在处理多分类问题的时候会用到一个叫做softmax的分类器,是用来将输出结果划归到[0,1]的,本讲将主要从softmax分类器入手来实现多分类问题。在前一章我们对糖尿病模型进行了二分类,二分类问题中只需要输出一个概率,另外的一个概率通过用1来减即可获得。但多分类需要输出多个概率。

本次我们采用MNIST手写数字数据集,首先我们来看一下如果有十个分类那他们的输出该是什么样的。若有十个分类,那这10个概率的输出应该是总和=1且均>0的。但某些情况下,可能会出现P(y=1)=0.8,P(y=2)=0.9 这样的情况,所以当我们求出P(y=1)=0.8后需要对后面的概率情况进行抑制。

神经网络计算出来的结果可能是小于0的,可能总和不为1,所以下面我们就要请softmax来了。softmax的公式如下:

其中,z_l是线性层最后一层的输出,e^zi作用为强制使其>0,分母的作用为保证概率求Σ之后为1,这样就实现了功能需求。softmax作用的示意图如图1所示。

图1 softmax处理原理

接下来我们将进行多分类问题中损失函数的求解。在多分类问题中,loss函数非常简单,就是把预测标签为1的那个概率值取log再加负号即可。

图2 loss函数

下面是求loss值的代码部分:

import numpy as npy = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()print(loss)

pytorch提供了现成的交叉熵损失函数框架,该框架包含了从softmax开始一直到输出的全过程,所以输入的时候只需要将神经网络计算的原始结果输入到框架中就行了,不需要做激活。下面是计算loss值的完整代码:

import torchcriterion = torch.nn.CrossEntropyLoss ()
Y = torch. LongTensor ([2, 0, 1])
Y_predl = torch.Tensor([[0.1,0.2, 0.9],                  [1.1, 0.1, 0.2],                  [0.2, 2.1, 0.1]])
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3],                  [0.2, 0.3, 0.5],                  [0.2, 0.2, 0.5]])   l1 = criterion(Y_predl, Y)
l2 = criterion(Y_pred2, Y)
print("Batch Lossl =", l1.data,"\nBatch Loss2 =",l2.data)

在这个例子中,[2, 0, 1]代表三个标签。2代表预测结果第三个概率最大,0代表第一个最大。在y1中我们可以看出,[0.1, 0.2, 0.9]显然第三个概率最大,即对应于[2, 0, 1]中的2。而在y2中显然是乱猜一气,完全对应不上,所以y2的loss值必然很大。

两个loss值的运行结果如下:

Batch Lossl = tensor(0.4966)
Batch Loss2 = tensor(1.2389)

下面我们回过头来看MNIST手写数字的多分类问题。之前我们讲的例子中输入都是一个向量,这里都是输入的图片。也不难,只需要将其映射为图像张量即可,如图3所示。

图3 图像转图像张量

下面是数据集的准备,其中Normalize是对数据进行归一化处理(就是映射到0~1之间),其中的平均值和标准差是根据MNIST数据集大量计算后得出的数字。ToTensor代表着映射通道,直接写上就行。

图4 数据集的准备

接下来我们看一下模型的构建,由于输入的是一个28×28=784的,我们需要先把矩阵一行一行的平铺成一行,也就是一个1行784列的向量。1行向量不能满足输入需求,这里我们使用view()方法将张量的形状改为一个二阶的,第一个参数设为-1是指让其自动计算。随后,我们需要将784降至512、256、128、64、10,到10的原因是输出是10种可能。不能一口气降到10的原因是这样会损失太多信息而无法有效训练。这期间再穿插着ReLU激活函数,即可完成模型的构建,流程如图5所示。

图5 模型构建的流程

下面是模型构建的代码实现:

class Net (torch.nn.Module) :def __init__(self) :super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x) :x = x.view(-1, 784)x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)    # 最后一层不做激活model = Net()

这里选用的损失函数和优化器要做一些变化,交叉熵损失作为计算loss函数的方法,梯度下降种我们采用带冲量的,引入冲量的目的是加快梯度下降的速率、突破局部极小值问题。

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr = 0.01,momentum = 0.5)

