[人工智能学习日志]深度学习-LSTM共享单车使用量预测
LSTM共享单车使用量预测
- 1、加载数据集、数据可视化、预处理
- - 引入包
- - 加载数据集
- - 数据集描述方式
- - 数据集可视化处理
- 5、模型搭建、编译、训练
- - 模型搭建
- - 模型编译
- - 保持模型权重文件
- - 模型训练
- - 曲线显示训练结果
- 6、模型验证
来自哔哩哔哩课程LSTM共享单车使用量预测
,属于多变量预测案例
1、加载数据集、数据可视化、预处理
- 引入包
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import datetimefrom sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_scoreimport tensorflow as tf
from tensorflow.keras import Sequential, layers, utils, losses
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoardimport warnings
warnings.filterwarnings('ignore')
- 加载数据集
# 加载数据集
dataset = pd.read_csv("BikeShares.csv", parse_dates=['timestamp'], index_col=['timestamp'])
- 数据集描述方式
dataset.shape #数据集大小
dataset.head() #默认显示前5行
dataset.tail() #默认显示后5行
dataset.info() #数据集信息:总结的信息
dataset.describe() #数据集描述:描述最大值最小值等信息
- 数据集可视化处理
通过可视化处理能直观地展现不同变量之间的关系
# 字段t1(气温)与字段cnt(单车使用量)之间的关系
plt.figure(figsize=(16,8))
sns.pointplot(x='t1', y='cnt', data=dataset) #使用'pointplot()'是'点-线'图
plt.show()# 字段t2(体感温度)与字段cnt(单车使用量)之间的关系
plt.figure(figsize=(16,8))
sns.lineplot(x='t2', y='cnt', data=dataset) #使用'lineplot()'
plt.show()
![]() |
![]() |
#### - 数据归一化处理 把数据较大的列数据进行归一化处理 ```python # 分别对字段t1, t2, hum, wind_speed进行归一化 columns = ['cnt', 't1', 't2', 'hum', 'wind_speed'] for col in columns: scaler = MinMaxScaler() dataset[col] = scaler.fit_transform(dataset[col].values.reshape(-1,1)) ``` #### - 构建特征集
# 特征数据集
X = dataset.drop(columns=['cnt'], axis=1)
# 标签数据集
y = dataset['cnt']# 1 数据集分离: X_train, X_test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False, random_state=666)
# 使用'train_test_split'函数进行数据集切分,切分比例为80%训练集,'shuffle=False'表示切分后不打乱顺序,'random_state'规定了数据的顺序# 2 构造特征数据集
def create_dataset(X, y, seq_len=10):features = []targets = []for i in range(0, len(X) - seq_len, 1):data = X.iloc[i:i+seq_len] # 序列数据 利用'iloc'函数,取行label = y.iloc[i+seq_len] # 标签数据# 保存到features和labelsfeatures.append(data)targets.append(label)# 返回return np.array(features), np.array(targets)
# ① 构造训练特征数据集
train_dataset, train_labels = create_dataset(X_train, y_train, seq_len=10)
# ② 构造测试特征数据集
test_dataset, test_labels = create_dataset(X_test, y_test, seq_len=10)
# 3 构造批数据
def create_batch_dataset(X, y, train=True, buffer_size=1000, batch_size=128):batch_data = tf.data.Dataset.from_tensor_slices((tf.constant(X), tf.constant(y))) # 数据封装,tensor类型if train: # 训练集return batch_data.cache().shuffle(buffer_size).batch(batch_size)else: # 测试集return batch_data.batch(batch_size)
# 训练批数据
train_batch_dataset = create_batch_dataset(train_dataset, train_labels)
# 测试批数据
test_batch_dataset = create_batch_dataset(test_dataset, test_labels, train=False)
5、模型搭建、编译、训练
- 模型搭建
# 模型搭建--版本1
model = Sequential([layers.LSTM(units=256, input_shape=train_dataset.shape[-2:], return_sequences=True),# 这里input_shape可以直接写(10,8)layers.Dropout(0.4), #舍去一部分神经元,保留一部分,符合LSTM定义layers.LSTM(units=256, return_sequences=True),layers.Dropout(0.3),layers.LSTM(units=128, return_sequences=True),layers.LSTM(units=32),layers.Dense(1)
])
- 模型编译
# 模型编译
model.compile(optimizer='adam',loss='mse')
- 保持模型权重文件
checkpoint_file = "best_model.hdf5" #实际报错了,可以改为.ckpt
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_file, monitor='loss',mode='min',save_best_only=True,save_weights_only=True)
- 模型训练
# 模型训练
history = model.fit(train_batch_dataset,epochs=30,validation_data=test_batch_dataset,callbacks=[tensorboard_callback, checkpoint_callback])
- 曲线显示训练结果
# 显示训练结果
plt.figure(figsize=(16,8))
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend(loc='best')
plt.show()
6、模型验证
test_preds = model.predict(test_dataset, verbose=1)
test_preds = test_preds[:, 0] # 获取列值
# 计算r2值
score = r2_score(test_labels, test_preds)
print("r^2 值为: ", score)
# 绘制 预测与真值结果
plt.figure(figsize=(16,8))
plt.plot(test_labels[:300], label="True value") #取前300个绘制曲线图
plt.plot(test_preds[:300], label="Pred value")
plt.legend(loc='best')
plt.show()
- LSTM多特征值与LSTM单特征值相比,基本相同,区别在于送入的特征数据集特征数增多了
[人工智能学习日志]深度学习-LSTM共享单车使用量预测相关推荐
- [人工智能学习日志]深度学习-股票价格预测案例1
来自股票价格预测bilibili课程. 源自jupyter notebook文件main.ipynb. 代码用tf1书写,使用tf2会因为版本不对应而报错,tf2版本的代码后续再研究. 股票价格预测 ...
- 1、数据分析--共享单车使用量预测
数据字段分析 列名 desc 中文描述 datetime hourly date + timestamp 小时日期 和时间戳 season 1 = spring, 2 = summer, 3 = fa ...
- python3人工智能网盘_《Python3入门人工智能掌握机器学习+深度学习提升实战能力》百度云网盘资源分享下载[MP4/5.77GB]...
内容简介 本资源为<Python3入门人工智能掌握机器学习+深度学习提升实战能力>百度云网盘资源分享下载,具体看下文目录,格式为MP4/5.77GB.本资源已做压缩包处理,请勿直接在百度网 ...
- 【人工智能项目】- 深度学习实现猫狗大战
[人工智能项目]- 深度学习实现猫狗大战 本次实现猫狗大战,实质上就是猫狗的二分类任务. 环境 !nvidia-smi Mon Jun 22 04:24:29 2020 +-------------- ...
- 人工智能趋势与深度学习算法
人工智能趋势与深度学习算法 1 前沿技术 1.1 Transformer模型: 1.2 BERT模型:基于Transformer Encoder构建的预测模型 1.3 自监督学习(Self-super ...
- 人工智能,机器学习,深度学习入门好文,强烈推荐
让我们从机器学习谈起 导读:在本篇文章中,将对机器学习做个概要的介绍.本文的目的是能让即便完全不了解机器学习的人也能了解机器学习,并且上手相关的实践.当然,本文也面对一般读者,不会对阅读有相关的前提要 ...
- 【人工智能项目】深度学习实现白葡萄酒品质预测
[人工智能项目]深度学习实现白葡萄酒品质预测 任务介绍 评价一款葡萄酒时不外乎从颜色.酸度.甜度.香气.风味等入手,而决定这些就是葡萄酒的挥发酸度.糖分.密度等. 根据给出的白葡萄酒酸度.糖分.PH值 ...
- 【人工智能项目】深度学习实现汉字书法识别
[人工智能项目]深度学习实现汉字书法识别 背景介绍 竞赛数据提供100个汉字书法单字,包括碑帖,手写书法,古汉字等.图片全部为单通道宽度jpg,宽高不定. 数据集介绍 训练集:每个汉字400张图片,共 ...
- 【人工智能项目】深度学习实现10类猴子细粒度识别
[人工智能项目]深度学习实现10类猴子细粒度识别 任务说明 本次比赛需要选手准确识别10种猴子,数据集只有图片,没有boundbox等标注数据. 环境说明 !nvidia-smi Fri Mar 27 ...
最新文章
- django系列 1 :python+django环境搭建 +mac提示找不到manage.py命令
- excel中如何et vb根据数据自动生成表格_如何实现excel与PPT互联互通(动态生成PPT)...
- gRPC客户端创建和调用原理解析
- 在数组里查找这样的数,它大于等于左侧所有数,小于等于右侧所有数
- 温州大学c语言作业布置的网站,老师APP上布置作业 三年级娃为刷排名半夜做题_央广网...
- 前端学习(2015)vue之电商管理系统电商系统之实现图片的预览效果
- Linux下grub.cnf详解
- 14 FI配置-财务会计-定义未结清过帐期间变式
- 【Bringing Old Photos Back to Life】How to train?如何训练
- linux安装neo4j
- iOS中的坑:URL不识别##
- moodle安装过程中可能出现的问题
- 【Java基础笔记】ASCll码表
- 批量调度工具 Taskctl 作业类型的维护管理
- 2018网易编程射击游戏
- Vue路由导航报错:NavigationDuplicated: Avoided redundant navigation to current location解决方法
- 苹果 watchOS 3.2 首个测试版:剧场模式、SiriKit
- 手把手教你做一个网页
- MediaRecorder录制视频和录音
- 照片文件与计算机系统,照片文件格式怎么修改
热门文章
- 价值1500的全新UI众人帮任务帮PHP源码/悬赏任务抖音快手头条点赞源码/带三级分销可封装小程序
- 错误代码:0x80030001的一个好的解决办法
- 交换机相关--VLAN
- nginx与php重启
- 不懂如何在图片上添加贴纸?马上教你图片加贴纸方法
- comint32.sys,GD*I32.dll,bj*rl.dll,addr*help.dll***群的删除
- 王者约战电竞平台 Java+原生开发 源码 顶级体验
- 设置vim 永久显示行号
- Android 快速开发框架:推荐10个框架
- fanuc机器人防干涉区域功能