Perceptron

原理

简单的感知机可以看作一个二分类,假定我们的公式为

f(x) = sign(w *x + b)

我们把 -b 做为一个标准,w* x 的结果与 -b 这个标准比较,

w*x > -b, f(x) = +1

w *x < -b, f(x) = -1

不难看出w是超平面的法向量,超平面上的向量与w的数量积为0。因此这个超平面就可以很好的区分我们的数据集。

而感知机就是来寻找w和b

优化方法

优化方法我们现有的方法比较多,诸如GD、SGD、Minibatch、Adam

当然我们的损失函数也包含多种,常见的有MSE, CrossEntropy.

这边简单展示一下MSE以及GD原理。

SoftMax

如果我们输出为多分类,那就成为一个SoftMax回归。

SoftMax回归和线性回归一样将输入特征与权重做线性叠加。与线性回归的一个主要不同在于,SoftMax回归的输出值个数等于标签里的类别数。

MLP

而我们给SoftMax回归增加隐藏层,就是我们所说的多层感知机,而

全连接层只是对数据做仿射变换,我们的方法是引入非线性变换,就是激活函数。

代码实现

这边选用CIFAR10数据集来做演示。CIFAR10包含10个类别,每个类别600张32x32的彩色图像。

1.导入依赖包

import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import os
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

2.加载数据集

这边对图片进行归一化处理。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
train_data = torchvision.datasets.CIFAR10(root="data",download=True,train=True,transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=4,shuffle=True, num_workers=8)val_data = torchvision.datasets.CIFAR10(root="data",download=True,train=False,transform=transform)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=4,shuffle=True, num_workers=8)

3.定义模型及参数

用Sequential快速构建,对数据进行展平处理输入尺寸为图片尺寸 x 通道数,输出10分类,hidden layer设置为512。

net = nn.Sequential(nn.Flatten(),nn.Linear(1024*3, 512),nn.ReLU(),nn.Linear(512,10)
)

4.训练

