深度学习:NiN(Network In Network)详细讲解与代码实现

  • 网络核心思想
    • 1*1卷积
    • NiN块的作用
    • 全局池化(Global Average Pooling)
  • 基于NiN的服装分类(Pytorch)
    • 服装分类数据集
    • 定义模型
    • 测试数据
    • 训练模型

网络核心思想

LeNet、AlexNet和VGG都有一个共同的设计模式:通过一系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进行处理。
AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。
网络中的网络(NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机 [Lin et al., 2013]

1*1卷积


1 * 1卷积实际上就是对特征图所有channel对应的像素点做全连接网络,由于它只考虑了1个像素点,它不像3 * 3卷积那样可以考虑周围像素点,但是可以让特征图在不需要padding的情况下保证的H、W不变,也就是融合了买个像素点不同通道的特征所以它也有跨通道交融的作用。卷积核的数量决定了输出的维度,所以用1 * 1卷积只会改变特征图的channel数,这也就是1 * 1卷积有升维 、降维的作用,在维度降低的同时,计算量也就减少了,模型速度会变快,与此同时,它在保留了空间信息的同时,还增加了非线性激活函数,非线性激活函数可以增加模型的复杂程度,让模型逼近更复杂的曲线。

NiN块的作用


NiN是由AlexNet改进而来的,他的主要贡献是提出了NiN块这个概念,随着网络层数的加大,参数的数量也水涨船高,我们举一个例子:

假设现在特征图是 28 * 28 * 256,我们想把它变成28 * 28 *32
我们的原始方法是采用32个 5 * 5的卷积核。如上图,那么它的参数量就是:
5∗5∗32∗256+32=201825*5*32*256+32=201825∗5∗32∗256+32=20182
而我们的NiN块的核心思想是把原始特征图先降维在升维,也就是我把32个 5 * 5卷积核替换成先经过16个 1 * 1的卷积核,特征图就变成了, 28 * 28 *16 ,然后在经过一个32个 5 * 5 的卷积核,它的参数数量为:
1∗1∗16∗256+16=41121 * 1 * 16 * 256+16 =41121∗1∗16∗256+16=4112
5∗5∗32∗16+32=128325*5*32*16 +32=128325∗5∗32∗16+32=12832
求和也就是16944,参数量对比之前少了15%左右。

全局池化(Global Average Pooling)

在NiN网络中,去掉了卷积层后面的全连接层,加入与了全局池化层,全局池化层是把最后的特征图数量变成了分类的数量,这样的可解释性更强,之后,我们只需要对每一个channel求一个全局平均值,然后经过,Softmax分类,这样也大大减少了参数量。

基于NiN的服装分类(Pytorch)

服装分类数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0到1之间
def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)#通过compose组合多个操作mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=4),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=4))#num workers 为线程数

定义模型


def nin_block(in_channels, out_channels, kernel_size,strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),# 将四维的输出转成二维的输出,其形状为(批量大小,10)nn.Flatten())

测试数据

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)#累加器with torch.no_grad():#禁止计算梯度for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

训练模型

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()#更新参数with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
train_iter, test_iter = load_data_fashion_mnist(256, resize=224)
train_ch6(net, train_iter, test_iter, 10, 0.01, d2l.try_gpu())

