文章目录

  • 我忍不住吐槽一下这个包,第一行代码就报错,没有datasets是闹哪样
  • 加载库
  • 训练样本长度
  • 测试样本长度
  • 有多少训练/测试样本
  • 批量大小(同时处理多少个样本?)
  • 超参数
  • 预测/目标长度
  • 使用 CUDA?
  • 人工种子初始化
  • NARMA30 数据集
  • 数据加载器
  • 内部矩阵W生成
  • 输入权重Win生成
  • Bias vector偏差向量生成
  • 创建一个Leaky-integrated ESN,用最小二乘训练算法,调用 esn = etrs.ESN(
  • 如果可能,在GPU中传输
  • 对每一批
  • 计算输出矩阵Wout来结束训练
  • 得到训练集中的第一个样本,转为Variable.
  • 用我们训练过的ESN做预测
  • 打印训练误差 MSE和 NRMSE
  • 得到测试集中的第一个样本转为Variable.
  • 用我们训练过的ESN做预测
  • 打印测试集上的 MSE 和NRMSE
  • 展示目标和预测

我忍不住吐槽一下这个包,第一行代码就报错,没有datasets是闹哪样

加载库

import torch
from echotorch.datasets.NARMADataset import NARMADataset
import echotorch.nn.reservoir as etrs
import echotorch.utils
import echotorch.utils.matrix_generation as mg
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
import numpy as np
import matplotlib.pyplot as plt

训练样本长度

train_sample_length = 5000

测试样本长度

test_sample_length = 1000

有多少训练/测试样本

n_train_samples = 1
n_test_samples = 1

批量大小(同时处理多少个样本?)

batch_size = 1

超参数

spectral_radius = 1.07
leaky_rate = 0.9261
input_dim = 1
reservoir_size = 410
connectivity = 0.1954
ridge_param = 0.00000409
input_scaling = 0.9252
bias_scaling = 0.079079

预测/目标长度

plot_length = 200

使用 CUDA?

use_cuda = False
use_cuda = torch.cuda.is_available() if use_cuda else False

人工种子初始化

np.random.seed(1)
torch.manual_seed(1)

NARMA30 数据集

narma10_train_dataset = NARMADataset(train_sample_length,
n_train_samples, system_order=10)

narma10_test_dataset = NARMADataset(test_sample_length,
n_test_samples, system_order=10)

数据加载器

trainloader = DataLoader(narma10_train_dataset,
batch_size=batch_size, shuffle=False, num_workers=2)

testloader = DataLoader(narma10_test_dataset,
batch_size=batch_size, shuffle=False, num_workers=2)

内部矩阵W生成

w_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
spetral_radius=spectral_radius
)

输入权重Win生成

win_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
scale=input_scaling,
apply_spectral_radius=False
)

Bias vector偏差向量生成

wbias_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
connectivity=connectivity,
scale=bias_scaling,
apply_spectral_radius=False
)

