LSTM共享单车使用量预测

  • 1、加载数据集、数据可视化、预处理
    • - 引入包
    • - 加载数据集
    • - 数据集描述方式
    • - 数据集可视化处理
  • 5、模型搭建、编译、训练
    • - 模型搭建
    • - 模型编译
    • - 保持模型权重文件
    • - 模型训练
    • - 曲线显示训练结果
  • 6、模型验证

来自哔哩哔哩课程LSTM共享单车使用量预测
,属于多变量预测案例

1、加载数据集、数据可视化、预处理

- 引入包

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import datetimefrom sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_scoreimport tensorflow as tf
from tensorflow.keras import Sequential, layers, utils, losses
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoardimport warnings
warnings.filterwarnings('ignore')

- 加载数据集

# 加载数据集
dataset = pd.read_csv("BikeShares.csv", parse_dates=['timestamp'], index_col=['timestamp'])

- 数据集描述方式

dataset.shape #数据集大小
dataset.head() #默认显示前5行
dataset.tail() #默认显示后5行
dataset.info() #数据集信息:总结的信息
dataset.describe() #数据集描述:描述最大值最小值等信息

- 数据集可视化处理

通过可视化处理能直观地展现不同变量之间的关系

# 字段t1(气温)与字段cnt(单车使用量)之间的关系
plt.figure(figsize=(16,8))
sns.pointplot(x='t1', y='cnt', data=dataset) #使用'pointplot()'是'点-线'图
plt.show()# 字段t2(体感温度)与字段cnt(单车使用量)之间的关系
plt.figure(figsize=(16,8))
sns.lineplot(x='t2', y='cnt', data=dataset) #使用'lineplot()'
plt.show()
图1 pointplot图 图2 lineplot图

#### - 数据归一化处理 把数据较大的列数据进行归一化处理 ```python # 分别对字段t1, t2, hum, wind_speed进行归一化 columns = ['cnt', 't1', 't2', 'hum', 'wind_speed'] for col in columns: scaler = MinMaxScaler() dataset[col] = scaler.fit_transform(dataset[col].values.reshape(-1,1)) ``` #### - 构建特征集

# 特征数据集
X = dataset.drop(columns=['cnt'], axis=1)
# 标签数据集
y = dataset['cnt']# 1 数据集分离: X_train, X_test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False, random_state=666)
# 使用'train_test_split'函数进行数据集切分,切分比例为80%训练集,'shuffle=False'表示切分后不打乱顺序,'random_state'规定了数据的顺序# 2 构造特征数据集
def create_dataset(X, y, seq_len=10):features = []targets = []for i in range(0, len(X) - seq_len, 1):data = X.iloc[i:i+seq_len] # 序列数据  利用'iloc'函数,取行label = y.iloc[i+seq_len] # 标签数据# 保存到features和labelsfeatures.append(data)targets.append(label)# 返回return np.array(features), np.array(targets)
# ① 构造训练特征数据集
train_dataset, train_labels = create_dataset(X_train, y_train, seq_len=10)
# ② 构造测试特征数据集
test_dataset, test_labels = create_dataset(X_test, y_test, seq_len=10)
# 3 构造批数据
def create_batch_dataset(X, y, train=True, buffer_size=1000, batch_size=128):batch_data = tf.data.Dataset.from_tensor_slices((tf.constant(X), tf.constant(y))) # 数据封装,tensor类型if train: # 训练集return batch_data.cache().shuffle(buffer_size).batch(batch_size)else: # 测试集return batch_data.batch(batch_size)
# 训练批数据
train_batch_dataset = create_batch_dataset(train_dataset, train_labels)
# 测试批数据
test_batch_dataset = create_batch_dataset(test_dataset, test_labels, train=False)

5、模型搭建、编译、训练

- 模型搭建

# 模型搭建--版本1
model = Sequential([layers.LSTM(units=256, input_shape=train_dataset.shape[-2:], return_sequences=True),# 这里input_shape可以直接写(10,8)layers.Dropout(0.4), #舍去一部分神经元,保留一部分,符合LSTM定义layers.LSTM(units=256, return_sequences=True),layers.Dropout(0.3),layers.LSTM(units=128, return_sequences=True),layers.LSTM(units=32),layers.Dense(1)
])

- 模型编译

# 模型编译
model.compile(optimizer='adam',loss='mse')

- 保持模型权重文件

checkpoint_file = "best_model.hdf5" #实际报错了,可以改为.ckpt
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_file, monitor='loss',mode='min',save_best_only=True,save_weights_only=True)

- 模型训练

# 模型训练
history = model.fit(train_batch_dataset,epochs=30,validation_data=test_batch_dataset,callbacks=[tensorboard_callback, checkpoint_callback])

- 曲线显示训练结果

# 显示训练结果
plt.figure(figsize=(16,8))
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend(loc='best')
plt.show()

6、模型验证

test_preds = model.predict(test_dataset, verbose=1)
test_preds = test_preds[:, 0] # 获取列值
# 计算r2值
score = r2_score(test_labels, test_preds)
print("r^2 值为: ", score)
# 绘制 预测与真值结果
plt.figure(figsize=(16,8))
plt.plot(test_labels[:300], label="True value") #取前300个绘制曲线图
plt.plot(test_preds[:300], label="Pred value")
plt.legend(loc='best')
plt.show()
  • LSTM多特征值与LSTM单特征值相比,基本相同,区别在于送入的特征数据集特征数增多了

