声明一下:本人现在在学习机器学习以及深度学习方面的知识,想通过CSDN平台去记录自己的学习历程,也希望可以和大家一起学习,共同进步。


文章目录

  • 1.读取数据
  • 2.添加特征(小时、天、月份、星期)
  • 3.按月进行重新采样
  • 4.划分训练集和测试集
  • 5.数据预处理
  • 6.设置时间步长(LSTM)
  • 7. 建立LSTM模型
  • 8.训练模型
  • 9.模型预测
  • 10.预测结果可视化

1.读取数据

df = pd.read_csv("london_bike_sharing.csv", parse_dates=['timestamp'], index_col="timestamp"
)

展示部分数据

2.添加特征(小时、天、月份、星期)

df['hour'] = df.index.hour
df['day_of_month'] = df.index.day
df['day_of_week'] = df.index.dayofweek
df['month'] = df.index.month

3.按月进行重新采样

df_by_month = df.resample('M').sum()#向下采样并执行聚合

4.划分训练集和测试集

train_size = int(len(df) * 0.9)
test_size = len(df) - train_size
train, test = df.iloc[0:train_size], df.iloc[train_size:len(df)]
print(len(train), len(test))

最后我们得到训练集和测试集的大小分别为15672 和1742。

5.数据预处理

拓展:我们在机器学习领域,总是看到“算法的鲁棒性”这类字眼。搜查资料发现鲁棒性可以有3个层面的概念:
1.模型具有较高的精度或有效性,这也是对于机器学习中所有学习模型的基本要求;
2.对于模型假设出现的较小偏差,只能对算法性能产生较小的影响; 主要是:噪声(noise)
3.对于模型假设出现的较大偏差,不可对算法性能产生“灾难性”的影响;主要是:离群点(outlier)
为了提高算法的鲁棒性(剔除不必要的离群点),该案例引入sklearn中的RobustScaler 函数。
RobustScaler的计算方法如下:

其中vi表示样本的某个值;median是样本的中位数;IQR是样本的四分位距

from sklearn.preprocessing import RobustScaler#导包
f_columns = ['t1', 't2', 'hum', 'wind_speed']#选择要处理的列
f_transformer = RobustScaler()
cnt_transformer = RobustScaler()
#fit()的作用:可以理解为一个训练过程
f_transformer = f_transformer.fit(train[f_columns].to_numpy())#这里将dataframe转化为numpy
cnt_transformer = cnt_transformer.fit(train[['cnt']])
#分别对训练集和测试集进行处理
train.loc[:, f_columns] = f_transformer.transform(train[f_columns].to_numpy())
train['cnt'] = cnt_transformer.transform(train[['cnt']])
test.loc[:, f_columns] = f_transformer.transform(test[f_columns].to_numpy())
test['cnt'] = cnt_transformer.transform(test[['cnt']])
#transfrom()的作用:在fit的基础上,进行标准化,降维,归一化等操作

拓展:
上面的代码将fit和transform分开写,其实也可以用fit_transform。
1.fit_transform是fit和transform的组合。fit_transform是将fit和transform合并,即先拟合数据,然后转化它将其转化为标准形式。
2.注意的是,训练集使用fit_transform(),而测试集使用tranform(),不再使用fit_transform();原因:必须确保两个训练集有相同的数据指标,即处理方式相同!
3.举个简单的例子
fit_transform(train data)先根据具体转换的目的找到数据的整体指标,如均值、方差、最大值最小值…,然后对train data进行transform,从而实现数据的标准化、归一化…。根据之前fit的整体指标,对test data使用同样的均值、方差、最大最小值等指标进行transform(test data),从而保证train data和test data的处理方式相同。
原文请查看此处

6.设置时间步长(LSTM)

#定义一个设置时间步长的函数
def create_dataset(X, y, time_steps=1):Xs, ys = [], []for i in range(len(X) - time_steps):v = X.iloc[i:(i + time_steps)].valuesXs.append(v)        ys.append(y.iloc[i + time_steps])return np.array(Xs), np.array(ys)#转换为numpy数组

对训练、测试集分别调用函数

time_steps = 10#时间步长为10个小时
# reshape to [samples, time_steps, n_features]
X_train, y_train = create_dataset(train, train.cnt, time_steps)
X_test, y_test = create_dataset(test, test.cnt, time_steps)
print(X_train.shape, y_train.shape)

得到训练、测试集的大小为(15662, 10, 13) (15662,)
其中的(15662, 10, 13)可以这样理解 :

  • 15662:样本数
  • 10:10小时
  • 13:除了训练标签以外的列数

7. 建立LSTM模型

model = keras.Sequential()#创建顺序模型
model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128, input_shape=(X_train.shape[1], X_train.shape[2])#输入层只需要给出样本的特征尺寸))
)
Bidirectional:实现RNN类型神经网络的双向构造
model.add(keras.layers.Dropout(rate=0.2))
model.add(keras.layers.Dense(units=1))
model.compile(loss='mean_squared_error', optimizer='adam')#损失函数为'MSE',优化器为Adam

8.训练模型

history = model.fit(X_train, y_train, epochs=30, batch_size=32, validation_split=0.1,shuffle=False#不打乱顺序
)

看一下训练结果

plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend();

9.模型预测

y_pred = model.predict(X_test)

对标签数据进行反归一化处理(inverse_transform)

y_train_inv = cnt_transformer.inverse_transform(y_train.reshape(1, -1))
y_test_inv = cnt_transformer.inverse_transform(y_test.reshape(1, -1))
y_pred_inv = cnt_transformer.inverse_transform(y_pred)
注意还是要采取之前归一化的标准,即这里的cnt_transformer

