1.环境
Ubuntu20.04
Vscode
Cuda 11.2
Pytorch 1.8
2.代码

import time
import torch
import torchvision
from torch import nn,optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.conv = nn.Sequential(nn.Conv2d(1,6,5),nn.Sigmoid(),nn.MaxPool2d(2,2),nn.Conv2d(6,16,5),nn.Sigmoid(),nn.MaxPool2d(2,2))self.fc = nn.Sequential(nn.Linear(16*4*4,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))def forward(self,img):feature = self.conv(img)output = self.fc(feature.view(img.shape[0],-1))return outputnet = LeNet()def load_data_fashion_mnist(batch_size,resize=None,root='~/Datasets/FashionMNIST'):trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root,train=True,download=True,transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root,train=False,download=True,transform=transform)train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=4)test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=4)return train_iter,test_iterdef evaluate_accuracy(data_iter,net,device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):acc_sum,n = 0.0,0with torch.no_grad():for X,y in data_iter:if isinstance(net,torch.nn.Module):net.eval()acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train()else:if('is_training' in net.__code__.co_varnames):acc_sum += (net(X,is_training=False).argmax(dim=1) == y).float().sum().item()else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum/ndef train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs):net = net.to(device)print("training on ",device)loss = torch.nn.CrossEntropyLoss()batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0,0.0,0,time.time()for X,y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat,y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter,net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' %(epoch + 1,train_l_sum/batch_count,train_acc_sum/n,test_acc,time.time()-start))batch_size = 256
train_iter,test_iter = load_data_fashion_mnist(batch_size=batch_size)lr, num_epochs = 0.001, 10
optimizer = torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)

3.结果

Pytorch学习笔记——LeNet模型相关推荐

  1. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  2. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  3. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  4. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  5. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  6. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  7. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  8. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  9. 深度学习入门之PyTorch学习笔记

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 5 循环神经网络 6 生成对抗网络 7 深度学习实战 参考资料 绪论 深度学习如今 ...

最新文章

  1. 自己拿项目,软件设计开发,释放你的力量
  2. Socket阻塞,非阻塞,同步,异步
  3. volume 生命周期管理 - 每天5分钟玩转 Docker 容器技术(44)
  4. (常用API)正则表达式语法规则
  5. linux accept过程,Linux协议栈accept和syn队列问题
  6. 北方人思想为什么落后_广西人为什么很少到北方打工?
  7. Vs2010工具栏显示“开始执行“按钮
  8. python爬取豆瓣电影名称与评分进行分析
  9. 信号与噪声经过匹配滤波器后能量
  10. 程序员在国外:我用20天在加拿大找到首份工作
  11. 如何使用c语言开发ebpf程序
  12. 《自拍教程48》Python_adb随机地图移图2小时
  13. mac部署rabbitmq流程与异常总结
  14. 运动酒店,如何“奇袭”文旅产业精准蓝海赛道——缤跃酒店
  15. 阿里云-钉钉-企业邮箱
  16. 如何使用python的openpyxl进行强大的图表处理
  17. 未婚同居能白头偕老吗
  18. 棋盘覆盖问题 (分治)
  19. Oracle导出DMP文件的两种方法
  20. 读源码学算法之Monte Carlo Tree Search

热门文章

  1. RF ADC指标:NSD、IM3和ACLR
  2. Android 天气APP(十)继续优化、下拉刷新页面天气数据
  3. 揭秘三大运营商在5G专网的布局!
  4. DICOM:DICOM开源库多线程分析之“ThreadPoolQueue in fo-dicom”
  5. Java:2022年全球使用的15种最流行的Java应用
  6. mac下的socket调试工具---sokit
  7. 抖音短视频源码中视频排序模块热门列表解决方案
  8. SQL: 视图和表的区别
  9. 【SqlServer-函数】
  10. Python脚本翻译英文到汉语