文章目录

  • 引入
    • 1.1 隐藏层
    • 1.2 激活函数
      • 1.2.1 ReLU函数:
      • 1.2.2 sigmoid函数
  • 2 完整代码
  • 支持代码

引入

  深度学习主要关注多层模型,接下来将以多层感知机 (multi-layer perceptron, MLP)为例,介绍多层神经网络的概念。

1.1 隐藏层

  多层感知机在单层神经网络的基础上引入了一到多个隐藏层 (hidden layer)。隐藏层位于输入层和输出层之间,以下图为例,它含有一个隐藏层,该层包含5个隐藏单元 (hidden unit):

  图片来源:李沐、Aston Zhang等老师的这本《动手学深度学习》一书。
  由于输入层不涉及计算,所以上图所示的感知机的层数为2。
  相关的符号如下:

符号 含义
X∈Rn×d\boldsymbol{X} \in \boldsymbol{R}^{n×d}X∈Rn×d 小批量样本
nnn 批量大小
ddd 输入个数
hhh 隐藏单元个数 (假设只有一个隐藏层)
H∈Rn×h\boldsymbol{H} \in \boldsymbol{R}^{n × h}H∈Rn×h 隐藏层的输出
Wh∈Rd×h\boldsymbol{W}_h \in \boldsymbol{R}^{d × h}Wh​∈Rd×h 隐藏层权重参数
bh∈R1×h\boldsymbol{b}_h \in \boldsymbol{R}^{1 × h}bh​∈R1×h 隐藏层偏差参数
Wo∈Rh×q\boldsymbol{W}_o \in \boldsymbol{R}^{h × q}Wo​∈Rh×q 输出层权重参数
b0∈R1×q\boldsymbol{b}_0 \in \boldsymbol{R}^{1 × q}b0​∈R1×q 输出层偏差参数
qqq 输出个数

  首先介绍一种含单隐藏层的多层感知机的设计,其输出O∈Rn×q\boldsymbol{O} \in \boldsymbol{R}^{n × q}O∈Rn×q的计算为:
H=XWh+bhO=XWo+bo(1)\begin{matrix} \boldsymbol{H} = \boldsymbol{XW}_h + \boldsymbol{b}_h\\ \boldsymbol{O} = \boldsymbol{XW}_o + \boldsymbol{b}_o \tag{1} \end{matrix} H=XWh​+bh​O=XWo​+bo​​(1)也就是将隐藏层的输出直接作为输出层的输入。联立上式:
O=XWhWo+bhWo+bo(2)\boldsymbol{O} = \boldsymbol{XW}_h \boldsymbol{Wo} + \boldsymbol{b}_h \boldsymbol{W}_o + \boldsymbol{b}_o \tag{2} O=XWh​Wo+bh​Wo​+bo​(2)从联立的式子可以看出,虽然神经网络引入了隐藏层,却依然等价于一个单层神经网络。显然,即便引入再多的隐藏层,以上设计依然只能与仅含输出层的单层神经网络等价。

1.2 激活函数

  以下给出相关激活函数的相关示例。

1.2.1 ReLU函数:

  示例代码如下:

'''
@(#)MultilayerPerceptron.py
The class of multilayer perceptron.
Author: inki
Email: inki.yinji@qq.com
Created on May 15, 2020
Last Modified on May 15, 2020
'''import torch
import CommonDPif __name__ == '__main__':x = torch.arange(-8., 8., 0.1, requires_grad=True)y = x.relu()CommonDP.plot(x, y, 'x', 'relu (x)')

  运行结果:

  当输入为负时,ReLU函数的导数为0;输入为正时,其导数为1。尽管输入为0时其函数不可导,但仍可取此处的导数为0。导数绘制代码如下:

import torch
import CommonDPif __name__ == '__main__':x = torch.arange(-8., 8., 0.1, requires_grad=True)y = x.relu()y.sum().backward()CommonDP.plot(x, x.grad, 'x', 'grad of relu (x)')

  运行结果:

1.2.2 sigmoid函数

  懂的都懂。

2 完整代码

"""
@author: Inki
@contact: inki.yinji@gmail.com
@version: Created in 2020 1122, last modified in 2020 1124.
"""import torch
from torch import nn
from torch.nn import init
from util.SimpleTool import load_data_fashion_mnistdef train(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):"""Model train.@param:net:        The constructed net.train_iter: The iterator of training set with labels.test_iter:  The iterator of test set with labels.loss:       The loss function.num_epochs: The number of epochs.batch_size: The batch_size.params:     The parameters for net.lr:         The learning rate.optimizer:  The optimizer."""# Main loop.for idx_epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (idx_epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))def sgd(params, lr, batch_size):# 为了和原书保持一致,这里除以了batch_size,但是应该是不用除的,因为一般用PyTorch计算loss时就默认已经# 沿batch维求了平均了。for param in params:param.data -= lr * param.grad / batch_size  # 注意这里更改param时用的param.datadef evaluate_accuracy(data_iter, net, device=None):if device is None and isinstance(net, torch.nn.Module):# 如果没指定device就使用net的devicedevice = list(net.parameters())[0].deviceacc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval()  # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train()  # 改回训练模式else:  # 自定义的模型, 3.13节之后不会用到, 不考虑GPUif ('is_training' in net.__code__.co_varnames):  # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / nclass LinearNet(nn.Module):"""THe linear net."""def __init__(self):"""The constructor."""super(LinearNet, self).__init__()def forward(self, x):"""The forward function.@param:x:  The given samples."""return x.view(x.shape[0], -1)if __name__ == '__main__':temp_num_inputs = 784temp_num_outputs = 10temp_num_hidden = 256temp_num_epochs = 5temp_batch_size = 10temp_train_iter, temp_test_iter = load_data_fashion_mnist(temp_batch_size)temp_net = nn.Sequential(LinearNet(),nn.Linear(temp_num_inputs, temp_num_hidden),nn.ReLU(),nn.Linear(temp_num_hidden, temp_num_outputs),)for temp_params in temp_net.parameters():init.normal_(temp_params, mean=0, std=0.01)temp_loss = nn.CrossEntropyLoss()temp_optimizer = torch.optim.SGD(temp_net.parameters(), lr=0.1)train(temp_net, temp_train_iter, temp_test_iter, temp_loss,temp_num_epochs, temp_batch_size, None, None, temp_optimizer)

