机器学习应用篇(八)——基于BP神经网络的预测

文章目录

  • 机器学习应用篇(八)——基于BP神经网络的预测
    • 一、Introduction
      • 1 BP神经网络的优点
      • 2 BP神经网络的缺点
    • 二、实现过程
      • 1 Demo
      • 2 基于BP神经网络的乳腺癌分类预测
    • 三、Keys

一、Introduction

1 BP神经网络的优点

  1. 非线性映射能力:BP神经网络实质上实现了一个从输入到输出的映射功能,数学理论证明三层的神经网络就能够以任意精度逼近任何非线性连续函数。这使得其特别适合于求解内部机制复杂的问题,即BP神经网络具有较强的非线性映射能力。
  2. 自学习和自适应能力:BP神经网络在训练时,能够通过学习自动提取输入、输出数据间的“合理规则”,并自适应地将学习内容记忆于网络的权值中。即BP神经网络具有高度自学习和自适应的能力。
  3. 泛化能力:所谓泛化能力是指在设计模式分类器时,即要考虑网络在保证对所需分类对象进行正确分类,还要关心网络在经过训练后,能否对未见过的模式或有噪声污染的模式,进行正确的分类。也即BP神经网络具有将学习成果应用于新知识的能力。

2 BP神经网络的缺点

  1. 局部极小化问题:从数学角度看,传统的 BP神经网络为一种局部搜索的优化方法,它要解决的是一个复杂非线性化问题,网络的权值是通过沿局部改善的方向逐渐进行调整的,这样会使算法陷入局部极值,权值收敛到局部极小点,从而导致网络训练失败。加上BP神经网络对初始网络权重非常敏感,以不同的权重初始化网络,其往往会收敛于不同的局部极小,这也是每次训练得到不同结果的根本原因
  2. BP 神经网络算法的收敛速度慢:由于BP神经网络算法本质上为梯度下降法,它所要优化的目标函数是非常复杂的,因此,必然会出现“锯齿形现象”,这使得BP算法低效;又由于优化的目标函数很复杂,它必然会在神经元输出接近0或1的情况下,出现一些平坦区,在这些区域内,权值误差改变很小,使训练过程几乎停顿;BP神经网络模型中,为了使网络执行BP算法,不能使用传统的一维搜索法求每次迭代的步长,而必须把步长的更新规则预先赋予网络,这种方法也会引起算法低效。以上种种,导致了BP神经网络算法收敛速度慢的现象。
  3. BP 神经网络结构选择不一:BP神经网络结构的选择至今尚无一种统一而完整的理论指导,一般只能由经验选定。网络结构选择过大,训练中效率不高,可能出现过拟合现象,造成网络性能低,容错性下降,若选择过小,则又会造成网络可能不收敛。而网络的结构直接影响网络的逼近能力及推广性质。因此,应用中如何选择合适的网络结构是一个重要的问题。

二、实现过程

1 Demo

#%% 基础数组运算库导入
import numpy as np
# 画图库导入
import matplotlib.pyplot as plt
# 导入三维显示工具
from mpl_toolkits.mplot3d import Axes3D
# 导入BP模型
from sklearn.neural_network import MLPClassifier
# 导入demo数据制作方法
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import warnings
from sklearn.exceptions import ConvergenceWarning#%%模型训练
# 制作五个类别的数据,每个类别1000个样本
train_samples, train_labels = make_classification(n_samples=1000, n_features=3, n_redundant=0,n_classes=5, n_informative=3, n_clusters_per_class=1,class_sep=3, random_state=10)
# 将五个类别的数据进行三维显示
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(train_samples[:, 0], train_samples[:, 1], train_samples[:, 2], marker='o', c=train_labels)
plt.title('Demo Data Map')

#%% 建立 BP 模型, 采用sgd优化器,relu非线性映射函数
BP = MLPClassifier(solver='sgd',activation = 'relu',max_iter = 500,alpha = 1e-3,hidden_layer_sizes = (32,32),random_state = 1)
# 进行模型训练
with warnings.catch_warnings():warnings.filterwarnings("ignore", category=ConvergenceWarning,module="sklearn")BP.fit(train_samples, train_labels)
# 查看 BP 模型的参数
print(BP)
#%% 进行模型预测
predict_labels = BP.predict(train_samples)
# 显示预测的散点图
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(train_samples[:, 0], train_samples[:, 1], train_samples[:, 2], marker='o', c=predict_labels)
plt.title('Demo Data Predict Map with BP Model')# 显示预测分数
print("预测准确率: {:.4f}".format(BP.score(train_samples, train_labels)))# 可视化预测数据
print("真实类别:", train_labels[:10])
print("预测类别:", predict_labels[:10])
# 准确率等报表
print(classification_report(train_labels, predict_labels))# 计算混淆矩阵
classes = [0, 1, 2, 3]
cofusion_mat = confusion_matrix(train_labels, predict_labels, classes)
sns.set()
figur, ax = plt.subplots()
# 画热力图
sns.heatmap(cofusion_mat, cmap="YlGnBu_r", annot=True, ax=ax)
ax.set_title('confusion matrix')  # 标题
ax.set_xticklabels([''] + classes, minor=True)
ax.set_yticklabels([''] + classes, minor=True)
ax.set_xlabel('predict')  # x轴
ax.set_ylabel('true')  # y轴
plt.show()



