CSV转tfRecord
参考博客
目录
- CSV转tfRecord代码
- tf.Example和TFRecord
- python中的namedtuple
- DataFrame中gruopby()的应用
CSV转tfRecord代码
from __future__ import division
from __future__ import print_function
from __future__ import absolute_importimport os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDictflags = tf.app.flags
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
flags.DEFINE_string('csv_input', 'train.csv', 'Path to the CSV input')
flags.DEFINE_string('image_dir', './data/train/', 'Path to the image directory')
flags.DEFINE_string('output_path', 'train.record', 'Path to output TFRecord')
FLAGS = flags.FLAGS# TO-DO replace this with label map
# 为什么从1开始,而不是从0开始?????
def class_text_to_int(row_label):if row_label == 'dog':return 1if row_label == 'pig':return 2if row_label=='cat':return 3else:return Nonedef split(df, group):"""对csv数据进行处理:param df: :param group: 聚合关键字:return: [('image_filename_1',DataFrame_1),('image_filename_2',DataFrame_2),...] (图片名,该图片的所有boxes信息)"""data = namedtuple('data', ['filename', 'object']) #创建一个namedtuple的数据类型,有两个属性filename,objectgb = df.groupby(group) #对关键列group进行聚合,有同一张图片多个标记框的聚合在一起return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):"""创建tf.Example消息:param group: tuple,每一张图片的信息(filename,DataFrame):param path: 数据集的路径:return:"""#tf.gfile.GFile(filename, mode)#获取文本操作句柄,类似于python提供的文本操作open()函数,filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width) #相对值xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.io.TFRecordWriter(FLAGS.output_path)path = FLAGS.image_direxamples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()output_path =FLAGS.output_pathprint('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.compat.v1.app.run()
tf.Example和TFRecord
参考
参考
- TFRecord 格式是一种用于存储二进制记录序列的简单格式。
- 协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由 .proto 文件定义,这通常是了解消息类型最简单的方法
- tf.Example 消息(或 protobuf)是一种灵活的消息类型,就是一种将数据表示为{‘string’: value}形式的 message类型,TensorFlow经常使用 tf.Example 来写入,读取 TFRecord数据
通常情况下,tf.Example中可以使用以下几种格式:
- tf.train.BytesList: 可以使用的类型包括 string和byte
- tf.train.FloatList: 可以使用的类型包括 float和double
- tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
如将string数据变为tf.train.BytesList后,再写入tf.train.Feature
def _bytes_feature(value):"""Returns a bytes_list from a string/byte."""if isinstance(value, type(tf.constant(0))):value = value.numpy() # BytesList won't unpack a string from an EagerTensor.return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _float_feature(value):"""Return a float_list form a float/double."""return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))def _int64_feature(value):"""Return a int64_list from a bool/enum/int/uint."""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
通过上述操作,我们以dict的形式把要写入的数据汇总,并构建 tf.train.Features,然后构建 tf.train.Example
def get_tfrecords_example(feature, label):tfrecords_features = {}feat_shape = feature.shapetfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
把创建的tf.train.Example序列化下,便可以通过 tf.python_io.TFRecordWriter 写入 tfrecord文件中
#创建tfrecord的writer,文件名为xxx
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')
#把数据写入Example
exmp = get_tfrecords_example(feats[inx], labels[inx])
#Example序列化
exmp_serial = exmp.SerializeToString()
#写入tfrecord文件
tfrecord_wrt.write(exmp_serial)
#写完后关闭tfrecord的writer
tfrecord_wrt.close()
python中的namedtuple
namedtuple是继承自tuple的子类。namedtuple创建一个和tuple类似的对象,而且对象拥有可访问的属性
from collections import namedtuple# 定义一个namedtuple类型User,并包含name,sex和age属性。
User = namedtuple('User', ['name', 'sex', 'age'])# 创建一个User对象
user1 = User(name='zhangsan', sex='woman', age=21)
print('user1:')
print(type(user1),len(user1)) #<class '__main__.User'> 3
print(user1) #User(name='zhangsan', sex='woman', age=21)#属性
print(user1.name) #zhangsan
print(user1.sex) #woman# 也可以通过一个list来创建一个User对象,这里注意需要使用"_make"方法
user2 = User._make(['lisi', 'man', 22])
print(user2) #User(name='lisi', sex='man', age=22)# 修改对象属性,注意要使用"_replace"方法
user1 = user1._replace(age=45)
print(user1) #User(name='zhangsan', sex='woman', age=45)# 将User对象转换成字典,注意要使用"_asdict"
print(user1._asdict()) #OrderedDict([('name', 'zhangsan'), ('sex', 'woman'), ('age', 45)])
DataFrame中gruopby()的应用
def split(df, group):#namedtuple是tuple的子类#定义一个namedtuple的数据类型,有两个属性filename,objectdata = namedtuple('data', ['filename', 'object']) #对关键列group进行聚合,有同一张图片多个标记框的聚合在一起gb = df.groupby(group) return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
import pandas as pd
import numpy as npseed=np.random.seed(2)
data=pd.DataFrame({'image_id':[0,0,1,2,3,3,3],'width':[23,23,34,45,67,67,67],'height':[32,32,43,45,76,76,76],'xmin':np.random.randn(7),'ymin':np.random.randn(7),'xmax':np.random.randn(7),'ymax':np.random.randn(7)})
print('原始数据:')
print(data)
'''
原始数据:image_id width height xmin ymin xmax ymax
0 0 23 32 -0.416758 -1.245288 0.539058 -0.156434
1 0 23 32 -0.056267 -1.057952 -0.596160 0.256570
2 1 34 43 -2.136196 -0.909008 -0.019130 -0.988779
3 2 45 45 1.640271 0.551454 1.175001 -0.338822
4 3 67 76 -1.793436 2.292208 -0.747871 -0.236184
5 3 67 76 -0.841747 0.041539 0.009025 -0.637655
6 3 67 76 0.502881 -1.117925 -0.878108 -1.187612
'''
gb = data.groupby('image_id')
# 对于DataFrame数据,根据关键列‘image_id’进行聚类,filename相同的放一起,成为一个DataFrame
# 聚类后的gb是一个包含所有聚类结果的对象print(gb) #<pandas.core.groupby.generic.DataFrameGroupBy object at 0x000002402C56A2B0>
print(type(gb)) #pandas.core.groupby.generic.DataFrameGroupBy
print(gb.count())
'''width height xmin ymin xmax ymax
image_id
0 2 2 2 2 2 2
1 1 1 1 1 1 1
2 1 1 1 1 1 1
3 3 3 3 3 3 3
聚类后只有0,1,2,3三个image_id,width表示images_id=0时,width列有两条数据
'''
print(gb.groups.keys()) #dict_keys([0, 1, 2, 3])#遍历gb对象,
for name,group in gb:#gb对象中print('name:',name)print('type:',type(group))print('shape:',group.shape)print(group)
'''
name: 0
type: <class 'pandas.core.frame.DataFrame'>
shape: (2, 7)image_id width height xmin ymin xmax ymax
0 0 23 32 -0.416758 -1.245288 0.539058 -0.156434
1 0 23 32 -0.056267 -1.057952 -0.596160 0.256570
name: 1
type: <class 'pandas.core.frame.DataFrame'>
shape: (1, 7)image_id width height xmin ymin xmax ymax
2 1 34 43 -2.136196 -0.909008 -0.01913 -0.988779
name: 2
type: <class 'pandas.core.frame.DataFrame'>
shape: (1, 7)image_id width height xmin ymin xmax ymax
3 2 45 45 1.640271 0.551454 1.175001 -0.338822
name: 3
type: <class 'pandas.core.frame.DataFrame'>
shape: (3, 7)image_id width height xmin ymin xmax ymax
4 3 67 76 -1.793436 2.292208 -0.747871 -0.236184
5 3 67 76 -0.841747 0.041539 0.009025 -0.637655
6 3 67 76 0.502881 -1.117925 -0.878108 -1.187612
'''
CSV转tfRecord相关推荐
- 【Python1】双系统安装,深度学习环境搭建,目标检测(Tensorflow_API_SSD)
文章目录 1.安装双系统 2.ubuntu安装常用软件 2.1 anaconda3 2.2 flameshot(截图) 2.3 SimpleScreenRecorder(录屏) 2.4 teamvie ...
- tensorflow系列之1:加载数据
本文介绍了如何加载各种数据源,以生成可以用于tensorflow使用的数据集,一般指Dataset.主要包括以下几类数据源: 预定义的公共数据源 内存中的数据 csv文件 TFRecord 任意格式的 ...
- tensorflow环境下的识别食物_在win10环境下进行tensorflow物体识别(ObjectDetection)训练...
安装ObjectDetection,CPU和GPU都需要 解压module.rar放到C:\TFWS\models目录 地址:https://github.com/tensorflow/models ...
- 手把手带你玩转Tensorflow 物体检测 API (2)——数据准备
致谢声明 本文在学习<Tensorflow object detection API 搭建属于自己的物体识别模型(2)--训练并使用自己的模型>的基础上优化并总结,此博客链接:https: ...
- object detection训练自己数据
1.用labelImg标自己数据集. 并将图片存放在JPEGImages中,xml存放在Annotations中 2.分离训练和测试数据 import os import randomtrainval ...
- 基于深度学习的动物识别方法研究与实现
基于深度学习的动物识别方法研究与实现 目 录 摘 要 I ABSTRACT II 第一章 绪论 1 1.1 研究的目的和意义 1 1.2国内外研究现状 1 1.2.1 目标检测国内外研究 ...
- 常用数据集预处理(dota)
从数据集中选出自己需要的类别 import os import cv2 import shutilcatogary = ['bridge'] #列表def customname(fullname):& ...
- 深度学习实战(七)——目标检测API训练自己的数据集(R-FCN数据集制作+训练+测试)
TensorFlow提供的网络结构的预训练权重:https://cloud.tencent.com/developer/article/1006123 将voc数据集转换成.tfrecord格式供te ...
- 建立自己的数据集 并用Tensorflow object detection API进行训练
ps: 欢迎大家光临我的博客 建立数据集 标注工具: ubuntu 图像标注工具labelImg sudo apt-get install pyqt5-dev-tools sudo pip3 inst ...
最新文章
- 爱奇艺蒙版AI:弹幕穿人过,爱豆心中坐
- IoC容器Autofac(3) - 理解Autofac原理,我实现的部分Autofac功能(附源码)
- Netty系列之Netty 服务端创建
- VII Python(9)socket编程
- BZOJ4432 : [Cerc2015]Greenhouse Growth
- maven pom java版本_Maven更新POM中的JDK版本(比如更新为JDK1.8)
- caffe检测图片是否包含人脸_caffe入门-人脸检测1
- parted如何将磁盘所有空间格式化_linux下大于2T的硬盘格式化问题
- linux 启动流详解
- 使用番石榴的5个理由
- IE css hack整理
- 分享一个查看JSON的程序
- SQL variable type
- ORA-01045 :user 用户名 lacks create session privilege; logon denied
- 企业号 网页授权 php,微信企业号开发之网页授权接口调用示例
- YUI可真是个不错的东东
- 四年级计算机教学目的,四年级计算机教学计划
- DataGrip 太好使了
- android 刘海机型适配,Android全面屏刘海适配
- 交互式多模型算法IMM——机动目标跟踪中的应用