支持代码

def load_data_fashion_mnist(batch_size=10, root='D:/Data/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""transform = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0else:num_workers = cpu_count()train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter

torch学习(六):多层感知机相关推荐

  1. 学习笔记 | 多层感知机(MLP)、Transformer

    目录 多层感知机(MLP) Transformer 1. inputs 输入 2. Transformer的Encoder 2.1 Multi-Head Attention 2.2 Add&N ...

  2. 【动手学深度学习】多层感知机(MLP)

    1 多层感知机的从零开始实现 torch.nn 继续使用Fashion-MNIST图像分类数据集 导入需要的包 import torch from torch import nn from d2l i ...

  3. 动手学深度学习之多层感知机

    多层感知机 多层感知机的基本知识 深度学习主要关注多层模型.本节将以多层感知机(multilayer perceptron,MLP)为例,介绍多层神经网络的概念. 隐藏层 下图展示了一个多层感知机的神 ...

  4. 深度学习基础——多层感知机

    多层感知机(Multilayer Perceptron, MLP)是最简单的深度网络.本文回顾多层感知机的相关内容及一些基本概念术语. 多层感知机 为什么需要多层感知机 多层感知机是对线性回归的拓展和 ...

  5. [深度学习] (sklearn)多层感知机对葡萄酒的分类

    时间:2021年12月2日 from sklearn.datasets import load_wine from sklearn.model_selection import train_test_ ...

  6. 机器学习理论之(13):感知机 Perceptron;多层感知机(神经网络)

    文章目录 表示学习 (representation Learning) 生物神经元 V.S. 人造神经元 感知机 (Perceptron) 训练感知机(Training Perceptron) 激活函 ...

  7. MLP多层感知机 学习笔记

    cvpr2022的 mobileformer中用到了mlp多层感知机,就来学习一下 其实就是3个全连接层,前面两个加了bn,最后一层没有加bn. import timeimport torch fro ...

  8. 动手学深度学习(PyTorch实现)(五)--多层感知机

    多层感知机 1. 基本知识 2. 激活函数 2.1 ReLU函数 2.2 Sigmoid函数 2.3 tanh函数 2.4 关于激活函数的选择 3. PyTorch实现 3.1 导入相应的包 3.2 ...

  9. 动手学习深度学习 04:多层感知机

    文章目录 01 多层感知机 1.感知机 总结 2.多层感知机 2.1.隐藏层 2.1.1 线性模型可能会出错 2.1.2 在网络中加入隐藏层 2.1.3 从线性到非线性 2.1.4 通用近似定理 3. ...

最新文章

  1. Matlab R2016a 如何设置自己称心的工作区域
  2. 《Android应用开发》——1.3节配置Eclipse
  3. 14.索引数组初始化
  4. 使用post向webservice发送请求,并且返回值
  5. 工业交换机的价格为什么有高低之分?
  6. Element type quot;Resourcequot; must be followed by either attribute specifications, quot;gt;qu...
  7. c 连接oracle 通用类,c#操作oracle,有没有相仿sqlhelp之类的通用操作类(6)
  8. Highcharts:小案例,自定义图片下载路径,中文乱码的解决办法(不足之处,求指点)。...
  9. 【Oracle】ORA-55610: Invalid DDL statement on history-tracked table
  10. 这四款录屏工具,也许是电脑录屏软件中免费、无广告且最实用的
  11. android studio实现ar,在Android Studio上运行EasyAR
  12. mysql的性能瓶颈_Mysql性能优化(一) - 性能检测与瓶颈分析
  13. 寻求路径问题————动态规划的思想
  14. android framework 引入jia包
  15. 间隙锁-记一次死锁原因分析
  16. 大数据方向学习系列——hadoop——hdfs学习
  17. H3C 大规模网络路由技术 笔记
  18. Android挂断电话以及Java Class Loader
  19. MTK4G安卓核心板_XY6739CW(MTK6739平台)详细参数性能
  20. Docker也被禁了,Oracle还远吗?

热门文章

  1. 【1024】程序员节丨致敬所有技术布道师
  2. HTML5网页设计笔记
  3. cfree运行程序错误_Java 错误和异常汇总
  4. 智能物联:联想刘军的新赛道
  5. 【java毕业设计】基于javaEE+SSH+mysql的码头船只出行及配套货柜码放管理系统设计与实现(毕业论文+程序源码)——码头船只出行及配套货柜码放管理系统
  6. android 在线解析pdf文件格式,Android PDF预览阅读:用Mozilla PDF.js浏览本地在线PDF文件 | KaelLi的博客...
  7. C语言一行太长的换行处理
  8. 【翻译】Deep Learning-Based Video Coding: A Review and A Case Study
  9. Win8无法打开hlp文件
  10. C Primer Plus 第三章 数据和C 阅读笔记