[每日一氵] torch.nn.linear 的高维度输入 (文本生成)
torch 更新是真的快啊,stable 版本已经 1.11 了吗…
看别人的代码,总能反映出我 “孤陋寡闻”
做一个唐诗的生成任务,Embedding, LSTM + Linear 就行了,很简单
然而我之前从没做过
别人的代码里是这样的数据集结构,蓝色框框是input,绿色框框是gt值
但是他的output部分:
class Model(nn.Module):def __init__(self, dataset):XXXXXXXXXXXXXX......self.fc = nn.Linear(self.lstm_size, n_vocab)def forward(self, x, prev_state):XXXXXXXXXXXXXX......logits = self.fc(output) # <--------return logits, state
只有这一行,啊?
而定义部分,也只有这一行:
self.fc = nn.Linear(self.lstm_size, n_vocab)
那我只能认为,他的输入input和gt是酱紫的:
仅仅只预测这一个字而已
这什么鬼??我从github找了个代码,进去单步调试看看:
这是那个代码博客:
https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html
这是那个代码地址:
https://github.com/closeheat/pytorch-lstm-text-generation-tutorial
output 和 logits 的shape 竟然都有3个axis!!!
我给 output 赋值为 output[0], 其shape为:torch.Size([4, 128])
再次执行:
logits = self.fc(output)
logits
的shape是:
>>> logits.shape
torch.Size([4, 6925])
这个才是最常用的操作吧…
点开那个 Torch 的官方网站:
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
果然,这里写着 * 的维度是任意的,包括None,
也就是说,哪怕只有一个axis,也可以用:
>>> self.fc(output[0][0]) # <------------ torch.Size([128])
tensor([-0.0436, 0.0072, 0.0553, ..., -0.0649, -0.0595, -0.0814],grad_fn=<AddBackward0>)>>> self.fc(output[0][0]).shape
torch.Size([6925])
呦西,呦西,这个好诶
[每日一氵] torch.nn.linear 的高维度输入 (文本生成)相关推荐
- PyTorch 笔记(16)— torch.nn.Sequential、torch.nn.Linear、torch.nn.RelU
PyTorch 中的 torch.nn 包提供了很多与实现神经网络中的具体功能相关的类,这些类涵盖了深度神经网络模型在搭建和参数优化过程中的常用内容,比如神经网络中的卷积层.池化层.全连接层这类层次构 ...
- Lesson 8.18.2 单层回归神经网络torch.nn.Linear实现单层回归神经网络的正向传播
在之前的介绍中,我们已经了解了神经网络是模仿人类大脑结构所构建的算法,在人脑里,我们有轴突连接神经元,在算法中,我们用圆表示神经元,用线表示神经元之间的连接,数据从神经网络的左侧输入, ...
- torch.nn.Linear()函数的理解
import torch x = torch.randn(128, 20) # 输入的维度是(128,20) m = torch.nn.Linear(20, 30) # 20,30是指维度 outpu ...
- torch.nn.Linear 笔记
最多支持两维, 我准备用这个代替1*1的卷积核 import torchx = torch.randn(128, 20) # 输入的维度是(128,20) m = torch.nn.Linear(20 ...
- torch.nn.Linear
x = torch.randn(128, 20) # 输入的维度是(128,20) m = torch.nn.Linear(20, 30) # 20,30是指维度 output = m(x) prin ...
- 为什么torch.nn.Linear的表达形式为y=xA^T+b而不是常见的y=Ax+b?
今天看代码,对比了常见的公式表达与代码的表达,发觉torch.nn.Linear的数学表达与我想象的有点不同,于是思索了一番. 众多周知,torch.nn.Linear作为全连接层,将下一层的每个结点 ...
- torch.nn.Linear详解
在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解. 参考:https://pytorch.org/docs/stable/_module ...
- (五)处理多维特征的输入(上)+torch.nn.Linear(8,1)表示什么+代码
目录 1.普通逻辑回归 一个样本中一条数据有1个特征 2.多维特征:一个样本中一条数据有n个特征(以n=8为例) 计算流程:矩阵乘法 (8,1)表示什么? 3.代码: 1.普通逻辑回归 一个样本中一条 ...
- 【每日一练】105—CSS实现一款输入文本动画的效果
文 | 杨小爱 写在前面 关于这个CSS实现的文本动画效果,我们在前面也分享过很多,今天这个是一个输入文本框样式的动画效果,它的最终效果如下: 它的HTML代码很简单,主要是在CSS这块,具体的实现过 ...
- 模型的第一层:详解torch.nn.Embedding和torch.nn.Linear
文章目录 1.概述 2.Embedding 2.1 nn.Linear 2.2 nn.Embedding 对比 初始化第一层 1.概述 torch.nn.Embedding是用来将一个数字变成一个指定 ...
最新文章
- 『TensorFlow』卷积层、池化层详解
- 老外谈设计: 2015年WEB设计趋势
- CentOS 7核心安装及基本配置
- 实用--HTML的命名规范
- bootstrap轮播,播放到最后一张图片的时候,就不正确了。
- 机器人防火墙:人机识别在应用安全及风控领域的一点实践
- Netty学习笔记(四)EventLoopGroup续篇
- 30岁前不要在乎的29件事(转载)
- 选择排序算法流程图_常用排序算法之选择排序
- 补习系列(22)-全面解读 Spring Profile 的用法
- java bigram_Java BiGramDictionary.getBiFrequency方法代碼示例
- log添加 oracle redo_Oracle更改redo log大小 or 增加redo log组
- js调用数科阅读器_aspx调用js函数
- w10系统桌面的计算机找不到,w10桌面我的电脑图标不见了怎么办
- OneApiConnect通讯demo,fins欧姆龙协议实现
- Markdown从入门到放弃
- lame库(iOS 和 Android)
- 读《蔡康永的说话之道》
- 2020中职技能高考计算机,我市62名中职学生获得2020年技能高考操作考试满分
- Oracle获取一周前,一个月前,一年前的日期