一、pytorch框架

1.1、概念

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。
2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:
1、具有强大的GPU加速的张量计算(如NumPy)。
2、包含自动求导系统的深度神经网络。

1.2、机器学习与深度学习的区别

两者之间区别很多,在本篇博客中只简单描述一部分。以图片的形式展现。
前者为机器学习的过程。
后者为深度学习的过程。

1.3、在python中导入pytorch成功截图

二、数据集

本次实验使用的是coco数据集中的植物病虫害数据集。分为训练文件Traindata和测试文件TestData.,
TrainData有9种分类,每一种分类有100张图片。
TestData有9中分类,每一种分类有10张图片。
在我下一篇博客中将数据集开源。
下面是我的数据集截图:

三、代码复现

3.1、导入第三方库

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib
import os
import cv2
from PIL import Image
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from Test.CNN import Net
import json
from Test.train_data import Mydataset,pad_image

3.2、CNN代码:

# 构建神经网络
class Net(nn.Module):#定义网络模块def __init__(self):super(Net, self).__init__()# 卷积,该图片有3层,6个特征,长宽均为5*5的像素点,每隔1步跳一下self.conv1 = nn.Conv2d(3, 6, 5)#//(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))self.pool = nn.MaxPool2d(2, 2)#最大池化#//(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)self.conv2 = nn.Conv2d(6, 16, 5)#卷积#//(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))self.fc1 = nn.Linear(16*77*77, 120)#全连接层,图片的维度为16,#(fc1): Linear(in_features=94864, out_features=120, bias=True)self.fc2 = nn.Linear(120, 84)#全连接层,输入120个特征输出84个特征self.fc3 = nn.Linear(84, 7)#全连接层,输入84个特征输出7个特征def forward(self, x):print("x.shape1: ", x.shape)x = self.pool(F.relu(self.conv1(x)))print("x.shape2: ", x.shape)x = self.pool(F.relu(self.conv2(x)))print("x.shape3: ", x.shape)x = x.view(-1, 16*77*77)print("x.shape4: ", x.shape)x = F.relu(self.fc1(x))print("x.shape5: ", x.shape)x = F.relu(self.fc2(x))print("x.shape6: ", x.shape)x = self.fc3(x)print("x.shape7: ", x.shape)return x

3.3、测试代码

img_path = "TestData/test_data/1/Apple2 (1).jpg" #使用相对路径
image = Image.open(img_path).convert('RGB')
image_pad = pad_image(image, (320, 320))
input = transform(image_pad).to(device).unsqueeze(0)
output = F.softmax(net(input), 1)
_, predicted = torch.max(output, 1)
score = float(output[0][predicted]*100)
print(class_map[predicted], " ", str(score)+" %")
plt.imshow(image_pad) # 显示图片

四、训练结果

4.1、LOSS损失函数

4.2、 ACC

4.3、单张图片识别准确率

四、小结

这次搭建的网络是基于深度学习框架Lenet,并自己做了一些修改完成。最终的训练的结果LOSS接近0,ACC接近100%。但是一般的识别率不会达到这么高,该模型可能会过拟合。可采取剪枝等操作减小过拟合。

pytorch深度学习框架——实现病虫害图像分类相关推荐

  1. 2021-7-26 pytorch深度学习框架学习

    1. Pytorch深度学习框架

  2. pytorch深度学习框架--gpu和cpu的选择

    pytorch深度学习框架–gpu和cpu的选择 基于pytorch框架,最近实现了一个简单的手写数字识别的程序,我安装的pytorch是gpu版(你也可以安装cpu版本的,根据个人需要),这里我介绍 ...

  3. 人工智能:PyTorch深度学习框架介绍

    目录 1.PyTorch 2.PyTorch常用的工具包 3.PyTorch特点 4.PyTorch不足之处 今天给大家讲解一下PyTorch深度学习框架的一些基础知识,希望对大家理解PyTorch有 ...

  4. pytorch深度学习框架—torch.nn模块(一)

    pytorch深度学习框架-torch.nn模块 torch.nn模块中包括了pytorch中已经准备好的层,方便使用者调用构建的网络.包括了卷积层,池化层,激活函数层,循环层,全连接层. 卷积层 p ...

  5. 【深度学习】基于PyTorch深度学习框架的序列图像数据装载器

    作者 | Harsh Maheshwari 编译 | VK 来源 | Towards Data Science 如今,深度学习和机器学习算法正在统治世界.PyTorch是最常用的深度学习框架之一,用于 ...

  6. [PyTorch] 深度学习框架PyTorch中的概念和函数

    Pytorch的概念 Pytorch最重要的概念是tensor,意为"张量". Variable是能够构建计算图的 tensor(对 tensor 的封装).借用Variable才 ...

  7. 开源基于PyTorch深度学习框架实现图卷积

    开源代码参考:学习与优化 Graph Convolutional Networks paper -> paper link -> github Distilling Knowledge F ...

  8. windows10使用cuda11搭建pytorch深度学习框架——运行Dlinknet提取道路(三)——模型精度评估代码完善

    重新调试好代码,使用Dinknet34模型对数据集进行训练 数据集大小为1480张图片 运行时间为2022年1月12日16:00 记录下该模型训练时间 但如何评估模型的精度也是一个问题,因此作如下总结 ...

  9. windows10使用cuda11搭建pytorch深度学习框架——运行Dlinknet提取道路(二)——代码运行问题解决

    运行程序 去github上下载Dlinknet的代码 https://github.com/zlckanata/DeepGlobe-Road-Extraction-Challenge 把数据集放进da ...

最新文章

  1. 关于div的滚动条滚动到底部,内容显示不全的问题。(已解决)
  2. 苹果成AI“收购狂魔”,5年买下25家公司
  3. Coursera吴恩达《卷积神经网络》课程笔记(3)-- 目标检测
  4. MyBatis查询结果resultType返回值类型详细介绍
  5. eval、json.parse()的介绍和使用注意点
  6. 删除sql下注册服务器
  7. 不修条地铁,都不好意思叫自己大城市
  8. VS2013中修改.dll工程项目的.lib和.dll的输出路径
  9. c语言禁止窗口关闭,无法关闭窗口的程序
  10. 用数据追女神:追女生如同创业
  11. springcloud分布式事务处理方案
  12. 软件分层的利与不利之处.txt
  13. 项目:识别Twitter用户性别
  14. 闪电网络开启BTC支付时代?他们不同意
  15. Python(三)微信公众号开发
  16. 2022年NPS基准:NPS分数达到多少算好?
  17. 魔兽争霸3 ce基址 偏移
  18. 华为MateBook E 12.6英寸 win11 16g+512g 轻评测
  19. html中auto是设置什么的,css中margin:auto什么意思?margin:auto属性的用法详解
  20. 程序员阵线联盟 之歌

热门文章

  1. 大一的计算机考试和英语考试,大一计算机期末考试和答案
  2. 使用 Fragment 处理 onActivityResult
  3. TightVNC怎么退出全屏
  4. unity事件系统3,三个博客脚本要一起用
  5. 电脑卖场 WIN7 SP1 x86 旗舰版 V201306 (IE9)【蓝天科技】
  6. 电脑C盘清理,你了解吗
  7. DSL学习总结 -- 絮絮叨叨
  8. dstwo linux n64,次世代?论坛惊现NDS用N64模拟器正在开发?
  9. 快递查询 (快递100)
  10. 2W五千字的C++基础知识整理汇总