欢迎关注 “小白玩转Python”,发现更多 “有趣”

引言

本文提供了一个使用PyTorch构建一个非常基本的 Logistic模型的简单步骤,并将其应用于猴子图像的分类。

首先我们可以从下面的网址下载用于模型训练和测试的数据集:

https://www.kaggle.com/slothkong/10-monkey-species

这个数据集包含了10种猴子的图片,包括:

n0 — alouattapalliata

n1 — erythrocebuspatas

n2 — cacajaocalvus

n3 — macacafuscata

n4 — cebuellapygmea

n5 — cebuscapucinus

n6 — micoargentatus

n7 — saimirisciureus

n8 — aotusnigriceps

n9 — trachypithecusjohnii

在数据集中有两个文件: 训练文件和验证文件。训练和验证文件都包含10个标记为 n0-n9的子文件夹,如上所述,它们各代表一种猴子。每个猴子的图像至少是400x300像素。训练文件中可用的总图像为1096,验证文件夹中可用的总图像为272个图像。训练图像将用于训练和验证模型,而验证图像将用作测试图像,以报告模型的最终准确性。

第一步: 加载和查看数据

构建任何机器学习模型的第一步是理解基础数据。让我们首先读取图像数据,查看其中的一些图像,并将图像数据转换为张量。

导入相关库:

# Import relevant librariesimport torch
import jovian
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_splitfrom PIL import Image
import glob

设置超参数:

# Hyperparameters
batch_size = 16
learning_rate = 1e-3jovian.reset()
jovian.log_hyperparams(batch_size=batch_size, learning_rate=learning_rate)

载入图像(包括训练图像和测试图像)并将图像转换为 float32类型的张量:

# Load image and convert image to multidimensional array
def image_to_array(images_folder):dataset = []for i in range(10):for filename in glob.glob(images_folder + "/n{}/*.jpg".format(i)):im = Image.open(filename)im = im.resize((400,300))pixels = np.asarray(im).astype('float32')pixels /= 255.0pixels = torch.from_numpy(pixels)dataset.append((pixels, i))return dataset
# Load Training Data
train_dataset = image_to_array("monkey_species/training/training")# Load Test Data
test_dataset = image_to_array("monkey_species/validation/validation")

查看示例图片:

# View a sample Image
img_tensor, label = train_dataset[0]
print(img_tensor.shape)plt.imshow(img_tensor)
print('Label:', label)

第二步:为训练准备数据

上面已经将图像数据转换为张量,我们可以开始准备用于模型训练,验证和测试的数据了。

训练数据——将用于训练模型(通过计算交叉熵损失和使用梯度下降法调整模型的权重)。

验证数据——将用于在训练时评估模型,并调整超参数(学习率和批量大小)。

测试数据——将用于计算模型的准确度。

从训练数据集创建验证集(20%的训练数据将用于验证)。同时生成批量的训练、验证和测试数据:

# Training and Validation dataset
val_size = round(0.2*len(train_dataset))
train_size = len(train_dataset) - val_size
train_ds, val_ds = random_split(train_dataset, [train_size, val_size])# Dataloaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# Verify batch
for xb, yb in train_loader:print("inputs:", xb)print("targets:", yb)break

第三步:训练模型

现在我们已经为训练、验证和测试准备好了数据。我们可以使用训练数据集开始训练模型,并使用验证集对其进行验证。

为了进行训练,让我们创建一个自定义模型类和一些实用程序函数,如下所示:

input_size = 300*400*3
num_classes = len(label_dict)
print("Input Size: ", input_size, "\nNumber of Classes: ", num_classes)
class MonkeyClassificationModel(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(input_size, num_classes)def forward(self, xb):xb = xb.reshape(-1, input_size)out = self.linear(xb)return outdef training_step(self, batch):images, labels = batch out = self(images)                  # Generate predictionsloss = F.cross_entropy(out, labels) # Calculate lossreturn lossdef validation_step(self, batch):images, labels = batch out = self(images)                    # Generate predictionsloss = F.cross_entropy(out, labels)   # Calculate lossacc = accuracy(out, labels)           # Calculate accuracyreturn {'val_loss': loss.detach(), 'val_acc': acc.detach()}def validation_epoch_end(self, outputs):batch_losses = [x['val_loss'] for x in outputs]epoch_loss = torch.stack(batch_losses).mean()   # Combine lossesbatch_accs = [x['val_acc'] for x in outputs]epoch_acc = torch.stack(batch_accs).mean()      # Combine accuraciesreturn {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}def epoch_end(self, epoch, result):print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))model = MonkeyClassificationModel()
list(model.parameters())
def accuracy(outputs, labels):_, preds = torch.max(outputs, dim=1)return torch.tensor(torch.sum(preds == labels).item() / len(preds))def evaluate(model, val_loader):outputs = [model.validation_step(batch) for batch in val_loader]return model.validation_epoch_end(outputs)def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):history = []optimizer = opt_func(model.parameters(), lr)for epoch in range(epochs):# Training Phase for batch in train_loader:loss = model.training_step(batch)loss.backward()optimizer.step()optimizer.zero_grad()# Validation phaseresult = evaluate(model, val_loader)model.epoch_end(epoch, result)history.append(result)return history
history1 = fit(100, learning_rate, model, train_loader, val_loader)
history2 = fit(100, learning_rate/10, model, train_loader, val_loader)
history3 = fit(100, learning_rate/10, model, train_loader, val_loader)
history4 = fit(100, learning_rate/100, model, train_loader, val_loader)
history5 = fit(100, learning_rate/1000, model, train_loader, val_loader)history = history1 + history2 + history3 + history4 + history5accuracies = [r['val_acc'] for r in history]
plt.plot(accuracies, '-x')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Accuracy vs. No. of epochs')
# Evaluate on test dataset
result = evaluate(model, test_loader)
result

第四步: 使用训练后的模型进行预测

模型经过训练后,我们可以使用该模型来预测,即对测试图像进行分类。让我们定义一个分类图片的函数:

def predict_image(input_img, model):inputs = input_img.unsqueeze(0)predictions = model(inputs)_, preds  = torch.max(predictions, dim=1)return preds[0].item()
label_dict = {0:"alouattapalliata", 1:"erythrocebuspatas", 2:"cacajaocalvus", 3:"macacafuscata",4:"cebuellapygmea", 5:"cebuscapucinus", 6:"micoargentatus", 7:"saimirisciureus",8:"aotusnigriceps", 9:"trachypithecusjohnii"}

部分测试结果如下:

预测正确

预测错误

保存模型:

# Save
torch.save(model.state_dict(), 'monkey_classification.pth')

模型精度及提高精度的思路

使用测试数据集计算模型的准确率约为56.6%。使用这个相当简单的 Logistic模型模型得到的准确度很差。因此,对于这个特定的数据集,需要考虑一个更复杂的机器学习或者深度学习模型。相信使用卷积神经网络(CNN)或深层神经网络(DNN)可以获得更好的分类精度。

可以通过以下策略进一步提高模型的准确性:

1. 增加数据

2. 使用CNN模型

3. 更改优化函数

结束语

尽管 Logistic模型的准确性很差,但本篇文章展示了如何使用 PyTorch 构建一个简单的 Logistic模型。类似的步骤可以应用于任何简单的线性分类问题。

·  END  ·

HAPPY LIFE