深度学习:NiN(Network In Network)详细讲解与代码实现相关推荐

  1. 【深度学习】基础知识 | 超详细逐步图解 Transformer

    作者 | Chilia 整理 | NewBeeNLP 1. 引言 读完先修知识中的文章之后,你会发现:RNN由于其顺序结构训练速度常常受到限制,既然Attention模型本身可以看到全局的信息, 那么 ...

  2. 深度学习框架caffe及py-faster-rcnn详细配置安装过程

    深度学习框架caffe及py-faster-rcnn详细配置安装过程 配置环境: ubuntu14.04 CUDA7.5 1.准备工作 安装vim.python-pip.git sudo apt-ge ...

  3. 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...

  4. 2023年的深度学习入门指南(14) - 不能只关注模型代码

    2023年的深度学习入门指南(14) - 不能只关注模型代码 最近,有一张大模型的发展树非常流行: 这个图是相当不错的,对于加深对于Transformer模型编码器.解码器作用的理解,模型的开源和闭源 ...

  5. 【OS系列-2】- 进程详细讲解(代码示例)

    进程 进程详细讲解(代码示例) 进程 示例代码 创建进程的具体过程? 执行 fork()的时候系统做了什么? 进程间通信 管道 消息队列 共享内存 信号量 套接字 进程间同步 信号量 文件锁 无锁 C ...

  6. 【深度学习】梯度和方向导数概念解析(代码基于Pytorch实现)

    [深度学习]梯度和方向导数概念解析(代码基于Pytorch实现) 文章目录 1 方向导数 2 梯度 3 自动求导实现 4 梯度下降4.1 概述4.2 小批量梯度下降 5 总结 1 方向导数 方向导数的 ...

  7. 手机上的机器学习资源!Github标星过万的吴恩达机器学习、深度学习课程笔记,《统计学习方法》代码实现!...

    吴恩达机器学习.深度学习,李航老师<统计学习方法>.CS229数学基础等,可以说是机器学习入门的宝典.本文推荐一个网站"机器学习初学者",把以上资源的笔记.代码实现做成 ...

  8. 【深度学习】2021 最新视频防抖论文+开源代码汇总

    大家好,今天给大家分享,今年三篇关于视频防抖的文章,这三篇文章分布采用了不同的方法来解决视频抖动的问题. 1.基于深度的三维视频稳定学习方法Deep3D稳定器 2.融合运动传感器数据和光流,实现在线视 ...

  9. lamport面包店算法详细讲解及代码实现

    lamport面包店算法详细讲解及代码实现 1 算法详解 1.1 一个较为直观的解释 1.2 Lamport算法的时间戳原理 1.3 Lamport算法的5个原则 1.4 一个小栗子 2 算法实现 3 ...

最新文章

  1. 使用 CAS 在 Tomcat 中实现单点登录
  2. 只此一招,全屏操作从此易如反掌
  3. oracle查询表的索引
  4. 使用 yum 安装Docker(CentOS 7下)
  5. laravel实现读写分离
  6. 201571030322/201571030319《小学生四则运算软件需求说明结对项目报告》
  7. oracle client 11.2.0.3 32位,oracle client 32位/64位下载(Oracl数据库)
  8. 一个例子看懂神马是闭包
  9. centos mysql rpm re_CentOS 7 RPM 安装 MySQL5.7
  10. Oracle 19c 新特性:混合分区表Hybrid partitioned tables强体验
  11. building xxx gradle project info的解决办法
  12. linux用六维BT
  13. [转][Err] 1452 - Cannot add or update a child row: a foreign key constraint fail
  14. Ovum 最新市场报告称数据中心持续改变光网络市场
  15. 数据库设计实例-教务管理系统
  16. 年轻人的第一笔债,在双11的直播间里
  17. 蓝桥杯——等差素数列(c语言)
  18. D盘目录或文件被损坏且无法读取怎么办
  19. 跟我学-域名解析故障排查技巧
  20. python爬取b站弹幕_如何爬取B站弹幕

热门文章

  1. 替代if else 的方法---巧用枚举类和抽象方法
  2. 程序员必看的10本书,轻松提升自己
  3. 视频播放器,基于videojs,NVR
  4. “新冠肺炎”会让远程移动办公成为很酷的工作方式?
  5. 技术人员创业的第一步分析
  6. go中的goroutine
  7. 抽象类和接口的区别(浅显易懂)
  8. 科幻电影十大经典段落
  9. [嵌入式] 重温Mini2440(二)移植Linux-4.9.270
  10. css、HTML制作小米商城网页(一)