用上帝的恋爱公式讲线性回归-下

  • 定义模型
  • 初始化模型
  • 定义损失函数
  • 定义优化算法
  • 训练模型
  • 在验证集上测试

前文:用上帝的恋爱公式讲线性回归-上

定义模型

net = nn.Sequential(nn.Linear(4, 1,bias=True))

初始化模型

定义模型初始化方式,使用默认的normal方式

def init_weights(layer):classname=layer.__class__.__name__if classname.find('Linear') != -1:I.normal_(layer.weight.data, std=0.1)print(layer.weight.data)I.normal_(layer.bias.data, std=0.1)print(layer.bias.data)

执行初始化:

net.apply(init_weights)

输出如下:

tensor([[-0.0355, -0.0437, -0.0928,  0.1229]])
tensor([-0.0566])Sequential((0): Linear(in_features=4, out_features=1, bias=True)
)

查看参数初始化情况:

for name, param in net.named_parameters():print(name, param)

输出如下:

0.weight Parameter containing:
tensor([[-0.0355, -0.0437, -0.0928,  0.1229]], requires_grad=True)
0.bias Parameter containing:
tensor([-0.0566], requires_grad=True)

定义损失函数

loss = torch.nn.MSELoss()

定义优化算法

trainer = torch.optim.SGD(net.parameters(), lr =0.06)

训练模型

num_epochs = 200for epoch in range(1, num_epochs+1):with torch.autograd.detect_anomaly():for index, (X, y) in enumerate(train_data_iter):#print(net(X))trainer.zero_grad()l = loss(torch.log(net(X).clamp(min=1e-20)), torch.log( y))l.backward()trainer.step()#breakl = loss(net(test_features), test_labels)print('epoch %d, loss: %f' % (epoch, l))
epoch 1, loss: 8447828.000000
epoch 2, loss: 7849816.500000......
epoch 194, loss: 2.003807
epoch 195, loss: 2.004321
epoch 196, loss: 2.339057
epoch 197, loss: 2.005957
epoch 198, loss: 2.246327
epoch 199, loss: 2.057835
epoch 200, loss: 2.034282

在验证集上测试

因为只是一个简单的demo,并没有对模型优化做更多的微调,各位有兴趣,可以自行微调,把loss降到更低。

for index, (X, y) in enumerate(test_data_iter):print("predict y_hat:", net(X))print("ture_y:", y)break

从输出看,我们的预测天数偏差还是比较大的,比如第一条我们预测的是58.65天,而真实情况是66.88天。我们的预测模型在测试集的表现是2,这是一个平均值,他的意义是把整个测试集所有的情况都预测完,总偏差在/测试集总数=2,他代表的是一个平均情况。

predict y_hat: tensor([[58.6569],[52.2799],[64.1686],[66.6607],[52.4029]], grad_fn=<AddmmBackward>)
ture_y: tensor([[66.8827],[60.7605],[72.7446],[60.7297],[71.2678]])

此时我们把模型中的w和b输出,看一下我们猜的结果

for name, param in net.named_parameters():print(name, param)
0.weight Parameter containing:
tensor([[0.5140, 3.7515, 0.1002, 0.1017]], requires_grad=True)
0.bias Parameter containing:
tensor([-0.0554], requires_grad=True)

把上面输出的结果和下面上帝手中公式中真实的w,b进行比较,总体来说还可以接收。

true_w = [0.5, 3.8, 0.1,  0.1]
true_b = 1

那我们现在猜的上帝手中的公式:y=0.514∗appearance+3.7515∗unhappy+0.1002∗study+0.1017∗game+by = 0.514*appearance + 3.7515 * unhappy + 0.1002*study + 0.1017*game + by=0.514∗appearance+3.7515∗unhappy+0.1002∗study+0.1017∗game+b 偏差b为-0.0554。
到此我们就破解了上帝手中的恋爱公式。

