最近为了实现HR-net在学习pytorch,然后突然发现这个框架简直比tensorflow要方便太多太多啊,我本来其实不太喜欢python,但是这个框架使用的流畅性真的让我非常的喜欢,下面我就开始介绍从0开始编写一个Lenet并用它来训练cifar10。

1.首先需要先找到Lenet的结构图再考虑怎么去实现它,在网上找了一个供参考

2.需要下载好cifar-10的数据集,在pytorch下默认的是下载cifar-10-python版本的,由于官网速度较慢,我直接提供度娘网盘的链接:链接:https://pan.baidu.com/s/18LNEZmGVkzEwf3SgOrO2rw  密码:n1h7

3.下载好数据集后,需要定义网络的结构,根据图我们可以看出,整个lenet只有两个卷积层,两个池化层(其实应该叫降采样层,那个时候还没有池化),三个全连接层。

pytorch中有一个容器,叫做Sequential,你可以在这个容器里添加你需要使用的卷积,池化,全连接操作,但是,这个Sequential它只能包含类方法定义的层,而不能包含像torch.Functional里面的函数方法(可能我说的不专业,见谅),所以如果当你想自己定义某个层的话,例如在输入全连接层之前,需要将形如[batch_size,channel,higth,width]的tensor转化成[batch_size,channel*higth*width]这种形式,那我如果想在Sequential这个容器里加入这一个操作该怎么办呢,这时候就需要我们继承nn.Module这个类来实现,具体的方法如下

importtorchimporttorch.nn as nnclassFlatten(nn.Module):def __init__(self):

super(Flatten, self).__init__()defforward(self,input):

out=input.view(input.size(0),-1)return out

好了,介绍完Sequential我们就开始实现这个网络的结构吧

#文件名是Lenet5.py

importtorchimporttorch.nn as nnfrom pytorch__lesson.pytorch_mnist.main importFlattenclassLenet(nn.Module):def __init__(self):

super(Lenet, self).__init__()

self.net=nn.Sequential(

nn.Conv2d(3,6,5,stride=1,padding=0),

nn.MaxPool2d(kernel_size=2,stride=2,padding=0),

nn.Conv2d(6,16,5,stride=1,padding=0),

nn.MaxPool2d(kernel_size=2,stride=2,padding=0),

Flatten(),

nn.Linear(400,120),

nn.ReLU(inplace=True),

nn.Linear(120,84),

nn.ReLU(inplace=True),

nn.Linear(84,10),

nn.ReLU(inplace=True)

)#self.criteon=nn.CrossEntropyLoss()

defforward(self,x):

logits=self.net(x)#pred=nn.Softmax(logits,dim=1),这一行不需要写,因为在CrossEntropyLoss这一步包含了softmax的操作

returnlogits#net=Lenet()#input=torch.randn(2,3,32,32)#out=net(input)#print(out.shape)

其中这里面的Flatten就是上面代码的Flatten类。因为它继承了nn.Module因此可以直接将其放在Sequential里面了,以后定义任何网络,我们都可以使用这个类来进行tensor的展平操作。

4.接下来就可以定义训练部分的代码了

importtorchfrom torchvision importdatasets,transformsfrom torch.utils.data importDataLoaderimporttorch.nn as nnimporttorch.optim as optimimporttorch.functional as Ffrom pytorch__lesson.pytorch_mnist.Lenet5 importLenet

batch_size=32

defmain():#cifar表示的是在当前的目录下新建一个叫cifar的文件夹,这个方法一次只能加载一张

cifar_train=datasets.CIFAR10('cifar',train=True,transform=transforms.Compose([

transforms.Resize((32,32)),

transforms.ToTensor()

]),download=True)#这个方法才能保证一次读取进来的是一个batch_size大小的数据

cifar_train_loader=DataLoader(cifar_train,batch_size=batch_size,shuffle=True)

cifar_test=datasets.CIFAR10('cifar',train=False,transform=transforms.Compose([

transforms.Resize((32,32)),

transforms.ToTensor()

]))

cifar_test_loader=DataLoader(cifar_test,batch_size=batch_size,shuffle=False)

x,label=iter(cifar_train_loader).next()print('x shapex:',x.shape,'label shape:',label.shape)#use CrossEntropy as the loss function

criteon=nn.CrossEntropyLoss()#use Lenet() function to build a model

#net=Lenet().to(device) 将模型放入cuda上进行加速

net=Lenet()

optimizer=optim.Adam(net.parameters(),lr=1e-3)#device=torch.device('cuda')

#net=Lenet().to(device) 将模型放入cuda上进行加速

print(net)for epoch in range(1000):for batchidx,(x,label) inenumerate(cifar_train_loader):#生成软对数

#将网络转化成train的模式

net.train()

logits=net(x)#x,label=x.to(device),label.to(device)

#使用crossentropyloss的就不需要将logits放入到softmax中了,直接就可以计算出loss

loss=criteon(logits,label)#接下来进行反向的传播,先是将梯度清零,再进行反向传播,再进行梯度更新

optimizer.zero_grad()

loss.backward()

optimizer.step()#loss是一个tensor scalor 是一个长度为0的标量

print(epoch,loss.item())

net.eval()

with torch.no_grad():#将整个网络转换成test模式或者validation模式

#test这一部分不需要构造计算图也不需要统计梯度,因此将这部分放在函数torch.no_grad()

total_correct=0

total_num=0for x,label incifar_test_loader:#如果有gpu的话先将x和label放入gup进行加速

#[batch_size,10]

logits=net(x)#取出最大下标的索引[b]

pred=logits.argmax(dim=1)#eq函数调用后会返回一个byte,true或者false估计,然后需要将其转换成float类型再通过item()函数来提取它的值

