一、安装

首先安装mxnet(cpu版):
命令行输入pip install -i https://pypi.doubanio.com/simple/ mxnet(豆瓣源、速度快)
接着安装gluoncv(cpu版):
命令行输入pip install --upgrade mxnet gluoncv(如果直接安装gluoncv,会从国外源下载mxnet安装包,速度慢到报错……建议还是先装mxnet再装gluoncv吧)

二、数据集制作

先自制PASCAL VOC格式的数据集,再将PASCAL VOC格式的数据集转换为适用于MXnet框架的数据集。

from gluoncv.data import VOCDetection
class VOCLike(VOCDetection):CLASSES = ['xxx1', 'xxx2']  #自己数据集中的类别def __init__(self, root, splits, transform=None, index_map=None, preload_label=True):super(VOCLike, self).__init__(root, splits, transform, index_map, preload_label)dataset = VOCLike(root='./VOCtemplate', splits=((2007, 'traineval'),))      #自己数据集的路径,包含VOC2007的上一层目录
print('length of dataset:', len(dataset))
print('label example:')
print(dataset[0][1])

报错:
d:\Anaconda3\lib\site-packages\gluoncv\data\pascal_voc\detection.py in _load_label(self, idx)
109 for obj in root.iter(‘object’):
110 try:
–> 111 difficult = int(obj.find(‘difficult’).text)
112 except ValueError:
113 difficult = 0

AttributeError: ‘NoneType’ object has no attribute ‘text’

解决:
我数据集的标签是用labelimg做的,标签中的difficult都写作“Difficult”,因此把gluoncv\data\pascal_voc\detection.py中的‘difficult’改为‘Difficult’即可。

三、迁移学习—以ssd_512_mobilenet1.0为例

#加载预训练过的模型
net = gcv.model_zoo.get_model('ssd_512_mobilenet1.0_voc', pretrained=True)
#重设类别
classes=['xxx1', 'xxx2']  #自己数据集中的类别
net.reset_class(classes)
#batch设置,规范输入网络进行训练的数据格式
def get_dataloader(net, train_dataset, data_shape, batch_size, num_workers):from gluoncv.data.batchify import Tuple, Stack, Padfrom gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransformwidth, height = data_shape, data_shape# use fake data to generate fixed anchors for target generationwith autograd.train_mode():_, _, anchors = net(mx.nd.zeros((1, 3, height, width)))batchify_fn = Tuple(Stack(), Stack(), Stack())  # stack image, cls_targets, box_targetstrain_loader = gluon.data.DataLoader(train_dataset.transform(SSDDefaultTrainTransform(width, height, anchors)),batch_size, True, batchify_fn=batchify_fn, last_batch='keep', num_workers=num_workers)return train_loadertrain_data = get_dataloader(net, dataset, 512, 4, 0)

设置用于训练的处理器类型

#尝试用gpu训练
try:a = mx.nd.zeros((1,), ctx=mx.gpu(0))ctx = [mx.gpu(0)]
except:ctx = [mx.cpu()]

开始训练:

net.collect_params().reset_ctx(ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd',{'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})mbox_loss = gcv.loss.SSDMultiBoxLoss()
ce_metric = mx.metric.Loss('CrossEntropy')
smoothl1_metric = mx.metric.Loss('SmoothL1')for epoch in range(0, 10):    #设置epoch数ce_metric.reset()smoothl1_metric.reset()tic = time.time()btic = time.time()net.hybridize(static_alloc=True, static_shape=True)for i, batch in enumerate(train_data):batch_size = batch[0].shape[0]data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)with autograd.record():cls_preds = []box_preds = []for x in data:cls_pred, box_pred, _ = net(x)cls_preds.append(cls_pred)box_preds.append(box_pred)sum_loss, cls_loss, box_loss = mbox_loss(cls_preds, box_preds, cls_targets, box_targets)autograd.backward(sum_loss)# since we have already normalized the loss, we don't want to normalize# by batch-size anymoretrainer.step(1)ce_metric.update(0, [l * batch_size for l in cls_loss])smoothl1_metric.update(0, [l * batch_size for l in box_loss])name1, loss1 = ce_metric.get()name2, loss2 = smoothl1_metric.get()if i % 20 == 0:print('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))btic = time.time()

保存训练好的模型参数:

net.save_parameters('./ssd_512_mobilenet1.0_mound.params') #设置模型参数的存放路径

四、测试


net = gcv.model_zoo.get_model('ssd_512_mobilenet1.0_custom', classes=classes, pretrained_base=False)
net.load_parameters('./ssd_512_mobilenet1.0_mound.params') #载入模型参数
x, image = gcv.data.transforms.presets.ssd.load_test('./xxxx.jpg', 512) #读取测试图像
cid, score, bbox = net(x)
ax = viz.plot_bbox(image, bbox[0], score[0], cid[0], class_names=classes)
plt.show()

