pytorch中的gather函数

pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。
立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑

今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

观察它的输出结果:

1  2  34  5  6
[torch.FloatTensor of size 2x3]1  26  4
[torch.FloatTensor of size 2x2]1  5  61  2  3
[torch.FloatTensor of size 2x3]

这里是官方文档的解释


torch.gather(input, dim, index, out=None) → TensorGathers values along an axis specified by dim.For a 3-D tensor the output is specified by:out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2Parameters: input (Tensor) – The source tensordim (int) – The axis along which to indexindex (LongTensor) – The indices of elements to gatherout (Tensor, optional) – Destination tensorExample:>>> t = torch.Tensor([[1,2],[3,4]])>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))1  14  3[torch.FloatTensor of size 2x2]

可以看出,gather的作用是这样的,index实际上是索引,具体是行还是列的索引要看前面dim 的指定,比如对于我们的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是横向,那么索引就是列号。index的大小就是输出的大小,所以比如index是【1,0;0,0】,那么看index第一行,1列指的是2, 0列指的是1,同理,第二行为4,4 。这样就输入为【2,1;4,4】,参考这样的解释看上面的输出结果,即可理解gather的含义。

gather在one-hot为输出的多分类问题中,可以把最大值坐标作为index传进去,然后提取到每一行的正确预测结果,这也是gather可能的一个作用。

2018年05月30日20:05:01

春去夏来,温情演为欲望。 —— 作家, 安德烈莫罗阿

Pytorch中的torch.gather函数的含义相关推荐

  1. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  2. gather torch_pytorch中的Torch.gather函数的含义

    在动手学习深度学习中学到了一个函数gather,原文是说可以通过gather得到标签的预测概率. y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]]) ...

  3. Pytorch中的torch.where函数

    首先我们看一下Pytorch中torch.where函数是怎样定义的: @overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...

  4. pytorch中的torch.squeeze()函数

    torch.squeeze(input, dim=None, out=None) squeeze()函数的功能是维度压缩.返回一个tensor(张量),其中 input 中大小为1的所有维都已删除. ...

  5. Pytorch中的torch.cat()函数

    转载自:https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是 ...

  6. Pytorch的使用:torch.gather函数

    Pytorch的使用:torch.gather函数 **torch.gather()** 作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的. 首先看一下官方文档中的3维 ...

  7. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  8. python中squeeze函数_详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...

  9. pytorch 中 expand ()函数

    pytorch 中 expand ()函数 expand函数的功能就是 用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量. 例如: x = torch.tensor([1, 2 ...

最新文章

  1. 了解过去与理解现在的一把钥匙
  2. js中event,event.srcElement,event.target在IE和firefox下的兼容性
  3. SQL语句书可以提高执行效率的5种需要注意的书写方法
  4. python官网下载哪个版本-python下载哪个版本好
  5. vue-自定义过滤器--时间
  6. 设置Tomcat字符编码UTF-8
  7. 程序员:你真的该养生了
  8. Linux下安装配置MySQL
  9. python变量名称跟着循环,在Python中使用列表中的名称循环创建新变量
  10. Font Awesome 中文网
  11. 幸福框架:模式驱动开发
  12. 苹果企业账号使用注意事项
  13. c语言static int x,为什么要使用static_cast int (x)而不是(int)x?
  14. 【愣锤笔记】能解决80%场景的Git必会知识点
  15. 计算机网络tcp/ip协议,UDP,HTTP/HTTPS基础知识
  16. PreparedStatement的使用
  17. 破解tomcat管理员密码
  18. c# ascii转换方法
  19. HHDBCS的快捷命令使用
  20. java向飞秋发文件_飞秋如何发文件夹

热门文章

  1. 自动执行一个php文件,使用crontab自动执行php文件
  2. day_08 字符编码乱码处理
  3. 定向耦合器——谈谈隔离度(四)
  4. 操作系统,生产者-消费者问题详解
  5. uboot-linux 官方网站及下载地址列表
  6. 轻轻挥别2014,悄悄迎来2015
  7. 关于cankao.com数据更新的说明
  8. 5g上行速率怎么提升_5G手机到底牛逼在哪里?(SRS轮发)
  9. C语言 printf 格式化 输出 右对齐补零
  10. 利用langchain-ChatGLM、langchain-TigerBot实现基于本地知识库的问答应用