目录

头文件

一、读取数据集(图片名)

二、将数据集图片、标签写入TFRecord

三、从TFRecord中读取数据集

四、构建模型

五、训练模型

实验结果


头文件

import tensorflow as tf
import os

一、读取数据集(图片名)

data_dir = "D:/dataset/cats_and_dogs_filtered"
train_cat_dir = data_dir + "/train/cats/"
train_dog_dir = data_dir + "/train/dogs/"
train_tfrecord_file = data_dir + "/train/train.tfrecords"test_cat_dir = data_dir + "/validation/cats/"
test_dog_dir = data_dir + "/validation/dogs/"
test_tfrecord_file = data_dir + "/validation/test.tfrecords"train_cat_filenames = [train_cat_dir + filename for filename in os.listdir(train_cat_dir)]
train_dog_filenames = [train_dog_dir + filename for filename in os.listdir(train_dog_dir)]
train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0]*len(train_cat_filenames) + [1]*len(train_dog_filenames)test_cat_filenames = [test_cat_dir + filename for filename in os.listdir(test_cat_dir)]
test_dog_filenames = [test_dog_dir + filename for filename in os.listdir(test_dog_dir)]
test_filenames = test_cat_filenames + test_dog_filenames
test_labels = [0]*len(test_cat_filenames) + [1]*len(test_dog_filenames)

二、将数据集图片、标签写入TFRecord

