PyTorch中Tensor的维度变换

对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。

维度查看:torch.Tensor.size()

查看当前 tensor 的维度

举个例子:

>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])

张量变形:torch.Tensor.view(*args) → Tensor

返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。

举个例子:

>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038,  0.5166,  0.9774,0.3455],[-0.2306,  0.4217,  1.2874, -0.3618,  1.7872, -0.9012,  0.8073, -1.1238,-0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],[ 0.5166,  0.9774,  0.3455, -0.2306,  0.4217,  1.2874],[-0.3618,  1.7872, -0.9012,  0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],[-0.6229, -1.0310, -0.8038],[ 0.5166,  0.9774,  0.3455]],[[-0.2306,  0.4217,  1.2874],[-0.3618,  1.7872, -0.9012],[ 0.8073, -1.1238, -0.3405]]

可以看到 x 和 y 、z 中数据的数量和每个数据的大小都是相等的,只是尺寸或维度数量发生了改变。

压缩 / 解压张量:torch.squeeze()、torch.unsqueeze()

torch.squeeze(input, dim=None, out=None)

将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

举个例子:

>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986,  0.4352]],[[ 0.0971,  0.2296]],[[ 0.8339, -0.5433]]])
>>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986,  0.4352],[ 0.0971,  0.2296],[ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])

可以看到如果加参数,只有维度中尺寸为 1 的位置才会消失

torch.unsqueeze(input, dim, out=None)

返回一个新的张量,对输入的制定位置插入维度 1

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果 dim 为负,则将会被转化 dim+input.dim()+1

接着用上面的数据举个例子:

>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986,  0.4352]],[[ 0.0971,  0.2296]],[[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],[ 0.4352]]],[[[ 0.0971],[ 0.2296]]],[[[ 0.8339],[-0.5433]]]])

可以看到在指定的位置,增加了一个维度。

扩大张量:torch.Tensor.expand(*sizes) → Tensor

返回 tensor 的一个新视图,单个维度扩大为更大的尺寸。 tensor 也可以扩大为更高维,新增加的维度将附在前面。 扩大 tensor 不需要分配新内存,只是仅仅新建一个 tensor 的视图,其中通过将 stride 设为 0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。

举个例子:

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],[2., 2., 2., 2.],[3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],[2.],[3.]])

原数据是 3 行 1 列,扩大后变为 3 行 4 列,方法中填 -1 的效果与 1 一样,只有尺寸为 1 才可以扩大,如果不为 1 就无法改变,而且尺寸不为 1 的维度必须要和原来一样填写进去。
重复张量:torch.Tensor.repeat(*sizes)

沿着指定的维度重复 tensor。 不同于 expand(),本函数复制的是 tensor 中的数据。

举个例子:

>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)[1., 2., 3., 1., 2., 3.],[1., 2., 3., 1., 2., 3.],[1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])

原数据为 1 行 3 列,按行方向扩大为原来的 4 倍,列方向扩大为原来的 2 倍,变为了 4 行 6 列。

变化时可以看成是把原数据作成一个整体,再按指定的维度和尺寸重复,变成一个 4 行 2 列的矩阵,其中的每一个单位都是相同的,再把原数据放到每个单位中。
矩阵转置:torch.t(input, out=None) → Tensor

输入一个矩阵(2维张量),并转置0, 1维。 可以被视为函数 transpose(input, 0, 1) 的简写函数。

举个例子:

>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224,  0.9776],[ 0.2489, -0.2986, -0.7816, -0.0823,  1.1811],[-1.1124,  0.2160, -0.8446,  0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752,  0.2489, -1.1124],[-0.9706, -0.2986,  0.2160],[-0.8770, -0.7816, -0.8446],[-0.4224, -0.0823,  0.1762],[ 0.9776,  1.1811, -0.5164]])
>>> torch.t(x) # 另一种用法
tensor([[-1.0752,  0.2489, -1.1124],[-0.9706, -0.2986,  0.2160],[-0.8770, -0.7816, -0.8446],[-0.4224, -0.0823,  0.1762],[ 0.9776,  1.1811, -0.5164]])

必须要是 2 维的张量,也就是矩阵,才可以使用。

维度置换:torch.transpose()、torch.Tensor.permute()

torch.transpose(input, dim0, dim1, out=None) → Tensor

返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,所以改变其中一个会导致另外一个也被修改。

举个例子:

>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363,  0.5534],[-0.2050,  3.1847, -1.6729],[-0.2591, -0.0860,  0.4660],[-1.2189, -1.1206,  0.0637]],[[ 1.4791, -0.7569,  2.5017],[ 0.0098, -1.0217,  0.8142],[-0.2414, -0.1790,  2.3506],[-0.6860, -0.2363,  1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],[-0.7363,  3.1847, -0.0860, -1.1206],[ 0.5534, -1.6729,  0.4660,  0.0637]],[[ 1.4791,  0.0098, -0.2414, -0.6860],[-0.7569, -1.0217, -0.1790, -0.2363],[ 2.5017,  0.8142,  2.3506,  1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363,  0.5534],[ 1.4791, -0.7569,  2.5017]],[[-0.2050,  3.1847, -1.6729],[ 0.0098, -1.0217,  0.8142]],[[-0.2591, -0.0860,  0.4660],[-0.2414, -0.1790,  2.3506]],[[-1.2189, -1.1206,  0.0637],[-0.6860, -0.2363,  1.0481]]])

