pytorch_bannar.png

import numpy as np

import torch

甜点

在学习神经网时,我们总是喜欢将时间花在如何通过用代码实现模型上,而往往对 tensor 在网络每一层的变化似乎不那么在意,可能觉得观察 tensor 形状比较简单和枯燥。

pytorch 改变 tensor 形状的 Api

view/reshape 改变形状

Squeeze/unsqueeze 增加维度/删减维度

transpose/permute 变换维度

Expand/repeat 维度扩展

高维 tensor

对于高纬 tensor,我们主要理解好后 2 个维度,可以理解为平面,3 维表示立体形状,随着维度增加我们就可以将每一个维度理解为容器或者盒子,更高维可以理解为装着低纬的容器或盒子。

改变形状

在 numpy 中使用 reshape 对 tensor 的形状进行改变,而在 pytorch 我们可以用 view 和 reshape 方法对 tensor 形状进行改变,他们除了名字不同,其他并没有什么区别,所以这里就以 view 为例来说一说如何改变 tensor 形状。

如果大家写过几个图片分类简单网络,这个形状 tensor 应该不会陌生,表示 4 张 1 个通道高度和宽度分别为 28 的图片。如果我们要用全连接网络进行识别,需要将高度和宽度拉平再输入到全连接神经网。这是就会用 view ,通过调用 tensor 的 view 然后传入要转换的形状的即可。

a = torch.rand(4,1,28,28)

a.view(4,28*28)

tensor([[0.6980, 0.3745, 0.9242, ..., 0.0148, 0.1390, 0.5306],

[0.5350, 0.2231, 0.6127, ..., 0.3930, 0.1939, 0.8876],

[0.3861, 0.1119, 0.3781, ..., 0.1558, 0.6248, 0.4389],

[0.3650, 0.8685, 0.7593, ..., 0.9291, 0.0493, 0.5362]])

# 总的维度数是不变的

a.view(4,28*28).shape

torch.Size([4, 784])

a.view(4*28,28)

tensor([[0.6980, 0.3745, 0.9242, ..., 0.8922, 0.5229, 0.4496],

[0.0137, 0.8016, 0.1643, ..., 0.1254, 0.4681, 0.6502],

[0.4278, 0.9356, 0.3542, ..., 0.5995, 0.4755, 0.6840],

...,

[0.9431, 0.3490, 0.0361, ..., 0.5326, 0.4426, 0.3506],

[0.6691, 0.0943, 0.7266, ..., 0.6576, 0.3677, 0.4801],

[0.2223, 0.3585, 0.4722, ..., 0.9291, 0.0493, 0.5362]])

a.view(4*28,28).shape

torch.Size([112, 28])

b = a.view(4,784)

b.shape

torch.Size([4, 784])

b.view(4,28,28,1).shape

torch.Size([4, 28, 28, 1])

# 值得注意得是 view 将通道从 1 位置变为 3 位置

a.shape

torch.Size([4, 1, 28, 28])

增加和删除维度

squeeze 和 unsqueeze 分别是对 tensor 进行删除维度和增加维度。

Pos.Idx

0

1

2

3

Neg.Idx

-4

-3

-2

-1

在 Pos.Idx(正向)指定维度前插入一个维度,在 Neg.Idx(负向)指定维度之后插入维度

# 在 0 前插入维度,那么 (4,1,28,28) 0 维前添加 (1,4,1,28,28)

a.unsqueeze(0).shape

torch.Size([1, 4, 1, 28, 28])

# 在 -1(3) 后插入维度,那么 (4,1,28,28) 0 维前添加 (4,1,28,28,1)

a.unsqueeze(-1).shape

torch.Size([4, 1, 28, 28, 1])

# 在 -4(0) 后插入维度,那么 (4,1,28,28) 0 维前添加 (4,1,1,28,28)

a.unsqueeze(-4).shape

torch.Size([4, 1, 1, 28, 28])

# 在 -5 后插入维度就相当在 0 维度前添加维度,那么 (4,1,28,28) 0 维前添加 (1,4,1,28,28)

a.unsqueeze(-5).shape

torch.Size([1, 4, 1, 28, 28])

a = torch.tensor([1.2,2.2])

a

tensor([1.2000, 2.2000])

a.unsqueeze(-1)

tensor([[1.2000],

[2.2000]])

a.unsqueeze(0)

tensor([[1.2000, 2.2000]])

# 是列

b = torch.rand(32)

f = torch.rand(4,32,14,14)

b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)

b.shape

torch.Size([1, 32, 1, 1])

b.squeeze().shape

torch.Size([32])

b.squeeze(0).shape

torch.Size([32, 1, 1])

b.squeeze(-1).shape

torch.Size([1, 32, 1])

# 没有报错

b.squeeze(1).shape

torch.Size([1, 32, 1, 1])

b.squeeze(-4).shape

torch.Size([32, 1, 1])

维度扩展

Expand 返回当前 tensor 在某维扩展更大后的 tensor expand不会分配新的内存,只是在存在的 tensor 上创建一个新的视图 view.

Repeat: 沿着特定的维度重复这个 tensor ,和 expand()不同的是,这个函数拷贝 tensor 的数据。

a = torch.rand(4,32,14,14)

b.shape

torch.Size([1, 32, 1, 1])

# 将原有维度进行扩展

b.expand(4,32,14,14).shape

torch.Size([4, 32, 14, 14])

# -1 表示在该维度上并不改变形状

b.expand(-1,32,-1,-1).shape

torch.Size([1, 32, 1, 1])

b.expand(-1,32,-1,-4).shape

torch.Size([1, 32, 1, -4])

# 对 0 和 1 维进行扩展维

b.repeat(4,32,1,1).shape