def encoder(filenames, labels, tfrecord_file):with tf.io.TFRecordWriter(tfrecord_file) as writer:for filename, label in zip(filenames, labels):image = open(filename, 'rb').read()feature = {# 建立feature字典'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}# 通过字典创建exampleexample = tf.train.Example(features=tf.train.Features(feature=feature))# 将example序列化并写入字典writer.write(example.SerializeToString())encoder(train_filenames, train_labels, train_tfrecord_file)
encoder(test_filenames, test_labels, test_tfrecord_file)

三、从TFRecord中读取数据集

def decoder(tfrecord_file, is_train_dataset=None):dataset = tf.data.TFRecordDataset(tfrecord_file)feature_discription = {'image': tf.io.FixedLenFeature([], tf.string),'label': tf.io.FixedLenFeature([], tf.int64)}def _parse_example(example_string): # 解码每一个examplefeature_dic = tf.io.parse_single_example(example_string, feature_discription)feature_dic['image'] = tf.io.decode_jpeg(feature_dic['image'])feature_dic['image'] = tf.image.resize(feature_dic['image'], [256, 256])/255.0return feature_dic['image'], feature_dic['label']batch_size = 32if is_train_dataset is not None:dataset = dataset.map(_parse_example).shuffle(buffer_size=2000).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)else:dataset = dataset.map(_parse_example)dataset = dataset.batch(batch_size)return datasettrain_data = decoder(train_tfrecord_file, 1)
test_data = decoder(test_tfrecord_file)

四、构建模型

class CNNModel(tf.keras.models.Model):def __init__(self):super(CNNModel, self).__init__()self.conv1 = tf.keras.layers.Conv2D(12, 3, activation='relu')self.maxpool1 = tf.keras.layers.MaxPooling2D()self.conv2 = tf.keras.layers.Conv2D(12, 5, activation='relu')self.maxpool2 = tf.keras.layers.MaxPooling2D()self.flatten = tf.keras.layers.Flatten()self.d1 = tf.keras.layers.Dense(64, activation='relu')self.d2 = tf.keras.layers.Dense(2, activation='softmax')def call(self, inputs):x = self.conv1(inputs)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.flatten(x)x = self.d1(x)x = self.d2(x)return x

五、训练模型

def train_CNNModel():model = CNNModel()loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()optimizer = tf.keras.optimizers.Adam(0.001)train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')@tf.functiondef train_step(images, labels):with tf.GradientTape() as tape:logits = model(images)loss = loss_obj(labels, logits)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_acc(labels, logits)@tf.functiondef test_step(images, labels):logits = model(images)test_acc(labels, logits)Epochs = 5for epoch in range(Epochs):train_acc.reset_states()test_acc.reset_states()for images, labels in train_data:train_step(images, labels)for images, labels in test_data:test_step(images, labels)tmp = 'Epoch {}, Acc {}, Test Acc {}'print(tmp.format(epoch + 1,train_acc.result() * 100,test_acc.result() * 100))train_CNNModel()

实验结果

Epoch 1, Acc 51.45000076293945, Test Acc 51.70000076293945
Epoch 2, Acc 60.650001525878906, Test Acc 58.099998474121094
Epoch 3, Acc 70.5, Test Acc 63.30000305175781
Epoch 4, Acc 78.05000305175781, Test Acc 69.30000305175781
Epoch 5, Acc 87.4000015258789, Test Acc 69.19999694824219

TensorFlow2.0 利用TFRecord存取数据集,分批次读取训练相关推荐

  1. tensorflow2.0莺尾花iris数据集分类|超详细

    tensorflow2.0莺尾花iris数据集分类 超详细 直接上代码 #导入模块 import tensorflow as tf #导入tensorflow模块from sklearn import ...

  2. 基于tensorflow2.0利用CNN与线性回归两种方法实现手写数字识别

    CNN实现手写数字识别 导入模块和数据集 import os import tensorflow as tf from tensorflow import keras from tensorflow. ...

  3. 笔记3:Tensorflow2.0实战之MNSIT数据集

    最近Tensorflow相继推出了alpha和beta两个版本,这两个都属于tensorflow2.0版本:早听说新版做了很大的革新,今天就来用一下看看 这里还是使用MNSIT数据集进行测试 导入必要 ...

  4. tensorflow2.0实现IMDB文本数据集学习词嵌入

    1. IMDB数据集示例如下所示 [{"rating": 5, "title": "The dark is rising!", " ...

  5. Tensorflow2.0 利用LSTM和爬虫做自动生成七言律诗

    从古诗网上获取七言律诗. 从网上随便找了一个古诗网 把该网站上的七言律诗爬取过来,该网站上也有五言律诗但没有把它们一起爬取下来做为数据来源,因为它们的文本长度不一样,如果把它们混在一起的话要对五言律诗 ...

  6. 【TensorFlow2.0】数据读取与使用方式

    大家好,这是专栏<TensorFlow2.0>的第三篇文章,讲述如何使用TensorFlow2.0读取和使用自己的数据集. 如果您正在学习计算机视觉,无论你通过书籍还是视频学习,大部分的教 ...

  7. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  8. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

  9. 【小白学PyTorch】扩展之Tensorflow2.0 | 20 TF2的eager模式与求导

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 19 TF2模型的存储与载入 扩展之Tensorflow2.0 | 18 ...

最新文章

  1. python制作验证码_Python编写生成验证码的脚本的教程
  2. 使用证书保护网站--兼谈证书服务器吊销列表的使用
  3. BZOJ1001 狼抓兔子
  4. Java线程池理解及用法
  5. 关于神经网络训练的一些建议笔记
  6. 【前端笔试题】文本居中的几种小技巧
  7. 写一段代码提高内存占用_记录一次生产环境中Redis内存增长异常排查全流程!...
  8. “世界百位名人”诠释上海世博会城市主题
  9. Java程序员从笨鸟到菜鸟之(九十四)深入java虚拟机(三)——类的生命周期(下)类的初始化...
  10. Java 书籍 Top 10
  11. 约瑟夫环数据结构课程设计详解
  12. Chart控件,chart、Series、ChartArea曲线图绘制的重要属性介绍(Windows窗体)
  13. Python代码反向解析列线图nomogram自动计算各项得分及总得分
  14. wordpress seo设置全套SEO插件教程
  15. shell中单引号和双引号的区别-经典解释
  16. RestTemplate使用实战-exchange方法讲解
  17. 向量检索milvus之一:以图搜图
  18. 蚂蚁的愤怒之源(落日余晖)-终结篇
  19. 关于网络、交换机、路由器
  20. 1、孟子·梁惠王上 孟子·梁惠王下

热门文章

  1. docker集群搭建
  2. iOS 仿写项目之微信聊天界面、QQ聊天界面
  3. PXE自动装机脚本原创代码(适合脚本新人)
  4. env中的dev和prd
  5. php的构造函数和析构函数
  6. 搭建DSS环境(一)之CentOS7基础设置
  7. 软件售后拜访记录模板
  8. java web 火车票预定系统 完整源码 下载直接运行
  9. CocosCreator动画弹窗
  10. Excel一键清除数据区域内汉字的操作?