可以对多维度的张量进行转置

torch.Tensor.permute(dims)

将 tensor 的维度换位

接着用上面的数据举个例子:

>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],[ 1.4791,  0.0098, -0.2414, -0.6860]],[[-0.7363,  3.1847, -0.0860, -1.1206],[-0.7569, -1.0217, -0.1790, -0.2363]],[[ 0.5534, -1.6729,  0.4660,  0.0637],[ 2.5017,  0.8142,  2.3506,  1.0481]]])

直接在方法中填入各个维度的索引,张量就会交换指定维度的尺寸,不限于两两交换。

参考博客:https://blog.csdn.net/weixin_44613063/article/details/89521464

Pytorch常用张量变换操作相关推荐

  1. 收藏!PyTorch常用代码段合集

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Jack Stark,来源:极市平台 来源丨https://zhu ...

  2. PyTorch常用代码段合集

    ↑ 点击蓝字 关注视学算法 作者丨Jack Stark@知乎 来源丨https://zhuanlan.zhihu.com/p/104019160 极市导读 本文是PyTorch常用代码段合集,涵盖基本 ...

  3. 【深度学习】PyTorch常用代码段合集

    来源 | 极市平台,机器学习算法与自然语言处理 本文是PyTorch常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处理.模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常 ...

  4. pytorch list转tensor_PyTorch 52.PyTorch常用代码段合集

    本文参考于: Jack Stark:[深度学习框架]PyTorch常用代码段​zhuanlan.zhihu.com 1. 基本配置 导入包和版本查询: import torch import torc ...

  5. 收藏 | PyTorch常用代码段合集

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨Jack Stark@知乎 来源丨https://zh ...

  6. 神经网络与深度学习(二) pytorch入门——张量

    本文章通过参考飞桨AI Studio - 人工智能学习与实训社区  教程进行pytorch相关学习. 目录 一. 概念:张量.算子 二. 使用pytorch实现张量运算 1.2.1 创建张量 1.2. ...

  7. 深度盘点:PyTorch常用代码段合集

    本文是PyTorch常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处理.模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常全面. PyTorch最好的资料是官方文档.本文 ...

  8. PyTorch 常用代码段整理合集

    PyTorch 常用代码段整理合集 来源:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知 ...

  9. 赶快收藏,PyTorch 常用代码段PDF合辑版来了

    前段时间我分享了 PyTorch 常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处 理.模型训练与测试等5个方面,还给出了多个值得注意的Tips. 这篇文章发布后,收到了很多朋友的喜爱和 ...

最新文章

  1. Platform Builder 6.0与Windows 7兼容性的问题
  2. REVIT使用中遇到的各种问题汇总
  3. [NewLife.XCode]实体工厂(拦截处理实体操作)
  4. ios进度条Demo一个
  5. SQLSERVER数据库设置varchar类型主键自增方法
  6. nyoj 8 一种排序(用vector,sort,不用set)
  7. 软件工程 超市库存管理系统 UML模型
  8. PB高拍仪无纸化软件方案
  9. 分位数回归 Quantile Regression,python 代码
  10. 【MM 容差】采购订单中的容差
  11. MMR 排序多样化重排序算法
  12. Serein 【懒人神器】一款图形化、批量采集url、批量对采集的url进行各种nday检测的工具 摸鱼项目问题解决
  13. 保护Android网络数据教程
  14. 双向可控硅晶片光耦 (TLP160J TLP260J TLP525G) 基本原理及应用实例
  15. 元宇宙:虚拟仿真技术的全面提升
  16. pattern和match的用法 java篇
  17. BigDecimal 元转分-加减乘除、百分比
  18. 万字技术干货 |YMatrix 高性能时序数据库引擎的技术实践
  19. 用计算机弹出当当当,电脑发出“当当”的声音怎么办
  20. 现代金融经济的眼重看历史[程序员学经济二]

热门文章

  1. ar 华为路由器 端口映射_求教华为AR2200路由器端口映射配置
  2. 修改了svn服务器配置,配置http方式访问svn服务器
  3. Eslint +Vue配置
  4. 【转】浅说语音用户界面:VUI+GUI
  5. 关闭Eslint中的规则 no-unused-vars
  6. 传智播客我来啦!!!
  7. 传智播客mysql分页的实现_传智播客 2010-03-07 员工信息的AJAX分页实现
  8. ceph---luminous版的安装
  9. 操盘手与散户老妈的对话 看完后所有人都沉默了
  10. 照片尺寸对照表[转]