ESN学习笔记——echotorch(2)narma10
文章目录
- 我忍不住吐槽一下这个包,第一行代码就报错,没有datasets是闹哪样
- 加载库
- 训练样本长度
- 测试样本长度
- 有多少训练/测试样本
- 批量大小(同时处理多少个样本?)
- 超参数
- 预测/目标长度
- 使用 CUDA?
- 人工种子初始化
- NARMA30 数据集
- 数据加载器
- 内部矩阵W生成
- 输入权重Win生成
- Bias vector偏差向量生成
- 创建一个Leaky-integrated ESN,用最小二乘训练算法,调用 esn = etrs.ESN(
- 如果可能,在GPU中传输
- 对每一批
- 计算输出矩阵Wout来结束训练
- 得到训练集中的第一个样本,转为Variable.
- 用我们训练过的ESN做预测
- 打印训练误差 MSE和 NRMSE
- 得到测试集中的第一个样本转为Variable.
- 用我们训练过的ESN做预测
- 打印测试集上的 MSE 和NRMSE
- 展示目标和预测
我忍不住吐槽一下这个包,第一行代码就报错,没有datasets是闹哪样
加载库
import torch
from echotorch.datasets.NARMADataset import NARMADataset
import echotorch.nn.reservoir as etrs
import echotorch.utils
import echotorch.utils.matrix_generation as mg
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
import numpy as np
import matplotlib.pyplot as plt
训练样本长度
train_sample_length = 5000
测试样本长度
test_sample_length = 1000
有多少训练/测试样本
n_train_samples = 1
n_test_samples = 1
批量大小(同时处理多少个样本?)
batch_size = 1
超参数
spectral_radius = 1.07
leaky_rate = 0.9261
input_dim = 1
reservoir_size = 410
connectivity = 0.1954
ridge_param = 0.00000409
input_scaling = 0.9252
bias_scaling = 0.079079
预测/目标长度
plot_length = 200
使用 CUDA?
use_cuda = False
use_cuda = torch.cuda.is_available() if use_cuda else False
人工种子初始化
np.random.seed(1)
torch.manual_seed(1)
NARMA30 数据集
narma10_train_dataset = NARMADataset(train_sample_length,
n_train_samples, system_order=10)
narma10_test_dataset = NARMADataset(test_sample_length,
n_test_samples, system_order=10)
数据加载器
trainloader = DataLoader(narma10_train_dataset,
batch_size=batch_size, shuffle=False, num_workers=2)
testloader = DataLoader(narma10_test_dataset,
batch_size=batch_size, shuffle=False, num_workers=2)
内部矩阵W生成
w_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
spetral_radius=spectral_radius
)
输入权重Win生成
win_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
scale=input_scaling,
apply_spectral_radius=False
)
Bias vector偏差向量生成
wbias_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
scale=bias_scaling,
apply_spectral_radius=False
)
创建一个Leaky-integrated ESN,用最小二乘训练算法,调用 esn = etrs.ESN(
esn = etrs.LiESN(
input_dim=input_dim,
hidden_dim=reservoir_size,
output_dim=1,
leaky_rate=leaky_rate,
learning_algo=‘inv’,
w_generator=w_generator,
win_generator=win_generator,
wbias_generator=wbias_generator,
ridge_param=ridge_param
)
如果可能,在GPU中传输
if use_cuda:
esn.cuda()
#end if
对每一批
for data in trainloader:
#输入和输出
inputs, targets = data
#将数据转换为变量Variables
inputs, targets = Variable(inputs), Variable(targets)
if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()
# ESN需要的inputs and targets
esn(inputs, targets)
# end for
计算输出矩阵Wout来结束训练
esn.finalize()
得到训练集中的第一个样本,转为Variable.
dataiter = iter(trainloader)
train_u, train_y = dataiter.next()
train_u, train_y = Variable(train_u), Variable(train_y)
if use_cuda: train_u, train_y = train_u.cuda(), train_y.cuda()
用我们训练过的ESN做预测
y_predicted = esn(train_u)
打印训练误差 MSE和 NRMSE
print(u"Train MSE: {}".format(echotorch.utils.mse(y_predicted.data, train_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, train_y.data)))
print(u"")
得到测试集中的第一个样本转为Variable.
dataiter = iter(testloader)
test_u, test_y = dataiter.next()
test_u, test_y = Variable(test_u), Variable(test_y)
if use_cuda: test_u, test_y = test_u.cuda(), test_y.cuda()
用我们训练过的ESN做预测
y_predicted = esn(test_u)
打印测试集上的 MSE 和NRMSE
print(u"Test MSE: {}".format(echotorch.utils.mse(y_predicted.data, test_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, test_y.data)))
print(u"")
展示目标和预测
plt.plot(test_y[0, :plot_length, 0].data, ‘r’)
plt.plot(y_predicted[0, :plot_length, 0].data, ‘b’)
plt.show()
ESN学习笔记——echotorch(2)narma10相关推荐
- CV学习笔记 | CV综述 [2020.10.01]
文章目录 0. 概述(整理完后随时修改) 1. 人工神经网络 1.1. 人工神经网络发展历程 1.2. 一些神经元节点的工作原理 1.2.1. 基本神经元 1.2.2. 卷积神经元(Convoluti ...
- PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call
您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...
- 容器云原生DevOps学习笔记——第三期:从零搭建CI/CD系统标准化交付流程
暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...
- 容器云原生DevOps学习笔记——第二期:如何快速高质量的应用容器化迁移
暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...
- 2020年Yann Lecun深度学习笔记(下)
2020年Yann Lecun深度学习笔记(下)
- 2020年Yann Lecun深度学习笔记(上)
2020年Yann Lecun深度学习笔记(上)
- 知识图谱学习笔记(1)
知识图谱学习笔记第一部分,包含RDF介绍,以及Jena RDF API使用 知识图谱的基石:RDF RDF(Resource Description Framework),即资源描述框架,其本质是一个 ...
- 计算机基础知识第十讲,计算机文化基础(第十讲)学习笔记
计算机文化基础(第十讲)学习笔记 采样和量化PictureElement Pixel(像素)(链接: 采样的实质就是要用多少点(这个点我们叫像素)来描述一张图像,比如,一幅420x570的图像,就表示 ...
- Go 学习推荐 —(Go by example 中文版、Go 构建 Web 应用、Go 学习笔记、Golang常见错误、Go 语言四十二章经、Go 语言高级编程)
Go by example 中文版 Go 构建 Web 应用 Go 学习笔记:无痕 Go 标准库中文文档 Golang开发新手常犯的50个错误 50 Shades of Go: Traps, Gotc ...
- MongoDB学习笔记(入门)
MongoDB学习笔记(入门) 一.文档的注意事项: 1. 键值对是有序的,如:{ "name" : "stephen", "genda" ...
最新文章
- multiprocessing python_Python多线程/进程(threading、multiprocessing)知识覆盖详解
- 智能指针auto_ptr介绍
- 计算机显示器不清楚跟电池有关系吗,电脑液晶显示器显像模糊是什么原因
- [error] OpenEvent(Global\ngx_stop_25184) failed (2: The system cannot find the file specified)
- 台湾国立大学郭彦甫Matlab教程笔记(11) advanced 2D plots 上
- 计算机如何改变沟通方式,雅思阅读模拟题:计算机改变沟通方式
- k近邻算法_K近邻(knn)算法是如何完成分类的?
- 如何找到一篇论文的源代码?
- 【Android Demo】简单手机通讯录
- 【转】Java杂谈(四)
- 移动开发者大会.html5。Android。ios。wp联盟
- 3D Bounding Box Estimation Using Deep Learning and Geometry
- 计算机图形学(裁剪)
- java项目如何打包?
- 运用flask框架发送短信验证码的流程及具体代码
- 记一次 ERROR scheduler.AsyncEventQueue: Dropping event from queue shared导致OOM
- 蚂蚁金服在云原生架构下的可观察性的探索和实践
- 【Homeassistant 与Ultrasonic Distance超声波距离传感器握手】
- Cadence 中贴片元件焊盘的制作
- 深入浅出工控机加固的那点事
热门文章
- 2016.7.31整机升级计划
- 木瓜奇迹洗服务器维护,木瓜奇迹各种职业+点法
- matlab学霸表白公式,学霸隐藏式表白数学公式
- 微信wifi服务器地址,从零开始改造路由器实现微信连WIFI的功能(七):更简单的认证服务器wifidog-server...
- Naive Bayes Model 朴素贝叶斯 简单易懂的笔记by hch
- Idea中如何查看pom中dependency Analyzer的快捷键
- HTML heading
- RK3399 hi3559A 平台离线语音识别、合成、翻译、声纹
- MyBatis自带的缓存配置(Cache)
- 1041: 数列求和1