最近开始利用Pytorch写一些深度学习的模型,没有系统的去学习pytorch的用法,也还没来得及去看别人的写法,先简单记录一些自己的想法。
利用Pytorch在写一些具有多个分支的模型时(比如具有特征融合、模型融合或者多个任务的网络结构),模型类该怎么写,loss会怎么传播,应该先将input融合再传入forward还是传入forward后再进行融合等问题。

特征融合

使用相同的模型对输入进行特征的提取,在输入FC层前进行特征融合,然后分类。

class _CCN(nn.Module):def __init__(self, num, p, cnt):super(_CCN, self).__init__()self.cnt = cntself.features = nn.Sequential(nn.Conv3d(1, num, 3, 1, 0, bias=False),nn.MaxPool3d(2, 1, 0),nn.BatchNorm3d(num),nn.LeakyReLU(),nn.Dropout(p),# ? * ? * ?)self.classifier = nn.Sequential(nn.Linear(self.cnt*num*?*?*?, 30),)# 传入forward的inpts为list,list里为tensor。def forward(self, inputs):input = inputs[0].cuda()out = self.features(input)for i in range(1, len(inputs)):input = inputs[i].cuda()x = self.features(input)out = torch.cat((out, x), dim=1)batch_size = out.shape[0]out = out.view(batch_size, -1)out = self.classifier(out)return out

如此写的话,每个input会共享卷积核权重,坏处在于如果每个input原本是想学习不同的特征,那么这个权重对每个input来讲就无法达到最优。如果每个input本身就是差不多的输入,那么这种共享就会减少参数。

模型融合

使用不同的模型对输入进行特征的提取,在输入FC层之前进行特征融合,然后分类。

目前的想法是在一个类里写两个模型,按上面那种模式写多个self.features,然后将input分别送入不同的self.features,得到结果以后进行融合在输入self.classifier。

class _CCN(nn.Module):def __init__(self, num, p, cnt):super(_CCN, self).__init__()self.cnt = cntself.features0 = nn.Sequential(nn.Conv3d(1, num, 3, 1, 0, bias=False),nn.MaxPool3d(2, 1, 0),nn.BatchNorm3d(num),nn.LeakyReLU(),nn.Dropout(p),# ? * ? * ?)self.features1 = nn.Sequential(nn.Conv3d(1, num, 3, 1, 0, bias=False),nn.MaxPool3d(2, 1, 0),nn.BatchNorm3d(num),nn.LeakyReLU(),nn.Dropout(p),# ? * ? * ?)self.classifier = nn.Sequential(nn.Linear(self.cnt*num*?*?*?, 30),)def forward(self, inputs):x0 = self.features0(inputs)x1 = self.features1(inputs)out = torch.cat((x0, x1), dim=1)batch_size = out.shape[0]out = out.view(batch_size, -1)out = self.classifier(out)return out

这样就解决掉了在特征融合中遇到的问题,如果多个input不想共享权重,就可以模型融合,或者写多个一样的特征提取的容器,可能还需要解决的问题是能否批量构造一堆一样的特征提取的容器,而不用写多个。

决策融合

训练不同的模型,得到输出以后进行加权求和等操作,将该结果作为最终分类结果。

class _CCN0(nn.Module):def __init__(self, num, p):super(_CCN0, self).__init__()self.features = nn.Sequential(nn.Conv3d(1, num, 3, 1, 0, bias=False),nn.MaxPool3d(2, 1, 0),nn.BatchNorm3d(num),nn.LeakyReLU(),nn.Dropout(p),# ? * ? * ?)self.classifier = nn.Sequential(nn.Linear(self.cnt*num*?*?*?, 30),)def forward(self, inputs):out = self.features(inputs)batch_size = out.shape[0]out = out.view(batch_size, -1)out = self.classifier(out)return outclass _CCN1(nn.Module):def __init__(self, num, p):super(_CCN1, self).__init__()self.features = nn.Sequential(nn.Conv3d(1, num, 3, 1, 0, bias=False),nn.MaxPool3d(2, 1, 0),nn.BatchNorm3d(num),nn.LeakyReLU(),nn.Dropout(p),# ? * ? * ?)self.classifier = nn.Sequential(nn.Linear(self.cnt*num*?*?*?, 30),)def forward(self, inputs):out = self.features(inputs)batch_size = out.shape[0]out = out.view(batch_size, -1)out = self.classifier(out)return out
class CCN_Wrapper():def __init__(self, fil_num,drop_rate,seed,batch_size,balanced,Data_dir,exp_idx,model_name):self.seed = seedself.exp_idx = exp_idxself.Data_dir = Data_dirself.model_name = model_nameself.eval_metric = get_accuself.batch_size = batch_sizeself.prepare_dataloader(batch_size, balanced, Data_dir)self.model0 = _CCN0(num=fil_num, p=drop_rate).cuda()self.model1 = _CCN1(num=fil_num, p=drop_rate).cuda()def train(self, lr, epochs):print("training ....")self.optimizer0 = optim.Adam(self.model0.parameters(), lr=lr, betas=(0.5, 0.999))self.optimizer1 = optim.Adam(self.model1.parameters(), lr=lr, betas=(0.5, 0.999))self.criterion = nn.CrossEntropyLoss().cuda()for self.epoch in range(epochs):self.train_model_epoch()valid_matrix = self.valid_model_epoch()def train_model_epoch(self):self.model.train(True)for inputs, labels in self.train_dataloader:inputs, labels = inputs.cuda(), labels.cuda()self.model.zero_grad()preds0 = self.model0(inputs)preds1 = self.model1(inputs)preds = weight0 * preds0 + weight1 * preds1loss0 = self.criterion(preds0, labels)loss1 = self.criterion(preds1, labels)loss0.backward()loss1.backward()self.optimizer0.step()self.optimizer1.step()def valid_model_epoch(self):with torch.no_grad():self.model.train(False)for inputs, labels in self.valid_dataloader:inputs, labels = inputs.cuda(), labels.cuda()preds0 = self.model0(inputs)preds1 = self.model1(inputs)preds = weight0 * preds0 + weight1 * preds1acc = get_acc(preds, labels)def test(self):print('testing ... ')self.model.load_state_dict(torch.load())self.model.train(False)with torch.no_grad():for stage in ['train', 'valid', 'test']:data = CCN_Data(self.Data_dir, self.exp_idx, stage=stage, seed=self.seed)dataloader = DataLoader(data, batch_size=10, shuffle=False)for idx, (inputs, labels) in enumerate(dataloader):inputs, labels = inputs.cuda(), labels.cuda()preds0 = self.model0(inputs)preds1 = self.model1(inputs)preds = weight0 * preds0 + weight1 * preds1acc = get_acc(preds, labels)def prepare_dataloader(self, batch_size, balanced, Data_dir):train_data = CCCN_Data(Data_dir, self.exp_idx, stage='train', seed=self.seed)valid_data = CCCN_Data(Data_dir, self.exp_idx, stage='valid', seed=self.seed)test_data = CCCN_Data(Data_dir, self.exp_idx, stage='test', seed=self.seed)self.train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)self.valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False)self.test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

上述代码其实就是两个模型分别输出结果,然后做一个加权平均,和深度学习、pytorch都没有关系。目前还在思考,如何在深度学习中将两个模型的loss进行加权平均再传播回去进行训练,如果按loss加权平均再反向传播的话应该是多任务学习了。

多任务学习

a = nn.CrossEntropyLoss()
b = nn.MSELoss()loss_a = a(output_x, x_labels)
loss_b = b(output_y, y_labels)loss = loss_a + loss_bloss.backward()

参考

https://discuss.pytorch.org/t/how-to-combine-multiple-criterions-to-a-loss-function/348/7

Timeline

2020.07.14 先记录这么多,后续学习一下再来补充和修改。

利用Pytorch写多分支网络的一些想法相关推荐

  1. 如何利用PyTorch写一个Transformer实现英德互译

    数据集中每一行是一对英语,德语句子对 Transformer模型出处:2017 <Attention is all you need> Transformer中的位置编码是什么意思? ht ...

  2. 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解

    目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...

  3. 利用Pytorch中深度学习网络进行多分类预测(multi-class classification)

    从下面的例子可以看出,在 Pytorch 中应用深度学习结构非常容易 执行多类分类任务. 在 iris 数据集的训练表现几乎是完美的. import torch.nn as nn import tor ...

  4. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  5. 利用MatConvNet进行孪生多分支网络设计

    前面提及到了利用vl_nndist作为多分支网络的特征测度函数,将多个网络的局部输出融合到一起.参见博客:https://blog.csdn.net/shenziheng1/article/detai ...

  6. 利用Pytorch实现GoogLeNet网络

    目  录 1 GoogLeNet网络 1.1 网络结构及参数 1.2 Inception结构 1.3 带降维功能的Inception结构 1.4 辅助分类器 2 利用Pytorch实现GoogLeNe ...

  7. PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

    PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN 目录 训练过程 代码设计 训练过程 代码设计 #PyTorch:利用PyTorch实现 ...

  8. 利用Pytorch实现ResNeXt网络

    目  录 1 ResNeXt网络介绍 1.1 组卷积 1.2 block 1.3 网络搭建 2 利用Pytorch实现ResNet网络 2.1 模型定义 2.2 训练结果 1 ResNeXt网络介绍 ...

  9. 利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结。

    利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结. 摘要 一.神经网络与卷积网络的对比 1.数据处理 2.对获取到的数据进行归一化和独热编码 二.开始我们的tensorflow神经 ...

最新文章

  1. 开源 免费 java CMS - FreeCMS-功能说明-操作日志
  2. mysql limit优化_MySQL:教你学会如何做性能分析与查询优化
  3. 微服务架构 为什么需要配置中心
  4. 【转】SAP Fiori Design Guidelines基础篇
  5. 001_汽车之家,新浪和360之间的交流
  6. js将docx转换为html,js 将word转换Html
  7. FedNLP: 首个联邦学习赋能NLP的开源框架,NLP迈向分布式新时代
  8. php 面向对象问题,PHP 面向对象开发的一些问题
  9. 学习 服务器部署 hello world
  10. linux telnet无法连接,奇怪的问题:telnet无法连接另一台server的正常的开放端口
  11. windows7 iis安装 Windows Modules Installer服务无法启动
  12. [蓝桥杯历届试题] 汉诺塔计数
  13. 计算机二级ms高级应用选择题,计算机二级考试MS-OFFICE高级应用选择题及答案
  14. Java线程的5种状态及切换(透彻讲解)-京东面试
  15. 台达plc使用c语言编程软件,台达PLC编程软件_台达PLC编程软件官方版下载[plc编程]-下载之家...
  16. Node.js文字与图片合成
  17. DTCC 2020 | 阿里云李飞飞:云原生分布式数据库与数据仓库系统点亮数据上云之路
  18. python中的token是什么
  19. 获取素材列表返回40004 invalid media type.获取公众号素材mediaId
  20. Meta-HAR: Federated Representation Learning for Human Activity Recognition

热门文章

  1. JavaPYB - 第二十天、第二十一天 - JavaWeb - Part 2: MySQL - 增、删、查、改、练习
  2. 数据库系统概论 第七章课后习题(部分)
  3. P1297 [国家集训队]单选错位
  4. 文章评论:“鞋服企业以销定产-零库存不难”
  5. 水位检测专用芯片VK36W1D电容式专用检水触控IC1-8点高灵敏度 4S自动校准功能
  6. 关于Git,你真的学会了吗?
  7. 包装类说明以及包装类的装箱和拆箱
  8. Lua语法 垃圾回收collectgarbage
  9. 从零开始学习Java
  10. 公司企业怎么设置域名邮箱?