代码:

  • 导入包
import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential按顺序构成的模型
from keras.models import Sequential
# Dense全连接层
from keras.layers import Dense,Activation
from keras.optimizers import SGD
  • 生成随机数据
# 使用Numpy生成200个-0.5~0.5之间的值
x_data = np.linspace(-0.5, 0.5, 200)
noise = np.random.normal(0, 0.02, x_data.shape)# y_data= x_data**2 + noise
y_data = np.square(x_data) + noise # 效果与上面一致# 显示随机点
plt.scatter(x_data, y_data)
plt.show()

  • 创建模型+训练
    加入隐藏层拟合更加复杂模型
    加入激活函数来拟合非线性模型
# 建立一个顺序模型
model = Sequential()
# 1-10-1: 加入一个隐藏层(10个神经元):来拟合更加复杂的线性模型。添加激活函数,来计算函数的非线性model.add(Dense(units=10, input_dim=1, activation='relu'))# 全连接层:输入一维数据,输出10个神经元
# model.add(Activation('tanh')) # 也可以直接在Dense里面加激活函数
model.add(Dense(units=1, activation='tanh')) # 全连接层:由于有上一层的添加,所以输入维度默认是10(可以不用写),输出1个值(要写)
# model.add(Activation('tanh'))# 自定义优化器SDG , 学习率默认是0.01(太小,导致要迭代好多次才能较好的拟合数据)
sgd = SGD(lr=0.3)
model.compile(optimizer=sgd, loss='mse')# 训练3000次数据
for step in range(3001):cost = model.train_on_batch(x_data, y_data)if step%500 == 0:print('cost: ',cost)# x_data输入神经网络中,得到预测值y_pred
y_pred = model.predict(x_data)# 显示随机点
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

总代码:

import keras
import numpy as np
import matplotlib.pyplot as plt
# Sequential按顺序构成的模型
from keras.models import Sequential
# Dense全连接层
from keras.layers import Dense,Activation
from keras.optimizers import SGD# 使用Numpy生成200个-0.5~0.5之间的值
x_data = np.linspace(-0.5, 0.5, 200)
noise = np.random.normal(0, 0.02, x_data.shape)# y_data= x_data**2 + noise
y_data = np.square(x_data) + noise # 效果与上面一致# 显示随机点
plt.scatter(x_data, y_data)
plt.show()# 建立一个顺序模型
model = Sequential()
# 1-10-1: 加入一个隐藏层(10个神经元):来拟合更加复杂的线性模型。添加激活函数,来计算函数的非线性model.add(Dense(units=10, input_dim=1, activation='relu'))# 全连接层:输入一维数据,输出10个神经元
# model.add(Activation('tanh')) # 也可以直接在Dense里面加激活函数
model.add(Dense(units=1, activation='tanh')) # 全连接层:由于有上一层的添加,所以输入维度默认是10(可以不用写),输出1个值(要写)
# model.add(Activation('tanh'))# 自定义优化器SDG , 学习率默认是0.01(太小,导致要迭代好多次才能较好的拟合数据)
sgd = SGD(lr=0.3)
model.compile(optimizer=sgd, loss='mse')# 训练3000次数据
for step in range(3001):cost = model.train_on_batch(x_data, y_data)if step%500 == 0:print('cost: ',cost)# x_data输入神经网络中,得到预测值y_pred
y_pred = model.predict(x_data)# 显示随机点
plt.scatter(x_data, y_data)
plt.plot(x_data, y_pred,'r-', lw=3)
plt.show()

参考:

视频: 覃秉丰老师的“Keras入门”:http://www.ai-xlab.com/course/32
博客参考:https://www.cnblogs.com/XUEYEYU/tag/keras%E5%AD%A6%E4%B9%A0/