如果归一化时的对象的shape是(n, 3),则反归一化时的 data 的shape必须是(m, 3)

10.预测结果可视化

plt.plot(y_test_inv.flatten(), marker='.', label="true")
plt.plot(y_pred_inv.flatten(), 'r', label="prediction")
plt.ylabel('Bike Count')
plt.xlabel('Time Step')
plt.legend()
plt.show();

关于flatten函数的介绍:

  1. 功能:将numpy数组展开为一维数组
  2. 默认方向是行方向,加’a’也是行方向,但是加‘f’是列方向
  3. flatten函数不能直接作用于列表 !

案例学习1.LSTM相关推荐

  1. 深度学习之LSTM案例分析(三)

    #背景 来自GitHub上<tensorflow_cookbook>[https://github.com/nfmcclure/tensorflow_cookbook/tree/maste ...

  2. 通过脚本案例学习shell(五) 通过创建DNS脚本一步一步教你将一个普通脚本规范到一个生产环境脚本...

    通过脚本案例学习shell(五) 通过创建DNS脚本一步一步教你将一个普通脚本规范到一个生产环境脚本   版权声明: 本文遵循"署名非商业性使用相同方式共享 2.5 中国大陆"协议 ...

  3. 《大数据导论》——1.4节案例学习背景

    本节书摘来自华章社区<大数据导论>一书中的第1章,第1.4节案例学习背景,作者瓦吉德·哈塔克(Wajid Khattak),保罗·布勒(Paul Buhler),更多章节内容可以访问云栖社 ...

  4. 【深度学习】LSTM神经网络解决COVID-19预测问题(二)

    [深度学习]LSTM神经网络解决COVID-19预测问题(二) 文章目录 1 概述 2 模型求解和检验 3 模型代码 4 模型评价与推广 5 参考 1 概述 建立一个普适性较高的模型来有效预测疫情的达 ...

  5. 【深度学习】LSTM神经网络解决COVID-19预测问题(一)

    [深度学习]LSTM神经网络解决COVID-19预测问题 文章目录 1 概述 2 数据分析 3 SIR模型和LSTM网络的对比 4 LSTM神经网络的建立 5 参考 1 概述 我们将SIR传播模型和L ...

  6. ArcGIS案例学习1_2

    ArcGIS案例学习1_2 联系方式:向日葵,135_4855_4328, xiexiaokui#qq.com 时间:第一天下午 案例1:矢量提取,栅格提取和坐标系投影变换 目的:认识数据类型 教程: ...

  7. 零元学Expression Blend 4 ndash; Chapter 21 以实作案例学习MouseDragElementBehavior

    原文:零元学Expression Blend 4 – Chapter 21 以实作案例学习MouseDragElementBehavior 本章将教大家如何运用Blend 4内建的行为注入元件「Mou ...

  8. ArcGIS案例学习笔记-找出最近距离的垂线

    ArcGIS案例学习笔记-找出最近距离的垂线 联系方式:谢老师,135-4855-4328,xiexiaokui@qq.com 目的:对于任意矢量要素类,查找最近距离并做图 数据: 方法: 0. 计算 ...

  9. ArcGIS案例学习笔记2_2_等高线生成DEM和三维景观动画

    ArcGIS案例学习笔记2_2_等高线生成DEM和三维景观动画 计划时间:第二天下午 教程:Pdf/405 数据:ch9/ex3 方法: 1. 创建DEM SA工具箱/插值分析/地形转栅格 2. 生成 ...

最新文章

  1. 环形缓冲区: ringbuf.c
  2. Visual Studio调试时遇到的问题:生成下面模块时,启用了优化或没有调试信息
  3. ajax和cs的关系,fetch、axios 与Ajax之间关系
  4. HttpServlet的doGet()和doPost()方法
  5. JavaScript中四种不同的属性检测方式比较
  6. 3-docker 架构和底层技术简介
  7. 把十六进制字符转换成十进制数
  8. Flutter NestedScrollView 滑动折叠头部下拉刷新效果
  9. 【英语学习】【Daily English】U03 Leisure Time L02 I'm more of an indoorsy person anyway
  10. Java Jackson
  11. 燃烧我的卡路里 ---- Flutter瘦内存瘦包之图片组件
  12. C# WebService 上传图片
  13. mysql优化--explain分析sql语句执行效率
  14. 用java写出死锁的例子_【面试】请写一个java死锁的例子-Go语言中文社区
  15. c语言各章知识重点(谭浩强版本)
  16. C++类学习---------step1
  17. 万恶的android
  18. aip通用文档 服务器,为 Rights Management 连接器配置服务器 - AIP | Microsoft Docs
  19. 运行tomcat报错:Address localhost:1099 is already in use
  20. Python免费的家庭视频监控系统(1)

热门文章

  1. 二维码生成器和二维码扫描器
  2. 普元 EOS Platform Governor HTTP接入不拦截Url配置规则
  3. 再次挑战自己,骑行成都天府绿道100公里
  4. Android Studio升级后 出现在No subject alternative DNS name matching services.gradle.org found.
  5. mysql在手游中的作用_战神引挚手游数据库解析mysql/mir
  6. MCycDB:环境微生物组甲烷循环数据库
  7. UI设计初学者快速入门的5大建议!
  8. 符合JEITA规范的锂离子电池充电器解决方案
  9. 视频直播推流不成功如何排查
  10. java se 09