在平常看一些卷积神经网络的时候,大多数都是直接通过写一个Model类来定义的,这样写的代码其实是比较好懂的,特别是在魔改网络的时候也很方便。然后也有一些会通过cfg配置文件进行模型的定义。在yolov5中可以看到是通过yaml文件进行网络的定义【个人感觉通过配置文件魔改网络有些不方便,当然每个人习惯不同】,可能很多人也用过,如果自己去写一个yaml文件,自己能不能定义出来呢?很多人不知道是如何具体通过yaml文件将里面的参数传入自己定义的网络中,这也就给自己修改网络带来了不便。这篇文章将仿照yolov5的方式,利用yaml定义一个自己的网络


定义卷积块

我们可以先定义一个卷积块CBL,C指卷积Conv,B指BN层,L为激活函数,这里我用ReLu.

class BaseConv(nn.Module):def __init__(self, in_channels, out_channels, k=1, s=1, p=None):super().__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.conv = nn.Conv2d(in_channels, out_channels, k, s, autopad(k, p))self.bn = nn.BatchNorm2d(out_channels)self.act_fn = nn.ReLU(inplace=True)def forward(self, x):return self.act_fn(self.bn(self.conv(x)))

卷积中的autopad是自动补充pad,代码如下:

def autopad(k, p=None):if p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k]return p

定义一个Bottleneck 

可以仿照yolov5定义一个Bottleneck,参考了残差块的思想。

class Bottleneck(nn.Module):def __init__(self, in_channels, out_channels, shortcut=True):super(Bottleneck, self).__init__()self.conv1 = BaseConv(in_channels, out_channels, k=1, s=1)self.conv2 = BaseConv(out_channels, out_channels, k=3, s=1)self.add = shortcut and in_channels == out_channelsdef forward(self, x):"""x-->conv1-->conv2-->add|_________________|"""return x + self.conv2(self.conv1(x)) if self.add else self.conv2(self.conv1(x))

攥写yaml配置文件

然后我们来写一下yaml配置文件,网络不要很复杂,就由两个卷积和两个Bottleneck组成就行。同理,仿v5的方法,我们的网络中的backone也是个列表,每行为一个卷积层,每列有4个参数,分别代表from(指该层的输入通道数为上一层的输出通道数,所以是-1),number【yaml中的1,1,2指该层的深度,或者说是重复几次】,Module_nams【该层的名字】,args【网络参数,包含输出通道数,k,s,p等设置】

# define own model
backbone:[[-1, 1, BaseConv, [32, 3, 1]],  # out_channles=32, k=3, s=1[-1, 1, BaseConv, [64, 1, 1]],[-1, 2, Bottleneck, [64]]]

我们现在用yaml工具来打开我们的配置文件,看看都有什么内容

    import yaml# 获得yaml文件名字yaml_file = Path('Model.yaml').namewith open(yaml_file,errors='ignore') as f:yaml_ = yaml.safe_load(f)print(yaml_)

输出:

{'backbone': [[-1, 1, 'BaseConv', [32, 3, 1]], [-1, 1, 'BaseConv', [64, 1, 1]], [-1, 2, 'Bottleneck', [64]]]}

然后我们可以定义下自己Model类,也就是定义自己的网络。可以看到与前面读取yaml文件相比,多了一行    ch = self.yaml["ch"] = self.yaml["ch"] = 3   这个是在原yaml内容中加入一个key和valuse,3指的3通道,因为我们的图像是3通道。parse_model是下面要说的传参过程。

class Model(nn.Module):def __init__(self, cfg='./Model.yaml', ch=3, ):super().__init__()self.yaml = cfgimport yamlyaml_file = Path(cfg).namewith open(yaml_file, errors='ignore')as f:self.yaml = yaml.safe_load(f)ch = self.yaml["ch"] = self.yaml["ch"] = 3self.backbone = parse_model(deepcopy(self.yaml), ch=[ch])def forward(self, x):output = self.backbone(x)return output

传入参数

这一步也是最关键的一步,我们需要定义传参的函数,将yaml中的卷积参数传入我们定义的网络中,这里会用的一个非常非常重要的函数eval(),后面也会介绍到这个函数的用法。

这里先附上完整代码:

def parse_model(yaml_cfg, ch):""":param yaml_cfg: yaml file:param ch: init in_channels default is 3:return: model"""layer, out_channels = [], ch[-1]for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):"""f:上一层输出通道number:该模块有几层,就是该模块要重复几次Mdule_name:卷积层名字args:参数,包含输出通道数,k,s,p等"""# 通过eval,将str类型转自己定义的BaseConvm = eval(Module_name) if isinstance(Module_name, str) else Module_namefor j, a in enumerate(args):# 通过eval,将str转int,获得输出通道数args[j] = eval(a) if isinstance(a, str) else a# 更新通道# args[0]是输出通道if m in [BaseConv, Bottleneck]:in_channels, out_channels = ch[f], args[0]args = [in_channels, out_channels, *args[1:]]  # args=[in_channels, out_channels, k, s, p]# 将参数传入模型model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)# 更新通道列表,每次获取输出通道ch.append(out_channels)layer.append(model_)return nn.Sequential(*layer)

