实现了正弦曲线的拟合,即regression问题。

创建的模型单输入单输出,两个隐层分别为100、50个神经元。

在keras的官方文档中,给的例子多是关于分类的。因此在测试regression时,遇到了一些问题。总结来说,应注意以下几个方面:

1)训练数据需是矩阵型,这里的输入和输出是1000*1,即1000个样本;每个样本得到一个输出;

注意:训练数据的生成非常关键,首先需要检查输入数据和输出数据的维度匹配;

2)对数据进行规范化,这里用到的是零均值单位方差的规范方法。规范化方法对于各种训练模型很有讲究,具体参照另一篇笔记:http://blog.csdn.net/csmqq/article/details/51461696;

3)输出层的激活函数选择很重要,该拟合的输出有正负值,因此选择tanh比较合适;

4)regression问题中,训练函数compile中的误差函数通常选择mean_squared_error。

5)值得注意的是,在训练时,可以将测试数据的输入和输出绘制出来,这样可以帮助调试参数。

6)keras中实现回归问题,返回的准确率为0。

# -*- coding: utf-8 -*-
"""
Created on Mon May 16 13:34:30 2016
@author: Michelle
"""
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import SGD
from keras.layers.advanced_activations import LeakyReLU
from sklearn import preprocessing
from keras.utils.visualize_plots import figures
import matplotlib.pyplot as plt
import numpy as np    #part1: train data
#generate 100 numbers from -2pi to 2pi
x_train = np.linspace(-2*np.pi, 2*np.pi, 1000)  #array: [1000,]
x_train = np.array(x_train).reshape((len(x_train), 1)) #reshape to matrix with [100,1]
n=0.1*np.random.rand(len(x_train),1) #generate a matrix with size [len(x),1], value in (0,1),array: [1000,1]
y_train=np.sin(x_train)+n#训练数据集:零均值单位方差
x_train = preprocessing.scale(x_train)
scaler = preprocessing.StandardScaler().fit(x_train)
y_train = scaler.transform(y_train)#part2: test data
x_test = np.linspace(-5,5,2000)
x_test = np.array(x_test).reshape((len(x_test), 1))
y_test=np.sin(x_test)#零均值单位方差
x_test = scaler.transform(x_test)
#y_test = scaler.transform(y_test)
##plot testing data
#fig, ax = plt.subplots()
#ax.plot(x_test, y_test,'g')#prediction data
x_prd = np.linspace(-3,3,101)
x_prd = np.array(x_prd).reshape((len(x_prd), 1))
x_prd = scaler.transform(x_prd)
y_prd=np.sin(x_prd)
#plot testing data
fig, ax = plt.subplots()
ax.plot(x_prd, y_prd,'r')#part3: create models, with 1hidden layers
model = Sequential()
model.add(Dense(100, init='uniform', input_dim=1))
#model.add(Activation(LeakyReLU(alpha=0.01)))
model.add(Activation('relu'))model.add(Dense(50))
#model.add(Activation(LeakyReLU(alpha=0.1)))
model.add(Activation('relu'))model.add(Dense(1))
#model.add(Activation(LeakyReLU(alpha=0.01)))
model.add(Activation('tanh'))#sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer="rmsprop", metrics=["accuracy"])
#model.compile(loss='mean_squared_error', optimizer=sgd, metrics=["accuracy"])#model.fit(x_train, y_train, nb_epoch=64, batch_size=20, verbose=0)
hist = model.fit(x_test, y_test, batch_size=10, nb_epoch=100, shuffle=True,verbose=0,validation_split=0.2)
#print(hist.history)
score = model.evaluate(x_test, y_test, batch_size=10)out = model.predict(x_prd, batch_size=1)
#plot prediction dataax.plot(x_prd, out, 'k--', lw=4)
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()
figures(hist)

虚线是预测值,红色是输入值;

绘制误差值随着迭代次数的曲线函数是Visualize_plots.py,

1)将其放在C:\Anaconda2\Lib\site-packages\keras\utils下面。

2)在使用时,需要添加这句话:from keras.utils.visualize_plots import figures,然后在程序中直接调用函数figures(hist)。

垓函数的实现代码为:

# -*- coding: utf-8 -*-
"""
Created on Sat May 21 22:26:24 2016@author: Shemmy
"""def figures(history,figure_name="plots"):""" method to visualize accuracies and loss vs epoch for training as well as testind data\nArgumets: history     = an instance returned by model.fit method\nfigure_name = a string representing file name to plots. By default it is set to "plots" \nUsage: hist = model.fit(X,y)\n              figures(hist) """from keras.callbacks import Historyif isinstance(history,History):import matplotlib.pyplot as plthist     = history.history epoch    = history.epochacc      = hist['acc']loss     = hist['loss']val_loss = hist['val_loss']val_acc  = hist['val_acc']plt.figure(1)plt.subplot(221)plt.plot(epoch,acc)plt.title("Training accuracy vs Epoch")plt.xlabel("Epoch")plt.ylabel("Accuracy")     plt.subplot(222)plt.plot(epoch,loss)plt.title("Training loss vs Epoch")plt.xlabel("Epoch")plt.ylabel("Loss")  plt.subplot(223)plt.plot(epoch,val_acc)plt.title("Validation Acc vs Epoch")plt.xlabel("Epoch")plt.ylabel("Validation Accuracy")  plt.subplot(224)plt.plot(epoch,val_loss)plt.title("Validation loss vs Epoch")plt.xlabel("Epoch")plt.ylabel("Validation Loss")  plt.tight_layout()plt.savefig(figure_name)else:print "Input Argument is not an instance of class History"