在PyTorch中使用Logistic回归进行10种猴子物种分类相关推荐

  1. Pytorch中CNN图像回归问题预测值都一样

    ** Pytorch中CNN图像回归问题预测值都一样 ** 上网也查阅了许多资料,然后对比各种方法都试了一遍,归结为以下几点: 1.出现预测值都一样的情况,一般都是在某一层梯度消失了,然后导致输入到下 ...

  2. 逆境中激励员工士气的10种方法

    逆境中激励员工士气的10种方法 如果你是主管,心里正想着要召集部属全员到齐,然后训示:"各位,现在是我们必须齐心协力的时刻否则我们就要经营不下去了."那你最好三思.再三思-- 现在 ...

  3. VGG16对10种猴子分类

    VGG16对10种猴子分类 dataset: import os import torch import numpy as np from PIL import Image from torch.ut ...

  4. [PyTorch]手动实现logistic回归(只借助Tensor和Numpy相关的库)

    文章目录 实验要求 一.生成训练集 二.数据加载器 三.手动构建模型 3.1 logistic回归模型 3.2 损失函数和优化算法 3.3 模型训练 四.训练结果 实验要求 人工构造训练数据集 手动实 ...

  5. 机器学习中的Logistic回归算法(LR)

    Logistic回归算法(LR) 算法简介 LR名为回归,实际是一种分类算法.其针对输入样本集 x x,假设的输出结果 y=hθ(x)y=h_{\theta}(x) 的取值范围为 y=[0,1] y= ...

  6. 激活层是每一层都有吗_我心目中最值得栽种的10种藤本月季,每一种都很优秀,你喜欢吗...

    我是从2015年开始接触藤本月季,当时藤本还比较贵,一株好一点的苗要二三十块,而且花市没有,都要从网上购买,一不小心就会买到小白花.截止目前,栽种过的藤本月季品种至少也有50种以上了,但真正保留下来的 ...

  7. 从Mybatis源码中,学习到的10种设计模式

    一.前言:小镇卷码家 总有不少研发伙伴问小傅哥:"为什么学设计模式.看框架源码.补技术知识,就一个普通的业务项目,会造飞机不也是天天写CRUD吗?" 你说的没错,但你天天写CRUD ...

  8. 从流程的自动化中获得最大价值的10种方式

    流程自动化很好,如果它可以节省时间并减少错误.但是如果它不能在业务流程中"很好地契合",那么会难以得到普及.问问有谁没有对语音助手感到伤脑筋. 所幸的是,某些最佳实践让你可以从流程 ...

  9. 从COVID-19大流行中汲取哪些教训?10种方法帮CIO预防下一次危机

    导读:以下这些从COVID-19大流行中汲取的经验教训,可以帮助IT领导者和TI经理为下一次紧急情况做好准备. 您的IT团队准备好应对COVID-19大流行了吗?您的网络可以一次使用VPN处理所有员工 ...

最新文章

  1. Paramiko: SSH and SFTP With Python
  2. 14岁上大学,29岁拿下教职,如今这位华裔学者拿下Jeffrey Elman大奖
  3. 云计算平台中虚拟专用网和VPC有什么区别?
  4. leetcode206.反转链表 解题思路(简单)
  5. 昨天订了一台FSC Lifebook S6220
  6. linux中使用lftp上传下载文件
  7. leetcode547. 省份数量
  8. tars 部署 oracle,Tars 部署介绍(必看)
  9. Redis发布订阅机制
  10. CTS(22)---GMS认证-Android8.x新增cts测试(VTS下测试GSI版本)
  11. JAVA中如何全局监听鼠标事件
  12. android刷机教程 华为,华为的安卓手机该怎么刷机
  13. 谷歌浏览器访问接口无返回
  14. 宏定义语句的 GPBCON 、GPBDAT、GPBUP 地址(老师布置的作业,没接触过,不懂,求大神解答,万分感谢)
  15. 潘金莲——中国女性解放思想的先驱《其实我的心没走》
  16. 自监督学习论文、代码汇总
  17. RDKit | 基于RDKit和SMARTS的化学反应处理
  18. 【统计学】从样本到总体
  19. 【水文模型】04 参数识别与敏感性分析方法
  20. SK海力士拟2022年后投资千亿美元新建4座半导体工厂

热门文章

  1. IOS 9.3.3更后打电话没声音解决方法
  2. 【MySQL 17】安装异常:Could not open file ‘/var/log/mysql/mysqld.log‘ for error logging: Permission denied
  3. 找情感挽回大师拯救情感的究竟是些什么人
  4. 质量内建是规模化敏捷(规模化研发交付)的核心
  5. 11月27日云栖精选夜读:阿里毕玄:智能时代,运维工程师在谈什么? 飞
  6. 有的时候,我们有需要将由不同栏位获得的资料串连在一起
  7. 特征融合 Pytorch concat串连两个预训练特征
  8. 计算机屏幕频率是什么,电脑“频率”什么意思?CPU、显卡、内存、显示器频率你知道多少?...
  9. Ebay Trading API整理
  10. 未磁科技完成超亿元A轮融资,核心团队毕业于北航