下面开始分析代码 。

这行代码是通过列表用来存放每层内容以及输出通道数。

# 这行代码是通过列表用来存放每层内容以及输出通道数
layer, out_channels = [], ch[-1]

然后进入我们的for循环,在每一次循环中可以获得我们yaml文件中的每一层网络:f是上一层网络的输出通道【用来作为本层的输入通道】,number【网络深度,也就是该层重复几次而已】,Module_name是该层的名字,args是该层的一些参数。

for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):

接下来会碰到一个很重要的函数eval()。下行的代码首先需要判断一下我们的Module_name类型是不是字符串类型,也就是判断一下yaml中“BaseConv”是不是字符串类型,如果是,则用eval进行对应类型的转化,转成我们的BaseConv类型

m = eval(Module_name) if isinstance(Module_name, str) else Module_name

这里我将对eval函数在深入点,如果知道这个函数用法的,就可以略去这部分。

我们先举个例子,比如我现在有个变量a="123",这个a的类型是什么呢?他是一个str类型,不是int类型。 现在我们用eval函数转一下,看看会变成什么样子。

>>> b = eval(a) if isinstance(a,str) else a
>>> b
123
>>> type(b)
<class 'int'>

我们可以看到,经过eval函数以后,会自动识别并转为int类型。那么我继续举例子,如果现在a="BaseConv",经过eval以后会变成什么?可以看到,这里报错了!这是为什么?这是因为我们没有导入BaseConv这个类,所以eval函数并不知道我们希望转为什么类型。所以我们需要用import导入BaseConv这个类才可以。

>>> a="BaseConv"
>>> b = eval(a) if isinstance(a,str) else a
Traceback (most recent call last):File "<stdin>", line 1, in <module>File "<string>", line 1, in <module>
NameError: name 'BaseConv' is not defined

当我们导入BaseConv以后,在经过eval就可以获得:

<class 'models.BaseConv'>


接下来是获得args中的网络参数,也是通过eval进行转化

        for j, a in enumerate(args):# 通过eval,将str转int,获得输出通道数args[j] = eval(a) if isinstance(a, str) else a

获取通道数,并在每次循环中对通道进行更新:可以仔细看一下ch[f]指的上一层输出通道,刚开始默认为[3],那么ch[-1]=3,我们yaml中第一层的BaseConv args[0]为32,表示输出32通道。因此在第一次循环中有in_channels = 3,out_channels=32。args也要更新,*args前面的"*"并不是指针的意思,也不是乘的意思,而是解压操作,因此我们第一次循环中得到的args=[3,32,3,1]。

# 更新通道
# args[0]是输出通道
if m in [BaseConv, Bottleneck]:in_channels, out_channels = ch[f], args[0]args = [in_channels, out_channels, *args[1:]]  # args=[in_channels, out_channels, k, s, p]

