参考博客

目录

  • CSV转tfRecord代码
    • tf.Example和TFRecord
    • python中的namedtuple
    • DataFrame中gruopby()的应用

CSV转tfRecord代码

from __future__ import division
from __future__ import print_function
from __future__ import absolute_importimport os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDictflags = tf.app.flags
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
flags.DEFINE_string('csv_input', 'train.csv', 'Path to the CSV input')
flags.DEFINE_string('image_dir', './data/train/', 'Path to the image directory')
flags.DEFINE_string('output_path', 'train.record', 'Path to output TFRecord')
FLAGS = flags.FLAGS# TO-DO replace this with label map
# 为什么从1开始,而不是从0开始?????
def class_text_to_int(row_label):if row_label == 'dog':return 1if row_label == 'pig':return 2if row_label=='cat':return 3else:return Nonedef split(df, group):"""对csv数据进行处理:param df: :param group: 聚合关键字:return: [('image_filename_1',DataFrame_1),('image_filename_2',DataFrame_2),...] (图片名,该图片的所有boxes信息)"""data = namedtuple('data', ['filename', 'object']) #创建一个namedtuple的数据类型,有两个属性filename,objectgb = df.groupby(group) #对关键列group进行聚合,有同一张图片多个标记框的聚合在一起return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):"""创建tf.Example消息:param group: tuple,每一张图片的信息(filename,DataFrame):param path: 数据集的路径:return:"""#tf.gfile.GFile(filename, mode)#获取文本操作句柄,类似于python提供的文本操作open()函数,filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。with  tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width) #相对值xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.io.TFRecordWriter(FLAGS.output_path)path = FLAGS.image_direxamples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()output_path =FLAGS.output_pathprint('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.compat.v1.app.run()

tf.Example和TFRecord

参考
参考

  • TFRecord 格式是一种用于存储二进制记录序列的简单格式。
  • 协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由 .proto 文件定义,这通常是了解消息类型最简单的方法
  • tf.Example 消息(或 protobuf)是一种灵活的消息类型,就是一种将数据表示为{‘string’: value}形式的 message类型,TensorFlow经常使用 tf.Example 来写入,读取 TFRecord数据

通常情况下,tf.Example中可以使用以下几种格式:

  • tf.train.BytesList: 可以使用的类型包括 string和byte
  • tf.train.FloatList: 可以使用的类型包括 float和double
  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
    如将string数据变为tf.train.BytesList后,再写入tf.train.Feature
def _bytes_feature(value):"""Returns a bytes_list from a string/byte."""if isinstance(value, type(tf.constant(0))):value = value.numpy() # BytesList won't unpack a string from an EagerTensor.return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _float_feature(value):"""Return a float_list form a float/double."""return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))def _int64_feature(value):"""Return a int64_list from a bool/enum/int/uint."""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

通过上述操作,我们以dict的形式把要写入的数据汇总,并构建 tf.train.Features,然后构建 tf.train.Example

def get_tfrecords_example(feature, label):tfrecords_features = {}feat_shape = feature.shapetfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

把创建的tf.train.Example序列化下,便可以通过 tf.python_io.TFRecordWriter 写入 tfrecord文件中

#创建tfrecord的writer,文件名为xxx
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')
#把数据写入Example
exmp = get_tfrecords_example(feats[inx], labels[inx])
#Example序列化
exmp_serial = exmp.SerializeToString()
#写入tfrecord文件
tfrecord_wrt.write(exmp_serial)
#写完后关闭tfrecord的writer
tfrecord_wrt.close()

python中的namedtuple

namedtuple是继承自tuple的子类。namedtuple创建一个和tuple类似的对象,而且对象拥有可访问的属性

