Introduction


Doraemon was one of my favorite characters from childhood cartoons — so cute, and with a pocket that could produce anything. While thinking about a simple exercise using the Tensorflow Object Detection API (hereafter TODA) on custom data, the image of Doraemon popped into my head. This post records using TODA to detect Doraemon, covering the following steps (assuming TODA is already installed; see the official documentation):
1. Collect images
2. Annotate the data
3. Convert the images to tfrecord format
4. Download a pretrained model
5. Edit labelmap.pbtxt and the model config file
6. Train the model
7. Convert the ckpt model to a pb model
8. Test the model
Directory layout: the following directories and files are added on top of the TODA directory, roughly as sketched below.
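A sketch of the layout, reconstructed from the steps that follow (the original post showed a screenshot here; exact file names under images/ are examples):

object_detection/
├── images/
│   ├── train/                 # training images plus their XML annotations
│   ├── test/                  # test images plus their XML annotations
│   ├── train_labels.csv
│   └── test_labels.csv
├── faster_rcnn_inception_v2_coco_2018_01_28/   # unpacked pretrained model
├── training/
│   ├── labelmap.pbtxt
│   └── faster_rcnn_inception_doraemon.config
├── inference_graph/           # exported pb model
├── search_bing_api.py
├── xml_to_csv.py
├── generate_tfrecord.py
├── train.record
└── test.record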

1. Collecting Images

This step uses the Bing Image Search API; you need to obtain an API key before using it. A Python script downloads the required Doraemon images — 500 of them here, set by MAX_RESULTS = 500 in the script. The script search_bing_api.py is as follows:

# USAGE
# python search_bing_api.py --query "doraemon" --output dataset/doraemon

# import the necessary packages
from requests import exceptions
import argparse
import requests
import cv2
import os

# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-q", "--query", required=True,
    help="search query to search Bing Image API for")
ap.add_argument("-o", "--output", required=True,
    help="path to output directory of images")
args = vars(ap.parse_args())

# set your Microsoft Cognitive Services API key along with (1) the
# maximum number of results for a given search and (2) the group size
# for results (maximum of 50 per request)
API_KEY = "Your API-key"
MAX_RESULTS = 500
GROUP_SIZE = 50

# set the endpoint API URL
URL = "https://api.cognitive.microsoft.com/bing/v7.0/images/search"

# when attempting to download images from the web both the Python
# programming language and the requests library have a number of
# exceptions that can be thrown so let's build a list of them now
# so we can filter on them
EXCEPTIONS = set([IOError, FileNotFoundError,
    exceptions.RequestException, exceptions.HTTPError,
    exceptions.ConnectionError, exceptions.Timeout])

# store the search term in a convenience variable then set the
# headers and search parameters
term = args["query"]
headers = {"Ocp-Apim-Subscription-Key" : API_KEY}
params = {"q": term, "offset": 0, "count": GROUP_SIZE}

# make the search
print("[INFO] searching Bing API for '{}'".format(term))
search = requests.get(URL, headers=headers, params=params)
search.raise_for_status()

# grab the results from the search, including the total number of
# estimated results returned by the Bing API
results = search.json()
estNumResults = min(results["totalEstimatedMatches"], MAX_RESULTS)
print("[INFO] {} total results for '{}'".format(estNumResults, term))

# initialize the total number of images downloaded thus far
total = 0

# loop over the estimated number of results in `GROUP_SIZE` groups
for offset in range(0, estNumResults, GROUP_SIZE):
    # update the search parameters using the current offset, then
    # make the request to fetch the results
    print("[INFO] making request for group {}-{} of {}...".format(
        offset, offset + GROUP_SIZE, estNumResults))
    params["offset"] = offset
    search = requests.get(URL, headers=headers, params=params)
    search.raise_for_status()
    results = search.json()
    print("[INFO] saving images for group {}-{} of {}...".format(
        offset, offset + GROUP_SIZE, estNumResults))

    # loop over the results
    for v in results["value"]:
        # try to download the image
        try:
            # make a request to download the image
            print("[INFO] fetching: {}".format(v["contentUrl"]))
            r = requests.get(v["contentUrl"], timeout=30)

            # build the path to the output image
            ext = v["contentUrl"][v["contentUrl"].rfind("."):]
            p = os.path.sep.join([args["output"], "{}{}".format(
                str(total).zfill(8), ext)])

            # write the image to disk
            f = open(p, "wb")
            f.write(r.content)
            f.close()

        # catch any errors that would prevent us from downloading the
        # image
        except Exception as e:
            # check to see if our exception is in our list of
            # exceptions to check for
            if type(e) in EXCEPTIONS:
                print("[INFO] skipping: {}".format(v["contentUrl"]))
                continue

        # try to load the image from disk
        image = cv2.imread(p)

        # if the image is `None` then we could not properly load the
        # image from disk (so it should be ignored)
        if image is None:
            print("[INFO] deleting: {}".format(p))
            os.remove(p)
            continue

        # update the counter
        total += 1

Split the downloaded images into training data and test data, stored under images/train and images/test respectively; a small sketch of one way to do the split follows.
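A minimal split sketch — the 80/20 ratio, the random seed, and the dataset/doraemon source directory (the download script's example --output) are my assumptions, not fixed by the post:

# split_dataset.py: copy downloaded images into images/train and images/test.
# The 80/20 split ratio and the dataset/doraemon source dir are assumptions.
import os
import random
import shutil

SRC = "dataset/doraemon"
random.seed(42)  # reproducible shuffle

files = sorted(os.listdir(SRC))
random.shuffle(files)
split = int(0.8 * len(files))

for subset, names in [("train", files[:split]), ("test", files[split:])]:
    out_dir = os.path.join("images", subset)
    os.makedirs(out_dir, exist_ok=True)
    for name in names:
        shutil.copy(os.path.join(SRC, name), out_dir)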

2. Annotating the Data

Use the open-source annotation tool LabelImg to label the training and test data. Each image yields a corresponding XML annotation file, saved in the same directory as the image; a sample annotation is shown below.
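LabelImg writes Pascal VOC-style XML; a typical file looks roughly like this (the file name, path, and coordinates are made-up examples). Note that xml_to_csv.py below relies on this element order: member[0] is <name> and member[4] is <bndbox>.

<annotation>
    <folder>train</folder>
    <filename>00000391.jpg</filename>
    <path>C:/ObjectDetect/models/research/object_detection/images/train/00000391.jpg</path>
    <source>
        <database>Unknown</database>
    </source>
    <size>
        <width>800</width>
        <height>600</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>Doraemon</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>120</xmin>
            <ymin>80</ymin>
            <xmax>520</xmax>
            <ymax>560</ymax>
        </bndbox>
    </object>
</annotation>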

3. Converting Images to tfrecord Format

First, generate CSV files from the XML annotation files, so that existing tooling can produce the tfrecord files needed for training from the CSVs. Run xml_to_csv.py in the main directory (adjust the data directories as needed) to get the CSV files for the training and test data. xml_to_csv.py:

import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET


def xml_to_csv(path):
    # collect all Pascal VOC XML annotations under `path` into a DataFrame
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            # filename (extension taken from <path>), image size, class, box
            value = (root.find('filename').text + root.find('path').text[-4:],
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text))
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class',
                   'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df


def main():
    for folder in ['train', 'test']:
        image_path = os.path.join(os.getcwd(), ('images/' + folder))
        xml_df = xml_to_csv(image_path)
        xml_df.to_csv(('images/' + folder + '_labels.csv'), index=None)
        print('Successfully converted xml to csv.')


main()
python xml_to_csv.py

Generating the tfrecords

Then use generate_tfrecord.py, whose contents are:

"""
Usage:# From tensorflow/models/# Create train data:python generate_tfrecord.py --csv_input=images/train_labels.csv --image_dir=images/train --output_path=train.record# Create test data:python generate_tfrecord.py --csv_input=images/test_labels.csv  --image_dir=images/test --output_path=test.record
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_importimport os
import io
import pandas as pd
import tensorflow as tffrom PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDictflags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('image_dir', '', 'Path to the image directory')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS# TO-DO replace this with label map
def class_text_to_int(row_label):if row_label == 'Doraemon':return 1else:Nonedef split(df, group):data = namedtuple('data', ['filename', 'object'])gb = df.groupby(group)return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path):with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:encoded_jpg = fid.read()encoded_jpg_io = io.BytesIO(encoded_jpg)image = Image.open(encoded_jpg_io)width, height = image.sizefilename = group.filename.encode('utf8')image_format = b'jpg'xmins = []xmaxs = []ymins = []ymaxs = []classes_text = []classes = []for index, row in group.object.iterrows():xmins.append(row['xmin'] / width)xmaxs.append(row['xmax'] / width)ymins.append(row['ymin'] / height)ymaxs.append(row['ymax'] / height)classes_text.append(row['class'].encode('utf8'))classes.append(class_text_to_int(row['class']))#print("Classes: ", classes)tf_example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/filename': dataset_util.bytes_feature(filename),'image/source_id': dataset_util.bytes_feature(filename),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature(image_format),'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),'image/object/class/text': dataset_util.bytes_list_feature(classes_text),'image/object/class/label': dataset_util.int64_list_feature(classes),}))return tf_exampledef main(_):writer = tf.python_io.TFRecordWriter(FLAGS.output_path)path = os.path.join(os.getcwd(), FLAGS.image_dir)examples = pd.read_csv(FLAGS.csv_input)grouped = split(examples, 'filename')for group in grouped:tf_example = create_tf_example(group, path)writer.write(tf_example.SerializeToString())writer.close()output_path = os.path.join(os.getcwd(), FLAGS.output_path)print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__':tf.app.run()

Training data:

python generate_tfrecord.py --csv_input=images/train_labels.csv --image_dir=images/train --output_path=train.record

Test data:

python generate_tfrecord.py --csv_input=images/test_labels.csv  --image_dir=images/test --output_path=test.record
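To sanity-check the generated files, you can count the records in each one — a small sketch using the same TF 1.x API as the scripts above:

import tensorflow as tf

# Count examples in each TFRecord file (TF 1.x API, matching the scripts above).
for name in ["train.record", "test.record"]:
    count = sum(1 for _ in tf.python_io.tf_record_iterator(name))
    print("{}: {} examples".format(name, count))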

4. Downloading a Pretrained Model

Transfer learning from a pretrained model gets a decent model much faster. The model chosen here is faster_rcnn_inception_v2_coco_2018_01_28, which can be downloaded here, or you can pick another pretrained model from the Tensorflow detection model zoo. Unpack the download into the faster_rcnn_inception_v2_coco_2018_01_28 directory, which should then contain the files shown below.
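For example — the URL below follows the model zoo's naming convention for this model, but verify it against the zoo page:

# download and unpack the pretrained model (URL assumed from the model zoo naming scheme)
wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
tar -xzvf faster_rcnn_inception_v2_coco_2018_01_28.tar.gz
# the unpacked directory should contain roughly:
#   checkpoint  frozen_inference_graph.pb  model.ckpt.data-00000-of-00001
#   model.ckpt.index  model.ckpt.meta  pipeline.config  saved_model/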

5. Editing labelmap.pbtxt and the Model Config File

Both labelmap.pbtxt and the model config file faster_rcnn_inception_doraemon.config live in the training directory. Relative to the stock pets config, the fields that change are num_classes, fine_tune_checkpoint, the input_path and label_map_path of the train and eval input readers, and num_examples; the class name in labelmap.pbtxt must match class_text_to_int in generate_tfrecord.py.

(1)labelmap.pbtxt

item {
  id: 1
  name: 'Doraemon'
}

(2)faster_rcnn_inception_doraemon.config

# Faster R-CNN with Inception v2, configured for Oxford-IIIT Pets Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  faster_rcnn {
    num_classes: 1
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 600
        max_dimension: 1024
      }
    }
    feature_extractor {
      type: 'faster_rcnn_inception_v2'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.125, 0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.25, 0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 300
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 300
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
  }
}

train_config: {
  batch_size: 1
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0003
          schedule {
            step: 10000
            learning_rate: .00003
          }
          schedule {
            step: 20000
            learning_rate: .000003
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "C:/ObjectDetect/models/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/model.ckpt"
  from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # empirically found to be sufficient enough to train the pets dataset. This
  # effectively bypasses the learning rate schedule (the learning rate will
  # never decay). Remove the below line to train indefinitely.
  num_steps: 200000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "C:/ObjectDetect/models/research/object_detection/train.record"
  }
  label_map_path: "C:/ObjectDetect/models/research/object_detection/training/labelmap.pbtxt"
}

eval_config: {
  num_examples: 27
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "C:/ObjectDetect/models/research/object_detection/test.record"
  }
  label_map_path: "C:/ObjectDetect/models/research/object_detection/training/labelmap.pbtxt"
  shuffle: false
  num_readers: 1
}

6. Training the Model

The training script shipped with TODA (train.py):

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training executable for detection models.

This executable is used to train DetectionModels. There are two ways of
configuring the training job:

1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used and
a train_pb2.TrainConfig file to configure training parameters.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --model_config_path=model_config.pbtxt \
        --train_config_path=train_config.pbtxt \
        --input_config_path=train_input_config.pbtxt
"""

import functools
import json
import os
import tensorflow as tf

from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU.  Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


if __name__ == '__main__':
  tf.app.run()

Once the steps above check out, start training:

python train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/faster_rcnn_inception_doraemon.config

Launch TensorBoard from the TODA directory:

tensorboard --logdir=training

After about 8000 steps, the training curves looked like this:

7. Converting the ckpt Model to a pb Model

TensorFlow training produces ckpt checkpoints, which can be used inside TensorFlow itself, but for permanent storage and deployment it is best to export them in pb form. Use export_inference_graph.py to convert model.ckpt-8243 into a pb model file; the code is:

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to export an object detection model for inference.

Prepares an object detection tensorflow graph for inference using model
configuration and an optional trained checkpoint. Outputs inference
graph, associated checkpoint files, a frozen inference graph and a
SavedModel (https://tensorflow.github.io/serving/serving_basic.html).

The inference graph contains one of three input nodes depending on the user
specified option.
  * `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3]
  * `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None]
    containing encoded PNG or JPEG images. Image resolutions are expected to be
    the same if more than 1 image is provided.
  * `tf_example`: Accepts a 1-D string tensor of shape [None] containing
    serialized TFExample protos. Image resolutions are expected to be the same
    if more than 1 image is provided.

and the following output nodes returned by the model.postprocess(..):
  * `num_detections`: Outputs float32 tensors of the form [batch]
    that specifies the number of valid boxes per image in the batch.
  * `detection_boxes`: Outputs float32 tensors of the form
    [batch, num_boxes, 4] containing detected boxes.
  * `detection_scores`: Outputs float32 tensors of the form
    [batch, num_boxes] containing class scores for the detections.
  * `detection_classes`: Outputs float32 tensors of the form
    [batch, num_boxes] containing classes for the detections.
  * `detection_masks`: Outputs float32 tensors of the form
    [batch, num_boxes, mask_height, mask_width] containing predicted instance
    masks for each box if its present in the dictionary of postprocessed
    tensors returned by the model.

Notes:
 * This tool uses `use_moving_averages` from eval_config to decide which
   weights to freeze.

Example Usage:
python export_inference_graph.py --input_type image_tensor --pipeline_config_path training_D/faster_rcnn_inception_doraemon.config --trained_checkpoint_prefix training_D/model.ckpt-6993 --output_directory inference_graphD
--------------
python export_inference_graph \
    --input_type image_tensor \
    --pipeline_config_path path/to/ssd_inception_v2.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory

The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
 - graph.pbtxt
 - model.ckpt.data-00000-of-00001
 - model.ckpt.info
 - model.ckpt.meta
 - frozen_inference_graph.pb
 + saved_model (a directory)

Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path.  These are useful for
making small changes to the inference graph that differ from the training or
eval config.

Example Usage (in which we change the second stage post-processing score
threshold to be 0.5):

python export_inference_graph \
    --input_type image_tensor \
    --pipeline_config_path path/to/ssd_inception_v2.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory \
    --config_override " \
            model{ \
              faster_rcnn { \
                second_stage_post_processing { \
                  batch_non_max_suppression { \
                    score_threshold: 0.5 \
                  } \
                } \
              } \
            }"
"""
import tensorflow as tf
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
                    'one of [`image_tensor`, `encoded_image_string_tensor`, '
                    '`tf_example`]')
flags.DEFINE_string('input_shape', None,
                    'If input_type is `image_tensor`, this can explicitly set '
                    'the shape of this input tensor to a fixed size. The '
                    'dimensions are to be provided as a comma-separated list '
                    'of integers. A value of -1 can be used for unknown '
                    'dimensions. If not specified, for an `image_tensor, the '
                    'default shape will be partially specified as '
                    '`[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', None,
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file.')
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
                    'pipeline_pb2.TrainEvalPipelineConfig '
                    'text proto to override pipeline_config_path.')
tf.app.flags.mark_flag_as_required('pipeline_config_path')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  text_format.Merge(FLAGS.config_override, pipeline_config)
  if FLAGS.input_shape:
    input_shape = [
        int(dim) if dim != '-1' else None
        for dim in FLAGS.input_shape.split(',')
    ]
  else:
    input_shape = None
  exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
                                  FLAGS.trained_checkpoint_prefix,
                                  FLAGS.output_directory, input_shape)


if __name__ == '__main__':
  tf.app.run()
python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/faster_rcnn_inception_doraemon.config --trained_checkpoint_prefix training/model.ckpt-8243 --output_directory inference_graph

The inference_graph directory now contains the files shown below.
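Per the docstring of export_inference_graph.py above, the exported directory should contain roughly:

inference_graph/
├── checkpoint
├── frozen_inference_graph.pb
├── model.ckpt.data-00000-of-00001
├── model.ckpt.index
├── model.ckpt.meta
├── pipeline.config
└── saved_model/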

8. Testing the Model

The Python script Object_detection_image.py:

######## Image Object Detection Using Tensorflow-trained Classifier #########
#
# Author: Evan Juras
# Date: 1/15/18
# Description:
# This program uses a TensorFlow-trained classifier to perform object detection.
# It loads the classifier and uses it to perform object detection on an image.
# It draws boxes and scores around the objects of interest in the image.

## Some of the code is copied from Google's example at
## https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

## and some is copied from Dat Tran's example at
## https://github.com/datitran/object_detector_app/blob/master/object_detection_app.py

## but I changed it to make it more understandable to me.

# Import packages
import os
import cv2
import numpy as np
import tensorflow as tf
import sys

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

# Import utilites
from utils import label_map_util
from utils import visualization_utils as vis_util

# Name of the directory containing the object detection module we're using
MODEL_NAME = 'inference_graph'
IMAGE_NAME = './images/test/00000391.jpg'

# Grab path to current working directory
CWD_PATH = os.getcwd()

# Path to frozen detection graph .pb file, which contains the model that is used
# for object detection.
PATH_TO_CKPT = os.path.join(CWD_PATH, MODEL_NAME, 'frozen_inference_graph.pb')

# Path to label map file
PATH_TO_LABELS = os.path.join(CWD_PATH, 'training', 'labelmap.pbtxt')

# Path to image
PATH_TO_IMAGE = os.path.join(CWD_PATH, IMAGE_NAME)

# Number of classes the object detector can identify
NUM_CLASSES = 1

# Load the label map.
# Label maps map indices to category names, so that when our convolution
# network predicts `5`, we know that this corresponds to `king`.
# Here we use internal utility functions, but anything that returns a
# dictionary mapping integers to appropriate string labels would be fine
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

    sess = tf.Session(graph=detection_graph)

# Define input and output tensors (i.e. data) for the object detection classifier

# Input tensor is the image
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

# Output tensors are the detection boxes, scores, and classes
# Each box represents a part of the image where a particular object was detected
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

# Each score represents level of confidence for each of the objects.
# The score is shown on the result image, together with the class label.
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

# Number of objects detected
num_detections = detection_graph.get_tensor_by_name('num_detections:0')

# Load image using OpenCV and
# expand image dimensions to have shape: [1, None, None, 3]
# i.e. a single-column array, where each item in the column has the pixel RGB value
image = cv2.imread(PATH_TO_IMAGE)
image_expanded = np.expand_dims(image, axis=0)

# Perform the actual detection by running the model with the image as input
(boxes, scores, classes, num) = sess.run(
    [detection_boxes, detection_scores, detection_classes, num_detections],
    feed_dict={image_tensor: image_expanded})

# Draw the results of the detection (aka 'visualize the results')
vis_util.visualize_boxes_and_labels_on_image_array(
    image,
    np.squeeze(boxes),
    np.squeeze(classes).astype(np.int32),
    np.squeeze(scores),
    category_index,
    use_normalized_coordinates=True,
    line_thickness=1,
    min_score_thresh=0.50)

# All the results have been drawn on image. Now display the image.
cv2.imshow('Object detector', image)

# Press any key to close the image, then save the drawn result to disk
cv2.waitKey(0)
cv2.imwrite("doraemon_result.png", image)

# Clean up
cv2.destroyAllWindows()

Run the test:

python Object_detection_image.py

The resulting image, doraemon_result.png:

Summary

  • Built on top of the Tensorflow Object Detection API
  • Collected images, annotated them, and converted them to tfrecord format
  • Edited labelmap.pbtxt and the model config file
Once the pipeline runs end to end, parameter tuning can draw on the relevant papers and their improvements. Real knowledge comes from practice, and that takes time, projects, and plenty of summarizing and reflection. Finally, a bit of doggerel on my view of practice while learning AI:

New to deep learning,
often at a total loss.
Better to practice, and practice more —
then every change settles in the heart.
