使用tf.data 加载文件夹下的图片集合并分类
Tensorflow原始教程链接在官网:
https://tensorflow.google.cn/tutorials/load_data/images
简化版:
https://colab.research.google.com/drive/146IoL0nVN7HOA3sUJ08zAGbngmwTArDp?usp=sharing
但原始教程中比较繁琐,对于想要直接使用的情况的话,本文将如下要点提炼出来。
1、数据
假设你有如下形式的数据:
每一个类别的名称就是文件夹名称,每个文件夹下面放置该类的图片。
现在我们就想使用tf.data 把数据整合成一个数据集,然后直接用于模型的训练。
2、生成数据集
2.1 获取所有数据的路径和标签
使用pathlib包来处理文件夹,包括获取文件夹名称和上一级文件夹
import pathlib
data_root="../../data/kaggle_dog/train_valid_test_tiny"
train_data_root = pathlib.Path(data_root+"/train")
#获取标签名称
label_names = sorted(item.name for item in train_data_root.glob('*/') if item.is_dir())
#标签名称到id的映射
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#获取图片路径
train_all_image_paths = [str(path) for path in list(train_data_root.glob('*/*'))]
#获取对应的labels
train_all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in train_all_image_paths]
#展示
print("First 10 images indices: ", train_valid_all_image_labels[:10])
print("First 10 labels indices: ", train_valid_all_image_labels[:10])
如上在../../data/kaggle_dog/train_valid_test_tiny
文件夹下有/train
的训练文件夹,里面又放了很多类别的文件夹,里面包含每个类的图片数据。
通过上面的操作可以获取/train
下的训练图片和类别。
2.2 路径–>图片–>模型输入
获取了图片的所有路径之后就可以根据路径获取图片,然后根据图片转化为模型可以接受的输入,将这个转化过程整理为transform_train
函数。
def transform_train(imgpath,label):#从路径中读取图片feature=tf.io.read_file(imgpath)#解码图片feature = tf.image.decode_jpeg(feature,channels=3)#重新设置大小feature = tf.image.resize(feature, size=[400, 400])#随机裁剪seed=random.randint(8,100)/100feature = tf.image.random_crop(feature, size=[int(seed*feature.shape[0]), int(seed*feature.shape[1]), 3])#最终设置为224*224的大小feature = tf.image.resize(feature, size=[224, 224])feature = tf.image.random_flip_left_right(feature)feature = tf.image.random_flip_up_down(feature)# 标准化feature = tf.divide(feature, 255.)# 正则化mean = tf.convert_to_tensor([0.485, 0.456, 0.406])std = tf.convert_to_tensor([0.229, 0.224, 0.225])feature = tf.divide(tf.subtract(feature, mean), std)#feature = tf.image.per_image_standardization(feature)return tf.image.convert_image_dtype(feature, tf.float32),label
需要注意的是让整个Dataset来统一应用这个函数的话,传入的path读取成图片后不一定会直接由shape的信息,所以提前调用resize图片。
2.3 构建 tf.data.Dataset
train_ds = tf.data.Dataset.from_tensor_slices((train_all_image_paths, train_all_image_labels)).map(transform_train).shuffle(len(train_all_image_paths)).batch(batch_size)
这样一个训练使用的 tf.data.Dataset就构建好了,需要注意的是使用了.batch(batch_size)
才能生成(None,224,224,3)
的数据集形式,否则就只是(224,224,3)
的图片形式。
3、训练
模型定义
from tensorflow.keras.applications import ResNet50
net=ResNet50(input_shape=(224, 224, 3),weights='imagenet',include_top=False
)
model = tf.keras.Sequential([net,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(len(label_names), activation='softmax',dtype=tf.float32)
])
model.summary()
训练参数设置
lr = 0.1
lr_decay = 0.01def scheduler(epoch):if epoch < 10:return lrelse:return lr * tf.math.exp(lr_decay * (10 - epoch))callback = tf.keras.callbacks.LearningRateScheduler(scheduler)model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr, momentum=0.9),loss='sparse_categorical_crossentropy')
训练模型
model.fit(train_ds, epochs=1 , validation_data=valid_ds, callbacks=[callback])
120/120 [==============================] - 105s 879ms/step - loss: 5.2341 - val_loss: 5.6558 - lr: 0.1000 <tensorflow.python.keras.callbacks.History at 0x7fa4cb2daac8>
由于只是测试,所以结果如上所示。
使用tf.data 加载文件夹下的图片集合并分类相关推荐
- unity加载文件夹下的所有预制体
public static void GetFileGame(string name)//文件夹名称 { foreach (var item in Directory.GetFiles(Applica ...
- Unity实现加载文件夹内所有图片并可显示和放大的一种解决方案
实现思路 Unity加载外部图片可以通过文件流和WWW的形式加载出来,再根据加载出来的纹理图片创建为精灵图片,赋值给Image对象即可:放大的思路主要是创建一张固定大小的Image图,点击小图时,将小 ...
- go gin框架:StaticFS搭建文件服务器(可以加载文件夹及文件)
Static只能展示文件,比如展示图片等 StaticFS可以连目录也一起展示 package mainimport ("net/http""github.com/gin ...
- 未能加载文件或程序集System.Data,Version=2.0.0.0解决方法
sqlserver 2005打开出现无法正常访问数据,提示信息: 未能加载文件或程序集"System.Data,Version=2.0.0.0,Culture=neutral,PublicK ...
- 未能加载文件或程序集System.Data,Version=2.0.0.0和System.XML,Version=2.0.0.0解决方法
1.未能加载文件或程序集"System.Data,Version=2.0.0.0,Culture=neutral,PublicKeyToken=b77a5c561934e089"或 ...
- 未能加载 mysql.data_连接MySQL 提示错误”未能加载文件或程序集“MySql.Data, Version=5.1.4.0, Culture=neutral,……..” | 学步园...
CodeSmith4.1.3版本连接MySQL 提示错误"未能加载文件或程序集"MySql.Data, Version=5.1.4.0, Culture=neutral,..... ...
- c# 未能加载文件或程序集mysql.data,SQLite的C#,.NET应用自适应32位/64位系统(未能加载文件或程序集“System.Data.SQLite.dll)...
SQLite异常报错 其他信息: 未能加载文件或程序集"System.Data.SQLite, Version=1.0.103.0, Culture=neutral, PublicKeyTo ...
- 关于在IDEA的Resources目录下无法加载文件的问题
我的目录结构 使用绝对路径加载 Properties p = new Properties();InputStream in = null;try {//用绝对路径加载File file =new F ...
- Tensorflow:TF模型文件(checkpoint文件夹下ckpt文件之data、index、meta)保存、模型导入、恢复并fine-tuning之详细攻略
Tensorflow:TF模型文件(checkpoint文件夹下ckpt文件之data.index.meta)保存.模型导入.恢复并fine-tuning之详细攻略 目录 保存TF训练好的模型 1.T ...
最新文章
- DF-SLAM:一种深度特征提取方法
- winform 外部组件发生异常
- 题目:社区人员登记管理系统(有源码链接免费下载)
- UE4学习-请求的操作需要提升
- 蚂蚁式管理(Style of Ant Management)
- SpringBoot-拦截器、过滤器、监听器
- 项目Alpha冲刺(团队)-第九天冲刺
- c语言 10^30,^ 在C语言中是什么意思?
- How to use neural network to realize logic 'and' and 'or'?
- 使用yum快速部署Oracle安装环境 11g
- 基于单片机的自动追日系统设计_电机太阳论文,关于基于Atmega32的主动式太阳能追日系统相关参考文献资料-免费论文范文...
- PWM波的原理和应用
- php 公众号多图文消息,微信公众号怎样群发多图文消息?
- 打印纸张尺寸换算_纸张尺寸对照表
- 小麦苗的常用代码--常用命令(仅限自己使用)
- 新手学编程?选python吧!
- mysql查看分片键
- linux文件夹内JPG批量转PNG
- 【win11】win10 资源管理器
- centos8上实现私有CA和证书申请颁发