利用深度学习的方法进行图像分类及目标检测时,由于自己数据集可能相对较小,直接利用自己的数据集进行网络的训练,非常容易过拟合。在迁移学习中,我们首先在一个基础数据集和基础任务上训练一个基础网络,然后我们再微调一下学到的特征,或者说将它们迁移到第二个目标网络中,用目标数据集和目标任务训练网络。

对于计算机视觉领域的图像特征提取的基础卷积网络backbone往往采用在ImageNet数据集上训练得到的预训练模型,在torchvision中存储了基础常用的基础网络的网络结构及预训练参数,例如VGG,inception,resnet,densenet,shufflenet,squeezenet等,可以直接调用,但是在目标检测中,我们往往仅仅使用这些网络的一些底部的层,上边的一些卷积层需要根据自己依据实际问题设计;或者在图像分类中,结果输出的类别往往不会正好与给定的相同,因而都需要对原始网络及参数进行修改

torchvision的预训练模型地址​github.com

1、pytorch中的预训练模型

在上边的连接地址中有各个基础网络模型的程序源码,还有预训练模型参数文件的下载地址,使用时主要采用下边代码块中的文件直接引用

import torchvision.models as models#resnet
model = models.resnet50(pretrained=True)#vgg
model = models.vgg16(pretrained=True)

2、预训练模型的修改

在实际的应用中,往往不能直接应用导出的网络模型,均需要对网络的模型进行一定的修改才能使用,例如在图像分类任务中,我们任务的分类数目不一定刚好与imagenet数据集模型类别一致(1000类),在使用时就需要自己对模型参数进行修改(或者直接修改最后一层,然后自己建立合适的层进行自己算法的分类操作)。

在目标检测中往往自己需要建立特征提取的基础网络,需要有选择的建立适合自己特定需求的网络,然后利用与预训练模型结构相同的部分进行初始化,加快模型的收敛速度。

1)参数的修改(分类网络最后一层参数)

以自己的分类数据集10类为例,直接传入入类别数目这一参数

# -*- coding:utf-8 -*-
import torchvision.models as models#调用模型
model = models.resnet50(pretrained=True,num_classes = 10)
print(model)

将最后一层fc的参数直接进行修改,对类的属性进行直接修改,替换掉fc这一层:

# -*- coding:utf-8 -*-
import torchvision.models as models#调用模型
model = models.resnet50(pretrained=True)
#提取fc层中固定的参数(这一层的输入节点数目
fc_features = model.fc.in_features
#修改类别为9,(直接对类的属性进行修改)
model.fc = nn.Linear(fc_features, 10)

2)增减卷积层

对于前一种方法,仅适用于在分类问题时对简单层进行修改,但是对于目标检测想要利用预训练模型进行fine-tune该怎么办呢?这就要学会利用增减卷积层进行预训练模型的使用了。这两天刚好需要 修改一个基础特征提取的网络结构,但是自己的数据集又太小,直接修改网络从头训练容易出现过拟合,必须使用预训练的网络模型进行训练,就研究了一下使用方法,总结如下:

基本思想就是:先建立好自己的网络(与预训练的模型类似,要不谈何fine-tune),然后将

削减卷积层

基本思想就是:

1、先建立好自己的网络(与预训练的模型类似,要不谈何fine-tune)

2、然后将预训练模型参数与自己搭建的网络不一致的部分参数去掉

3、将保留的合适的参数读入网络初始化,实现fine-tune的效果

基础网络准备使用resnet152的前143层(为什么是143,这里直接去掉了最后一个残差块)

# -*- coding:utf-8 -*-
#####################
#建立自己的网络模型net
########################然后读出预训练模型参数以resnet152为例,我不是利用程序下载的,我是习惯了下载好存储在文件夹中
pretrained_dict = torch.load(save_path)
model_dict = net.state_dict()   (读出搭建的网络的参数,以便后边更新之后初始化)####去除不属于model_dict的键值
pretrained_dict={ k : v for k, v in pretrained_dict.items() if k in model_dict}###更新现有的model_dict的值
model_dict.update(pretrained_dict)##加载模型需要的参数
net.load_state_dict(model_dict)

介绍到这儿,相信应该对模型读入参数的方法load_state_dict()有了清晰的了解,就是利用键值的对应关系读入参数进行初始化网络,键值不对应会报错

