线性回归实现

1.准备数据
2.定义模型
3.定义损失函数
4.定义优化方法
5.训练
超参数初始化
循环传入数据计算损失
梯度回传更新参数
评估训练结果

%matplotlib inline
#可以直接画图
import random
import torch
# from d2l import torch as d2l#prepare the data
def synthetic_data(w, b, num_examples):"""y = XW+b"""x = torch.normal(0,1,(num_examples,len(w)))y = torch.matmul(x,w)+by += torch.normal(0, 0.01, y.shape)return x, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2features, labels = synthetic_data(true_w, true_b, 1000) # data iteration
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i+batch_size, num_examples)]#最后一组可能会超过总的数量,所以就选取最后剩下的样本作为最后一组)yield features[batch_indices], labels[batch_indices]# model
def linereg(x,w,b):return torch.matmul(x,w) + b# loss
def mse_loss(y_hat, y):return ((y_hat - y)**2) / 2# sgd
def sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print(id(param))#注意这里要原地操作,如果写成param = param - 的形式,param的地址就改变了,就不携带梯度了,所以会报错 AttributeError: 'NoneType' object has no attribute 'zero_'#param = param - lr * param.grad / batch_size# print(id(param))param -= lr * param.grad / batch_sizeparam.grad.zero_()# init the parameters
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)# train
num_epochs = 10
lr = 1
net = linereg
loss = mse_loss
batch_size = 10
for epoch in range(num_epochs):for x,y in data_iter(batch_size, features, labels):train_loss = loss(net(x,w,b), y)train_loss.sum().backward()#梯度回传# print(w.grad)sgd([w,b], lr, batch_size)# testwith torch.no_grad():l = loss(net(features, w, b), labels)print(f'epoch {epoch+1}, loss{float(l.mean()):f}')print("b loss", (true_b - b))
print("w_loss", (true_w - w.reshape(true_w.shape)))

pytorch---线性回归实现相关推荐

  1. pytorch线性回归_PyTorch中的线性回归

    pytorch线性回归 For all those amateur Machine Learning and Deep Learning enthusiasts out there, Linear R ...

  2. (pytorch-深度学习系列)pytorch线性回归的便捷实现

    pytorch线性回归的便捷实现 继上一篇blog,使用更加简洁的方法实现线性回归 生成数据集: num_inputs = 2 num_examples = 1000 true_w = [2, -3. ...

  3. pytorch线性回归代码_[PyTorch 学习笔记] 1.3 张量操作与线性回归

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/linear_regression.py 张量的操作 拼 ...

  4. 《Pytorch - 线性回归模型》

    2020年10月4号,依然在家学习. 今天是我写的第一个 Pytorch程序,从今天起也算是入门了. 就从简单的线性回归开始吧. 话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的. imp ...

  5. pytorch线性回归(笔记一)

    代码部分: import numpy as np import torch import matplotlib.pyplot as plt from torch import nn,optim fro ...

  6. Pytorch线性回归的详细实现

    线性回归 线性回归-单层神经网络 代码实现: 线性回归-单层神经网络 线性回归是⼀个单层神经⽹络  输⼊分别为x1和x2,因此输⼊层的输⼊个数为2,输⼊个数也叫特征数或 特征向量维度,输出层的输出个数 ...

  7. Pytorch线性回归

    1. 实现线性回归 用基础模型 y = wx + b 步骤: 1. 准备数据 2. 计算预测值 3. 计算损失,把参数的梯度置为0,进行反向传播 4. 更新参数 import torch import ...

  8. 使用线性回归和 PyTorch 预测自行车道的使用情况

    点击关注我哦 一篇文章带你亲临Kaggle项目 这篇文章将使用 Kaggle 的 Montréal 自行车道数据集(数据集下载地址:https://www.kaggle.com/pablomonleo ...

  9. 资源 | 4天学会PyTorch!香港科技大学开放PyTorch机器学习课件资源

    整理 | suiling 上周,香港科技大学计算机系Sung Kim教授公开了一份"三日 TensorFlow 速成课程"的学习资料(主要涉及 TensorFlow 的安装.内部机 ...

  10. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

最新文章

  1. BZOJ4196[Noi2015]软件包管理器——树链剖分+线段树
  2. ibmmq 通道命令_IBM_MQ常用命令的.doc
  3. BZOJ3566 [SHOI2014]概率充电器 (树形DP概率DP)
  4. php 内置mail 包,配置php自带的mail功能
  5. opera9.6 的一个顽固的bug
  6. 百分制转化为五级制java_javav 的日志-编写存储过程,将百分制成绩,转换成绩等级’A’,’B’,’C’,’D’,’E’...
  7. Ruby 的环境搭建及安装
  8. Linux 串口编程三 使用termios与API进行串口程序开发
  9. 常量指针与指针常量勘误
  10. 关于Oracle RAC调整网卡MTU值的问题
  11. python codec_深入理解Python特性
  12. php 打印debug日志
  13. 【Hadoop学习笔记】大数据框架原理及主要工具概述
  14. 华硕笔记本重装系统bios设置
  15. 小米游戏本退出安全模式/win10安全模式密码
  16. python爬取高德地图数据_你的未来有我导航----教你如何爬取高德地图
  17. 计算机软件定时运行,做一回达人 Windows7定时运行程序
  18. 粒子群算法的matlab实现
  19. 女生英文名字的义意:
  20. 通过百度API获取城市经纬度(1)

热门文章

  1. kuka机器人码垛编程网盘_kuka机器人循环指令码垛编程探索
  2. JAVA 调用 labview_制作软接入点ESP8266并通过labview读取数据
  3. 一种新型智慧停车场车位占用监测模块
  4. Pikachu漏洞靶场 敏感信息泄露
  5. 猫和狗类(继承、多态、抽象、接口)
  6. ssh登录极路由后台_OpenWrt刷机详细流程(极路由)
  7. UESTC - 59 数据大搜索
  8. 【网络基础知识】VLAN技术介绍(详细)
  9. c++builder读取系统时间Now函数
  10. 计算机一级考试图片水印怎么加,图片水印怎么添加?一起来看看这几个方法