创建一个Leaky-integrated ESN,用最小二乘训练算法,调用 esn = etrs.ESN(

esn = etrs.LiESN(
input_dim=input_dim,
hidden_dim=reservoir_size,
output_dim=1,
leaky_rate=leaky_rate,
learning_algo=‘inv’,
w_generator=w_generator,
win_generator=win_generator,
wbias_generator=wbias_generator,
ridge_param=ridge_param
)

如果可能,在GPU中传输

if use_cuda:
esn.cuda()
#end if

对每一批

for data in trainloader:
#输入和输出
inputs, targets = data
#将数据转换为变量Variables
inputs, targets = Variable(inputs), Variable(targets)
if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()
# ESN需要的inputs and targets
esn(inputs, targets)
# end for

计算输出矩阵Wout来结束训练

esn.finalize()

得到训练集中的第一个样本,转为Variable.

dataiter = iter(trainloader)
train_u, train_y = dataiter.next()
train_u, train_y = Variable(train_u), Variable(train_y)
if use_cuda: train_u, train_y = train_u.cuda(), train_y.cuda()

用我们训练过的ESN做预测

y_predicted = esn(train_u)

打印训练误差 MSE和 NRMSE

print(u"Train MSE: {}".format(echotorch.utils.mse(y_predicted.data, train_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, train_y.data)))
print(u"")

得到测试集中的第一个样本转为Variable.

dataiter = iter(testloader)
test_u, test_y = dataiter.next()
test_u, test_y = Variable(test_u), Variable(test_y)
if use_cuda: test_u, test_y = test_u.cuda(), test_y.cuda()

用我们训练过的ESN做预测

y_predicted = esn(test_u)

打印测试集上的 MSE 和NRMSE

print(u"Test MSE: {}".format(echotorch.utils.mse(y_predicted.data, test_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, test_y.data)))
print(u"")

展示目标和预测

plt.plot(test_y[0, :plot_length, 0].data, ‘r’)
plt.plot(y_predicted[0, :plot_length, 0].data, ‘b’)
plt.show()

ESN学习笔记——echotorch(2)narma10相关推荐

  1. CV学习笔记 | CV综述 [2020.10.01]

    文章目录 0. 概述(整理完后随时修改) 1. 人工神经网络 1.1. 人工神经网络发展历程 1.2. 一些神经元节点的工作原理 1.2.1. 基本神经元 1.2.2. 卷积神经元(Convoluti ...

  2. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  3. 容器云原生DevOps学习笔记——第三期:从零搭建CI/CD系统标准化交付流程

    暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...

  4. 容器云原生DevOps学习笔记——第二期:如何快速高质量的应用容器化迁移

    暑期实习期间,所在的技术中台-效能研发团队规划设计并结合公司开源协同实现符合DevOps理念的研发工具平台,实现研发过程自动化.标准化: 实习期间对DevOps的理解一直懵懵懂懂,最近观看了阿里专家带 ...

  5. 2020年Yann Lecun深度学习笔记(下)

    2020年Yann Lecun深度学习笔记(下)

  6. 2020年Yann Lecun深度学习笔记(上)

    2020年Yann Lecun深度学习笔记(上)

  7. 知识图谱学习笔记(1)

    知识图谱学习笔记第一部分,包含RDF介绍,以及Jena RDF API使用 知识图谱的基石:RDF RDF(Resource Description Framework),即资源描述框架,其本质是一个 ...

  8. 计算机基础知识第十讲,计算机文化基础(第十讲)学习笔记

    计算机文化基础(第十讲)学习笔记 采样和量化PictureElement Pixel(像素)(链接: 采样的实质就是要用多少点(这个点我们叫像素)来描述一张图像,比如,一幅420x570的图像,就表示 ...

  9. Go 学习推荐 —(Go by example 中文版、Go 构建 Web 应用、Go 学习笔记、Golang常见错误、Go 语言四十二章经、Go 语言高级编程)

    Go by example 中文版 Go 构建 Web 应用 Go 学习笔记:无痕 Go 标准库中文文档 Golang开发新手常犯的50个错误 50 Shades of Go: Traps, Gotc ...

  10. MongoDB学习笔记(入门)

    MongoDB学习笔记(入门) 一.文档的注意事项: 1.  键值对是有序的,如:{ "name" : "stephen", "genda" ...

最新文章

  1. multiprocessing python_Python多线程/进程(threading、multiprocessing)知识覆盖详解
  2. 智能指针auto_ptr介绍
  3. 计算机显示器不清楚跟电池有关系吗,电脑液晶显示器显像模糊是什么原因
  4. [error] OpenEvent(Global\ngx_stop_25184) failed (2: The system cannot find the file specified)
  5. 台湾国立大学郭彦甫Matlab教程笔记(11) advanced 2D plots 上
  6. 计算机如何改变沟通方式,雅思阅读模拟题:计算机改变沟通方式
  7. k近邻算法_K近邻(knn)算法是如何完成分类的?
  8. 如何找到一篇论文的源代码?
  9. 【Android Demo】简单手机通讯录
  10. 【转】Java杂谈(四)
  11. 移动开发者大会.html5。Android。ios。wp联盟
  12. 3D Bounding Box Estimation Using Deep Learning and Geometry
  13. 计算机图形学(裁剪)
  14. java项目如何打包?
  15. 运用flask框架发送短信验证码的流程及具体代码
  16. 记一次 ERROR scheduler.AsyncEventQueue: Dropping event from queue shared导致OOM
  17. 蚂蚁金服在云原生架构下的可观察性的探索和实践
  18. 【Homeassistant 与Ultrasonic Distance超声波距离传感器握手】
  19. Cadence 中贴片元件焊盘的制作
  20. 深入浅出工控机加固的那点事

热门文章

  1. 2016.7.31整机升级计划
  2. 木瓜奇迹洗服务器维护,木瓜奇迹各种职业+点法
  3. matlab学霸表白公式,学霸隐藏式表白数学公式
  4. 微信wifi服务器地址,从零开始改造路由器实现微信连WIFI的功能(七):更简单的认证服务器wifidog-server...
  5. Naive Bayes Model 朴素贝叶斯 简单易懂的笔记by hch
  6. Idea中如何查看pom中dependency Analyzer的快捷键
  7. HTML heading
  8. RK3399 hi3559A 平台离线语音识别、合成、翻译、声纹
  9. MyBatis自带的缓存配置(Cache)
  10. 1041: 数列求和1