total_correct+=torch.eq(label,pred).float().sum()

total_num+=x.size(0)

acc=total_correct/total_numprint(epoch,'the acc of the test is :',(acc*100))if __name__=='__main__':

main()

因为我的电脑没有英伟达的显卡,不支持cuda加速,因此的话没办法都训练出来截图,如果有N卡的,可以自己试试,注释写的比较详细,我就不再赘述了,不是很难。

ps:我太唠叨了吧

python实现lenet_手把手教你写一个用pytorch实现的Lenet5相关推荐

  1. python k线合成_手把手教你写一个Python版的K线合成函数

    手把手教你写一个Python版的K线合成函数 在编写.使用策略时,经常会使用一些不常用的K线周期数据.然而交易所.数据源又没有提供这些周期的数据.只能通过使用已有周期的数据进行合成.合成算法已经有一个 ...

  2. 手把手教你写一个生成对抗网络

    成对抗网络代码全解析, 详细代码解析(TensorFlow, numpy, matplotlib, scipy) 那么,什么是 GANs? 用 Ian Goodfellow 自己的话来说: " ...

  3. 手把手教你写一个中文聊天机器人

    本文来自作者 赵英俊(Enjoy) 在 GitChat 上分享 「手把手教你写一个中文聊天机器人」,「阅读原文」查看交流实录. 「文末高能」 编辑 | 哈比 一.前言 发布这篇 Chat 的初衷是想和 ...

  4. 手把手教你写一个spring IOC容器

    本文分享自华为云社区<手把手教你写一个spring IOC容器>,原文作者:技术火炬手. spring框架的基础核心和起点毫无疑问就是IOC,IOC作为spring容器提供的核心技术,成功 ...

  5. 手把手教你写一个Matlab App(一)

    对于传统工科的学生用的最多的编程软件应该就是matlab,其集成度高,计算能力强,容易上手,颇受大众青睐.今天挖的这个新坑,主要是分享用matlab app designer设计GUI界面的一些方法和 ...

  6. 后端思维篇:手把手教你写一个并行调用模板

    前言 36个设计接口的锦囊中,也提到一个知识点:就是使用并行调用优化接口.所以接下来呢,就快马加鞭写第二篇:手把手教你写一个并行调用模板~ 一个串行调用的例子(App首页信息查询) Completio ...

  7. 从原理到实现丨手把手教你写一个线程池丨源码分析丨线程池内部组成及优化

    人人都能学会的线程池 手写完整版 1. 线程池的使用场景 2. 线程池的内部组成 3. 线程池优化 [项目实战]从原理到实现丨手把手教你写一个线程池丨源码分析丨线程池内部组成及优化 内容包括:C/C+ ...

  8. 手把手教你写一个没有服务器的颜值打分小程序,可直接上线

    小程序现在可以说非常火爆了,流量入口非常多.尤其是出了流量主功能以后,普通开发者也能在自己的个人小程序里植入官方广告来获取收入.程序员想赚点外快再合适不过了.今天教大家写一个颜值打分的小程序,利用现成 ...

  9. 手把手教你写一个手势密码解锁View(GesturePasswordView)

    相信大家在很多的app肯定看到过手势密码解锁View,但是大家有没有想过怎么实现这样一个View,哈,接下来,小编手把手教大家教写一个GesturePasswordView. 先看一张效果图 要实现这 ...

最新文章

  1. JVM生产环境参数实例及分析
  2. 多个安卓设备投屏到电脑_辅助多手机同时直播控场 TotalControl手机投屏软件
  3. .Net 程序集 签名工具sn.exe 密钥对SNK文件 最基本的用法
  4. 五、实例:在波士顿房价数据集上用随机森林回归填补缺失值
  5. 11 PP配置-生产主数据-工作中心相关-定义工作中心屏幕顺序
  6. Android studio ,Gradle 添加so库
  7. 镜像翻转_《蒙娜丽莎》镜像翻转后,暗藏神秘的第二张脸?网友:笑容消失了...
  8. Java实现PDF添加图片水印和文字水印
  9. 数据库系统原理与应用教程(031)—— MySQL 的数据完整性(四):定义外键(FOREIGN KEY)
  10. 反垃圾邮件技术介绍和部署思路
  11. 如何提高公寓房屋出租率?
  12. python+tkinter+threading制作多线程简易音乐播放器(自动播放,上一曲,下一曲,播放,暂停,实时显示歌曲名并能自动切换歌曲的功能)
  13. 19、生鲜电商平台-安全设计与架构
  14. 【全局面包屑导航】依据路由动态生成面包屑导航
  15. ColorMatrix 5*5颜色矩阵
  16. android极光推送原理,【揭秘】极光推送ios、Android消息推送达率的原理
  17. 调配颜色(自己随便造的名字)
  18. 联接+AI,华为用智能联接为智能时代加速
  19. 数据库报错1166 - Incorrect column name 的解决方法
  20. Oracle自动生成ID,自动编号,自增补零填充

热门文章

  1. windows10cmd中测试qq邮箱smtp服务
  2. 三星html查看器怎么取消默认,三星galaxy note各种使用小技巧
  3. Windows添加路由:route add 173.18.18.0 mask 255.255.255.0 172.18.18.1
  4. 利用ICMP协议反弹shell
  5. jmeter执行脚本并生成测试报告时报错:Results file:result.jtl is not empty
  6. nvue标签换行影响横排效果
  7. linux 触摸结构体,xboot-x4412ibox项目实战54-Linux触摸屏驱动之I2C驱动实验 - Powered by Discuz!...
  8. r55625u和i51135g7选哪个 r5 5625u和i5 1135g7对比
  9. Android第三方登录
  10. JRebel 激活地址