使用PyTorch实现对花朵的分类
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
PyTorch是一个非常适合初学者的高度可靠且强大的机器学习库。自2016年10月以来,它已经开源并由Facebook维护,并被开发人员用于研究其原型,以部署最先进的深度学习应用程序。与TensorFlow等其他机器学习库相比,PyTorch更加直观,并具有实现模型的Python方式。
决定要分类什么?
识别花朵的类型需要某种形式关于花朵的知识,人必须事先看过花朵才能识别花朵。同样,对于计算机,很难对算法进行硬编码以识别花朵的类型。到目前为止,机器学习是从给定的大量花朵图片中识别花朵名称的唯一选择。这使得使用深度学习实现花识别任务对于每个初学者来说都非常有趣。
花朵识别数据集对于像我这样的初学者而言,是一个很好的数据集,可用于实施和练习各种机器学习模型。
使用什么数据集?
我们将使用Kaggle上可用的花朵识别数据集。数据集链接:https ://www.kaggle.com/alxmamaev/flowers-recognition
预处理数据集
我们将使用神经网络对花朵进行分类。神经网络是深度学习的一种形式,最适合当今的图像分类。我们首先导入所有需要的模块以运行我们的代码。
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
from torchvision.datasets.utils import download_url
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
我们导入了PyTorch的组件以及NumPy和Pandas等数据科学库。图片是非结构化数据,为了将其输入到我们的深度学习模型中,我们必须将其转换为张量。我们需要对图像进行预处理,然后才能为模型做好准备。我们首先使用ImageFolder 存在于torchvision.datasets 来准备数据集。ImageFolder是一个非常有用的工具当图像存储在不同的文件夹中,其中每个文件夹都充当类名。PyTorch还具有其他更简单的准备数据集的方式,我们可以在其中准备自己的自定义数据集。
transformer = torchvision.transforms.Compose([ # Applying Augmentationtorchvision.transforms.Resize((224, 224)),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.RandomVerticalFlip(p=0.5),torchvision.transforms.RandomRotation(30),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
)
dataset = ImageFolder(base_dir, transform=transformer)
我们还习惯于transforms.Compose将图像转换为张量并应用其他图像增强技术。此外,在将各种图像加载到数据集时,请阅读各种变换技术并应用于图像。我们应该使用程序加载图像,以便可以每次分批添加数据集,并且可以优化效率。
定义模型
我们可以使用从PyTorch类继承的类来定义深度学习模型的框架 nn.Module.
def accuracy(outputs, labels): _, preds = torch.max(outputs, dim=1)return torch.tensor(torch.sum(preds == labels).item() / len(preds))class ImageClassificationModel(nn.Module):def training_step(self, batch):images, labels = batchout = self(images) # Generate predictionsloss = F.cross_entropy(out, labels) # Calculate lossreturn lossdef __init__(self):super().__init__()self.network = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2), # output: 64 x 16 x 16nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2), # output: 128 x 8 x 8nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2), # output: 256 x 4 x 4nn.Flatten(),nn.Linear(256*28*28, 1024),nn.ReLU(),nn.Linear(1024, 512),nn.ReLU(),nn.Linear(512, 5))def forward(self, xb):return self.network(xb)def validation_step(self, batch):images, labels = batchout = self(images) # Generate predictionsloss = F.cross_entropy(out, labels) # Calculate lossacc = accuracy(out, labels) # Calculate accuracyreturn {'val_loss': loss.detach(), 'val_acc': acc}def validation_epoch_end(self, outputs):batch_losses = [x['val_loss'] for x in outputs]epoch_loss = torch.stack(batch_losses).mean() # Combine lossesbatch_accs = [x['val_acc'] for x in outputs]epoch_acc = torch.stack(batch_accs).mean() # Combine accuraciesreturn {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}def epoch_end(self, epoch, result):print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['train_loss'], result['val_loss'], result['val_ac']))
训练模型
首先训练模型,让我们将超参数设置为:
num_epochs = 10
opt_func = torch.optim.Adam
lr = 0.001
现在,在将模型运行10个epoach后,我们可以看到使用基本的卷积神经网络(CNN)模型达到了约65%。
测试模型
65%是一个很好的结果,因为我以前曾尝试过使用带有一些隐藏层的简单神经网络(NN),结果仅为40%左右。因此,CNN非常适合对图像进行分类,因为它们有比其他形式的机器学习更好的检测模式。
使用转移学习
现在让我们再次尝试使用已经定义的模型(如Resnet-18)进行转移学习,以改善模型的预测。使用相同的超参数集,我们的测试集中可以达到82%左右,这是非常令人印象深刻的。如果我们使用其他更好的CNN架构,例如Resnet50,Inception V3等,则可以进一步改善结果。
plot_accuracies(history)
保存模型
训练完成后,我们必须保存我们的模型,以便我们可以使用它来根据模型生成预测,甚至将来可以进行更多训练。
weights_fname = 'flower-resnet.pth'
torch.save(model.state_dict(), weights_fname)
产生预测
每个机器学习周期的目标是创建一个可被用于对常规数据进行分类的模型。这可以通过几行python代码为最终用户实现模型。
def predict_image(img, model):# Convert to a batch of 1xb = to_device(img.unsqueeze(0), device)# Get predictions from modelyb = model(xb)# Pick index with highest probability_, preds = torch.max(yb, dim=1)# Retrieve the class labelreturn dataset.classes[preds[0].item()]img, label = test_ds[2]
plt.imshow(img.permute(1, 2, 0))
print('Label:', dataset.classes[label], ', Predicted:', predict_image(img, model))
Label: sunflower , Predicted: sunflower
我们还可以使用服务器上的模型来识别花朵的类型。该模型可以轻松部署在服务器上,以供最终用户识别不同类型的花朵。
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
使用PyTorch实现对花朵的分类相关推荐
- 基于pytorch后量化(mnist分类)---浮点训练vs多bit后量化vs多bit量化感知训练效果对比
基于pytorch后量化(mnist分类)-浮点训练vs多bit后量化vs多bit量化感知训练效果对比 代码下载地址:下载地址 试了 bit 数为 1-8 的准确率,得到下面这张折线图: 发现,当 b ...
- 【小白学习PyTorch教程】十五、BERT:通过PyTorch来创建一个文本分类的Bert模型
@Author:Runsen 2018 年,谷歌发表了一篇题为<Pre-training of deep bidirectional Transformers for Language Unde ...
- Pytorch TextCNN实现中文文本分类(附完整训练代码)
Pytorch TextCNN实现中文文本分类(附完整训练代码) 目录 Pytorch TextCNN实现中文文本分类(附完整训练代码) 一.项目介绍 二.中文文本数据集 (1)THUCNews文本数 ...
- Pytorch实现102类鲜花分类(102 Category Flower Dataset)
Pytorch实现102类鲜花分类(VGG19和ResNet152模型) 本文主要讲解该算法的实现过程,原理部分需读者自行研究,可以找一些论文之类的. 实验环境 python3.6+pytorch1. ...
- 【项目实战课】基于Pytorch的EfficientNet血红细胞分类竞赛实战
欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的EfficientNet血红细胞分类竞赛实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题, ...
- CNN02:Pytorch实现VGG16的CIFAR10分类
CNN02:Pytorch实现VGG16的CIFAR10分类 1.VGG16的网络结构和原理 VGG的具体网络结构和原理参考博客: https://www.cnblogs.com/guoyaohua/ ...
- 使用Pytorch实现图像花朵分类
基于pytorch-classifier这个源码进行实现的图像分类 代码的介绍在这个链接里面,这篇博客主要是为了带着大家通过实践的方式熟悉一下代码的使用,并且了解相关功能. 1. 下载相关资料 这里我 ...
- PyTorch框架:(3)使用PyTorch框架构构建神经网络分类任务
目录 0.背景 1.分类任务介绍: 2.网络架构 3.手写网络 3.1.读取数据集 3.2.查看数据集 3.3将x和y转换成tensor的格式 3.4.定义model 0.背景 其实分类和回归本质上没 ...
- 迁移学习CNN图像分类模型 - 花朵图片分类
训练一个好的卷积神经网络模型进行图像分类不仅需要计算资源还需要很长的时间.特别是模型比较复杂和数据量比较大的时候.普通的电脑动不动就需要训练几天的时间.为了能够快速地训练好自己的花朵图片分类器,我们可 ...
最新文章
- oracle library cache lock,【案例】Oracle等待事件library cache lock产生原因和解决办法...
- golang教程汇总
- outlook2007 未知错误,代码0x80040600解决方法
- 个人工作总结09(第二阶段)
- 200 道算法面试题集锦!Python 实现,含华为、BAT 等校招真题!
- IOS8 Playground介绍
- mysql 自身参照自身_mysql个人散乱笔记,慎重参考
- 本地分发_2020年分发Python应用程序的12个热门途径
- 编码 括号_Java编码规范整理汇总
- TensorFlow笔记(10) CheckPoint
- CCF201312-1 出现次数最多的数
- hexo搭建个人博客_hexo 搭建个人博客
- linux学习之用户的切换
- 怒拒Facebook:语音识别大神、Kaldi之父将加盟小米
- 监听input框值得改变
- 电子线路(线性部分)——第一章 晶体二极管
- 学计算机的女生选择公务员还是考研,女生本科毕业!考研好,还是考公务员更好?...
- Python: PS 滤镜--素描
- 古人的谦称、尊称与贱称
- 购物商城系统设计与实现总结_商品列表展示页的实现
热门文章
- 通过ppt制作圆形图标及自定义形状图形制作
- idea修改注释颜色
- vscode 选中相同批量更改_VSCode同时更改所有相同的变量名或类名的图文教程
- excel中第一列相同,合并第2列中相应单元格内容,并用顿号隔开
- oracle 新建绑定变量,在Oracle中,绑定变量是什么?绑定变量有什么优缺点?
- android浮浮窗广告实现,浮窗广告的实现原理
- python里面的缩进是什么意思_Python缩进规则(一看即懂)
- 怎么联通linux网络设置,linux用联通的网络如何上网.docx
- 不错的一篇小说《风月恨》
- 解决Windows 2012、2016、2019Server服务器系统Intel 网卡驱动数字签名安装方法