损失计算选用交叉熵函数,优化器选用SGD,调用显卡运行。

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01)
epochs = 30device = "cuda:0"
train(net,train_loader,val_loader,epochs,optimizer,loss,device)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XrRdwV5J-1645528506210)(https://z3.ax1x.com/2021/10/26/5I1IN8.jpg)]

结果

可以看出我们的验证准确值过低,这主要是因为数据集特征不明显,我们在更换数据集验证。

更换数据集

选用7分类的海贼王图片进行训练,可以看出训练结果明显优于CIFAR数据集。

同时我们再挑选一张不在训练集的图片进行验证,发现结果正确。

Perceptron相关推荐

  1. sklearn MLP(多层感知机、Multi-layer Perceptron)模型使用RandomSearchCV获取最优参数及可视化

    sklearn MLP(多层感知机.Multi-layer Perceptron)模型使用RandomSearchCV获取最优参数及可视化 Deep Learning 近年来在各个领域都取得了 sta ...

  2. MLPclassifier,MLP 多层感知器的的缩写(Multi-layer Perceptron)

    先看代码(sklearn的示例代码): [python] view plain copy from sklearn.neural_network import MLPClassifier X = [[ ...

  3. Hinton神经网络公开课编程练习1 The perceptron learning algorithm

    为什么80%的码农都做不了架构师?>>>    本文由码农场同步,最新版本请查看原文:http://www.hankcs.com/ml/the-perceptron-learning ...

  4. 【Python-ML】SKlearn库感知器(perceptron) 使用

    # -*- coding: utf-8 -*- ''' Created on 2018年1月12日 @author: Jason.F @summary: Scikit-Learn库感知器学习算法 '' ...

  5. 【Python-ML】感知器学习算法(perceptron)

    1.数学模型   2.权值训练 3.Python代码 感知器收敛的前提是两个类别必须是线性可分的,且学习速率足够小.如果两个类别无法通过一个线性决策边界进行划分,要为模型在训练集上的学习迭代次数设置一 ...

  6. 感知器 Perceptron

    基本概念 线性可分:在特征空间中可以用一个线性分界面正确无误地分开两 类样本:采用增广样本向量,即存 在合适的增广权向量 a 使得: 则称样本是线性可分的.如下图中左图线性可分,右图不可分.所有满足条 ...

  7. DL之PerceptronAdalineGD:基于iris莺尾花数据集利用Perceptron感知机和AdalineGD算法实现二分类

    DL之Perceptron&AdalineGD:基于iris莺尾花数据集利用Perceptron感知机和AdalineGD算法实现二分类 目录 基于iris莺尾花数据集利用Perceptron ...

  8. ML:基于自定义数据集利用Logistic、梯度下降算法GD、LoR逻辑回归、Perceptron感知器、SVM支持向量机、LDA线性判别分析算法进行二分类预测(决策边界可视化)

    ML:基于自定义数据集利用Logistic.梯度下降算法GD.LoR逻辑回归.Perceptron感知器.支持向量机(SVM_Linear.SVM_Rbf).LDA线性判别分析算法进行二分类预测(决策 ...

  9. DL之perceptron:利用perceptron感知机对股票实现预测

    DL之perceptron:利用perceptron感知机对股票实现预测 目录 输出结果 实现代码 输出结果 更新-- 实现代码 import numpy as np import operator ...

  10. DL之Perceptron:Perceptron感知器(感知机/多层感知机/人工神经元)的简介、原理、案例应用(相关配图)之详细攻略

    DL之Perceptron:Perceptron感知器(感知机/多层感知机/人工神经元)的简介.原理.案例应用(相关配图)之详细攻略 目录 Perceptron的简介.原理 多层感知机 实现代码 案例 ...

最新文章

  1. pycharm debug后会出现 step over /step into/step into my code /force step into /step out 分别表示...
  2. ItemsControl 解析
  3. 正直、智慧、成熟、诚信——毒霸用人的基本原则
  4. CV_LOAD_IMAGE_COLOR 和 CV_BGR2RGBA找不到定义
  5. vim安装时报错:Depends:vim-common (=2:7.4.1689-3ubuntu1.4) but 2:8.0.1453-1ubuntu1.1 is to be installed
  6. Java 设计模式之Bridge桥接模式
  7. Ubuntu12.04 安装(无法将 grub-efi 软件包安装到/target/中,如果没有 GRUB 启动引导期,所安装的系统无法启动)...
  8. 前景检测算法(九)--PBAS算法
  9. @ResponseBody与@RestController的作用与区别
  10. MixGo V1.0 发布,混合型高性能 Go 框架
  11. idea swagger生成接口文档_Springboot结合swagger-ui自动生成接口文档
  12. 消息队列控制灯代码_基于ARM的智能灯光控制系统经验总结分享
  13. java 获取32位纯数字 或字母与数字结合的唯一id
  14. ERP、CRM、SCM、电子商务、BI、ITSS
  15. 区块链赋能数字交通建设 PPT
  16. react-pdf预览pdf
  17. 什么是索引回表,如何避免(索引覆盖)
  18. 笔记本电脑变WiFi和WiFi共享精灵的应用教程比较
  19. Android系统下载管理DownloadManager功能介绍及使用示例
  20. windows下编译Sqlite-3.38.0及使用(存储json)

热门文章

  1. 专家称新冠肺炎传播途径包括气溶胶传播,意味着什么?应如何防控?
  2. 提高国外 VPS 云主机性能(Linux系统)的 4 个步骤
  3. c语言老鼠迷宫程序,C语言经典算法——老鼠走迷宫(二)
  4. 刺激战场pc服务器没有响应,刺激战场PC端玩不了怎么办 PC端玩不了解决方法[多图]...
  5. 【艾琪出品】-【计算机】测试题系列五参考资料
  6. Mac安装git flow
  7. 基于Leap Motion设备及Unity3D引擎的自定义手势识别
  8. 克莱斯勒召回1604辆牧马人,涉及高压电池保险丝存在安全隐患
  9. 树莓派安装Ubuntu Mate解决无法连接WiFi问题,并部署Ros系统
  10. C语言报告多少字,一个统计字数的程序