按照模式分类课本写的代码,如有错误欢迎指正!

main.m

%程序运行可能会需要3-5分钟的时间,请耐心等待。

clear;

%已对lms.mat进行随机打乱,并将Y由标量化为[1,10]矩阵形成data.mat

load data;

%定义测试集和验证集并增加偏置,测试集与验证集比例为4:1

testX = X(1:4000,:);

testY = Y(1:4000,:);

verifyX = X(4001:end,:);

verifyY = Y(4001:end,:);

testX = [ones(4000,1),testX];

%定义权重Wij和权重Wjk,并增加偏置

%输入层有400个单元,隐藏层有25个单元,输出层有10个单元

Wij = (rand(401,25)*2 - 1)*0.1;

Wjk = (rand(26,10)*2 - 1)*0.1;

%设置学习率和epoch,batch为整个数据集

epoch = 1000;

eta = 0.0005;

count = 1;

%变量初始化完成

%前馈与反向传播运算

while(count <= epoch)

%采用批量梯度下降法

%输入层与隐藏层的静激活

netj = testX*Wij;

%激活函数使用sigmoid函数

yj = sigmoid(netj);

%对第二层加入偏置

yj = [ones(4000,1),yj];

%第二层的静激活

netk = yj*Wjk;

%第二层的激活sigmoid函数

zk = sigmoid(netk);

%批量学习算法,计算损失函数

J = (norm(zk-testY)^2)/(2*4000);

%损失可视化用

plot_J(1,count) = J;

%开始反向传播

%计算两层sigmoid激活函数的导数

dnetj = dsigmoid(netj);

dnetk = dsigmoid(netk);

%计算deltaWjk,在批量梯度下降算法中应累加

deltaWjk = zeros(10,26);

for i = 1:4000

%bug:这里计算的是内积——已修复

deltaWjk = deltaWjk + eta*((testY(i,:) - zk(i,:))'.*dnetk(i,:)')*yj(i,:);

end

Wjk = deltaWjk'+Wjk;

%计算deltaWij,在批量梯度下降算法中应累加

deltaWij = zeros(25,401);

for i = 1:4000

deltak = (testY(i,:) - zk(i,:)).*dnetk(i,:);

%去掉偏重,权重减少一维

tmp = zeros(25, 1);

for j = 1:10

tmp = tmp + eta*Wjk(2:end,j)*deltak(j);

end

