前面已经介绍了几种经典的目标检测算法,光学习理论不实践的效果并不大,这里咱们使用谷歌的开源框架来实现目标检测。至于为何不去本身实现呢?主要是由于本身实现比较麻烦,并且调参比较麻烦,咱们直接利用别人的库去学习,能够节约不少时间,并且逐渐吃透别人代码,使得咱们能够慢慢的接受。html
Object Detection API是谷歌开放的一个内部使用的物体识别系统。2016年 10月,该系统在COCO识别挑战中名列第一。它支持当前最佳的实物检测模型,可以在单个图像中定位和识别多个对象。该系统不只用于谷歌于自身的产品和服务,还被推广至整个研究社区。python
Object Detection 模块的位置与slim的位置相近,同在github.com 中TensorFlow 的models\research目录下。相似slim, Object Detection也囊括了各类关于物体检测的各类先进模型:git
上述每个模型的冻结权重 (在COCO数据集上训练)可被直接加载使用。github
SSD模型使用了轻量化的MobileNet,这意味着它们能够垂手可得地在移动设备中实时使用。谷歌使用了 Faster R-CNN模型须要更多计算资源,但结果更为准确。算法
在在实物检测领域,训练模型的最权威数据集就是COCO数据集。
COCO数据集是微软发布的一个能够用来进行图像识别训练的数据集,官方网址为http://mscoco.org 其图像主要从复杂的平常场景中截取,图像中的目标经过精确的segmentation进行位置的标定。
COCO数据集包括91类目标,分两部分发布,前部分于2014年发布,后部分于2015年发布。express
Objet Detection API使用protobufs来配置模型和训练参数,这些文件以".proto"的扩展名放models\research\object_detection\protos下。在使用框架以前,必须使用protobuf库将其编译成py文件才能够正常运行。protobuf库的下载地址为https://github.com/google/protobuf/releases/tag/v2.6.1apache
下载并解压protoc-2.6.1-win32.zip到models\research路径下。ubuntu
打开cmd命令行,进入models\research目录下,执行以下命令windows
protoc.exe object_detection/protos/*.proto --python_out=.
若是不显示任何信息,则代表运行成功了,为了检验成功效果,来到models\research\object_detection\protos下,能够看到生成不少.py文件。数组
若是前面两步都完成了,下面能够测试一下object detection API是否能够正常使用,还须要两步操做:
代表object detection API一切正常,可使用、
为了避免用每次都将文件复制到Object Detection文件夹外,能够将Object Detection加到python引入库的默认搜索路径中,将Object Detection文件整个复制到anaconda3安装文件目录下lib\site-packages下:
这样不管文件在哪里,只要搜索import Objec Detection xxx,系统到会找到Objec Detection。
以前已经说过Objec Detection API默认提供了5个预训练模型。他们都是使用COCO数据集训练完成的,如何使用这些预训练模型呢?官方已经给了一个用jupyter notebook编写好的例子。首先在research文件下下,运行命令:jupyter-notebook,会直接打开http://localhost:8888/tree。
接着打开object_detection文件夹,并单击object_detection_tutorial.jpynb运行示例文件。
该代码使用Object Detection API基于COCO上训练的ssd_mobilenet_v1模型,对任意图片进行分类识别。
以前介绍的已有模型,在下面网站能够下载:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
每个压缩文件里包含如下文件:
咱们在models\research文件夹下建立一个文件夹my_download_pretrained,用于保存预训练的模型。
咱们对该代码进行一些修改,并给出该代码的中文注释:
在models\research下建立my_object_detection.py文件。程序只能在GPU下运行,CPU会报错。
# -*- coding: utf-8 -*- """ Created on Tue Jun 5 20:34:06 2018 @author: zy """ ''' 调用Object Detection API进行实物检测 须要GPU运行环境,CPU下会报错 模型下载网址:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md TensorFlow 生成的 .ckpt 和 .pb 都有什么用? https://www.cnblogs.com/nowornever-L/p/6991295.html 如何用Tensorflow训练模型成pb文件(一)——基于原始图片的读取 https://blog.csdn.net/u011463646/article/details/77918980?fps=1&locationNum=7 ''' import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util from PIL import Image def test(): #重置图 tf.reset_default_graph() ''' 载入模型以及数据集样本标签,加载待测试的图片文件 ''' #指定要使用的模型的路径 包含图结构,以及参数 PATH_TO_CKPT = './my_download_pretrained/ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb' #测试图片所在的路径 PATH_TO_TEST_IMAGES_DIR = './object_detection/test_images' TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR,'image{}.jpg'.format(i)) for i in range(1,3) ] #数据集对应的label mscoco_label_map.pbtxt文件保存了index到类别名的映射 PATH_TO_LABELS = os.path.join('./object_detection/data','mscoco_label_map.pbtxt') NUM_CLASSES = 90 #从新定义一个图 output_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT,'rb') as fid: #将*.pb文件读入serialized_graph serialized_graph = fid.read() #将serialized_graph的内容恢复到图中 output_graph_def.ParseFromString(serialized_graph) #print(output_graph_def) #将output_graph_def导入当前默认图中(加载模型) tf.import_graph_def(output_graph_def,name='') print('模型加载完成') #载入coco数据集标签文件 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes = NUM_CLASSES,use_display_name = True) category_index = label_map_util.create_category_index(categories) ''' 定义session ''' def load_image_into_numpy_array(image): ''' 将图片转换为ndarray数组的形式 ''' im_width,im_height = image.size return np.array(image.getdata()).reshape((im_height,im_width,3)).astype(np.uint0) #设置输出图片的大小 IMAGE_SIZE = (12,8) #使用默认图,此时已经加载了模型 detection_graph = tf.get_default_graph() with tf.Session(graph=detection_graph) as sess: for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) #将图片转换为numpy格式 image_np = load_image_into_numpy_array(image) ''' 定义节点,运行并可视化 ''' #将图片扩展一维,最后进入神经网络的图片格式应该是[1,?,?,3] image_np_expanded = np.expand_dims(image_np,axis = 0) ''' 获取模型中的tensor ''' image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') #boxes用来显示识别结果 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') #Echo score表明识别出的物体与标签匹配的类似程度,在类型标签后面 scores = detection_graph.get_tensor_by_name('detection_scores:0') classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') #开始检查 boxes,scores,classes,num_detections = sess.run([boxes,scores,classes,num_detections], feed_dict={image_tensor:image_np_expanded}) #可视化结果 vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=IMAGE_SIZE) print(type(image_np)) print(image_np.shape) image_np = np.array(image_np,dtype=np.uint8) plt.imshow(image_np) if __name__ == '__main__': test()
以VOC 2012数据集为例,介绍如何使用Object Detection API训练新的模型。VOC 2012是VOC2007数据集的升级版,一共有11530张图片,每张图片都有标准,标注的物体包括人、动物(如猫、狗、鸟等)、交通工具(如车、船飞机等)、家具(如椅子、桌子、沙发等)在内的20个类别。
首先下载数据集,并将其转换为tfrecord格式。下载地址为:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar。
首先下载谷歌models库,而后删除一些没必要要的文件,获得文件结构以下:
在research文件夹下,建立一个voc文件夹,把VOC2012解压到这个文件夹下,解压后,获得一个VOCdevkit文件夹:
JPEGImages文件中文件夹里存放了所有的训练图片和验证图片。
对于每一张图像,都在Annotations文件夹中存放有对应的xml文件。保存着物体框的标注,包括图片文件名,图片大小,图片边界框等信息。
以2007_000027.xml为例:
<annotation> #数据所在的文件夹名 <folder>VOC2012</folder> #图片名称 <filename>2007_000027.jpg</filename> <source> <database>The VOC2007 Database</database> <annotation>PASCAL VOC2007</annotation> <image>flickr</image> </source> #图片的宽和高 <size> <width>486</width> <height>500</height> <depth>3</depth> </size> <segmented>0</segmented> <object> #类别名 <name>person</name> #物体的姿式 <pose>Unspecified</pose> #物体是否被部分遮挡 <truncated>0</truncated> ##是否为难以辨识的物体, 主要指要结合背景才能判断出类别的物体。虽有标注, 但通常忽略这类物体 跳过难以识别的? <difficult>0</difficult> #边界框 <bndbox> <xmin>174</xmin> <ymin>101</ymin> <xmax>349</xmax> <ymax>351</ymax> </bndbox> #下面的数据是人体各个部位边界框 <part> <name>head</name> <bndbox> <xmin>169</xmin> <ymin>104</ymin> <xmax>209</xmax> <ymax>146</ymax> </bndbox> </part> <part> <name>hand</name> <bndbox> <xmin>278</xmin> <ymin>210</ymin> <xmax>297</xmax> <ymax>233</ymax> </bndbox> </part> <part> <name>foot</name> <bndbox> <xmin>273</xmin> <ymin>333</ymin> <xmax>297</xmax> <ymax>354</ymax> </bndbox> </part> <part> <name>foot</name> <bndbox> <xmin>319</xmin> <ymin>307</ymin> <xmax>340</xmax> <ymax>326</ymax> </bndbox> </part> </object> </annotation>
ImageSets文件夹包括Action Layout Main Segmentation四部分,Action存放的是人的动做,Layout存放人体部位数据,Main存放的是图像物体识别数据(里面的test.txt,train.txt,val.txt,trainval.txt当本身制做数据集时须要生成)。
ImageSets\Main文件夹以下。
SegmentationClass(标注出每个像素的类别)和SegmentationObject(标注出每一个像素属于哪个物体)是分割相关的。
把pascal_label_map.pbtxt文件复制到voc文件夹下,这个文件存放在voc2012数据集物体的索引和对应的名字。
从object_detection\dataset_tools下把create_pascal_tf_record.py文件复制到research文件夹下,这个代码是为VOC2012数据集提早编写好的。代码以下:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Convert raw PASCAL dataset to TFRecord for object_detection. Example usage: ./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \ --year=VOC2012 \ --output_path=/home/user/pascal.record """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import hashlib import io import logging import os from lxml import etree import PIL.Image import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util import sys #配置logging logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO, stream=sys.stdout) #命令行参数 主要包括数据集根目录,数据类型,输出tf文件路径等 flags = tf.app.flags flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.') flags.DEFINE_string('set', 'train', 'Convert training set, validation set or ' 'merged set.') flags.DEFINE_string('annotations_dir', 'Annotations', '(Relative) path to annotations directory.') flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.') flags.DEFINE_string('output_path', '', 'Path to output TFRecord') flags.DEFINE_string('label_map_path', 'voc/pascal_label_map.pbtxt', 'Path to label map proto') flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore ' 'difficult instances') FLAGS = flags.FLAGS SETS = ['train', 'val', 'trainval', 'test'] YEARS = ['VOC2007', 'VOC2012', 'merged'] def dict_to_tf_example(data, dataset_directory, label_map_dict, ignore_difficult_instances=False, image_subdirectory='JPEGImages'): """Convert XML derived dict to tf.Example proto. Notice that this function normalizes the bounding box coordinates provided by the raw data. Args: data: dict holding PASCAL XML fields for a single image (obtained by running dataset_util.recursive_parse_xml_to_dict) dataset_directory: Path to root directory holding PASCAL dataset label_map_dict: A map from string label names to integers ids. ignore_difficult_instances: Whether to skip difficult instances in the dataset (default: False). image_subdirectory: String specifying subdirectory within the PASCAL dataset directory holding the actual image data. Returns: example: The converted tf.Example. Raises: ValueError: if the image pointed to by data['filename'] is not a valid JPEG """ #获取图片相对数据集的相对路径 img_path = os.path.join(data['folder'], image_subdirectory, data['filename']) #获取图片绝对路径 full_path = os.path.join(dataset_directory, img_path) #读取图片 with tf.gfile.GFile(full_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = PIL.Image.open(encoded_jpg_io) if image.format != 'JPEG': raise ValueError('Image format not JPEG') key = hashlib.sha256(encoded_jpg).hexdigest() #获取图片的宽和高 width = int(data['size']['width']) height = int(data['size']['height']) xmin = [] ymin = [] xmax = [] ymax = [] classes = [] classes_text = [] truncated = [] poses = [] difficult_obj = [] for obj in data['object']: #是否为难以辨识的物体, 主要指要结合背景才能判断出类别的物体。虽有标注, 但通常忽略这类物体 跳过难以识别的? difficult = bool(int(obj['difficult'])) if ignore_difficult_instances and difficult: continue difficult_obj.append(int(difficult)) #bounding box 计算目标边界 归一化到[0,1]之间 左上角坐标,右下角坐标 xmin.append(float(obj['bndbox']['xmin']) / width) ymin.append(float(obj['bndbox']['ymin']) / height) xmax.append(float(obj['bndbox']['xmax']) / width) ymax.append(float(obj['bndbox']['ymax']) / height) #类别名 classes_text.append(obj['name'].encode('utf8')) #获取该类别对应的标签 classes.append(label_map_dict[obj['name']]) #物体是否被部分遮挡 truncated.append(int(obj['truncated'])) #物体的姿式 poses.append(obj['pose'].encode('utf8')) #tf文件一条记录格式 example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature( data['filename'].encode('utf8')), 'image/source_id': dataset_util.bytes_feature( data['filename'].encode('utf8')), 'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmin), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmax), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymin), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymax), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), 'image/object/difficult': dataset_util.int64_list_feature(difficult_obj), 'image/object/truncated': dataset_util.int64_list_feature(truncated), 'image/object/view': dataset_util.bytes_list_feature(poses), })) return example def main(_): ''' 主要是经过读取VOCdevkit\VOC2012\Annotations下的xml文件 而后获取对应的图片文件的路径,图片大小,文件名,边界框、以及图片数据等信息写入rfrecord文件 ''' if FLAGS.set not in SETS: raise ValueError('set must be in : {}'.format(SETS)) if FLAGS.year not in YEARS: raise ValueError('year must be in : {}'.format(YEARS)) data_dir = FLAGS.data_dir years = ['VOC2007', 'VOC2012'] if FLAGS.year != 'merged': years = [FLAGS.year] #建立对象,用于向记录文件写入记录 writer = tf.python_io.TFRecordWriter(FLAGS.output_path) #获取类别名->index的映射 dict类型 label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path) for year in years: logging.info('Reading from PASCAL %s dataset.', year) #获取aeroplane_train.txt文件的全路径 改文件保存部分文件名(一共5717/5823个文件,各种图片都有) examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt') #获取全部图片标注xml文件的路径 annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir) #list 存放文件名 examples_list = dataset_util.read_examples_list(examples_path) #遍历annotations_dir目录下,examples_list中指定的xml文件 for idx, example in enumerate(examples_list): if idx % 100 == 0: logging.info('On image %d of %d', idx, len(examples_list)) path = os.path.join(annotations_dir, example + '.xml') with tf.gfile.GFile(path, 'r') as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) #获取annotation节点的内容 data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation'] #把数据整理成tfrecord须要的数据结构 tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict, FLAGS.ignore_difficult_instances) #向tf文件写入一条记录 writer.write(tf_example.SerializeToString()) writer.close() if __name__ == '__main__': tf.app.run()
若是读者但愿使用本身的数据集,有两种方法:
在research文件夹中,执行如下命令能够把VOC 2012数据集转换为tfrecord格式,转换好的tfrecord保存在voc文件夹下,分别为pasal_train.record和pascal_val.record:
python create_pascal_tf_record.py --data_dir voc/VOCdevkit --year=VOC2012 --set=train --output_path=voc/pascal_train.record
python create_pascal_tf_record.py --data_dir voc/VOCdevkit --year=VOC2012 --set=val --output_path=voc/pascal_val.record
以上执行完成后,咱们把voc文件夹和create_pascal_tf_record.py文件剪切到object_detection文件下。(其实在以前咱们就能够直接把文件建立在object_detection文件夹下,主要是由于create_pascal_tf_record.py在执行的时候会调用到object_detection库,我是懒得把object_detection库加入环境变量了,因此才这样作。)
若是想配置临时环境变量,在research目录下:
windows下命令:
set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim
ubuntu系统下:
export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim
下载完VOC 2012数据集后,须要选择合适的训练模型。这里以Faster R-CNN + Inception-ResNet_v2模型为例进行介绍。首先下载在COCO数据集上预训练的Faster R-CNN + Inception-ResNet_v2模型。解压到voc文件夹下,如图:
Object Detection API是依赖一种特殊的设置文件进行训练的。在object_detection/samples/configs文件夹下,有一些设置文件的示例。能够参考faster_rcnn_inception_resnet_v2_atrous_coco.config文件建立的设置文件。先将faster_rcnn_inception_resnet_v2_atrous_coco.config复制一份到voc文件夹下,命名为faster_rcnn_inception_resnet_v2_atrous_voc.config。
faster_rcnn_inception_resnet_v2_atrous_voc.config文件有7处须要修改:
gradient_clipping_by_norm: 10.0 fine_tune_checkpoint: "voc/faster_rcnn_inception_resnet_v2_atrous_coco_2018_01_28/model.ckpt" from_detection_checkpoint: true # Note: The below line limits the training process to 200K steps, which we # empirically found to be sufficient enough to train the pets dataset. This # effectively bypasses the learning rate schedule (the learning rate will # never decay). Remove the below line to train indefinitely. num_steps: 200000 data_augmentation_options { random_horizontal_flip { } } } train_input_reader: { tf_record_input_reader { input_path: "voc/pascal_train.record" } label_map_path: "voc/pascal_label_map.pbtxt" } eval_config: { num_examples: 5823 # Note: The below line limits the evaluation process to 10 evaluations. # Remove the below line to evaluate indefinitely. max_evals: 10 } eval_input_reader: { tf_record_input_reader { input_path: "voc/pascal_val.record" } label_map_path: "voc/pascal_label_map.pbtxt" shuffle: false num_readers: 1 }
最后,在voc文件夹中新建一个train_dir做为保存模型和日志的目录,在使用object_detection目录下的train.py文件训练的时候会使用到slim下库,所以咱们须要先配置临时环境变量,在research目录下:
windows下命令:
set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim
ubuntu系统下:
export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim
在object_detection目录下,使用下面的命令就能够开始训练了:(要在GPU下运行,在CPU运行会抛出module 'tensorflow' has no attribute 'data'的错误)
python train.py --train_dir voc/train_dir/ --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config
解决:
出错缘由:知乎的大佬说是python3的兼容问题
解决办法:把research/object_detection/utils/learning_schedules.py
文件的 第167-169行由
解决:
出错缘由:知乎的大佬说是python3的兼容问题
解决办法:把research/object_detection/utils/learning_schedules.py
文件的 第167-169行由
程序运行结果以下:
....
因为咱们在设置文件中设置的训练步数为200k,所以整个训练可能会消耗大量时间,这里我训练到4万屡次就强行终止训练了.
num_steps: 200000
训练的日志和最终的模型(默认保存了5个不一样步数时的模型)都会保存在train_dir中,所以,一样可使用TensorBoard来监控训练状况。
使用cmd来到日志文件的上级路径下,输入以下命令:
tensorboard --logdir ./train_dir
接着打开浏览器,输入http://127.0.0.1:6006,若是训练时保存了一下变量,则能够在这里看到(我这里没有保存变量):
须要注意的是,若是发生内存和显存不足报错的状况,除了使用较小模型进行训练外,还能够修改配置文件中的如下内容:
image_resizer { keep_aspect_ratio_resizer { min_dimension: 600 max_dimension: 1024 } }
这个部分表示将输入图像进行等比例缩放再进行训练,缩放后的最大边长为1024,最小边长为600.能够将整两个数值改小(我训练的时候就分别改为512和300),使用的显存就会变小。不过这样作也可能致使模型的精度降低,所以咱们须要根据本身的状况选择适合的处理方法。
如何将train_dir中的checkpoint文件导出并用于单张图片的目标检测?TensorFlow Object Detection API提供了一个export_inference_graph.py脚本用于导出训练好的模型。具体方法是在research目录下执行:
python export_inference_graph.py --input_type image_tensor --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config --trained_checkpoint_prefix voc/train_dir/model.ckpt-47837 --output_directory voc/export
其中model.ckpt-47837表示使用第47837步保存的模型。咱们须要根据voc/train_dir时间保存的checkpoint,将47837改成合适的数值。导出的模型是voc/export/frozen_inference_graph.pb文件。
而后能够参考上面咱们介绍的jupyter notebook代码,自行编写利用导出模型对单张图片作目标检测的脚本。而后将PATH_TO_CKPT的值赋值为voc/export/frozen_inference_graph.pb,即导出模型文件。将PATH_TO_LABELS修改成voc/pascal_label_map.pbtxt,即各个类别的名称。把NUM_CLASSES设置为20。其它代码均可以不改变,而后测试咱们的图片(注意:须要添加上文中提到的临时环境变量),因为VOC2012数据集中的类别也有狗和人,所以咱们能够直接使用object_detection/test_images中的测试图片。
# -*- coding: utf-8 -*- """ Created on Tue Jun 5 20:34:06 2018 @author: zy """ ''' 调用Object Detection API进行实物检测 须要GPU运行环境,CPU下会报错 TensorFlow 生成的 .ckpt 和 .pb 都有什么用? https://www.cnblogs.com/nowornever-L/p/6991295.html 如何用Tensorflow训练模型成pb文件(一)——基于原始图片的读取 https://blog.csdn.net/u011463646/article/details/77918980?fps=1&locationNum=7 ''' #运行前须要把object_detection添加到环境变量 #ubuntu 在research目录下,打开终端,执行export PYTHONPATH=$PYTHONPATH:${PWD}:${PWD}/slim 而后执行spyder,运行程序 #windows 在research目录下,打开cmd,执行set PYTHONPATH=%PYTHONPATH%;%CD%;%CD%/slim 而后执行spyder,运行程序 import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util from PIL import Image def test(): #重置图 tf.reset_default_graph() ''' 载入模型以及数据集样本标签,加载待测试的图片文件 ''' #指定要使用的模型的路径 包含图结构,以及参数 PATH_TO_CKPT = './voc/export/frozen_inference_graph.pb' #测试图片所在的路径 PATH_TO_TEST_IMAGES_DIR = './test_images' TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR,'image{}.jpg'.format(i)) for i in range(1,3) ] #数据集对应的label pascal_label_map.pbtxt文件保存了index和类别名之间的映射 PATH_TO_LABELS = './voc/pascal_label_map.pbtxt' NUM_CLASSES = 20 #从新定义一个图 output_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT,'rb') as fid: #将*.pb文件读入serialized_graph serialized_graph = fid.read() #将serialized_graph的内容恢复到图中 output_graph_def.ParseFromString(serialized_graph) #print(output_graph_def) #将output_graph_def导入当前默认图中(加载模型) tf.import_graph_def(output_graph_def,name='') print('模型加载完成') #载入coco数据集标签文件 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes = NUM_CLASSES,use_display_name = True) category_index = label_map_util.create_category_index(categories) ''' 定义session ''' def load_image_into_numpy_array(image): ''' 将图片转换为ndarray数组的形式 ''' im_width,im_height = image.size return np.array(image.getdata()).reshape((im_height,im_width,3)).astype(np.uint0) #设置输出图片的大小 IMAGE_SIZE = (12,8) #使用默认图,此时已经加载了模型 detection_graph = tf.get_default_graph() with tf.Session(graph=detection_graph) as sess: for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) #将图片转换为numpy格式 image_np = load_image_into_numpy_array(image) ''' 定义节点,运行并可视化 ''' #将图片扩展一维,最后进入神经网络的图片格式应该是[1,?,?,3] image_np_expanded = np.expand_dims(image_np,axis = 0) ''' 获取模型中的tensor ''' image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') #boxes用来显示识别结果 boxes = detection_graph.get_tensor_by_name('detection_boxes:0') #Echo score表明识别出的物体与标签匹配的类似程度,在类型标签后面 scores = detection_graph.get_tensor_by_name('detection_scores:0') classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') #开始检查 boxes,scores,classes,num_detections = sess.run([boxes,scores,classes,num_detections], feed_dict={image_tensor:image_np_expanded}) #可视化结果 vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=IMAGE_SIZE) print(type(image_np)) print(image_np.shape) image_np = np.array(image_np,dtype=np.uint8) plt.imshow(image_np) if __name__ == '__main__': test()
咱们再来看一下若是直接使用官方在COCO数据集上训练的Faster R-CNN + Inception-ResNet_v2模型,进行目标检测:
咱们能够看到咱们使用本身数据集训练的模型进行目标检测效果没有官方提供的模型那个好,可能有如下几个缘由:
参考文章:
[1]将数据集作成VOC2007格式用于Faster-RCNN训练
[2]VOC数据集制做2——ImageSets\Main里的四个txt文件
[3]21个项目玩转深度学习-何之源
[4]深度学习之TensorFlow-李金洪