from collections import namedtuple# 定义一个namedtuple类型User,并包含name,sex和age属性。
User = namedtuple('User', ['name', 'sex', 'age'])# 创建一个User对象
user1 = User(name='zhangsan', sex='woman', age=21)
print('user1:')
print(type(user1),len(user1)) #<class '__main__.User'> 3
print(user1) #User(name='zhangsan', sex='woman', age=21)#属性
print(user1.name) #zhangsan
print(user1.sex) #woman# 也可以通过一个list来创建一个User对象,这里注意需要使用"_make"方法
user2 = User._make(['lisi', 'man', 22])
print(user2) #User(name='lisi', sex='man', age=22)# 修改对象属性,注意要使用"_replace"方法
user1 = user1._replace(age=45)
print(user1) #User(name='zhangsan', sex='woman', age=45)# 将User对象转换成字典,注意要使用"_asdict"
print(user1._asdict()) #OrderedDict([('name', 'zhangsan'), ('sex', 'woman'), ('age', 45)])

DataFrame中gruopby()的应用

def split(df, group):#namedtuple是tuple的子类#定义一个namedtuple的数据类型,有两个属性filename,objectdata = namedtuple('data', ['filename', 'object']) #对关键列group进行聚合,有同一张图片多个标记框的聚合在一起gb = df.groupby(group) return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
import pandas as pd
import numpy as  npseed=np.random.seed(2)
data=pd.DataFrame({'image_id':[0,0,1,2,3,3,3],'width':[23,23,34,45,67,67,67],'height':[32,32,43,45,76,76,76],'xmin':np.random.randn(7),'ymin':np.random.randn(7),'xmax':np.random.randn(7),'ymax':np.random.randn(7)})
print('原始数据:')
print(data)
'''
原始数据:image_id  width  height      xmin      ymin      xmax      ymax
0         0     23      32 -0.416758 -1.245288  0.539058 -0.156434
1         0     23      32 -0.056267 -1.057952 -0.596160  0.256570
2         1     34      43 -2.136196 -0.909008 -0.019130 -0.988779
3         2     45      45  1.640271  0.551454  1.175001 -0.338822
4         3     67      76 -1.793436  2.292208 -0.747871 -0.236184
5         3     67      76 -0.841747  0.041539  0.009025 -0.637655
6         3     67      76  0.502881 -1.117925 -0.878108 -1.187612
'''
gb = data.groupby('image_id')
# 对于DataFrame数据,根据关键列‘image_id’进行聚类,filename相同的放一起,成为一个DataFrame
# 聚类后的gb是一个包含所有聚类结果的对象print(gb) #<pandas.core.groupby.generic.DataFrameGroupBy object at 0x000002402C56A2B0>
print(type(gb)) #pandas.core.groupby.generic.DataFrameGroupBy
print(gb.count())
'''width  height  xmin  ymin  xmax  ymax
image_id
0             2       2     2     2     2     2
1             1       1     1     1     1     1
2             1       1     1     1     1     1
3             3       3     3     3     3     3
聚类后只有0,1,2,3三个image_id,width表示images_id=0时,width列有两条数据
'''
print(gb.groups.keys()) #dict_keys([0, 1, 2, 3])#遍历gb对象,
for name,group in gb:#gb对象中print('name:',name)print('type:',type(group))print('shape:',group.shape)print(group)
'''
name: 0
type: <class 'pandas.core.frame.DataFrame'>
shape: (2, 7)image_id  width  height      xmin      ymin      xmax      ymax
0         0     23      32 -0.416758 -1.245288  0.539058 -0.156434
1         0     23      32 -0.056267 -1.057952 -0.596160  0.256570
name: 1
type: <class 'pandas.core.frame.DataFrame'>
shape: (1, 7)image_id  width  height      xmin      ymin     xmax      ymax
2         1     34      43 -2.136196 -0.909008 -0.01913 -0.988779
name: 2
type: <class 'pandas.core.frame.DataFrame'>
shape: (1, 7)image_id  width  height      xmin      ymin      xmax      ymax
3         2     45      45  1.640271  0.551454  1.175001 -0.338822
name: 3
type: <class 'pandas.core.frame.DataFrame'>
shape: (3, 7)image_id  width  height      xmin      ymin      xmax      ymax
4         3     67      76 -1.793436  2.292208 -0.747871 -0.236184
5         3     67      76 -0.841747  0.041539  0.009025 -0.637655
6         3     67      76  0.502881 -1.117925 -0.878108 -1.187612
'''

