在动手学习深度学习中学到了一个函数gather,原文是说可以通过gather得到标签的预测概率。

y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])

y = torch.LongTensor([0,2])

y_hat.gather(1,y.view(-1,1))

tensor([[0.1000],

[0.5000]])

开始我看到这个输出一头雾水 不知道怎么回事

查了查 gather的时候我才知道

torch.gather(input,dim,index,out=None)

example:

t = torch.Tensor([1,2],[3,4])

torch.gather(t,1,torchLongTensor([[0,0],[1,0]]))

1,1

4,3

可以看出gather的作用是根据索引返回该项元素,首先先输入一个Tensor 然后根据dim进行判断是是行的还是列的,当dim=0 时候竖行查找,当dim=1的时候是横向查找

上题中,dim=1,那么索引就是列号。index的大小就是输出的大小,比如index是[1,0;0,0]其实就是第一行的第二个元素和第一个元素,第二行的第一个元素也就是返回的是2,1 3,3

所以例子中是[0,0],[1,0] 返回的就是[1,1],[4,3]

在例题中的他是通过view函数来返回index的,开始不知道view的意思,查过后知道了,他实际上和resize的意思差不多。

a = torch.Tensor([[1,2,3],[4,5,6]])

b = torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6))

print(b.view(1,6))

得到的都是

tensor([[1,2,3,4,5,6]])

再看一个例子

a = torch.Tensor([[1,2,3],[4,5,6]])

print(a.view(3,2))

将会得到

tensor([[1,2],

[3,4],

[5,6]

])

相当于就是从1,2,3,4,5,6 顺序的拿数组来填充需要的形状。

参数中的-1就代表这个位置由其他位置的数字来进行推断,只要不在歧义的情况下,view参数就可以推断出来,也就是人可以推断出形状的情况下,view也是可以推断出来的,比如a tensor的数据个数是6个,如果view(1,-1)我们就可以推断出来-1代表6。而如果view(-1,-1,2)的话,人也不知道的话,机器也不会知道的,所以就会报错

gather torch_pytorch中的Torch.gather函数的含义相关推荐

  1. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  2. Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  3. Pytorch中的torch.where函数

    首先我们看一下Pytorch中torch.where函数是怎样定义的: @overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...

  4. mysql的or能去重吗_mysql中bit_count和bit_or函数的含义

    翻阅mysql手册时,看到有个示例使用了bit_or方法来去除重复的数据,一开始没看明白,后来看明白之后感觉非常巧妙.示例要实现的功能就是计算每月有几天有访问,先把示例摘录在这里. -- 创建表 CR ...

  5. pytorch中的torch.squeeze()函数

    torch.squeeze(input, dim=None, out=None) squeeze()函数的功能是维度压缩.返回一个tensor(张量),其中 input 中大小为1的所有维都已删除. ...

  6. torch的拼接函数_Pytorch中的torch.cat()函数

    cat( )的用法 按维数0拼接(竖着拼) C = torch.cat( (A,B),0 ) 按维数1拼接(横着拼) C = torch.cat( (A,B),1 ) 按维数0拼接 A=torch.o ...

  7. Pytorch中的torch.cat()函数

    转载自:https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是 ...

  8. Pytorch的使用:torch.gather函数

    Pytorch的使用:torch.gather函数 **torch.gather()** 作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的. 首先看一下官方文档中的3维 ...

  9. gather torch_我对torch中的gather函数的一点理解

    本文首发于公众号[拇指笔记] 官方文档的解释 torch.gather(input,dim,index,out=None) → Tensor torch.gather(input, dim, inde ...

最新文章

  1. find cp命令的用法
  2. learning中的数学
  3. php deprecated是什么意思,解决:PHP Deprecated: Comments starting with '#' are deprecated in ……...
  4. 中移动飞信2010Beta1.0体验版
  5. php网站服务器500,php服务器错误500
  6. 厦门大学c语言上机答案,厦门大学C语言程序设计2016模拟题讲评及课程复习.pptx...
  7. [SQL实战]之统计出当前各个title类型对应的员工当前薪水对应的平均工资
  8. Tomcat 部署多个项目出现错误
  9. 密歇根州立大学联合领英提出基于AutoML的Embedding框架AutoDim
  10. 《彩虹屁》快夸夸我!彩虹屁生成器
  11. 计算机 64虚拟内存设置方法,win7 64位系统虚拟内存设置及虚拟内存太小的影响...
  12. 密码学大事件! SHA-1 哈希碰撞实例
  13. java正态分布随机数产生方法
  14. 微信开发如何优雅的注入token(2)
  15. 日出日落时间和年均光照时长计算 java
  16. 获取163联系人名字和邮箱地址
  17. poi实现单元格行合并
  18. 排名趋于稳定后,最新的博主排名(TOP10)
  19. Fortran语法汇总(上)
  20. 闽南师范大学计算机系实力,这5所地方师范大学实力挺强,在本地很受认可,性价比高...

热门文章

  1. bind9配置转发服务器
  2. SSH2整合完整案例(四十三)
  3. Pivot Table
  4. 南京趋势科技开发测试岗实习面经(2020-11-11)
  5. Android Dialog中监听返回键事件
  6. 【计算机网络】第七话·计算机网络的流量控制
  7. 程序员版本的八荣八耻~
  8. 虚拟机下linux的静态ip地址配置
  9. 首个双手控制脑机接口:开颅手术 10 小时植入 6 个电极,瘫痪人士用意念吃蛋糕
  10. 记:解决腾讯云服务器程序被恶意登录植入进程