知道了如何读入参数,那么增加卷积层的操作应该已经能够理解了。

增加卷积层

上边介绍了削减卷积层,那么如果我想在后边累加一些卷积层怎么办呢,这个更简单了,直接对应将前边的卷积层参数利用预训练模型初始化,后边添加的卷积层利用nn.init.kaiming_normal_进行初始化

#torch.nn.init.uniform_(tensor, a=0, b=1)
#从均匀分布U(a, b)中生成值,填充输入的张量或变量
#参数:
#tensor - n维的torch.Tensor
#a - 均匀分布的下界
#b - 均匀分布的上界
#例子
w = torch.Tensor(3, 5)
nn.init.uniform_(w)#torch.nn.init.normal_(tensor, mean=0, std=1)
#从给定均值和标准差的正态分布N(mean, std)中生成值,填充输入的张量或变量
#参数:
#tensor – n维的torch.Tensor
#mean – 正态分布的均值
#std – 正态分布的标准差
#例子
w = torch.Tensor(3, 5)
nn.init.normal_(w)#torch.nn.init.constant(tensor, val)
#用val的值填充输入的张量或变量#torch.nn.init.eye_(tensor)
#用单位矩阵来填充2维输入张量或变量。在线性层尽可能多的保存输入特性#####################################这个用的比较多
#torch.nn.init.xavier_uniform_(tensor, gain=1)
#根据Glorot, X.和Bengio, Y.在“Understanding the difficulty of training deep feedforward neural networks”中描述的方法,用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-a, a),其中a= gain * sqrt( 2/(fan_in + fan_out))* sqrt(3). 该方法也被称为Glorot initialisation
#参数:
#tensor – n维的torch.Tensor
#gain - 可选的缩放因子
#例子:w = torch.Tensor(3, 5)nn.init.xavier_uniform_(w, gain=math.sqrt(2.0))#torch.nn.init.xavier_normal_(tensor, gain=1)
#根据Glorot, X.和Bengio, Y. 于2010年在“Understanding the difficulty of training deep feedforward neural networks”中描述的方法,用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自均值为0,标准差为gain * sqrt(2/(fan_in + fan_out))的正态分布。也被称为Glorot initialisation#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in')
#根据He, K等人于2015年在“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification”中描述的方法,用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-bound, bound),其中bound = sqrt(2/((1 + a^2) * fan_in)) * sqrt(3)。也被称为He initialisation.
#tensor – n维的torch.Tensor或autograd.Variable
#a -这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
#mode -可以为“fan_in”(默认)或“fan_out”。“fan_in”保留前向传播时权值方差的量级,“fan_out”保留反向传播时的量级。.
################################这个用的较多
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in')
#根据He, K等人在“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification”中描述的方法,用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自均值为0,标准差为sqrt(2/((1 + a^2) * fan_in))的正态分布。#tensor – n维的torch.Tensor或 autograd.Variable
#a -这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
#mode -可以为“fan_in”(默认)或“fan_out”。“fan_in”保留前向传播时权值方差的量级,“fan_out”保留反向传播时的量级。

在中间插入卷积层呢,又该怎么办呢?

利用state-dict的键值对应读入原始的参数,新添加的层利用合适的方式直接初始化就可以。

总之,增加卷积层的操作,就是将原有的卷积层参数按照对应的键值读入,将新添加的层利用nn.init.kaiming_normal_进行初始化就可以

看到这儿,应该已经可以利用torchvision的预训练模型轻松的进行fine-tune了。

