Tensorflow原始教程链接在官网:
https://tensorflow.google.cn/tutorials/load_data/images
简化版:
https://colab.research.google.com/drive/146IoL0nVN7HOA3sUJ08zAGbngmwTArDp?usp=sharing

但原始教程中比较繁琐,对于想要直接使用的情况的话,本文将如下要点提炼出来。

1、数据

假设你有如下形式的数据:

每一个类别的名称就是文件夹名称,每个文件夹下面放置该类的图片。
现在我们就想使用tf.data 把数据整合成一个数据集,然后直接用于模型的训练。

2、生成数据集

2.1 获取所有数据的路径和标签

使用pathlib包来处理文件夹,包括获取文件夹名称和上一级文件夹

import pathlib
data_root="../../data/kaggle_dog/train_valid_test_tiny"
train_data_root = pathlib.Path(data_root+"/train")
#获取标签名称
label_names = sorted(item.name for item in train_data_root.glob('*/') if item.is_dir())
#标签名称到id的映射
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#获取图片路径
train_all_image_paths = [str(path) for path in list(train_data_root.glob('*/*'))]
#获取对应的labels
train_all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in train_all_image_paths]
#展示
print("First 10 images indices: ", train_valid_all_image_labels[:10])
print("First 10 labels indices: ", train_valid_all_image_labels[:10])

如上在../../data/kaggle_dog/train_valid_test_tiny文件夹下有/train的训练文件夹,里面又放了很多类别的文件夹,里面包含每个类的图片数据。
通过上面的操作可以获取/train下的训练图片和类别。

2.2 路径–>图片–>模型输入

获取了图片的所有路径之后就可以根据路径获取图片,然后根据图片转化为模型可以接受的输入,将这个转化过程整理为transform_train函数。

def transform_train(imgpath,label):#从路径中读取图片feature=tf.io.read_file(imgpath)#解码图片feature = tf.image.decode_jpeg(feature,channels=3)#重新设置大小feature = tf.image.resize(feature, size=[400, 400])#随机裁剪seed=random.randint(8,100)/100feature = tf.image.random_crop(feature, size=[int(seed*feature.shape[0]), int(seed*feature.shape[1]), 3])#最终设置为224*224的大小feature = tf.image.resize(feature, size=[224, 224])feature = tf.image.random_flip_left_right(feature)feature = tf.image.random_flip_up_down(feature)# 标准化feature = tf.divide(feature, 255.)# 正则化mean = tf.convert_to_tensor([0.485, 0.456, 0.406])std = tf.convert_to_tensor([0.229, 0.224, 0.225])feature = tf.divide(tf.subtract(feature, mean), std)#feature = tf.image.per_image_standardization(feature)return tf.image.convert_image_dtype(feature, tf.float32),label

需要注意的是让整个Dataset来统一应用这个函数的话,传入的path读取成图片后不一定会直接由shape的信息,所以提前调用resize图片。

2.3 构建 tf.data.Dataset

train_ds = tf.data.Dataset.from_tensor_slices((train_all_image_paths, train_all_image_labels)).map(transform_train).shuffle(len(train_all_image_paths)).batch(batch_size)

这样一个训练使用的 tf.data.Dataset就构建好了,需要注意的是使用了.batch(batch_size)才能生成(None,224,224,3)的数据集形式,否则就只是(224,224,3)的图片形式。

3、训练

模型定义

from tensorflow.keras.applications import ResNet50
net=ResNet50(input_shape=(224, 224, 3),weights='imagenet',include_top=False
)
model = tf.keras.Sequential([net,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(len(label_names), activation='softmax',dtype=tf.float32)
])
model.summary()


训练参数设置

lr = 0.1
lr_decay = 0.01def scheduler(epoch):if epoch < 10:return lrelse:return lr * tf.math.exp(lr_decay * (10 - epoch))callback = tf.keras.callbacks.LearningRateScheduler(scheduler)model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr, momentum=0.9),loss='sparse_categorical_crossentropy')

训练模型

model.fit(train_ds, epochs=1 , validation_data=valid_ds,  callbacks=[callback])

120/120 [==============================] - 105s 879ms/step - loss: 5.2341 - val_loss: 5.6558 - lr: 0.1000 <tensorflow.python.keras.callbacks.History at 0x7fa4cb2daac8>
由于只是测试,所以结果如上所示。