代码参考gluoncv的 tutorial
https://gluon-cv.mxnet.io/build/examples_detection/finetune_detection.html

MXnet-gluoncv实现基于迁移学习的目标检测(自己的数据集)相关推荐

  1. Yolov5官方网络改进:增加search模块(基于迁移学习的目标检测+多模态零样本自定义标签分类网络)

    1.效果展示[label:目标检测:概率+针对检测框的细分类(颜色情绪都行,此部分标签可自定义)]: 1.1目标检测+颜色识别 1.2人物检测+情绪检测 1.3针对特定类别的自动裁切效果+情绪识别结果 ...

  2. 综述 | 基于深度学习的目标检测算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:计算机视觉life 导读:目标检测(Object Det ...

  3. 基于深度学习的目标检测算法综述(一)

    基于深度学习的目标检测算法综述(一) 基于深度学习的目标检测算法综述(二) 基于深度学习的目标检测算法综述(三) 本文内容原创,作者:美图云视觉技术部 检测团队,转载请注明出处 目标检测(Object ...

  4. 基于深度学习的目标检测综述

    基于深度学习的目标检测算法归类和总结 整体框架 目标检测算法 主要包括:[两阶段]目标检测算法.[多阶段]目标检测算法.[单阶段]目标检测算法 什么是两阶段目标检测算法,与单阶段目标检测有什么区别? ...

  5. 病虫害模型算法_基于深度学习的目标检测算法综述

    sigai 基于深度学习的目标检测算法综述 导言 目标检测的任务是找出图像中所有感兴趣的目标(物体),确定它们的位置和大小,是机器视觉领域的核心问题之一.由于各类物体有不同的外观,形状,姿态,加上成像 ...

  6. 基于深度学习的目标检测的研究进展2

    普通的深度学习监督算法主要是用来做分类,如图1(1)所示,分类的目标是要识别出图中所示是一只猫.而在ILSVRC(ImageNet Large Scale Visual Recognition Cha ...

  7. 基于深度学习的目标检测研究进展

    前言 开始本文内容之前,我们先来看一下上边左侧的这张图,从图中你看到了什么物体?他们在什么位置?这还不简单,图中有一个猫和一个人,具体的位置就是上图右侧图像两个边框(bounding-box)所在的位 ...

  8. 基于深度学习的目标检测算法综述(从R-CNN到Mask R-CNN)

    深度学习目标检测模型全面综述:Faster R-CNN.R-FCN和SSD 从RCNN到SSD,这应该是最全的一份目标检测算法盘点 基于深度学习的目标检测算法综述(一) 基于深度学习的目标检测算法综述 ...

  9. 【深度学习】基于深度学习的目标检测研究进展

    原文出处:http://chuansong.me/n/353443351445 开始本文内容之前,我们先来看一下上边左侧的这张图,从图中你看到了什么物体?他们在什么位置?这还不简单,图中有一个猫和一个 ...

最新文章

  1. Product Helper
  2. 第十一届河南省赛--H : Attack City and Capture Territory
  3. 如何给SAP C4C的产品主数据division配置出新的下拉选项
  4. 前端学习(338):堆栈
  5. 南工院linux考试题库,操作系统复习题..doc
  6. python课程水平测试成绩查询_学业水平考试成绩查询系统入口
  7. android fragment传递参数_fragment之间传值的两种方法
  8. 字节跳动自研重度游戏;中国移动前董事长谈飞信失败;Linux 5.3-rc3 发布 | 极客头条...
  9. Mac可读可写remount硬盘
  10. excel不显示0_Excel2007:Excel表格中完整输入身份证号码的几种方法
  11. 第六章 第一个Linux驱动程序:统计单词个数
  12. Windows下U盘无法格式化原因及解决办法:Windows无法完成格式化
  13. 网络工程师面试题收集
  14. 什么软件测试情侣头像,情侣头像另一半查找器
  15. 基于php+mysql的大学生四六级英语考试报名成绩管理
  16. 计算机后置音频接口,电脑后面音频插孔没声音怎么办?电脑后置插孔没声音的解决方案...
  17. 已知四边形的四个点,求一个点是否在四边形之内的解决方法
  18. 查询pytorch文档的实用方法
  19. 微信公众h5页面如何在pc端调试
  20. 蓝桥杯-魔方旋转问题

热门文章

  1. elf文件中代码段有绝对地址但重定位表中无.text.rel
  2. 某ARM服务器与X86服务器简单性能对比
  3. 局部性原理——各类优化的基石
  4. internet协议
  5. Android 简单SlidingTabLayout的用法
  6. 目标检测——day45 Deep Affinity Network for Multiple Object Tracking
  7. 刘涵 美国 西北大学 计算机,西北大学关于表彰2010-2011学年度学生先进集体.doc...
  8. jis拉伸试棒图纸_拉伸试棒延伸率快速测量装置制造方法
  9. 2.1.3原语对进程的控制
  10. 车载基础软件——AUTOSAR CP