将参数传入模型

这里用for _ in range(number)来判断网络的深度【或者说该模块重复几次】,这里的m就是前面经过eval转化的 <class 'models.BaseConv'>。通过*args解压操作将args列表中的内容放入m中,再通过*解压操作放入nn.Sequential。

model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)

这样就可以获得我们第一次循环BaseConv了。后面的循环也是同样的反复操作而已。

BaseConv(
  (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_fn): ReLU(inplace=True)
)

然后是更新通道列表和layer列表,为的是获取每次循环的输出通道,没有这一步,再下一次循环的时候将不能正确得到通道数。

# 更新通道列表,每次获取输出通道
ch.append(out_channels)
layer.append(model_)

然后我们就可以对模型调用进行实例化了,可以打印下模型:

Model((backbone): Sequential((0): BaseConv((conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True))(1): BaseConv((conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True))(2): Sequential((0): Bottleneck((conv1): BaseConv((conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True))(conv2): BaseConv((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True)))(1): Bottleneck((conv1): BaseConv((conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True))(conv2): BaseConv((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(act_fn): ReLU(inplace=True)))))
)

同时我们也可以对模型每层可视化看一下。可以看到和我们定义的模型是一样的。


上述完整的代码:

from copy import deepcopyfrom models import BaseConv, Bottleneck
import torch.nn as nn
import ospath = os.getcwd()
from pathlib import Path
import torchdef parse_model(yaml_cfg, ch):""":param yaml_cfg: yaml file:param ch: init in_channels default is 3:return: model"""layer, out_channels = [], ch[-1]for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):"""f:上一层输出通道number:该模块有几层,就是该模块要重复几次Mdule_name:卷积层名字args:参数,包含输出通道数,k,s,p等"""# 通过eval,将str类型转自己定义的BaseConvm = eval(Module_name) if isinstance(Module_name, str) else Module_namefor j, a in enumerate(args):# 通过eval,将str转int,获得输出通道数args[j] = eval(a) if isinstance(a, str) else a# 更新通道# args[0]是输出通道if m in [BaseConv, Bottleneck]:in_channels, out_channels = ch[f], args[0]args = [in_channels, out_channels, *args[1:]]  # args=[in_channels, out_channels, k, s, p]# 将参数传入模型model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)# 更新通道列表,每次获取输出通道ch.append(out_channels)layer.append(model_)return nn.Sequential(*layer)class Model(nn.Module):def __init__(self, cfg='./Model.yaml', ch=3, ):super().__init__()self.yaml = cfgimport yamlyaml_file = Path(cfg).namewith open(yaml_file, errors='ignore')as f:self.yaml = yaml.safe_load(f)ch = self.yaml["ch"] = self.yaml["ch"] = 3self.backbone = parse_model(deepcopy(self.yaml), ch=[ch])def forward(self, x):output = self.backbone(x)return outputif __name__ == "__main__":cfg = path + '/Model.yaml'model = Model()model.eval()print(model)x = torch.ones(1, 3, 512, 512)output = model(x)torch.save(model, "model.pth")# model = torch.load('model.pth')# model.eval()# x = torch.ones(1,3,512,512)# input_name = ['input']# output_name = ['output']# torch.onnx.export(model, x, 'myonnx.onnx', verbose=True)