torch.Size([4, 1024, 1, 1])

b.repeat(4,1,1,1).shape

torch.Size([4, 32, 1, 1])

b.repeat(4,1,32,32).shape

torch.Size([4, 32, 32, 32])

tensor 转置

转置只能适用于 2 维 tensor,通过交换行列来实现 tensor 维度变换。

a = torch.randn(2,3)

a.t().shape

torch.Size([3, 2])

# 交换的操作 [bcHW] 交换后变为 [bWHc]

a = torch.rand(4,3,32,32)

# transpose 操作交换了 1 轴和 3 轴 view 操作就是改变 a 的形状(reshape),

# [b,c,W,H],

# 原来存储数据顺序结构也改变了

a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)

---------------------------------------------------------------------------

RuntimeError Traceback (most recent call last)

in

----> 1 a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)

a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)

torch.all(torch.eq(a,a1))

tensor(False)

torch.all(torch.eq(a,a2))

tensor(True)

# 创建表示图片

a = torch.rand(4,3,28,28)

# 对 tensor 1 和 3 轴进行交换

a.transpose(1,3).shape

torch.Size([4, 28, 28, 3])

b = torch.rand(4,3,28,32)

b.transpose(1,3).shape

torch.Size([4, 32, 28, 3])

b.transpose(1,3).transpose(1,2).shape

torch.Size([4, 28, 32, 3])

# permut

# [b c H W] 变为 [b H W c] [0 1 2 3] 变为 [0 2 3 1]

b.permute(0,2,3,1).shape

torch.Size([4, 28, 32, 3])

tensor如何实现转置_pytorch tensor 变换相关推荐

  1. tensor如何实现转置_PyTorch中的傅立叶卷积:通过FFT有效计算大核卷积的数学原理和代码实现...

    卷积 卷积在数据分析中无处不在.几十年来,它们已用于信号和图像处理.最近,它们已成为现代神经网络的重要组成部分. 在数学上,卷积表示为: 尽管离散卷积在计算应用程序中更为常见,但由于本文使用连续变量证 ...

  2. python开方运算符_Pytorch Tensor基本数学运算详解

    1. 加法运算 示例代码: import torch # 这两个Tensor加减乘除会对b自动进行Broadcasting a = torch.rand(3, 4) b = torch.rand(4) ...

  3. Tensor to img imge to tensor (pytorch的tensor转换)

    Tensor to img && imge to tensor 在pytorch中经常会遇到图像格式的转化,例如将PIL库读取出来的图片转化为Tensor,亦或者将Tensor转化为n ...

  4. IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in

    使用python pytorch框架出现问题: IndexError: invalid index of a 0-dim tensor. Use tensor.item() in Python or ...

  5. torch学习笔记--tensor介绍2,对tensor的结构

    本章将介绍tensor的结构与函数 torch.Tensor():返回一个空tensor. torch.Tensor(tensor):返回一个拥有相同内存的tensor,类似于指针.不是重新开辟一个内 ...

  6. torch.Tensor(dim)与torch.Tensor((dim)), torch.Tensor(dim1,dim2)与torch.Tensor((dim1,dim2))的区别

    1 torch.Tensor(dim)与torch.Tensor((dim))的区别 从三张截图可以看出这两者其实是完全一样的,都表示的是这个张量的维度而不是这个张量的数据,其中第一处之所以不同是因为 ...

  7. TypeError: can‘t convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory

    项目场景: 运行程序,出现报错信息 TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor ...

  8. TypeError: can‘t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to

    问题描述: Traceback (most recent call last):File "D:\rotation-yolov5-master\detect.py", line 1 ...

  9. tensor如何实现转置_转置()TensorFlow中的函数

    转置是TensorFlow中提供的函数.此函数用于转置输入张量.语法:转置(input_tensor,perm,conjugate)参数:input_tensor:顾名思义,它是要转置的张量.类型:T ...

最新文章

  1. python在会计工作中的应用-浅谈各行各业到底该如何应用python?
  2. git clone 失败
  3. WeihanLi.Npoi 1.18.0 Released
  4. php 点对点,浅析点对点(End-to-End)的场景文字识别
  5. 专栏推荐丨Oracle Database 21c 专栏
  6. 软硬交互代码示例_HarmonyOS应用开发-元程序交互
  7. Htmlunit 使用记录
  8. opencv鼠标回调函数实现ROI区域像素值相同化
  9. 彻底卸载sql sever 2005
  10. 求1到20的阶乘之和
  11. Ubuntu文件目录结构详解
  12. SSM框架的原理和运行流程
  13. 中国医科大学网络教育学院试卷计算机,中国医科大学网络教育学院补考试卷
  14. 全面详解互联网企业开放API的 “守护神”
  15. ubantu apt命令失败
  16. ios 系统状态栏样式修改_iOS 导航栏颜色和状态栏颜色修改
  17. jsp页面打开为空白页
  18. cs224w(图机器学习)2021冬季课程学习笔记16 Community Detection in Networks
  19. EXCEL表格-系统时间及进度自动记录工具制作
  20. linux 内核 文件到磁盘影射

热门文章

  1. 静态页面 常见问题 margin-top塌陷、padding把盒子撑大
  2. java读取文件夹下所有文件并替换文件每一行中指定的字符串
  3. php语言中的符号,php语言中的面向对象
  4. css3 性能优化之 will-change 属性
  5. 个人微信壁纸小程序正式上线
  6. 路由控制配置apply cost命令解析
  7. redirect重定向
  8. 手机触摸屏扫描信号实测波形
  9. from Crypto.Cipher import AES报错解决【WindowsLinux】
  10. BUUCTF MISC刷题笔记(三)