这里写目录标题

  • 1.集成学习方法
  • 2.深度学习中的集成学习
    • Dropout
    • TTA
    • Snapshot

1.集成学习方法

在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stacking、Bagging和Boosting,同时这些集成学习方法与具体验证集划分联系紧密。

由于深度学习模型一般需要较长的训练周期,如果硬件设备不允许建议选取留出法,如果需要追求精度可以使用交叉验证的方法。

例如构建了10折交叉验证,训练得到10个CNN模型。那么在10个CNN模型可以使用如下方式进行集成:
对预测的结果的概率值进行平均,然后解码为具体字符;对预测的字符进行投票,得到最终字符。

2.深度学习中的集成学习

Dropout

Dropout可以作为训练深度神经网络的一种技巧。在每个训练批次中,通过随机让一部分的节点停止工作。同时在预测的过程中让所有的节点都其作用。Dropout经常出现在在先有的CNN网络中,可以有效的缓解模型过拟合的情况,也可以在预测时增加模型的精度。

代码使用nn.Dropout

class SVHN_Model1(nn.Module):def __init__(self):super(SVHN_Model1, self).__init__()# CNN提取特征模块self.cnn = nn.Sequential(nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2)),nn.ReLU(),nn.Dropout(0.25),nn.MaxPool2d(2),nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)),nn.ReLU(), nn.Dropout(0.25),nn.MaxPool2d(2),)# self.fc1 = nn.Linear(32*3*7, 11)self.fc2 = nn.Linear(32*3*7, 11)self.fc3 = nn.Linear(32*3*7, 11)self.fc4 = nn.Linear(32*3*7, 11)self.fc5 = nn.Linear(32*3*7, 11)self.fc6 = nn.Linear(32*3*7, 11)def forward(self, img):        feat = self.cnn(img)feat = feat.view(feat.shape[0], -1)c1 = self.fc1(feat)c2 = self.fc2(feat)c3 = self.fc3(feat)c4 = self.fc4(feat)c5 = self.fc5(feat)c6 = self.fc6(feat)return c1, c2, c3, c4, c5, c6

TTA

测试集数据扩增(Test Time Augmentation,简称TTA)也是常用的集成学习技巧,数据扩增不仅可以在训练时候用,而且可以同样在预测时候进行数据扩增,对同一个样本预测三次,然后对三次结果进行平均。

def predict(test_loader, model, tta=10):model.eval()test_pred_tta = None# TTA 次数for _ in range(tta):test_pred = []with torch.no_grad():for i, (input, target) in enumerate(test_loader):c0, c1, c2, c3, c4, c5 = model(data[0])output = np.concatenate([c0.data.numpy(), c1.data.numpy(),c2.data.numpy(), c3.data.numpy(),c4.data.numpy(), c5.data.numpy()], axis=1)test_pred.append(output)test_pred = np.vstack(test_pred)if test_pred_tta is None:test_pred_tta = test_predelse:test_pred_tta += test_predreturn test_pred_tta

Snapshot

本章的开头已经提到,假设我们训练了10个CNN则可以将多个模型的预测结果进行平均。但是加入只训练了一个CNN模型,如何做模型集成呢?

在论文Snapshot Ensembles中,作者提出使用cyclical learning rate进行训练模型,并保存精度比较好的一些checkopint,最后将多个checkpoint进行模型集成。

由于在cyclical learning rate中学习率的变化有周期性变大和减少的行为,因此CNN模型很有可能在跳出局部最优进入另一个局部最优。在Snapshot论文中作者通过使用表明,此种方法可以在一定程度上提高模型精度,但需要更长的训练时间。