#%%# 进行新的测试数据测试
test_sample = np.array([[-1, 0.1, 0.1]])
print(f"{test_sample} 类别是: ", BP.predict(test_sample))
print(f"{test_sample} 类别概率分别是: ", BP.predict_proba(test_sample))test_sample = np.array([[-1.2, 10, -91]])
print(f"{test_sample} 类别是: ", BP.predict(test_sample))
print(f"{test_sample} 类别概率分别是: ", BP.predict_proba(test_sample))test_sample = np.array([[-12, -0.1, -0.1]])
print(f"{test_sample} 类别是: ", BP.predict(test_sample))
print(f"{test_sample} 类别概率分别是: ", BP.predict_proba(test_sample))test_sample = np.array([[100, -90.1, -9.1]])
print(f"{test_sample} 类别是: ", BP.predict(test_sample))
print(f"{test_sample} 类别概率分别是: ", BP.predict_proba(test_sample))

2 基于BP神经网络的乳腺癌分类预测

#%%基于BP神经网络的乳腺癌分类
#基本库导入
# 导入乳腺癌数据集
from sklearn.datasets import load_breast_cancer
# 导入BP模型
from sklearn.neural_network import MLPClassifier
# 导入训练集分割方法
from sklearn.model_selection import train_test_split
# 导入预测指标计算函数和混淆矩阵计算函数
from sklearn.metrics import classification_report, confusion_matrix
# 导入绘图包
import seaborn as sns
import matplotlib.pyplot as plt
# 导入三维显示工具
from mpl_toolkits.mplot3d import Axes3D
# 导入乳腺癌数据集
cancer = load_breast_cancer()
# 查看数据集信息
print('breast_cancer数据集的长度为:',len(cancer))
print('breast_cancer数据集的类型为:',type(cancer))
# 分割数据为训练集和测试集
cancer_data = cancer['data']
print('cancer_data数据维度为:',cancer_data.shape)
cancer_target = cancer['target']
print('cancer_target标签维度为:',cancer_target.shape)
cancer_names = cancer['feature_names']
cancer_desc = cancer['DESCR']
#分为训练集与测试集
cancer_data_train,cancer_data_test = train_test_split(cancer_data,test_size=0.2,random_state=42)#训练集
cancer_target_train,cancer_target_test = train_test_split(cancer_target,test_size=0.2,random_state=42)#测试集

#%%# 建立 BP 模型, 采用Adam优化器,relu非线性映射函数
BP = MLPClassifier(solver='adam',activation = 'relu',max_iter = 1000,alpha = 1e-3,hidden_layer_sizes = (64,32, 32),random_state = 1)
# 进行模型训练
BP.fit(cancer_data_train, cancer_target_train)
#%% 进行模型预测
predict_train_labels = BP.predict(cancer_data_train)
# 可视化真实数据
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(cancer_data_train[:, 0], cancer_data_train[:, 1], cancer_data_train[:, 2], marker='o', c=cancer_target_train)
plt.title('True Label Map')
plt.show()
# 可视化预测数据
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(cancer_data_train[:, 0], cancer_data_train[:, 1], cancer_data_train[:, 2], marker='o', c=predict_train_labels)
plt.title('Cancer with BP Model')
plt.show()

#%% 显示预测分数
print("预测准确率: {:.4f}".format(BP.score(cancer_data_test, cancer_target_test)))
# 进行测试集数据的类别预测
predict_test_labels = BP.predict(cancer_data_test)
print("测试集的真实标签:\n", cancer_target_test)
print("测试集的预测标签:\n", predict_test_labels)
#%% 进行预测结果指标统计 统计每一类别的预测准确率、召回率、F1分数
print(classification_report(cancer_target_test, predict_test_labels))

#%% 计算混淆矩阵
confusion_mat = confusion_matrix(cancer_target_test, predict_test_labels)
# 打印混淆矩阵
print(confusion_mat)
# 将混淆矩阵以热力图的方式显示
sns.set()
figure, ax = plt.subplots()
# 画热力图
sns.heatmap(confusion_mat, cmap="YlGnBu_r", annot=True, ax=ax)
# 标题
ax.set_title('confusion matrix')
# x轴为预测类别
ax.set_xlabel('predict')
# y轴实际类别
ax.set_ylabel('true')
plt.show()


注:之前还做过基于BP神经网络的人口普查数据预测,有需要的猿友私信