训练过程就不再过多解释了,整个的代码如下所示:

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim# prepare datasetbatch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化,均值和方差train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)# design model using classclass Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # -1其实就是自动获取mini_batchx = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)  # 最后一层不做激活,不进行非线性变换model = Net()# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# training cycle forward, backward, updatedef train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):# 获得一个批次的数据和标签inputs, target = dataoptimizer.zero_grad()# 获得模型预测结果(64, 10)outputs = model(inputs)# 交叉熵代价函数outputs(64,10),target(64)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0个维度,行是第1个维度total += labels.size(0)correct += (predicted == labels).sum().item() # 张量之间的比较运算print('accuracy on test set: %d %% ' % (100*correct/total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

[Pytorch] 学习记录(七)MNIST多分类问题相关推荐

  1. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  2. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  3. Pytorch学习记录(七):自定义模型 Auto-Encoders 使用numpy实现BP神经网络

    文章目录 1. 自定义模型 1.1 自定义数据集加载 1.2 自定义数据集数据预处理 1.3 图像数据存储结构 1.4 模型构建 1.5 训练模型 2. Auto-Encoders 2.1 无监督学习 ...

  4. PyTorch学习记录——PyTorch生态

    Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...

  5. 黄金时代 —— Pytorch学习记录(一)

    文章目录 Tensor Tensor操作 桥接 NumPy Cuda张量 Autograd:自动求导 张量 梯度 定义网络 关于nn和nn.Module模块 网络 BP过程 损失函数 反向传播 更新权 ...

  6. 【多线程】学习记录七种主线程等待子线程结束之后在执行的方法

    最近遇到一个问题需要主线程等待所有的子线程结束,才能开始执行,统计所有的子线程执行结果,返回,网上翻阅各种资料,最后记录一下,找到七种方案 第一种:while循环 对于"等待所有的子线程结束 ...

  7. excel分类_Excel数据处理学习(七)使用分类汇总

    端午节快乐 今天端午节呀,所以更新的有点晚-这个系列的日更只会迟到,但不会缺席! 来看看分类汇总和定位工具的应用以及创建组工具吧- 01 先来看看数据源,会发现使用的数据源就是我们第一节讲的数据透视表 ...

  8. PyTorch学习记录-1PyTorch安装

    学习建议里有PyTorch,所以我就开始了PyTorch的学习. 首先就是安装啦,去官网很清楚,可以选择自己的版本和平台,然后下面就会出现 Run this command:  后面跟着的命令复制运行 ...

  9. pytorch学习笔记七:nn网络层——池化层、线性层

    一.池化层 池化运算:对信号进行"收集" 并"总结",类似于水池收集水资源,因而得名池化层. 收集:由多变少,图像的尺寸由大变小 总结:最大值/平均值 下面是最 ...

最新文章

  1. leetcode9 Palindrome Number 回数
  2. 最长子段和 11061008 谢子鸣
  3. [翻译] WindowsPhone-GameBoy模拟器开发二--Rom文件分析
  4. Educational Codeforces Round 32 G. Xor-MST 01tire + 分治 + Boruvka
  5. Chapter7-6_Text Style Transfer
  6. 使用Docker部署SpringBoot
  7. AI 高等数学、概率论基础
  8. Web 前端知识体系精简
  9. SQL语言入坑—1.数据的检索、排序、过滤、分组
  10. java 设置系统参数_Java设置系统参数和运行参数
  11. 英语常用九种时态记忆要点
  12. 自学android刷机包,Android刷机包解包打包
  13. 常用编程语言开发工具
  14. hmailserver mysql密码_mysql+hmailserver+roundcube修改密码
  15. 离散数学_命题逻辑的演绎推理
  16. oracle 逗号,查询oracle中逗号分隔字符串中所有值
  17. 深以为然-为什么一些JAVA EE / J2EE 工程是效率低下或者至少是效率欠佳的(翻译)
  18. rono在oracle的作用_Oracle 11g各种服务作用以及哪些需要开启
  19. 2006年江苏专转本计算机试卷答案,2006年度江苏省普通高校专转本计算机试卷.doc...
  20. HTTPS之TLS证书

热门文章

  1. 导出excel 规则数据多个sheet
  2. 基于蓝牙适配器的PC与Android端通讯
  3. 中文分词下载IK Analyzer 2012FF_hf1
  4. GIS核心期刊资料等(精心收集资料)
  5. WEB漏洞扫描器 – 北极熊扫描器
  6. centos 7 重启mysql_centOS7 如何启动/停止/重启MySQL
  7. Error response from daemon: conflict: unable to delete d0957ffdf8a2 (must be forced) - image is refe
  8. 发财项目冲顶大会逆向分析
  9. Java Math.log10()方法
  10. 未实名的.com/.net域名即将被暂停解析,网站/邮箱等无法访问!