用上帝的恋爱公式讲线性回归-下相关推荐

  1. 用上帝的恋爱公式讲线性回归-上

    用上帝的恋爱公式讲线性回归-上 准备好开发环境 切换成上帝的身份 切换成数据科学家的视角 加载数据 用上帝的恋爱公式讲线性回归-下 上帝他老人家手里有一条神奇的恋爱公式!特别神奇,只要你告诉他怎么追妹 ...

  2. 在EXCEL表格中经常会遇到有合并单元格时,汇总计算的公式无法直接下拉自动填充计算,掌握这个小技巧一键汇总

    在EXCEL表格中经常会遇到有合并单元格时,汇总计算的公式无法直接下拉自动填充计算,掌握这个小技巧一键汇总 目录 在EXCEL表格中经常会遇到有合并单元格时,汇总计算的公式无法直接下拉自动填充计算,掌 ...

  3. CLIP 改进工作串讲(下)

    CLIP 改进工作串讲(下) 本文为 CLIP 改进工作串讲(下)[论文精读] 的学习笔记. 图像生成 最近一年图像生成领域扩散模型大火,尤其是文本生成图像,DALL-E.imagen 等工作层出不穷 ...

  4. WPS表格Excel:公式实现向下填充

    WPS表格Excel:公式实现向下填充 目标:用公式实现从A列到B列的转变 观察数据得到结论 如果B列中的第N行对应的A列中的数据不为空,则取A列中的数据,如果为空则取B列中N-1行的数据即可 使用I ...

  5. 欧拉公式-上帝创造的公式

    欧拉公式: (1)分式里的欧拉公式: a^r/(a-b)(a-c)+b^r/(b-c)(b-a)+c^r/(c-a)(c-b) 当r=0,1时式子的值为0 当r=2时值为1 当r=3时值为a+b+c ...

  6. asp.net电子商务开发实战 视频 第二讲 (下)

    第二讲主要是讲门类列表,第二讲(下)这里主要是业务层和表示层的代码编写演示.. 这里表示的页面代码我没有详细的演示,只是复制过去再解释了一下,有朋友告诉我仔细的演示下,逐个去敲下,我在第三讲里开始表示 ...

  7. 【ROS】ros入门21讲(下)

    前言:上文讲了话题的通信机制,接下来是ROS的第二种重要通信机制-服务. [ROS]ros学习21讲(上) 客户端请求,相当于开关,一次开,再一次关,控制运动的状态.服务端应答. 目录 ROS 七:客 ...

  8. 数值分析公式大赏(下)

    线性代数知识 向量的范数 设x和y为n维实向量,则x和y的任意范数具有以下基本性质: 正定性: ∣ ∣ x ∣ ∣ ≥ 0 ||x||≥0 ∣∣x∣∣≥0,只有x为零向量时||x||=0 齐次性:∀k ...

  9. JB的阅读之旅-软件测试52讲(下)

    17)精益求精:聊聊提高GUI测试稳定性的关键技术 问题:同样的测试用例在同样的环境上,时而测试通过,时而测试失败: 造成GUI测试不稳定的常见五种因素: 非预计的弹出对话框: 页面控件属性的细微变化 ...

最新文章

  1. 20181023-3 每周例行报告
  2. python 数据分析学什么-python数据分析哪些课程好?
  3. java throw 什么意思_[转载]java中throw和throws的区别
  4. 004-流程控制和类型转换
  5. 04. Web大前端时代之:HTML5+CSS3入门系列~HTML5 表单
  6. java分页数据导出excel
  7. batch size自适应log(1)
  8. js几种生成随机颜色方法
  9. element做树形下拉_Element input树型下拉框的实现代码
  10. 8uftp怎么下载客户文件,8uftp怎么实现下载客户文件
  11. 远程计算机没反映6678,6678 PCIe 与FPGA LINK UP 后 不能获得FPGA的DEVICE_ID和VENDDR_ID
  12. keil4在线仿真教程分享
  13. 做LeetCode题的感悟 (1-10题)
  14. 深入理解HashMap
  15. 全局鼠标手势linux,Firefox通过用户脚本和热键进行的全局鼠标手势(Win7 / Linux + FF 68 esr)...
  16. 软考中的嵌入式系统设计师为什么考的人少?
  17. webservice 缺少根元素_草莓种植,钙、硼元素十分重要,直接关系到草莓的产量和品质!...
  18. python-ip端口扫描器
  19. Python 用Ursina引擎制作一个3D迷宫游戏
  20. 找回win10自带的windows照片查看器

热门文章

  1. linux系统编译内核源码的步骤演示
  2. Spring boot 启动错误:Could not resolve placeholder
  3. Laytpl 1.2
  4. OFDM OFDMA
  5. transformer xl 用于文本生成
  6. matlab 反向二值化,MATLAB:图像二值化、互补图(反运算)(im2bw,imcomplement函数)...
  7. cesium修改灰色背景(默认蓝色)(cesium篇.81)
  8. 分享Python采集79个NET其他类别源码,总有一款适合您
  9. 首屏时间从12.67s到1.06s,我是如何做到的?
  10. 使用C#中的AutoCAD .NET API对CAD二次开发,获取动态块可见性值