利用yaml定义卷积网络【附代码】相关推荐

  1. 经典卷积网络——DenseNet代码实现

    题目:Densely Connected Convolutional Networks 论文地址:https://arxiv.org/pdf/1608.06993.pdf 常见的卷积网络结构对比:   ...

  2. Graph Convolution Network图卷积网络(二)数据加载与网络结构定义

    背景 : 弄懂Graph Convolution Network的pytorch代码如何加载数据并且如何定义网络结构的. 代码地址:https://github.com/tkipf/pygcn 论文地 ...

  3. 图卷积网络 GCN Graph Convolutional Network(谱域GCN)的理解和详细推导

    文章目录 1. 为什么会出现图卷积神经网络? 2. 图卷积网络的两种理解方式 2.1 vertex domain(spatial domain):顶点域(空间域) 2.2 spectral domai ...

  4. 深度卷积网络CNN与图像语义分割

    转载请注明出处:  http://xiahouzuoxin.github.io/notes/html/深度卷积网络CNN与图像语义分割.html 级别1:DL快速上手 级别2:从Caffe着手实践 级 ...

  5. 图卷积 节点分类_在节点分类任务上训练图卷积网络

    图卷积 节点分类 This article goes through the implementation of Graph Convolution Networks (GCN) using Spek ...

  6. java 递归_采用递归算法求解迷宫问题(Java版) | 附代码+视频

    递归算法能够解决很多计算机科学问题,迷宫问题就是其中一个典型案例.本篇教程我们将采用递归算法求解迷宫问题,输出从入口到出口的所有迷宫路径. 01 用递归算法解决迷宫问题 迷宫问题在<数据结构教程 ...

  7. 论文翻译 SGCN:Sparse Graph Convolution Network for Pedestrian Trajectory Prediction 用于行人轨迹预测的稀疏图卷积网络

    SGCN:Sparse Graph Convolution Network for Pedestrian Trajectory Prediction 用于行人轨迹预测的稀疏图卷积网络 行人轨迹预测是自 ...

  8. python抓取朋友圈动态_如何利用Python网络爬虫爬取微信朋友圈动态--附代码(下)...

    原标题:如何利用Python网络爬虫爬取微信朋友圈动态--附代码(下) 前天给大家分享了如何利用Python网络爬虫爬取微信朋友圈数据的上篇(理论篇),今天给大家分享一下代码实现(实战篇),接着上篇往 ...

  9. 关系抽取:图卷积网络的学习(二)(附代码)

    关系抽取:图卷积网络的学习(二)(附代码) 目录 关系抽取:图卷积网络的学习(二)(附代码) 论文一:基于关系图的实体关系联合抽取 摘要 1.Introduction 2.Motivation 3.G ...

最新文章

  1. 如何把手变成手控_在这个模拟手的VR游戏里,你能体验到很多手控福利
  2. Java 8 一行代码解决了空指针问题,太厉害了...
  3. Filter学习(一)
  4. 双塔模型没效果了?请加大加粗!
  5. leetcode1292. 元素和小于等于阈值的正方形的最大边长(二分法+前缀和)
  6. 好的软件人员必看的书
  7. pytorch Resnet
  8. C++技巧:用kdevelop进行交叉编译的方法
  9. Linux 探索之旅 | 第五部分第七课:Shell 实现图片展示网页
  10. javascript跨域、iframe跨域访问
  11. Layui 监听 复选框 提交表单
  12. oracle 安全备份与rman_Oracle RMAN备份与还原注意事项
  13. 论基于架构的软件设计方法及应用
  14. 服装erp系统的设计方案
  15. ArcEngine实现要素类排序的四种方法
  16. c语言:“有一个已排好序的数组,要求输入一个数后,按原来的规律将它插入数组中” 的程序分析及详细代码
  17. Exsel 设置固定表头
  18. SpringBoot Hanlp的集成
  19. windows10无法连接远程桌面的问题的解决方法
  20. 分享几个有趣实用的网站

热门文章

  1. 阿里云域名解析与绑定服务器IP地址—域名和端口访问自己的web网站
  2. 高职高专计算机类教师招聘计划,2021年这类教师招聘8.4万人,专科生的福利,服务期满就能入编?...
  3. 店大欺客的微信与腾信
  4. jquery判断字符串的长度,中英文都可
  5. vue项目 运行的时候报To install it, you can run: npm install --save ..\common\css\common.scss
  6. 软件项目怎么快速响应用户需求
  7. oracle用户常用权限,Oracle用户权限分配的具体方法【常用财务软件使用教程】
  8. css基础 层叠样式表 选择器
  9. 意法半导体MCU微控制器技术突破在哪?
  10. Ex20b示例程序:DLL测试客户程序