Datawhale 零基础入门CV赛事-Task5 模型集成相关推荐

  1. Datawhale 零基础入门CV赛事-Task4 模型训练与验证

    文章目录 1.构造验证集 2.模型训练与验证 1.构造验证集 在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的.深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走 ...

  2. Datawhale 零基础入门CV赛事-Task3 字符识别模型

    文章目录 1.CNN实现 2.Pytorch实现CNN 3.使用ImangeNet预训练模型 1.CNN实现 CNN基础 2.Pytorch实现CNN 构建一个简单的CNN模型和训练过程 import ...

  3. Datawhale零基础入门NLP赛事 - Task5 基于深度学习的文本分类2

    在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习基于深度学习的文本分类. 学习目标 学 ...

  4. Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

    文章目录 数据读取 图像读取 1.pillow 2.opencv 数据读取 数据扩增 数据读取 导入需要的包以及文件路径 import json, glob import numpy as np fr ...

  5. 阿里云天池竞赛-零基础入门CV赛事-Task4 模型训练与验证

    在上一章节我们构建了一个简单的CNN进行训练,并可视化了训练过程中的误差损失和第一个字符预测准确率,但这些还远远不够.一个成熟合格的深度学习训练流程至少具备以下功能: 在训练集上进行训练,并在验证集上 ...

  6. 零基础入门CV赛事,理论结合实践

    Datawhale干货 作者:阿水,Datawhale成员 本次分享的背景是,Datawhle联合天池发布的学习赛:零基础入门CV赛事之街景字符识别.本文以该比赛为例,对计算机视觉赛事中,赛事理解和B ...

  7. 零基础入门CV赛事- 街景字符编码识别

    零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...

  8. Datawhale零基础入门NLP day5/Task5基于深度学习的文本分类2

    基于深度学习的文本分类 本章将继续学习基于深度学习的文本分类. 学习目标 学习Word2Vec的使用和基础原理 学习使用TextCNN.TextRNN进行文本表示 学习使用HAN网络结构完成文本分类 ...

  9. 零基础入门CV赛事—街景字符编码识别—task2数据读取与扩增

    数据读取与扩增 上节学习了街景字符编码识别的解题思路,让我们对本赛题有了基本的idea,这节在定长字符编码的思路基础上学习读取数据和数据扩增. 图像数据读取 由于赛题数据是图像数据,赛题的任务是识别图 ...

最新文章

  1. 如何使用React提前三天计划
  2. Factory - 工厂模式
  3. Spark程序运行常见错误解决方法以及优化
  4. c#样条曲线命令_如何定制CAD功能区界面中的命令?
  5. Windows上安装Mysql解压缩版教程
  6. 小鹤双拼记忆口诀_选择双拼之自然码
  7. xenserver 虚拟机扩容lvm磁盘分区的方法_Linux磁盘扩容
  8. LeetCode meituan-003. 小美的跑腿代购(排序)
  9. ruby在类中访问@,类外访问调用方法
  10. python office库使用_看完这篇Python操作PPT总结,从此使用Python玩转Office全家桶就没有压力了!...
  11. 核函数与径向基函数 (Radial Basis Function 简称 RBF)详解
  12. 关闭Apple Watch 上的激活锁的方法
  13. 这根网线真奇怪——笔记本可用,台式机不可用(另一端重压水晶头后可以)
  14. 复习高数下册8-10章主要内容(简略版)
  15. JDK 商用正式免费、Log4j2 爆核弹级漏洞、LayUI 下线...2021 发生的 10 件大事。。。
  16. 灵猫二维码 - 二维码中间加图片的方法
  17. APP第一次请求HTTPS慢
  18. 硬盘变成RAW 修复
  19. Window10系统电脑开机超级慢的解决方法(从180秒提升至12秒)
  20. 调节阀卡塞的处理方法

热门文章

  1. margin与padding的bug
  2. Python环境搭建之OpenCV(转载)
  3. 这样讲闭包,你终生难忘
  4. android开发(44) 使用了 SoundPool 播放提示音
  5. SVN安装,SVN服务搭建与eclipse里插件安装
  6. iframe design=on 时,oncontextmeun不能触发之问题!
  7. mysql注释符号_MySQL基础知识(2021最新版教程)
  8. maven java web项目_Maven创建JavaWeb项目
  9. 文献阅读软件_推荐一款阅读英文文献的神器,效率高不少,理解深不少!
  10. 学会这一招,轻松玩转 app 中混合应用自动化测试