Basic CNN

传送门:https://www.bilibili.com/video/BV1Y7411d7Ys?p=10

模型框架:



代码

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt#1.prepare dataset
#2.design model using class
#3.construct loss and optimizer
#4.training cycle+test#1.准备数据集batch_size = 64
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307, ), (0.3081, ))#均值,标准化
])
train_dataset = datasets.MNIST(root='./dataset/mnist',train=True,transform=transform,download=True)
print(train_dataset[0])
test_dataset = datasets.MNIST(root='./dataset/mnist',train=False,transform=transform,download=True)train_loader = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=32,shuffle=False)# ---------------------------卷积模型---------------------------
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)self.pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)self.fc = torch.nn.Linear(40, 20)def forward(self, x):batch_size = x.size(0)x = F.relu(self.pooling(self.conv1(x)))x = F.relu(self.pooling(self.conv2(x)))x = x.view(batch_size, -1)x = self.fc(x)return xmodel = Net()
# 开启显卡
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)#3.构建loss和optimzer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)#4.循环
def train(epoch):running_loss = 0.0for batch_idx, (inputs, target) in enumerate(train_loader):inputs, target = inputs.to(device), target.to(device) #显卡加速optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100*correct / total))return correct / totalif __name__ == '__main__':epoch_list = []acc_list = []for epoch in range(10):train(epoch)acc = test()epoch_list.append(epoch)acc_list.append(acc)# if epoch % 10 == 9:#     test()import osos.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'plt.plot(epoch_list, acc_list)plt.xlabel('epoch')plt.ylabel('accuracy')plt.show()

对于图像的处理,线性模型显然会丢失掉图像之间的关系和联系,卷积神经网络有较好的提升,但网络过于简单会使得图像学习的效果较少,改进的网络/课后习题:

https://blog.csdn.net/frighting_ing/article/details/120773888?spm=1001.2014.3001.5501

《PyTorch深度学习实战》第十讲相关推荐

  1. 深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割

    深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割 1. 项目简介 2. 3D医学图像分割的需求 3. 医学图像和MRI 4. 三维医学图像表示 5. 3D-Unet模型 5.1损失函 ...

  2. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  3. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  4. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  5. PyTorch 深度学习实践 第13讲

    PyTorch 深度学习实践 第13讲 引言 代码 结果 引言 近期学习了B站 刘二大人的PyTorch深度学习实践,传送门PyTorch 深度学习实践--循环神经网络(高级篇),感觉受益匪浅,发现网 ...

  6. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

  7. PyTorch深度学习实战(5)——计算机视觉基础

    PyTorch深度学习实战(5)--计算机视觉基础 0. 前言 1. 图像表示 2. 将图像转换为结构化数组 2.1 灰度图像表示 2.2 彩色图像表示 3 利用神经网络进行图像分析的优势 小结 系列 ...

  8. PyTorch深度学习实战:从新手小白到数据科学家电子书

    作者:张敏 著 出版社:电子工业出版社 ISBN:9787121388293 出版时间:2020-08-01 PyTorch深度学习实战:从新手小白到数据科学家

  9. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  10. PyTorch 深度学习实践 第4讲

    第4讲  反向传播back propagation 源代码 B站 刘二大人 ,传送门PyTroch 深度学习实践--反向传播 如果需安装PyTorch,传送门 PyTorch深度学习快速入门教程 传送 ...

最新文章

  1. 算法学习之路|统计同成绩学生
  2. error: dereferencing pointer to incomplete type
  3. JAVA接口返回面积_java – 将接口的返回值限制为实现类的范围
  4. mysql中文乱码的一点理解
  5. 请解释为什么集合类没有实现Cloneable和Serializable接口?
  6. LocalStorage与SessionStorage
  7. python函数使用两个小括号
  8. k8s与监控--从telegraf改造谈golang多协程精确控制
  9. 新手安装Ubuntu操作系统
  10. 技术大佬:我去,你写的 switch 语句也太老土了吧!
  11. 一致 先验分布 后验分布_「分布式技术」分布式事务最终一致性解决方案,下篇...
  12. 拷贝data/data/包名/files文件记下所有文件及文件夹到本地sdcard根目录teddyData_files文件夹下...
  13. SpringBoot学习之文件结构和配置文件
  14. verilog学习 (二)
  15. 对比起来学习前端三大框架(持续更新)
  16. 苏宁易购关键词搜索商品方法
  17. 电子计算机为什么123安不出来,方正软件常见问题及其解决办法-精.doc
  18. Object.keys()的用法
  19. MATLAB 四点定球及三点定圆(完整代码)
  20. java数组的实例化

热门文章

  1. 征服者 游骑兵系列T117一体机最新款09年5月上市
  2. 解决RabbitMQ的The channelMax limit is reached. Try later.
  3. 难忘今宵,四六级和考研成绩公布!说说那些与英语考试有关的往事
  4. ERP 和 MES 之间的联系是什么?
  5. 升级版NanoDet-Plus来了 | 简单辅助模块加速训练收敛,精度大幅提升
  6. 斗罗大陆妇女节壁纸高清
  7. BZOJ_1014_[JSOI2008]_火星人prefix_(Splay+LCP_Hash+二分)
  8. CSV多标签数据预处理+加上csv多标签转文件夹+ ISIC2018数据集分类预处理
  9. 楼市刺激第三波政策潮来袭:税费补贴主打
  10. 三步换系统 win10到Ubuntu20.04