讨论keras中实现拟合回归问题的帖子: https://github.com/fchollet/keras/issues/108

用keras创建拟合网络解决回归问题Regression相关推荐

  1. Keras过拟合相关解决办法

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 这种过拟合的处理称为正则化.我们来学习一些最常用的正则化技术,并将其应用于实践中.       ...

  2. 【机器学习】多项式回归案例五:正则惩罚解决过拟合(Ridge回归和Lasso回归)

    正则惩罚解决过拟合(Ridge回归和Lasso回归) 案例五: 正则惩罚解决过拟合(Ridge回归和Lasso回归) 3.2.1 模块加载与数据读入 3.2.2 特征工程 3.2.3 模型搭建与应用 ...

  3. 收藏 | 用 Keras 实现神经网络来解决梯度消失的问题

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | Jonathan Quijas 编译 | rong ...

  4. 用 Keras 创建自己的图像标题生成器

    总览 了解图像字幕生成器如何使用编码器-解码器工作 知道如何使用Keras创建自己的图像标题生成器 介绍 图像标题生成器是人工智能的热门研究领域,涉及图像理解和该图像的语言描述.生成格式正确的句子需要 ...

  5. 哪个才是解决回归问题的最佳算法?线性回归、神经网络还是随机森林?

    编译 | AI科技大本营 参与 | 王珂凝 编辑 | 明 明 [AI科技大本营导读]现在,不管想解决什么类型的机器学习(ML)问题,都会有各种不同的算法可以供你选择.尽管在一定程度上,一种算法并不能总 ...

  6. Keras: 创建多个输入以及混合数据输入的神经网络模型

    目录 摘要 正文 Keras: 创建多个输入以及混合数据输入的神经网络模型 什么是混合数据? Keras如何接受多个输入? 房价数据集 获取房价数据集 项目结构 加载数值和分类数据 加载图像数据集 定 ...

  7. 30分钟 Keras 创建一个图像分类器

    深度学习是使用人工神经网络进行机器学习的一个子集,目前已经被证明在图像分类方面非常强大.尽管这些算法的内部工作在数学上是严格的,但 Python 库(比如 keras)使这些问题对我们所有人都可以接近 ...

  8. python回归分析预测模型_在Python中如何使用Keras模型对分类、回归进行预测

    姓名:代良全 学号:13020199007 转载自:https://www.jianshu.com/p/83ba11abdffc [嵌牛导读]: 在Python中如何使用Keras模型对分类.回归进行 ...

  9. 如何解决回归任务数据不均衡的问题?

    摘要:现有的处理不平衡数据/长尾分布的方法绝大多数都是针对分类问题,而回归问题中出现的数据不均衡问题确极少被研究. 本文分享自华为云社区<如何解决回归任务数据不均衡的问题?>,原文作者:P ...

最新文章

  1. C++中最好不要在构造函数和析构函数中调用虚函数!!!
  2. 屏蔽storm ui的kill功能
  3. Eclipse 中 SDK无法更新---解决方法
  4. lvs+keepalived 集群
  5. SAP Spartacus元素被选中后,focus颜色的css实现
  6. python综合管理系统_Python-20 (信息系统-框架/循环/增删/综合应用)
  7. Java ArrayList 数组之间相互转换
  8. mvc:annotation-driven/与mvc:default-servlet-handler/之间的一个问题
  9. java如何获取scanner_java – 使用Scanner获取用户输入
  10. iText 7 基础
  11. lol大脚一直卡在读取服务器信息,英雄联盟大脚 - 英雄联盟 - LOL英雄联盟官网 - 英雄联盟攻略 - 英雄联盟专题站...
  12. C语言也能干大事第十四节(如鹏基础)
  13. java书名号乱码_别骗我,这些居然是汉字,不是乱码
  14. Ffmpeg视频压制的基础知识
  15. day42.自动关机小程序
  16. 汽车距离报警系统c语言编程,基于单片机的汽车防盗报警系统设计与实现.doc
  17. 批量同时创建邮箱和AD账户
  18. 《我要进大厂》- Java基础夺命连环10问,你能坚持到第几问?(面向对象基础篇)
  19. 购物网站(测试+步骤+代码)
  20. android ogg转mp3,MP3提取转换器

热门文章

  1. Windows打包为用户安装字体
  2. 联想笔记本恢复默认F1~F12功能---功能键的切换
  3. 我的世界服务器怎么修改书与笔,书与笔 - Minecraft Wiki,最详细的官方我的世界百科...
  4. TOGAF认证流程图
  5. 易达号-云对讲智能门禁M101-拆机
  6. 小程序助手多功能【微信小程序反编译】工具
  7. channel shuffle通道洗牌
  8. request-response
  9. 2023长沙理工大学计算机考研信息汇总
  10. 机器人新车号牌安装_很帅的动作!现代机器人这样安装汽车挡风玻璃