(注:代码来自pytorch官网)

学习使用pytorch构建神经网络,首先我们来看一下不使用深度学习框架的网络如何构建


```python
# -*- coding: utf-8 -*-
import numpy as np# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)learning_rate = 1e-6
for t in range(500):# Forward pass: compute predicted yh = x.dot(w1)h_relu = np.maximum(h, 0)y_pred = h_relu.dot(w2)# Compute and print lossloss = np.square(y_pred - y).sum()print(t, loss)# Backprop to compute gradients of w1 and w2 with respect to lossgrad_y_pred = 2.0 * (y_pred - y)grad_w2 = h_relu.T.dot(grad_y_pred)grad_h_relu = grad_y_pred.dot(w2.T)grad_h = grad_h_relu.copy()grad_h[h < 0] = 0grad_w1 = x.T.dot(grad_h)# Update weightsw1 -= learning_rate * grad_w1w2 -= learning_rate * grad_w2

不使用深度学习框架,我们来构造一个单隐层的网络
输入层维度1000,隐层100,输出层10,假设batch为64
学习率为1e-6

首先我们初始化数据集,使用随机初始化
x为(64*1000),y为(64*10)

接下来我们初始化两层网络的权重
w1 为(1000*100),w2位(100*10)

接下来我们迭代优化500次,每次的过程如下所示:

隐层结果h = x与w1相乘,结果为(64*100)
对隐层结果h使用relu函数进行裁剪,小于0的结果裁剪为0,得到h_relu结果
最终预测结果y_pred = h_relu与w2相乘,结果为(64*10)

loss函数为
预测结果y_pred与实际结果y的误差平方,一共(64*10)个误差平方
总loss为所有的误差平方和,将640个误差平方相加为一个数

计算梯度
计算y的梯度,因为我们取的是平方和,所以计算梯度时乘2,维度为(64*10)

计算w2的梯度,隐层结果转置与y梯度相乘,为(100*64)*(64*10)

计算隐层结果梯度,预测结果与w2转置相乘
(64*10)*(10*100)

将实际训练过程中裁剪的位置转换为0
(注:这里的h为训练时的h,h<0的位置没有传播)

计算w1的梯度

根据梯度调整参数


这些是一个例子中的loss结果

pytorch学习教程笔记(一)相关推荐

  1. PyTorch学习教程、手册

    文章目录 PyTorch学习教程.手册 PyTorch视频教程 NLP&PyTorch实战 CV&PyTorch实战 PyTorch论文推荐 PyTorch书籍推荐 PyTorch学习 ...

  2. LaTeX中文学习教程 笔记

    视频地址: LaTeX中文学习教程(用于论文或稿件排版,15集全) 用LaTeX写期刊论文的详细教程 一.基本结构 % 导言区 \documentclass{article}%book,article ...

  3. Scala学习教程笔记二之函数式编程、Object对象、伴生对象、继承、Trait、

    1:Scala之函数式编程学习笔记: 1:Scala函数式编程学习:1.1:Scala定义一个简单的类,包含field以及方法,创建类的对象,并且调用其方法:class User {private v ...

  4. MySQL精品学习资源合集 | 含学习教程笔记、运维技巧、图书推荐

    MySQL凭借开源.免费.低门槛.稳定等优势,成为了当前最流行的关系型数据库之一.从1998年发行第一版以来,通过不断地更新迭代,MySQL被越来越多的公司使用,已然成为当下潮流. 为了帮助大家更好地 ...

  5. Leaflet学习教程+笔记(Mars2D)

    Leaflet学习(Mars2D) 一 快速开始 中文教程–小白必备 中文API文档–学会查阅文档 准备工作-引入文件 <link rel="stylesheet" href ...

  6. MybatisMybatisPlusSpringboot之最全入门学习教程笔记

    目录 1 Mybatis概述 1.1 Mybatis概念 1.1.1 JDBC 缺点 1.1.2 Mybatis优化 1.2 快速入门 1.2.1 创建数据库 1.2.2 IDEA2021创建项目 1 ...

  7. 2021-7-20 pytorch学习基础笔记

    1.torch起源 2002发布Torch,后面Torch7在2011(Lua语言)大大制约了它的发展 Facebook在Torch7的基础上,在2016年10月发布了0.1 THNN后端 2018年 ...

  8. Cesium学习教程+笔记(Mars3D) 图层 图层组 矢量数据

    图层与图层组 为什么需要了解图层组?图层与图层组又是啥关系? 之前学过的点线面都是矢量数据 , 添加到矢量图层直接到地图上的 var tucengdian = new mars3d.layer.Gra ...

  9. python 入门学习教程笔记-- BMR 计算器

    本讲内容涉及到的知识点有: 1.数值类型: 2.字符串分割,字符串格式化输出,使用{}占位 https://docs.python.org/3/library/stdtypes.html#str.sp ...

最新文章

  1. OpenCV实现遍历文件夹下所有文件
  2. 高可用 Prometheus 架构实践中的踩坑集锦
  3. jquery学习资源
  4. android nougat和安卓7.1,Android Nougat 7.1.2 先睹为快
  5. C语言里 指针变量强制类型转换,C语言之强制类型转换与指针--#define DIR *((volatile unsigned int *) 0x0022)...
  6. sun服务器如何查cpu信息,solaris 如何查看CPU信息
  7. MySQL基础 增删改查练习
  8. Python五角星绘制
  9. 模拟斗地主洗牌发牌,并对已发好的拍进行排序(红桃A,方块A, 黑桃2.......)
  10. iOS OC利用imageview属性切出类似圆柱图形
  11. android 广告效果图,Android_Android实现加载广告图片和倒计时的开屏布局,这是一个android开屏布局的实例 - phpStudy...
  12. Linux下运行jar包的方法
  13. 考研复试英语自我介绍模版
  14. 【安全狗高危安全通告】OpenSSL存在远程代码执行漏洞和拒绝服务漏洞
  15. 微信小程序——圆形图片image控件、两个字和三个字对齐
  16. 四格漫画《MUXing》——他们在干什么
  17. 移动端检测更新安装包
  18. centos7.2系统下运行.sh文件的办法
  19. MATLAB数字图像小程序设计
  20. 康拓编码——Permutation Sequence

热门文章

  1. wordpress采集器-wordpress采集器安装下载教程
  2. 【新能力】京东小程序已加入知晓云 SDK 全家桶
  3. Maven教程初级篇01
  4. 【Thingsboard】源码分析:OTA 更新
  5. idea连接数据库无法识别sql语句中的表
  6. php dwt格式是什么,Dwt格式有什么特点?与dwg格式有什么区别?
  7. 华为监控服务器型号,华为NVR视频监控产品介绍.pdf
  8. 基于python的游戏市场分析
  9. Python 数据结构 tree 树
  10. asterisk中Callback+DISA电话回拨应用释例