利用LSTM对股价进行预测并可视化

Tushare API

网址:Tushare大数据社区

注册账号并获取自己的token码

tushare提供了众多接口,根据积分不同,token权限不同,读者请自行前往官网了解,这里不做过多介绍。

首先导入要用到的包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tushare as ts
from keras.layers import LSTM, Dense, Dropout
from keras.models import Sequential, load_model
from sklearn.preprocessing import MinMaxScaler
import talib
import mpl_finance as mpf

调用tushare API获取股价数据

token = '你的token码'
ts.set_token(token)
data_ts = ts.pro_bar(ts_code='000001.SZ', start_date='20100101', end_date='20201214', asset='E', freq='D').iloc[::-1, :]
data_ts = data_ts.reset_index()
data_ts = data_ts.drop('index', axis=1)

对2020年的数据进行可视化

数据处理

## K线
data_plot = data_ts.loc[data_ts['trade_date'] >= '20200101']
data_plot = data_plot.reset_index()
data_plot = data_plot.drop(['index'], axis=1)
data_plot = data_plot.reset_index()
ochl = data_plot[['index','open','close','high','low']].values
## MACD
dif, dea, macdbar = talib.MACD(data_plot['close'].values, fastperiod=12, slowperiod=26, signalperiod=9)
x = data_plot['index'].values
# RSI
shortRSI = talib.RSI(data_plot['close'].values, timeperiod=6)
longRSI = talib.RSI(data_plot['close'].values, timeperiod=12)

绘图

## 绘图
# 指数平滑移动平均线
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(20,8), dpi=80)
axes[0].plot(data_plot['close'].ewm(span=5).mean(), label='5日均线', linewidth=0.8)
axes[0].plot(data_plot['close'].ewm(span=10).mean(), label='10日均线', linewidth=0.8)
axes[0].plot(data_plot['close'].ewm(span=20).mean(), label='20日均线', linewidth=0.8)
axes[0].plot(data_plot['close'].ewm(span=30).mean(), label='30日均线', linewidth=0.8)
axes[0].plot(data_plot['close'].ewm(span=60).mean(), label='60日均线', linewidth=0.8)
axes[0].legend(loc='best')
# k线图
mpf.candlestick_ochl(axes[0], ochl, width=0.5, colorup='r', colordown='g')
# plt.xticks(data_plot['index'].values[::25],data_plot['trade_date'][::25])
# MACD
axes[1].plot(x, dif, label='差离值', linewidth=1)
axes[1].plot(x, dea, label='讯号线', linewidth=1)
# MACD Bar
bar1 = np.where(macdbar>0, macdbar, 0)
bar2 = np.where(macdbar<0, macdbar, 0)
axes[1].bar(x, bar1, color='r', label='bar1')
axes[1].bar(x, bar2, color='g', label='bar2')
plt.xticks(data_plot['index'].values[::25],data_plot['trade_date'][::25])
plt.show()

LSTM预测股价

准备数据

取2010-2019年日收盘价做训练,2020的收盘价做验证

data_train_index = data_ts.loc[data_ts['trade_date'] <= '20200101'].index
data_train = data_ts.loc[data_ts['trade_date'] <= '20200101']['close'].values.reshape(-1, 1)
data_test_index = data_ts.loc[data_ts['trade_date'] >= '20200101'].index
data_test = data_ts.loc[data_ts['trade_date'] >= '20200101']['close'].values.reshape(-1, 1)

绘图看看训练集和测试集

plt.figure(figsize=(20, 6))
plt.plot(data_train_index, data_train, label='Train_data')
plt.plot(data_test_index, data_test, label='Test_data')
plt.legend(loc='best')
plt.xticks(data_ts.index[::150],data_ts['trade_date'][::150])
plt.title('训练数据和测试数据')
plt.show()

对数据归一化并reshape成想要的格式

scaler = MinMaxScaler(feature_range=(0, 1))
data_train = scaler.fit_transform(data_train)
data_test = scaler.transform(data_test)def creat_dataset(data):x=[]y=[]for i in range(50, data.shape[0]):x.append(data[i-50:i, 0])y.append(data[i, 0])x = np.array(x)y = np.array(y)return x, yx_train, y_train = creat_dataset(data_train)
x_test, y_test = creat_dataset(data_test)x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)

定义模型

model = Sequential()
model.add(LSTM(units=96, return_sequences=True, input_shape=(x_train.shape[1], 1)))
model.add(Dropout(0.2))
model.add(LSTM(units=96, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=96, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=96, return_sequences=True))
model.add(Dropout(0.2))
model.add(Dense(units=1))model.compile(loss='mean_squared_error', optimizer='adam')print(model.summary())

模型训练

一行代码就够

model.fit(x_train, y_train, epochs=50, batch_size=32)

预测股价并可视化

用刚刚训练的模型在2020年收盘价的数据上预测一下

predictions = model.predict(x_test)
predictions = scaler.inverse_transform(predictions.squeeze())
y_test_scaled = scaler.inverse_transform(y_test.reshape(-1, 1))

画图看看结果如何

fig = plt.figure(figsize=(20, 6))
x1 = data_ts.index[0:len(data_ts['close'].values)]
x2 = data_ts.index.values[len(y_train)+100:len(y_train)+100+len(predictions)]
plt.plot(x1, data_ts['close'].values, color='red', label='True Price')
plt.plot(x2, predictions[:,49], color='blue', label='Predicted Testing Price')
plt.xticks(data_ts.index[::150],data_ts['trade_date'][::150])
plt.legend(loc=3)
plt.title('预测结果')
left,bottom,width,height=0.5,0.55,0.35,0.3
plt.axes([left,bottom,width,height])    # 设置figure的比例大小
plt.plot(data_ts.index[-181:-1], y_test_scaled, color='red')
plt.plot(data_ts.index[-181:-1], predictions[:, 49], color='blue')
plt.xticks(data_ts.index[-180:-1][::40],data_ts['trade_date'][::40])
plt.xticks([])
plt.yticks([])
plt.show()

