文章目录

  • 简述
  • 理论基础
    • 回归
    • softmax
    • 损失函数
  • 读取数据
  • 初始化模型参数
  • 实现softmax运算
  • 定义模型
  • 定义损失函数
  • 计算分类准确率
  • 训练模型
  • 预测
  • 整体代码
    • d2lzh_pytorch.py
    • main.py

简述

这次,将会使用到Fashion-MNIST数据集和操作

理论基础

回归

假设有输入的特征数有4个,分类的标签有3个.则回归方程为:
o1=x1w11+x2w21+x3w31+x4w41+b1o_1=x_1w_{11}+x_2w{21}+x_3w_{31}+x_4w_{41}+b_1o1​=x1​w11​+x2​w21+x3​w31​+x4​w41​+b1​
o2=x1w12+x2w22+x3w32+x4w42+b1o_2=x_1w_{12}+x_2w{22}+x_3w_{32}+x_4w_{42}+b_1o2​=x1​w12​+x2​w22+x3​w32​+x4​w42​+b1​
o3=x1w11+x2w21+x3w31+x4w43+b1o_3=x_1w_{11}+x_2w{21}+x_3w_{31}+x_4w_{43}+b_1o3​=x1​w11​+x2​w21+x3​w31​+x4​w43​+b1​

softmax

这样会得[y1′,y2′,y3′][y_1^{'},y_2^{'},y_3^{'}][y1′​,y2′​,y3′​],哪个数字更大,就取哪个.表示该样本属于这个标签.通过softmax可以得到:

损失函数

交叉熵损失函数

读取数据

使用Fashion-MNIST数据集,并设置批量大小为256

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

型的输入向量的长度是 28×28=78428×28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为784×10和1×101×10的矩阵

num_inputs = 784
num_outputs = 10W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

实现softmax运算

def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

定义模型

def net(X):return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

定义损失函数

def cross_entropy(y_hat, y):return - torch.log(y_hat.gather(1, y.view(-1, 1)))

计算分类准确率

def accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item()

训练模型

'''训练模型'''
num_epochs, lr = 5, 0.1
d2l.train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [w, b], lr)

预测

X, y = iter(test_iter).next()true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

整体代码

d2lzh_pytorch.py

import random
from IPython import display
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sysdef use_svg_display():# 用矢量图显示display.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsize'''给定batch_size, feature, labels,做数据的打乱并生成指定大小的数据集'''
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size): #(start, staop, step)j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) #最后一次可能没有一个batchyield features.index_select(0, j), labels.index_select(0, j)'''定义线性回归的模型'''
def linreg(X, w, b):return torch.mm(X, w) + b'''定义线性回归的损失函数'''
def squared_loss(y_hat, y):return (y_hat - y.view(y_hat.size())) ** 2 / 2'''线性回归的优化算法 —— 小批量随机梯度下降法'''
def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size #这里使用的是param.data'''MINIST,可以将数值标签转成相应的文本标签'''
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]'''定义一个可以在一行里画出多张图像和对应标签的函数'''
def show_fashion_mnist(images, labels):use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()'''获取并读取Fashion-MNIST数据集;该函数将返回train_iter和test_iter两个变量'''
def load_data_fashion_mnist(batch_size):mnist_train = torchvision.datasets.FashionMNIST(root='Datasets/FashionMNIST', train=True, download=True,transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='Datasets/FashionMNIST', train=False, download=True,transform=transforms.ToTensor())if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter'''评估模型net在数据集data_iter的准确率'''
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / n'''训练模型,softmax'''
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:sgd(params, lr, batch_size)else:optimizer.step()train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

main.py

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l# 使用Fashion-MNIST数据集
'''获取和读取数据'''
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)'''初始化模型参数'''
num_inputs = 784
num_outputs = 10
w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype = torch.float)
w.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad = True)'''softmax运算'''
def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp/partition'''定义模型'''
def net(X):return softmax(torch.mm(X.view(-1, num_inputs), w) + b)'''定义损失函数'''
def cross_entropy(y_hat, y):return - torch.log(y_hat.gather(1, y.view(-1, 1)))'''计算准确率'''
def accuracy(y_hat, y):return (y.hat.argmax(dim=1) == y).float().mean().item()'''训练模型'''
num_epochs, lr = 5, 0.1
d2l.train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [w, b], lr)'''预测'''
X, y = iter(test_iter).next()true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

