Pytorch系列(1):torch.gather()
torch.gather
作用:收集输入的特定维度指定位置的数值
参数:
input(tensor): 待操作数。不妨设其维度为(x1, x2, …, xn)
dim(int): 待操作的维度。
index(LongTensor): 如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y, …,xn),既是将input的第i维的大小更改为y,且要满足y>=1(除了第i维之外的其他维度,大小要和input保持一致)。
out: 注意输出和index的维度是一致的
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
解释及举例
torch.gather()函数从公式上来说并不容易理解,我们以一个场景为例。
在序列标注问题上,我们给每一个单词都标上一个标签。不妨假设我们有4个句子,每个句子的长度不一定相同,标签如下:
input = [
[2, 3, 4, 5],
[1, 4, 3],
[4, 2, 2, 5, 7],
[1]
]
上例中有四个句子,长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。我们知道,处理自然语言问题时,一般都需要进行padding,即将不同长度的句子padding到同一长度,以0为padding,那么上述经padding后变为:
input = [
[2, 3, 4, 5, 0, 0],
[1, 4, 3, 0, 0, 0],
[4, 2, 2, 5, 7, 0],
[1, 0, 0, 0, 0, 0]
]
那么问题来了,现在我们想获得每个句子中最后一个词语的标签,该怎么得到呢?既是,第一句话中的5,第二句话中的3,第三句话中7,第四句话中的1。
此时就需要用gather函数了(当然你说可以循环什么的,当我没问)。
此时我们的input就是填充之后的tensor,dim=1, index就是各个句子的长度,即[[4],[3],[5],[1]]。之所以维度是4*1,是为了满足index维度和input维度之间的关系(讲解参数时有讲)。
import torch
input = [[2, 3, 4, 5, 0, 0],[1, 4, 3, 0, 0, 0],[4, 2, 2, 5, 7, 0],[1, 0, 0, 0, 0, 0]
]
input = torch.tensor(input)
#注意index的类型
length = torch.LongTensor([[4],[3],[5],[1]])
#index之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(input, 1, length-1)
out
补充
此函数的作用感觉一句话说不出来,硬说的话,我感觉应该是:
利用index来索引input特定位置的数值
例如上例中的length,再加上dim=1,指定了索引每句话中的最后一个单词(length-1)。
另外可以琢磨一下gather的计算公式
Pytorch系列(1):torch.gather()相关推荐
- gather torch_浅谈Pytorch中的torch.gather函数的含义
pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...
- Pytorch中的torch.gather函数的含义
pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...
- Pytorch的使用:torch.gather函数
Pytorch的使用:torch.gather函数 **torch.gather()** 作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的. 首先看一下官方文档中的3维 ...
- 小白学Pytorch 系列--Torch API(1)
小白学Pytorch 系列–Torch API Torch version 1.13 Tensors TORCH.IS_TENSOR 如果obj是PyTorch张量,则返回True. 注意,这个函数只 ...
- pycharm安装pytorch报错 提示系列问题 torch 包找不到因为pip版本低,结果升级了pip从19.3到20.2 4又提示高版不支持torch安装
pycharm安装pytorch报错 提示系列问题 torch 包找不到因为pip版本低,结果升级了pip从19.3到20.2 4又提示高版不支持torch安装 DEPRECATION: The - ...
- 小白学Pytorch系列-- Torch API (5)
小白学Pytorch系列-- Torch API (5) Math operations Pointwise Ops TORCH.ABS 计算输入中每个元素的绝对值. >>> t ...
- 小白学Pytorch系列--Torch API (7)
小白学Pytorch系列–Torch API (7) Comparison Ops allclose 此函数检查输入和其他是否满足条件: >>> torch.allclose(tor ...
- pytorch之torch.gather方法
首先,先给出torch.gather函数的函数定义: torch.gather(input, dim, index, out=None) → Tensor 官方给出的解释是这样的: ...
- [pytorch]——torch.gather(以BERT中的MLM为例)
前言 都知道BERT中有MLM的任务,假设此时ENCODER的输出output的大小为: batch_size x max_len x d_model,而对于每一个句子,都有对应的数个被mask掉的单 ...
最新文章
- 平台(洛谷P1105题题解,Java语言描述)
- 数据中台赋能企业数字化转型的四个关键成功因素
- 信息学奥赛一本通(1175:除以13)
- 2018微博词云项目深度解析
- sqlite expert 未找到提供程序。该程序可能未正确安装_SolidWorks2019安装过程中出现常见问题及解决方案...
- web前端基础知识-(六)jQuery-补
- ios应用提交审核出现的问题总结
- npm换成国内源 npm换源 npm换淘宝源镜像
- 【实践】BiLSTM上的CRF,用命名实体识别任务来解释CRF(1)
- 板子无法进入loader模式升级固件时需短接emmc或flash
- Ansible之playbook的使用总结 - 运维笔记
- phython编写图形界面
- 西方科学家依然对互联网的进化表示质疑
- Ubuntu使用WPS打开文档出现缺失字体情况解决方法
- 致远SPM解决方案之库存管理
- 企业为什么要建立独立电商网站?
- python蜂鸣器天空之城频率_用python来一首钢琴solo天空之城
- Journal of Computational Physics, latex模板
- Date Calendaer
- 蓄电池内阻测试仪分析软件,蓄电池内阻测试仪
热门文章
- 以颤抖之身追赶 怀敬畏之心挑战
- 小偷写给失主的信(爆笑)
- 幼师学计算机心得体会怎么写,从事幼儿教师心得体会范文(通用9篇)
- android radiobutton属性大全,Android RadioButton
- MAX之不关闭MAX脚本开发
- 40年前地球首次收到外星信号幕后真相
- 洛谷 P1344 [USACO4.4] 追查坏牛奶Pollutant Control【网络流】
- VScode第一行头文件报错,‘iostream‘ file not found
- 【转】 SumaTra PDF 常用快捷键
- P0级重大事故:超卖了100瓶飞天茅台,整个项目组慌了