CSV转tfRecord相关推荐

  1. 【Python1】双系统安装,深度学习环境搭建,目标检测(Tensorflow_API_SSD)

    文章目录 1.安装双系统 2.ubuntu安装常用软件 2.1 anaconda3 2.2 flameshot(截图) 2.3 SimpleScreenRecorder(录屏) 2.4 teamvie ...

  2. tensorflow系列之1:加载数据

    本文介绍了如何加载各种数据源,以生成可以用于tensorflow使用的数据集,一般指Dataset.主要包括以下几类数据源: 预定义的公共数据源 内存中的数据 csv文件 TFRecord 任意格式的 ...

  3. tensorflow环境下的识别食物_在win10环境下进行tensorflow物体识别(ObjectDetection)训练...

    安装ObjectDetection,CPU和GPU都需要 解压module.rar放到C:\TFWS\models目录 地址:https://github.com/tensorflow/models ...

  4. 手把手带你玩转Tensorflow 物体检测 API (2)——数据准备

    致谢声明 本文在学习<Tensorflow object detection API 搭建属于自己的物体识别模型(2)--训练并使用自己的模型>的基础上优化并总结,此博客链接:https: ...

  5. object detection训练自己数据

    1.用labelImg标自己数据集. 并将图片存放在JPEGImages中,xml存放在Annotations中 2.分离训练和测试数据 import os import randomtrainval ...

  6. 基于深度学习的动物识别方法研究与实现

    基于深度学习的动物识别方法研究与实现 目  录 摘  要 I ABSTRACT II     第一章  绪论 1 1.1 研究的目的和意义 1 1.2国内外研究现状 1 1.2.1 目标检测国内外研究 ...

  7. 常用数据集预处理(dota)

    从数据集中选出自己需要的类别 import os import cv2 import shutilcatogary = ['bridge'] #列表def customname(fullname):& ...

  8. 深度学习实战(七)——目标检测API训练自己的数据集(R-FCN数据集制作+训练+测试)

    TensorFlow提供的网络结构的预训练权重:https://cloud.tencent.com/developer/article/1006123 将voc数据集转换成.tfrecord格式供te ...

  9. 建立自己的数据集 并用Tensorflow object detection API进行训练

    ps: 欢迎大家光临我的博客 建立数据集 标注工具: ubuntu 图像标注工具labelImg sudo apt-get install pyqt5-dev-tools sudo pip3 inst ...

最新文章

  1. 爱奇艺蒙版AI:弹幕穿人过,爱豆心中坐
  2. IoC容器Autofac(3) - 理解Autofac原理,我实现的部分Autofac功能(附源码)
  3. Netty系列之Netty 服务端创建
  4. VII Python(9)socket编程
  5. BZOJ4432 : [Cerc2015]Greenhouse Growth
  6. maven pom java版本_Maven更新POM中的JDK版本(比如更新为JDK1.8)
  7. caffe检测图片是否包含人脸_caffe入门-人脸检测1
  8. parted如何将磁盘所有空间格式化_linux下大于2T的硬盘格式化问题
  9. linux 启动流详解
  10. 使用番石榴的5个理由
  11. IE css hack整理
  12. 分享一个查看JSON的程序
  13. SQL variable type
  14. ORA-01045 :user 用户名 lacks create session privilege; logon denied
  15. 企业号 网页授权 php,微信企业号开发之网页授权接口调用示例
  16. YUI可真是个不错的东东
  17. 四年级计算机教学目的,四年级计算机教学计划
  18. DataGrip 太好使了
  19. android 刘海机型适配,Android全面屏刘海适配
  20. 交互式多模型算法IMM——机动目标跟踪中的应用

热门文章

  1. macOS 修改mysql账号密码
  2. 【msm8953】带clk的gpio口模拟pwm
  3. 利用结构体设计游戏背包属性的思路
  4. 便携式明渠流量计有哪几种呢?
  5. 电子工程师,你在深圳值多少钱
  6. 坚持理想与目标、并从小事慢慢做起
  7. C语言学习笔记 —— 内存管理
  8. 程序员如何在情人节脱单?
  9. 宏观框架-海通梁中华-01)
  10. 深入理解Java Lambda表达式,匿名函数,闭包