deltaWij = deltaWij + (tmp.*dnetj(i,:)')*testX(i,:);

end

Wij = deltaWij'+Wij;

count = count+1;

end

pre_res = predict(verifyX,Wij,Wjk);

acc = accuracy(verifyY, pre_res);

fprintf('epoch = %d\n',epoch);

fprintf('learning_rate = %f\n',eta);

fprintf('第一次epoch的cost: %f\n', plot_J(1));

fprintf('最后一次epoch的cost: %f\n',plot_J(end));

fprintf('测试集的正确率为%f\n',acc);

%为了使cost的减少更直观,从第十次开始画

plot(plot_J(10:end));

title('损失函数变化');

xlabel('迭代次数');

ylabel('cost');

predict.m

function [res] = predict2(x,Wij,Wjk)

%采用批量梯度下降法

%对x加入偏置

x = [ones(1000,1), x];

%输入层与隐藏层的静激活

netj = x*Wij;

%激活函数使用ReLU函数

yj = sigmoid(netj);

%对第二层加入偏置

yj = [ones(1000,1),yj];

%第二层的静激活

netk = yj*Wjk;

%第二层的激活Relu函数

res = sigmoid(netk);

end

sigmoid.m

function [y] = sigmoid(x)

y = 1./(1+exp(-x));

end

dsigmoid.m

function [y] = dsigmoid(x)

tem = 1./(1+exp(-x));

y = tem.*(1 - tem);

end

实验结果:

matlab实现BP神经网络minst手写数字识别相关推荐

  1. MATLAB--基于BP神经网络的手写数字识别

    MATLAB–基于BP神经网络的手写数字识别 在干活的过程中整理下来的,希望对大家有帮助. 手写数字识别技术作为图像处理和模式识别中的研究热点,在大规模数据统计(如行业年检.人口普查等).票据识别.财 ...

  2. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  3. 基于BP神经网络的手写数字识别

    基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...

  4. Python学习记录 搭建BP神经网络实现手写数字识别

    搭建BP神经网络实现手写数字识别 通过之前的文章我们知道了,构建一个简单的神经网络需要以下步骤 准备数据 初始化假设 输入神经网络进行计算 输出运行结果 这次,我们来通过sklearn的手写数字数据集 ...

  5. BP神经网络实现手写数字识别Python实现,带GUI手写画板

    BP神经网络实现手写数字识别 BP神经网络模型 用tkinter编写用于手写输入的画板 程序运行的效果截图 在B站看了一个机器学习基础的视频( 链接)后,发现到资料里面有一个用BP神经网络对手写数字进 ...

  6. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  7. 小生不才:tensorflow实战01-基于bp神经网络的手写数字识别

    前言 利用搭建网络八股,使用简单的bp神经网络完成手写数字的识别. 搭建过程 导入相应的包 获取数据集,划分数据集和测试集并进行简单处理(归一化等) 对数据进行乱序处理 定义网络结构 选择网络优化器以 ...

  8. 全连接神经网络——MINST手写数字识别

    简介 本文构建了一个全连接神经网络(FCN),实现对MINST数据集手写数字的识别,没有借助任何深度学习算法库,从原理上理解手写数字识别的全过程,包括反向传播,梯度下降等.最终的代码总行数不超过200 ...

  9. BP神经网络(手写数字识别)

    1实验环境 实验环境:CPU i7-3770@3.40GHz,内存8G,windows10 64位操作系统 实现语言:python 实验数据:Mnist数据集 程序使用的数据库是mnist手写数字数据 ...

最新文章

  1. 字符串转 Json 数组
  2. Bash脚本: 根据关键字做替换
  3. 34tomcat设置默认页面
  4. 品质主管每日工作需要做哪些_做微信社群运营需要用到哪些工具来铺助工作呢?...
  5. SpringMVC连接MongoDB操作数据库
  6. Android Studio 设置主题及字体
  7. 【weiphp微信开发教程】留言板插件开发详解
  8. Amlogic_Android7.1 HDMI显示流程源码分析
  9. 57. Attribute specified 属性
  10. python 等号 什么编码_Python运算符与编码
  11. ORA-39194: Table mode jobs require the tables to be comma separated.
  12. 今日头条含室内设计用户粉丝数量统计(2019.12.24)
  13. R语言做复杂金融产品的几何布朗运动的模拟
  14. 搜索引擎点击日志聚类实现相关搜索
  15. 手机:运行内存,机身内存,内存卡的区分
  16. Crunch生成字典
  17. oracle物料属性主要单位,Oracle EBS物料属性设定.doc
  18. 领导力:“不懂带团队你就自己累”
  19. Generalizing to Unseen Domains via Adversarial Data Augmentation 正文
  20. Windows日志研究

热门文章

  1. 随机行走(random walk)
  2. Layui数据表格获取数据库数据
  3. python模块安装_Python模块安装问题
  4. React高阶组件HOC配置
  5. 不会安装该公布程序,因为它可能不安全,请与管理员联系,将程序包的安装用户界面选项更改为基本
  6. 文件加密软件的局限性都有哪些呢,用户知道吗?
  7. 计算机专业学硕的专业代码,考研专业的代码是什么意思
  8. 嵌入式C语言--面试题
  9. 2021-11-18 WinFrom面试题 在Winform中,我们发现在一个Form相关的cs文件有两个,它们的类名都是一样的,只是代码后台类文件中的class之前有partical修饰而已,这是为
  10. 苏炳添未上线,雷军“非正式带货”7000万