【机器学习代码例】用BP神经网络做预测
机器学习算法
源码下载链接
导入包
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
定义激活函数
# 激活函数
def tanh(x):return (np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))
def de_tanh(x):return (1-x**2)
数据的读取
# 数据的读取
def Data(df):df.columns = ["x", "y", "h", "use", "As"]longtitude = df["x"] # 抽取前四列作为训练数据的各属性值longtitude = np.array(longtitude)latitude = df["y"]latitude = np.array(latitude)elevation = df["h"]elevation = np.array(elevation)functional = df["use"]functional = np.array(functional)ag = df["As"]ag = np.array(ag)return longtitude,latitude,elevation,functional,ag
归一化
# 对数据进行归一化处理
def Normalized(samplein,sampleout):inminmax = np.array([samplein.min(axis=1).T.tolist()[0],samplein.max(axis=1).T.tolist()[0]]).transpose() # 对应最大值最小值# print("sampleinminmax:\n",sampleinminmax)# print("sampleout:\n",sampleout)outminmax = np.array([sampleout.min(axis=1).T.tolist()[0],sampleout.max(axis=1).T.tolist()[0]]).transpose() # 对应最大值最小值,4*299# print("sampleoutminmax:\n",sampleoutminmax)innorm = (2 * (np.array(samplein.T) - inminmax.transpose()[0])/ (inminmax.transpose()[1] - inminmax.transpose()[0]) - 1).transpose() # 1*299# print("sampleinnorm:\n",sampleinnorm)outnorm = (2 * (np.array(sampleout.T) - outminmax.transpose()[0])/ (outminmax.transpose()[1] - outminmax.transpose()[0]) - 1).transpose()# print("sampleoutnorm:\n",sampleoutnorm)# 给输出样本添加噪音noise = 0.03 * np.random.rand(outnorm.shape[0], outnorm.shape[1])outnorm += noisereturn innorm,outnorm,outminmax
初始化参数
# 初始化 w1,b1,w2,b2
def initial():scale = np.sqrt(3/((indim+outdim)*0.5)) #最大值最小值范围为-1.44~1.44w1 = np.random.uniform(low=-scale, high=scale, size=[hiddenunitnum,indim])b1 = np.random.uniform(low=-scale, high=scale, size=[hiddenunitnum,1])w2 = np.random.uniform(low=-scale, high=scale, size=[outdim,hiddenunitnum])b2 = np.random.uniform(low=-scale, high=scale, size=[outdim,1])# print("scale:\n",scale)# print("w1:\n",w1)# print("b1",b1)# print("w2:\n",w2)# print("b2",b2)return w1,b1,w2,b2
更新参数
# 学习训练更新权重
def new_weight(w1,b1,w2,b2,maxepochs,sampleinnorm,sampleoutnorm,samnum,errorfinal,learnrate):errhistory = []for i in range(maxepochs):print("Generation : ", i)hiddenout = tanh((np.dot(w1, sampleinnorm).transpose() + b1.transpose())).transpose() # 8*299 np.dot为矩阵的乘法networkout = tanh((np.dot(w2, hiddenout).transpose() + b2.transpose())).transpose() # 1*299err = sampleoutnorm - networkout # 1*299loss = np.sum(np.abs(err)) / samnummse = np.sum(np.square(sampleoutnorm - networkout)) / len(networkout)print(loss)print(mse)sse = sum(sum(err ** 2))errhistory.append(sse)if sse < errorfinal:breakdelta2 = err * de_tanh(networkout)delta1 = np.dot(w2.transpose(), delta2) * de_tanh(hiddenout) # hiddenout*(1-hiddenout) #8*299dw2 = np.dot(delta2, hiddenout.transpose()) # 1*8db2 = np.dot(delta2, np.ones((samnum, 1))) # 1*1dw1 = np.dot(delta1, sampleinnorm.transpose()) # 8*4db1 = np.dot(delta1, np.ones((samnum, 1))) # 8*1w2 += learnrate * dw2b2 += learnrate * db2w1 += learnrate * dw1b1 += learnrate * db1# print('更新的权重w1:', w1)# print('更新的偏置b1:', b1)# print('更新的权重w2:', w2)# print('更新的偏置b2:', b2)# print("平均损失值为:", loss)return w1,b1,w2,b2
计算误差
# 计算误差
def err(output,output1):mse = sum(np.square(output - output1)) / len(output)mae = sum(np.abs(output - output1)) / len(output)mape = sum(np.abs((output - output1) / output)) / len(output)return mse,mae,mape
预测
# 对测试集进行预测
def predict(w1,b1,w2,b2,inputnorm):hiddenout = tanh((np.dot(w1, inputnorm).transpose() + b1.transpose())).transpose()networkout = tanh((np.dot(w2, hiddenout).transpose() + b2.transpose())).transpose()return networkout
反归一化
# 对预测值进行反归一化
def Anti_Normal(sampleoutminmax,networkout):diff = sampleoutminmax[:, 1] - sampleoutminmax[:, 0]networkout2 = (networkout + 1) / 2networkout2 = networkout2 * diff + sampleoutminmax[0][0]output1 = networkout2.flatten() # 降成一维数组output1 = output1.tolist() # output1 为预测值return output1
画图
# 根据预测结果进行画图
def show(output,output1):fig = plt.figure(figsize=(11, 9), dpi=120)plt.title('Prediction of As Content', fontdict={'weight': 'normal', 'size': 24})x = range(1, 21, 1)plt.plot(x, output, color="black", label="real", linewidth=2.0, linestyle="-")plt.plot(x, output1, color="black", label="prediction", linewidth=2.0, linestyle="--")plt.tick_params(labelsize=18)plt.xlim(0, 21)plt.ylim(0, 30)plt.xlabel("spot", fontdict={'weight': 'normal', 'size': 20})plt.ylabel("As Content(μg/g)", fontdict={'weight': 'normal', 'size': 20})plt.xticks(range(1, 21, 1))plt.legend(loc="upper right", prop={'size': 20})plt.savefig("bp.png")plt.show()
主函数
if __name__ == '__main__':# 1、======= 读取数据 ==========# 训练集的读取train = pd.read_csv("train.csv") # 返回一个DataFrame的对象,这个是pandas的一个数据结构train_longtitude, train_latitude, train_elevation, train_functional, train_ag = Data(train)samplein = np.mat([train_longtitude, train_latitude, train_elevation, train_functional])sampleout = np.mat([train_ag])# print("samplein:\n", samplein)# print("sampleout:\n",sampleout)# print("samplein.shape:",samplein.shape) # 4 * 300# print("sampleout.shape:",sampleout.shape) # 1 * 300# 测试集的读取test = pd.read_csv("test.csv")test_longtitude, test_latitude, test_elevation, test_functional, test_ag = Data(test)input = np.mat([test_longtitude, test_latitude,test_elevation, test_functional])output = np.mat([test_ag])# output = ag# print("input:\n", input)# print("input.shape:", input.shape)# print("output:\n",output)# print("output.shape:",output.shape)# 2、======= 数据的归一化处理 ==========# 训练集归一化处理sampleinnorm, sampleoutnorm,sampleoutminmax = Normalized(samplein, sampleout)# print("sampleinnorm:\n",sampleinnorm)# print("sampleoutnorm:\n",sampleoutnorm)# print("sampleoutminmax:\n",sampleoutminmax)# 测试集归一化inputnorm, outputnorm,outputminmax = Normalized(input, output)# print("inputnorm:\n",inputnorm)# print("outputnorm:\n",outputnorm)# print("outputminmax:\n",outputminmax)# 3、====== 神经网络训练 =======maxepochs = 10000learnrate = 0.01errorfinal = 0.65 * 10 ** (-3)samnum = 300indim = 4outdim = 1hiddenunitnum = 7# 参数 w1,b1,w2,b2 赋初值w1,b1,w2,b2 = initial()# 更新参数w1,b1,w2,b3 = new_weight(w1,b1,w2,b2,maxepochs,sampleinnorm,sampleoutnorm,samnum,errorfinal,learnrate)# print('更新的权重w1:', w1)# print('更新的偏置b1:', b1)# print('更新的权重w2:', w2)# print('更新的偏置b2:', b2)# 4、====== 训练好的神经网络预测测试集 ======networkout = predict(w1,b1,w2,b2,inputnorm)# 5、======= 反归一化 =========output1 = Anti_Normal(sampleoutminmax, networkout) # list 表# print("output1:\n",output1)# 6、======= 计算误差 =========output = test_ag # list 表# print("output:\n",output)mse, mae, mape = err(output,output1)print("mse:",mse,"mae:",mae,"mape:",mape)# 7、 ========= 出图 ===========show(output,output1)
【机器学习代码例】用BP神经网络做预测相关推荐
- BP神经网络做分类+隐含层节点确定+红酒数据为例
网上用BP神经网络做预测的代码有很多,但是做分类的很少,(虽然都是一个道理),但是预测的代码下载下来还得动手修改,对于想直接复制粘贴的友友们很不友好.想用分类代码的直接来我这里复制粘贴即可,跑不通的欢 ...
- 机器学习应用篇(八)——基于BP神经网络的预测
机器学习应用篇(八)--基于BP神经网络的预测 文章目录 机器学习应用篇(八)--基于BP神经网络的预测 一.Introduction 1 BP神经网络的优点 2 BP神经网络的缺点 二.实现过程 1 ...
- 基于果蝇优化的BP神经网络(预测应用) - 附代码
基于果蝇优化的BP神经网络(预测应用) 文章目录 基于果蝇优化的BP神经网络(预测应用) 1.数据介绍 3.FOA优化BP神经网络 3.1 BP神经网络参数设置 3.2 果蝇算法应用 4.测试结果: ...
- 基于头脑风暴优化的BP神经网络(预测应用) - 附代码
基于头脑风暴优化的BP神经网络(预测应用) - 附代码 文章目录 基于头脑风暴优化的BP神经网络(预测应用) - 附代码 1.数据介绍 3.BSO优化BP神经网络 3.1 BP神经网络参数设置 3.2 ...
- 基于布谷鸟优化的BP神经网络(预测应用) - 附代码
基于布谷鸟优化的BP神经网络(预测应用) - 附代码 文章目录 基于布谷鸟优化的BP神经网络(预测应用) - 附代码 1.数据介绍 3.CS优化BP神经网络 3.1 BP神经网络参数设置 3.2 布谷 ...
- 基于鸟群优化的BP神经网络(预测应用) - 附代码
基于鸟群优化的BP神经网络(预测应用) - 附代码 文章目录 基于鸟群优化的BP神经网络(预测应用) - 附代码 1.数据介绍 3.BSA优化BP神经网络 3.1 BP神经网络参数设置 3.2 鸟群算 ...
- bp神经网络时间序列预测,bp神经网络有几个阶段
什么是BP神经网络? . BP算法的基本思想是:学习过程由信号正向传播与误差的反向回传两个部分组成:正向传播时,输入样本从输入层传入,经各隐层依次逐层处理,传向输出层,若输出层输出与期望不符,则将误差 ...
- 利用MATLAB 2016a进行BP神经网络的预测(含有神经网络工具箱)
最近一段时间在研究如何利用预测其销量个数,在网上搜索了一下,发现了很多模型来预测,比如利用回归模型.时间序列模型,GM(1,1)模型,可是自己在结合实际的工作内容,发现这几种模型预测的精度不是很高,于 ...
- 神经网络可以用来预测吗,如何用神经网络做预测
如何利用训练好的神经网络进行预测 谷歌人工智能写作项目:神经网络伪原创 如何人工神经网络来预测下一个数值 newff函数建立BP神经网络,历史数据作为样本,例如前n个数据作为输入,输入节点为n写作猫. ...
最新文章
- cass道路道路设计参数文件命令为什么没反应_为什么MySQL不建议使用delete删除数据?...
- 【AI产品】一键去除杂物,Photo Eraser助你拍出美丽照片
- python seek_Python 文件操作seek()函数
- spring-boot-route(十二)整合redis做为缓存
- IIS 配置 url 重写...
- Atitit.js javascript异常处理机制与java异常的转换 多重catc hDWR 环境 .js exception process Vob7
- AH3050_12V升18V2A 同步升压芯片
- IDL学习——调用enviTask对高分2号影像进行预处理
- 基于LZ77算法和Huffman编码的文件压缩项目
- 【C语言】 文件指针编程应用
- hbase metric 监控项
- 梅西大学研究员创造出新3D打印系统 用螺杆作为进料机构挤出颗粒
- MySQL数据库知识大全
- mysql error code: 1205_Mysql错误:ERROR 1205 (HY000): Lock wait timeout exceeded解决办法
- js实现json格式化,以及json校验工具的简单实现
- EFR32芯科zigbee学习文档资源总结
- 参禅静坐--虚极静笃--快速恢复脑力体力
- STS (Spring Tool Suite) 目录和作用初级
- HTML5(李炎恢)学习笔记三 ------------- HTML5元素(上)
- 3dsMax9 64bit版本下载
热门文章
- c语言二进制转换算法栈,用C语言顺序栈实现十进制和二进制的转换
- 对象入参指定泛型类型,如何得到正确的MethodInfo对象当一个类使用泛型和泛型类型参数...
- 渝粤题库 陕西师范大学 《经济法Ⅰ》作业
- 2018江苏专转本计算机知识点,2018江苏专转本计算机真题[含解析]
- Centos7 卸载 Nginx 并重新安装 Nginx
- Unity资源--几种卡通动物的模型+动画
- http,socks5,socks4代理的区别
- 华为交换机导入配置_华为交换机配置的导出和导入方法
- 别再问如何用Python提取PDF内容了!
- 1619C. Wrong Addition