Pytorch手动实现softmax回归相关推荐

  1. [PyTorch]手动实现logistic回归(只借助Tensor和Numpy相关的库)

    文章目录 实验要求 一.生成训练集 二.数据加载器 三.手动构建模型 3.1 logistic回归模型 3.2 损失函数和优化算法 3.3 模型训练 四.训练结果 实验要求 人工构造训练数据集 手动实 ...

  2. pytorch学习笔记(九):softmax回归的简洁实现

    文章目录 1. 获取和读取数据 2. 定义和初始化模型 3. softmax和交叉熵损失函数 4. 定义优化算法 5. 训练模型 6. 总代码 7.小结 使用Pytorch实现一个softmax回归模 ...

  3. Lesson 12.5 softmax回归建模实验

    Lesson 12.5 softmax回归建模实验 接下来,继续上一节内容,我们进行softmax回归建模实验. 导入相关的包 # 随机模块 import random# 绘图模块 import ma ...

  4. 手动以及使用torch.nn实现logistic回归和softmax回归

    其他文章 手动以及使用torch.nn实现logistic回归和softmax回(当前文章) 手动以及使用torch.nn实现前馈神经网络实验 文章目录 任务 一.Pytorch基本操作考察 1.1 ...

  5. 【深度学习】基于Pytorch的softmax回归问题辨析和应用(一)

    [深度学习]基于Pytorch的softmax回归问题辨析和应用(一) 文章目录 1 概述 2 网络结构 3 softmax运算 4 仿射变换 5 对数似然 6 图像分类数据集 7 数据预处理 8 总 ...

  6. 【深度学习】基于Pytorch的softmax回归问题辨析和应用(二)

    [深度学习]基于Pytorch的softmax回归问题辨析和应用(二) 文章目录1 softmax回归的实现1.1 初始化模型参数1.2 Softmax的实现1.3 优化器1.4 训练 2 多分类问题 ...

  7. [pytorch、学习] - 3.7 softmax回归的简洁实现

    参考 3.7. softmax回归的简洁实现 使用pytorch实现softmax import torch from torch import nn from torch.nn import ini ...

  8. pytorch学习笔记(八):softmax回归的从零开始实现

    文章目录 1. 获取和读取数据 2. 初始化模型参数 3. 实现softmax运算 4. 定义模型 5. 定义损失函数 6. 计算分类准确率 7. 训练模型 8. 预测 9. 总代码 小结 这一节我们 ...

  9. 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络

    1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...

最新文章

  1. 考研计算机专业英语题型,考研英语一题型及分值
  2. qualcomm memory dump 抓取方法
  3. doc命令下查看java安装路径
  4. 博客园 CSS 代码定制
  5. 使用腾讯开发平台获取QQ用户数据资料
  6. 工业机器人导轨 百度文库_工业机器人或许开创一个全新的PLC时代
  7. 深浅克隆面试题汇总——附详细答案
  8. 网页Object标签 遮盖DIV标签解决方法
  9. python几种排序_Python实现几种排序算法
  10. pylon 内存泄露的问题
  11. 该学Java或.NET?
  12. 变量和算术运算之变量(三)
  13. 服务器 'xxx' 上的 MSDTC 不可用。
  14. 后端游戏引擎调研-2021.07
  15. python-docx读取word段落的样式字体
  16. python汉字转gb2312_PYTHON中UTF-8向GB2312编码转换的问题一解
  17. push_back()函数的用法
  18. Data too long for column ‘xxxx‘ at row 1 解决办法
  19. 苹果8a1660是什么版本_苹果A1660是什么型号?
  20. MATLAB如何固定text在图中的相对位置

热门文章

  1. 拔掉电源会怎样?GaussDB(for Redis)双活让你有备无患
  2. JAVA SOCKET实现全双工通信
  3. 知网导入EndNote
  4. select html默认选中的值,HTML/jquery中的select标签设置默认选中取值
  5. LTE(4G) RRC消息流程
  6. 软件工程——流程图和盒图
  7. java jsp+servlet+mysql实现登录网页设计
  8. Java实现二分图的最大匹配
  9. Docker启动Nacos(单例)、Redis
  10. java将字符串转换为大写或小写