三、Keys

  1. BP神经网络的要点在于前向传播和误差反向传播,来对参数进行更新,使得损失最小化。
  2. 它是一个迭代算法,基本思想是:
  3. 先计算每一层的状态和激活值,直到最后一层(即信号是前向传播的);
    
  4. 计算每一层的误差,误差的计算过程是从最后一层向前推进的(反向传播);
    
  5. 更新参数(目标是误差变小)。迭代前面两个步骤,直到满足停止准则(比如相邻两次迭代的误差的差别很小)。
    

886~~~

机器学习应用篇(八)——基于BP神经网络的预测相关推荐

  1. 机器学习之基于BP神经网络的预测

    BP神经网络具有以下优点: 1) 非线性映射能力:BP神经网络实质上实现了一个从输入到输出的映射功能,数学理论证明三层的神经网络就能够以任意精度逼近任何非线性连续函数.这使得其特别适合于求解内部机制复 ...

  2. 基于BP神经网络飞机颠簸预测

    背景介绍 飞机在飞行过程中遇到扰动气流或者受到方向.大小不同的气流冲击导致的左右摇晃.前后颠簸.上下抛掷以及局部震颤等想象统称为颠簸.中度以上颠簸会使飞机仪表指示失常,操纵困难:特别严重时会破坏飞机结 ...

  3. 【机器学习代码例】用BP神经网络做预测

    机器学习算法 源码下载链接 导入包 import numpy as np import matplotlib.pyplot as plt import pandas as pd 定义激活函数 # 激活 ...

  4. 基于bp神经网络的pid算法,神经网络pid控制器设计

    基于BP神经网络的PID控制器设计 参考一下刘金琨的<先进PID控制>这本书. 例子:被控对象yout(k)=a(k)yout(k-1)/(1+yout(k-1)^2)+u(k_1)其中a ...

  5. 基于bp神经网络的pid控制,pid神经网络什么原理

    关于基于神经网络的PID液位控制用MATLAB怎么编程啊?求高手指点!!!! . 其实只需要PID参数能够顺利确定就行了,这里有个程序,你试试看closeallclearallclctic%初始化x= ...

  6. 【论文研读】基于BP 神经网络的 Q235 钢力学性能预测模型

    基于BP 神经网络的 Q235 钢力学性能预测模型 刘志伟1, 2 , 马劲红1, 2 , 陈伟1 , 王文正1 1.华北理工大学 冶金与能源学院, 河北 唐山 063210; 2.现代冶金技术教育部 ...

  7. 基于BP神经网络的手写数字识别

    基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...

  8. gadecod matlab,【预测模型】基于遗传算法优化BP神经网络房价预测matlab源码

    一.简介 1 遗传算法概述 遗传算法(Genetic Algorithm,GA)是进化计算的一部分,是模拟达尔文的遗传选择和自然淘汰的生物进化过程的计算模型,是一种通过模拟自然进化过程搜索最优解的方法 ...

  9. 基于BP神经网络的非线性函数拟合(一维高斯函数)研究-含Matlab代码

    目录 一.引言 二.BP神经网络的结构与原理 2.1 信息前向传播 2.2 误差的反向传播过程 三.基于BP神经网络的非线性函数拟合 3.1 数据生成 3.2 神经网络拟合结果 四.参考文献 五.Ma ...

最新文章

  1. 深度学习基础总结,无一句废话(附完整思维导图)
  2. 云原生解决什么问题?
  3. 剑指云内存数据库,阿里云在下一盘大棋
  4. html 页面视图中的资源文件(css/js/image)的路径问题。
  5. cpu,内核和逻辑处理器的关系
  6. (四)数据结构之“队列”
  7. 新知丨口服益生菌是商业噱头?
  8. (四)深入浅出TCPIP之TCP三次握手和四次挥手(下)的抓包分析
  9. 要写related_name的两种情况
  10. 【白皮书分享】2021年智慧城市白皮书:依托智慧服务,共创新型智慧城市.pdf(附下载链接)
  11. Context与ApplicationContext
  12. I00015 打印等腰三角形字符图案(底边在上)
  13. 禁止选中页面内容-兼容ie、firefox、chrome
  14. 【Unity】12.2 导航网格寻路简单示例
  15. Eclipse安装包 百度网盘
  16. 爱立信实习总结之面试心得
  17. 计算机中任务管理器的主要功能是什么,任务管理器的作用有哪些 可以解决9成电脑问题...
  18. java11降到java8
  19. Linux的memory日志,Linux:日志,cpu,memory,mount,load等系统信息查看
  20. git stach储藏功能(SourceTree 使用方法,Visual studio 2019 中使用

热门文章

  1. NOI Linux 2.0 Arbiter 测评系统详细步骤(保姆式指南)
  2. 城市轨道交通信号与通信系统
  3. project 2007 调整字体颜色
  4. 更改ubuntu登陆界面
  5. chatgpt赋能python:Python绝对值符号:用法及实例
  6. 计算机文化学习笔记2
  7. php 选座,jQuery在线选座(高铁版)
  8. 【三、网络配置与系统管理】
  9. 摸鱼神器,GitHub 上开源的坦克大战!
  10. 培训教育系统阶段性效果展示