使用tf.data 加载文件夹下的图片集合并分类相关推荐

  1. unity加载文件夹下的所有预制体

    public static void GetFileGame(string name)//文件夹名称 { foreach (var item in Directory.GetFiles(Applica ...

  2. Unity实现加载文件夹内所有图片并可显示和放大的一种解决方案

    实现思路 Unity加载外部图片可以通过文件流和WWW的形式加载出来,再根据加载出来的纹理图片创建为精灵图片,赋值给Image对象即可:放大的思路主要是创建一张固定大小的Image图,点击小图时,将小 ...

  3. go gin框架:StaticFS搭建文件服务器(可以加载文件夹及文件)

    Static只能展示文件,比如展示图片等 StaticFS可以连目录也一起展示 package mainimport ("net/http""github.com/gin ...

  4. 未能加载文件或程序集System.Data,Version=2.0.0.0解决方法

    sqlserver 2005打开出现无法正常访问数据,提示信息: 未能加载文件或程序集"System.Data,Version=2.0.0.0,Culture=neutral,PublicK ...

  5. 未能加载文件或程序集System.Data,Version=2.0.0.0和System.XML,Version=2.0.0.0解决方法

    1.未能加载文件或程序集"System.Data,Version=2.0.0.0,Culture=neutral,PublicKeyToken=b77a5c561934e089"或 ...

  6. 未能加载 mysql.data_连接MySQL 提示错误”未能加载文件或程序集“MySql.Data, Version=5.1.4.0, Culture=neutral,……..” | 学步园...

    CodeSmith4.1.3版本连接MySQL 提示错误"未能加载文件或程序集"MySql.Data, Version=5.1.4.0, Culture=neutral,..... ...

  7. c# 未能加载文件或程序集mysql.data,SQLite的C#,.NET应用自适应32位/64位系统(未能加载文件或程序集“System.Data.SQLite.dll)...

    SQLite异常报错 其他信息: 未能加载文件或程序集"System.Data.SQLite, Version=1.0.103.0, Culture=neutral, PublicKeyTo ...

  8. 关于在IDEA的Resources目录下无法加载文件的问题

    我的目录结构 使用绝对路径加载 Properties p = new Properties();InputStream in = null;try {//用绝对路径加载File file =new F ...

  9. Tensorflow:TF模型文件(checkpoint文件夹下ckpt文件之data、index、meta)保存、模型导入、恢复并fine-tuning之详细攻略

    Tensorflow:TF模型文件(checkpoint文件夹下ckpt文件之data.index.meta)保存.模型导入.恢复并fine-tuning之详细攻略 目录 保存TF训练好的模型 1.T ...

最新文章

  1. DF-SLAM:一种深度特征提取方法
  2. winform 外部组件发生异常
  3. 题目:社区人员登记管理系统(有源码链接免费下载)
  4. UE4学习-请求的操作需要提升
  5. 蚂蚁式管理(Style of Ant Management)
  6. SpringBoot-拦截器、过滤器、监听器
  7. 项目Alpha冲刺(团队)-第九天冲刺
  8. c语言 10^30,^ 在C语言中是什么意思?
  9. How to use neural network to realize logic 'and' and 'or'?
  10. 使用yum快速部署Oracle安装环境 11g
  11. 基于单片机的自动追日系统设计_电机太阳论文,关于基于Atmega32的主动式太阳能追日系统相关参考文献资料-免费论文范文...
  12. PWM波的原理和应用
  13. php 公众号多图文消息,微信公众号怎样群发多图文消息?
  14. 打印纸张尺寸换算_纸张尺寸对照表
  15. 小麦苗的常用代码--常用命令(仅限自己使用)
  16. 新手学编程?选python吧!
  17. mysql查看分片键
  18. linux文件夹内JPG批量转PNG
  19. 【win11】win10 资源管理器
  20. centos8上实现私有CA和证书申请颁发

热门文章

  1. win10系统vscode、sublime等无法全局搜索
  2. Fail2ban在ip网络电话调度中的作用和实现方法
  3. 高中数学三角函数图像平移变换问题解题技巧(附免费视频教程)
  4. 谁在制造“完美男性”?
  5. openwrt 软路由 docker安装青龙面板(基础版)
  6. 从URL启动程序:也谈谈旺旺的页面启动--转载
  7. 京东是如何管人的(刘强东)
  8. 计算机word表格加法公式,Word中的表格使用公式计算的方法(推荐)
  9. 下面这些人,你肯定想不到他们也做过程序员
  10. 技术型企业如何做好人才通道?