ubuntu使用pytorch训练出现killed_目标检测之pytorch预训练模型的使用(削减削减网络层,修改参数)fine-tune技巧...相关推荐

  1. PyTorch实现 | 车牌OCR识别,《PyTorch深度学习之目标检测》

    注:本文选自中国水利水电出版社出版<PyTorch深度学习之目标检测>一书,有改动 福利!免费寄送图书!! 公众号[机器学习与AI生成创作]后台回复:168.即可参与免费寄送图书活动,活动 ...

  2. 睿智的目标检测61——Pytorch搭建YoloV7目标检测平台

    睿智的目标检测61--Pytorch搭建YoloV7目标检测平台 学习前言 源码下载 YoloV7改进的部分(不完全) YoloV7实现思路 一.整体结构解析 二.网络结构解析 1.主干网络Backb ...

  3. 睿智的目标检测66——Pytorch搭建YoloV8目标检测平台

    睿智的目标检测66--Pytorch搭建YoloV8目标检测平台 学习前言 源码下载 YoloV8改进的部分(不完全) YoloV8实现思路 一.整体结构解析 二.网络结构解析 1.主干网络Backb ...

  4. 目标检测-基于Pytorch实现Yolov3(1)- 搭建模型

    原文地址:https://www.cnblogs.com/jacklu/p/9853599.html 本人前段时间在T厂做了目标检测的项目,对一些目标检测框架也有了一定理解.其中Yolov3速度非常快 ...

  5. 使用PyTorch从零开始实现YOLO-V3目标检测算法 (一)

    原文:https://blog.csdn.net/u011520516/article/details/80222743 点击查看博客原文 标检测是深度学习近期发展过程中受益最多的领域.随着技术的进步 ...

  6. 使用pytorch从零开始实现YOLO-V3目标检测算法 (二)

    原文:https://blog.csdn.net/u011520516/article/details/80212960 博客翻译 这是从零开始实现YOLO v3检测器的教程的第2部分.在上一节中,我 ...

  7. 使用yolov5训练自己的目标检测数据集

    使用yolov5训练自己的目标检测数据集 yolov4出来后不久,又出现了yolov5,没有论文.虽然作者没有放上和yolov4的直接测试对比,但在COCO数据集的测试效果还是很可观的.很多人考虑到Y ...

  8. 睿智的目标检测30——Pytorch搭建YoloV4目标检测平台

    睿智的目标检测30--Pytorch搭建YoloV4目标检测平台 学习前言 什么是YOLOV4 代码下载 YOLOV4改进的部分(不完全) YOLOV4结构解析 1.主干特征提取网络Backbone ...

  9. 睿智的目标检测35——Pytorch搭建YoloV4-Tiny目标检测平台

    睿智的目标检测35--Pytorch搭建YoloV4-Tiny目标检测平台 学习前言 什么是YOLOV4-Tiny 代码下载 YoloV4-Tiny结构解析 1.主干特征提取网络Backbone 2. ...

  10. 睿智的目标检测36——Pytorch搭建Efficientdet目标检测平台

    睿智的目标检测33--Pytorch搭建Efficientdet目标检测平台 学习前言 什么是Efficientdet目标检测算法 源码下载 Efficientdet实现思路 一.预测部分 1.主干网 ...

最新文章

  1. 妙用postman系列——postman建组、分享
  2. TensorFlow实战Google深度学习框架5-7章学习笔记
  3. c语言中的数组二分法排序程序,#C语言#二分法查找有序数组
  4. 普通软件项目开发过程规范(五)—— 总结
  5. [转载] Python中不可变集合的使用frozenset()方法
  6. 数值分析(2)-多项式插值: 拉格朗日插值法
  7. 假设法求最大值和数组的优点
  8. (2)机器学习_train_test_split
  9. 关于SQL SERVER 2005 开发版
  10. c# json转对象
  11. 管家婆普及版_昆明逸马软件 — 管家婆普及版新手入门指南
  12. WES7 SKU WES7E和WES7P的区别
  13. macOS 安装老旧版本的 adobe 应用
  14. 2017年第八届CSTQB®国际软件测试高峰论坛日程发布
  15. 集线器、交换机、路由器、网桥、网关之间的区别
  16. 什么是邮箱域名,企业邮箱域名有什么好处?
  17. Maven插件列表_Maven插件查询_Maven插件查看
  18. Python和R的GUI图形化编程与用户界面
  19. conversion failed: could not load input document
  20. Hadoop之HDFS常见面试题

热门文章

  1. 欧盛K7儿童手机,全面保护青少年儿童身心健康
  2. 我的站(艾网---城市生活新门户)重新上线了
  3. 人工智能与深度学习概念(2)——人工神经网络-ANN
  4. python函数的面向对象——面向对象设计
  5. mybatis使用oracle自动生成主键
  6. HightChar图表控件
  7. 终于 知道为什么datagrid有时候翻页要双击了...
  8. 【内推】AI独角兽-数美科技-NLP/CV/ASR等开放百余岗位,薪资诱人
  9. 【清华大学-腾讯】关系提取综述,Review and Outlook for Relation Extraction
  10. 3月面经汇总-字节跳动,美团,腾讯算法岗