3. 使用Keras-神经网络来拟合非线性模型相关推荐

  1. keras神经网络回归预测_如何使用Keras建立您的第一个神经网络来预测房价

    keras神经网络回归预测 by Joseph Lee Wei En 通过李维恩 一步一步的完整的初学者指南,可使用像Deep Learning专业版这样的几行代码来构建您的第一个神经网络! (A s ...

  2. 从零开始学keras之过拟合与欠拟合

    在预测电影评论.主题分类和房价回归中,模型在留出验证数据上的性能总是在几轮后达到最高点,然后开始下降.也就是说,模型很快就在训练数据上开始过拟合.过拟合存在于所有机器学习问题中.学会如何处理过拟合对掌 ...

  3. Keras神经网络实现泰坦尼克号旅客生存预测

    Keras神经网络实现泰坦尼克号旅客生存预测 介绍 数据集介绍 算法 学习器 分类器 实现 数据下载与导入 预处理 建立模型 训练 可视化 评估,预测 结果 代码 介绍 参考资料: 网易云课堂的深度学 ...

  4. Keras神经网络的学习与使用(1)

    Keras神经网络层学习与使用 Keras的简单介绍 Keras框架中的方法介绍 Compile()方法 fit()方法 summary()方法 evaluate()方法 perdict()方法 Ke ...

  5. Keras神经网络集成技术

    Keras神经网络集成技术 create_keras_neuropod 将Keras模型打包为神经网络集成包.目前,上文已经支持TensorFlow后端. create_keras_neuropod( ...

  6. 避免神经网络过拟合的5种技术(附链接) | CSDN博文精选

    作者 | Abhinav Sagar 翻译 | 陈超 校对 | 王琦 来源 | 数据派THU(ID:DatapiTHU) (*点击阅读原文,查看作者更多精彩文章) 本文介绍了5种在训练神经网络中避免过 ...

  7. 神经网络+过拟合+避免

    神经网络+过拟合+避免 作者:Abhinav Sagar & THU 最近一年我一直致力于深度学习领域.这段时间里,我使用过很多神经网络,比如卷积神经网络.循环神经网络.自编码器等等.我遇到的 ...

  8. 独家 | 避免神经网络过拟合的5种技术(附链接)

    作者:Abhinav Sagar 翻译:陈超 校对:王琦 本文约1700字,建议阅读8分钟. 本文介绍了5种在训练神经网络中避免过拟合的技术. 最近一年我一直致力于深度学习领域.这段时间里,我使用过很 ...

  9. 避免神经网络过拟合的5种技术

    作者:Abhinav Sagar 翻译:陈超 校对:王琦 本文约1700字,建议阅读8分钟. 本文介绍了5种在训练神经网络中避免过拟合的技术. 最近一年我一直致力于深度学习领域.这段时间里,我使用过很 ...

  10. 解决神经网络过拟合问题—Dropout方法、python实现

    解决神经网络过拟合问题-Dropout方法 一.what is Dropout?如何实现? 二.使用和不使用Dropout的训练结果对比 一.what is Dropout?如何实现? 如果网络模型复 ...

最新文章

  1. ZooKeeper集群环境安装与配置
  2. STM32 基础系列教程 24 - USB_HID_key
  3. Servlet中的HttpServlet
  4. 为什么手机版scp进不去_SCP1471,只属于你一人你的异常狗子,scp基金会系列
  5. calendar前推n天_Shell获取ES3天的索引列表进行迁移操作
  6. java new string作用_java中直接new String对象?
  7. Codeforces Round 493
  8. daoi php_聊聊这些年用过的AOI
  9. Codeforces Round 546 (Div. 2)
  10. Spring Cloud与微服务学习总结(2)——Spring Cloud相较于Dubbo等RPC服务框架的优势
  11. 敲黑板划重点!「PV,UV流量预测算法大赛」明日结果提交最后1天!
  12. 零基础怎么开启编程之路 -(第1期)
  13. c语言函数名称大全,C语言函数大全
  14. 关于mysql卸载不干净
  15. 煮酒论英雄——点评三国人物
  16. 麒麟鲲鹏升腾鸿蒙巴龙,华为四大芯片 麒麟、巴龙、昇腾和鲲鹏“四大天王”...
  17. UI设计中按钮如何设计,常见的按钮设计类型
  18. 打开Word文档的时候提示 “安全警告 宏已被禁用”
  19. C语言strchr()函数以及strstr()函数的实现
  20. RxJava入门之生命周期管理

热门文章

  1. zabbix重点笔记
  2. 从零到实现Shiro中Authorization和Authentication的缓存
  3. “约女生图书馆一起自习”总结
  4. 工作上碰到的技术问题积累
  5. 借助Squid代理服务器,建立灵活的访问控制系统
  6. Exchange 2010 使用http访问 OWA
  7. ADO.NET入门教程(六) 谈谈Command对象与数据检索
  8. redhat5.4上安装oracle9i
  9. [翻译]Chameleon介绍(6) : 动作控件
  10. java 注册表 下载_Java修改windows注册表(完全修改)