图中蓝色的是预测的结果,虽然不能完全准确但走势基本是相同的。

参考油管:https://www.youtube.com/watch?v=lpU3PGyDKQ4

注:本文只用做经验分享,不做任何商业推广。

如有侵权,请联系删除!

利用LSTM对股价进行预测并可视化相关推荐

  1. 利用LSTM进行股价预测

    利用LSTM进行股价预测 效果 原理 代码 应用 效果 原理 LSTM即长短记忆网络,是一种很强的RNN,这种网络的特性是以前的输入会影响现在的输出,具体原理请自行搜索. 算法流程: 获取yahoo财 ...

  2. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  3. 利用LSTM进行空气指数预测

    毕设终于结束,感谢指导老师以及团队大伙们的辛苦付出,是时候总结一下毕设的内容了. 我们团队的毕业设计是关于利用递归神经网络模型LSTM(long-short-term memory)对中国主要城市的空 ...

  4. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)

    目录 I. 前言 II. 数据处理 III. LSTM模型 IV. 训练/测试 V. 源码及数据 I. 前言 在前面的一篇文章TensorFlow搭建LSTM实现时间序列预测(负荷预测)中,我们利用L ...

  5. Python中TensorFlow长短期记忆神经网络LSTM、指数移动平均法预测股票市场时间序列和可视化

    最近我们被客户要求撰写关于LSTM的研究报告,包括一些图形和统计输出. 本文探索Python中的长短期记忆(LSTM)网络,以及如何使用它们来进行股市预测. 相关视频:LSTM神经网络架构和工作原理及 ...

  6. TF之LSTM:利用LSTM算法对Boston(波士顿房价)数据集【13+1,506】进行回归预测(房价预测)

    TF之LSTM:利用LSTM算法对Boston(波士顿房价)数据集[13+1,506]进行回归预测(房价预测) 相关文章 DL之LSTM:利用LSTM算法对Boston(波士顿房价)数据集[13+1, ...

  7. pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

    一.利用rnn通过sin预测cos 1.首先可视化一下数据 import numpy as np from matplotlib import pyplot as plt def show(sin_n ...

  8. pytorch LSTM的股价预测

    股价预测一直以来都是幻想能够被解决的问题,本文中主要使用了lstm模型去对股价做一个大致的预测,数据来源是tushare,非常感谢tushare的数据!! 为什么要用LSTM? LSTM是一种序列模型 ...

  9. DL之LSTM:基于《wonderland爱丽丝梦游仙境记》小说数据集利用LSTM算法(层加深,基于keras)对单个character字符预测

    DL之LSTM:基于<wonderland爱丽丝梦游仙境记>小说数据集利用LSTM算法(层加深,基于keras)对单个character字符预测 目录 基于<wonderland爱丽 ...

  10. ML之kNNC:基于iris莺尾花数据集(PCA处理+三维散点图可视化)利用kNN算法实现分类预测daiding

    ML之kNNC:基于iris莺尾花数据集(PCA处理+三维散点图可视化)利用kNN算法实现分类预测 目录 基于iris莺尾花数据集(PCA处理+三维散点图可视化)利用kNN算法实现分类预测 设计思路 ...

最新文章

  1. 鲲鹏服务器光盘安装操作系统,鲲鹏服务器上安装
  2. 程序员入错行怎么办?
  3. 2016 、12 、11本周
  4. Newtonsoft.Json高级用法
  5. 为什么 if else 不是好代码?
  6. pythonappium环境搭建_python appium环境搭建
  7. 能使曲线变平滑的一维滤波器_双边滤波器的原理及实现
  8. python删除首行_Python删除文件第一行
  9. java clone concurrentlinkedqueue_java – ConcurrentLinkedQueue代码解释
  10. 笔记本电脑下载python视频-学Python买什么笔记本电脑?
  11. 【优化算法】吉萨金字塔建造优化算法(GPC)【含Matlab源码 1438期】
  12. 魔兽实名好友怎么显示服务器,魔兽世界实名好友跨服组队详细解析
  13. Android 渠道游戏 - 聚合SDK
  14. 数字系统实验—第11-12周任务(认识数据存储芯片HM62256、IP核、LPM开发流程和平台、 IIC串行总线时序分析)
  15. 近期爬虫学习体会以及爬豆瓣Top250源码实战
  16. LED电子时钟显示屏(NTP时间同步服务器)是如何完成授时服务的?
  17. 嵩山少林寺网站向全世界公布了千年武功秘籍
  18. 某网站提供的免费香港虚拟主机测试
  19. 强国的语言与语言强国
  20. 【STM32学习】(21)STM32实现步进电机

热门文章

  1. python爬虫爬取链家二手房信息
  2. 【转】iPhone通讯录AddressBook.framework和AddressBookUI.framework的应用
  3. postgres 禁止远程登录_Postgresql允许远程访问配置修改
  4. C++ 制作FlappyBird
  5. 深入了解Element Form表单动态验证问题
  6. CentOS7配简单的桌面环境openbox
  7. lang3之StringUtils
  8. java java.lang.string_无法将java.lang.String字段设置为java.lang.String
  9. 开发一款游戏引擎需要的知识与技术
  10. VS2019配置WTL10.0