[人工智能学习日志]深度学习-LSTM共享单车使用量预测相关推荐

  1. [人工智能学习日志]深度学习-股票价格预测案例1

    来自股票价格预测bilibili课程. 源自jupyter notebook文件main.ipynb. 代码用tf1书写,使用tf2会因为版本不对应而报错,tf2版本的代码后续再研究. 股票价格预测 ...

  2. 1、数据分析--共享单车使用量预测

    数据字段分析 列名 desc 中文描述 datetime hourly date + timestamp 小时日期 和时间戳 season 1 = spring, 2 = summer, 3 = fa ...

  3. python3人工智能网盘_《Python3入门人工智能掌握机器学习+深度学习提升实战能力》百度云网盘资源分享下载[MP4/5.77GB]...

    内容简介 本资源为<Python3入门人工智能掌握机器学习+深度学习提升实战能力>百度云网盘资源分享下载,具体看下文目录,格式为MP4/5.77GB.本资源已做压缩包处理,请勿直接在百度网 ...

  4. 【人工智能项目】- 深度学习实现猫狗大战

    [人工智能项目]- 深度学习实现猫狗大战 本次实现猫狗大战,实质上就是猫狗的二分类任务. 环境 !nvidia-smi Mon Jun 22 04:24:29 2020 +-------------- ...

  5. 人工智能趋势与深度学习算法

    人工智能趋势与深度学习算法 1 前沿技术 1.1 Transformer模型: 1.2 BERT模型:基于Transformer Encoder构建的预测模型 1.3 自监督学习(Self-super ...

  6. 人工智能,机器学习,深度学习入门好文,强烈推荐

    让我们从机器学习谈起 导读:在本篇文章中,将对机器学习做个概要的介绍.本文的目的是能让即便完全不了解机器学习的人也能了解机器学习,并且上手相关的实践.当然,本文也面对一般读者,不会对阅读有相关的前提要 ...

  7. 【人工智能项目】深度学习实现白葡萄酒品质预测

    [人工智能项目]深度学习实现白葡萄酒品质预测 任务介绍 评价一款葡萄酒时不外乎从颜色.酸度.甜度.香气.风味等入手,而决定这些就是葡萄酒的挥发酸度.糖分.密度等. 根据给出的白葡萄酒酸度.糖分.PH值 ...

  8. 【人工智能项目】深度学习实现汉字书法识别

    [人工智能项目]深度学习实现汉字书法识别 背景介绍 竞赛数据提供100个汉字书法单字,包括碑帖,手写书法,古汉字等.图片全部为单通道宽度jpg,宽高不定. 数据集介绍 训练集:每个汉字400张图片,共 ...

  9. 【人工智能项目】深度学习实现10类猴子细粒度识别

    [人工智能项目]深度学习实现10类猴子细粒度识别 任务说明 本次比赛需要选手准确识别10种猴子,数据集只有图片,没有boundbox等标注数据. 环境说明 !nvidia-smi Fri Mar 27 ...

最新文章

  1. django系列 1 :python+django环境搭建 +mac提示找不到manage.py命令
  2. excel中如何et vb根据数据自动生成表格_如何实现excel与PPT互联互通(动态生成PPT)...
  3. gRPC客户端创建和调用原理解析
  4. 在数组里查找这样的数,它大于等于左侧所有数,小于等于右侧所有数
  5. 温州大学c语言作业布置的网站,老师APP上布置作业 三年级娃为刷排名半夜做题_央广网...
  6. 前端学习(2015)vue之电商管理系统电商系统之实现图片的预览效果
  7. Linux下grub.cnf详解
  8. 14 FI配置-财务会计-定义未结清过帐期间变式
  9. 【Bringing Old Photos Back to Life】How to train?如何训练
  10. linux安装neo4j
  11. iOS中的坑:URL不识别##
  12. moodle安装过程中可能出现的问题
  13. 【Java基础笔记】ASCll码表
  14. 批量调度工具 Taskctl 作业类型的维护管理
  15. 2018网易编程射击游戏
  16. Vue路由导航报错:NavigationDuplicated: Avoided redundant navigation to current location解决方法
  17. 苹果 watchOS 3.2 首个测试版:剧场模式、SiriKit
  18. 手把手教你做一个网页
  19. MediaRecorder录制视频和录音
  20. 照片文件与计算机系统,照片文件格式怎么修改

热门文章

  1. 价值1500的全新UI众人帮任务帮PHP源码/悬赏任务抖音快手头条点赞源码/带三级分销可封装小程序
  2. 错误代码:0x80030001的一个好的解决办法
  3. 交换机相关--VLAN
  4. nginx与php重启
  5. 不懂如何在图片上添加贴纸?马上教你图片加贴纸方法
  6. comint32.sys,GD*I32.dll,bj*rl.dll,addr*help.dll***群的删除
  7. 王者约战电竞平台 Java+原生开发 源码 顶级体验
  8. 设置vim 永久显示行号
  9. Android 快速开发框架:推荐10个框架
  10. fanuc机器人防干涉区域功能