torch学习二之nn.Convolution

  • nn.Conv1d
    • 函数参数
    • 输入数据维度转换
    • 关于kernel
  • nn.Conv2D

nn.Conv1d

一维卷积通常用于处理文本数据

函数参数

首先看一下官网定义

CLASS torch.nn.Conv1d(in_channels: int, out_channels: int, kernel_size: Union[T, Tuple[T]],
stride: Union[T, Tuple[T]] = 1, padding: Union[T, Tuple[T]] = 0,
dilation: Union[T, Tuple[T]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros')

一般我们需要关注的是以下几个参数:

  1. int_channels:

    输入通道,在文本数据中,即词向量的维数 word_embedding_size

  2. out_channels:

    输出通道,可以认为是特征的维数,做一卷积(对应一个卷积核)可以得到一个特征,所以输出通道数决定了做卷积的
    (这里对于遍的理解看完下面的示例图就明白了)

  3. kernel_size:

    卷积核的大小,可以理解为多少个词之间进行交互
    卷积之后的维数为 ( (word_embedding_size - kernel_size) / stride + 1)
    但是一般都会做一个pooling,因为不同的卷积核最后输出的维数不同,用pooling统一,一般为max_pool,最后都是 out_channels * 1
    此时,word_embedding_size = 10,kernel_size = 2,stride = 1

  4. stride:

    kernel 移动的步长,参照上图

  5. padding:

    由上面可知,做完一遍卷积后维数降低,如果希望维数保持不变(或自定义),可以在向量两边做padding,padding则定义了你希望在两边增加的维数,padding_mode 则定义了你希望补充的数是多少

输入数据维度转换

通常我们输入的数据格式如下:

batch_size(sentence_number) * sentence_max_length * word_embedding_size

但是,要明确,我们希望的卷积是在一个句子的词与词之间进行,而不是在一个句子的一个词的embedding里进行,而一维卷积是对最后一维进行的,所以需要先对tensor进行一下维度转换,使得sentence_max_length那一维是最后一维。

tensor.permute()

batch_size(sentence_number) * word_embedding_size * sentence_max_length

关于kernel

虽然只定义了 kernel_size,但是 kernel 并不是一维的,其实是一个二维的矩阵,大小为 kernel_size * in_channels


这样做完一遍卷积会得到一个channel,其维数为( (sentence_max_length - kernel_size) / stride + 1)

需要做 out_channels遍卷积,最终输出为:

batch_size(sentence_number) * out_channels * ((sentence_max_length - kernel_size) / stride + 1)


参考 : 链接

更新:对于这里 out_channel 的理解

图片源于:链接

nn.Conv2D

这里一般处理图片数据

基本与Conv1D类似,注意这里 kernel 大小为 kernel_size * kernel_size (或者是定义的元组的大小)

详情请看 链接

torch学习二之nn.Convolution相关推荐

  1. torch学习 (二十四):卷积神经网络之GoogleNet

    文章目录 引入 1 Inception块 2 GoogleNet模型 3 模型训练 完整代码 util.SimpleTool 引入   GoogleNet吸收了NIN网络串联网络的思想,并在此基础上做 ...

  2. Pytorch学习(二)—— nn模块

    torch.nn nn.Module 常用的神经网络相关层 损失函数 优化器 模型初始化策略 nn和autograd nn.functional nn和autograd的关系 hooks简介 模型保存 ...

  3. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  4. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  5. 深度学习二(Pytorch物体检测实战)

    深度学习二(Pytorch物体检测实战) 文章目录 深度学习二(Pytorch物体检测实战) 1.PyTorch基础 1.1.基本数据结构:Tensor 1.1.1.Tensor数据类型 1.1.2. ...

  6. [pytorch] PyTorch Metric Learning库代码学习二 Inference

    PyTorch Metric Learning库代码学习二 Inference Install the packages Import the packages Create helper funct ...

  7. torch学习 (十七):填充和步幅

    文章目录 引入 1 填充 2 步幅 引入   一般来说,对于输入为 n h × n w n_h \times n_w nh​×nw​的矩阵,以及 k h × k w k_h \times k_w kh ...

  8. 幻方萤火 | 性能卓越的深度学习算子 hfai.nn

    深度学习框架的流行(如 PyTorch,Tensorflow 等)极大方便了我们研发设计各种各样的 AI 模型,而在实际落地的环节中,孵化于实验室里的模型代码往往在生产环境上面临着性能.准确度.资源等 ...

  9. PyTorch框架学习二——基本数据结构(张量)

    PyTorch框架学习二--基本数据结构(张量) 一.什么是张量? 二.Tensor与Variable(PyTorch中) 1.Variable 2.Tensor 三.Tensor的创建 1.直接创建 ...

最新文章

  1. 2022-2028年中国锂电池用聚烯烃隔膜行业市场发展调研及投资方向分析报告
  2. 优先深度搜索判断曲线相交_程序员必知的十大基础实用算法之-DFS(深度优先搜索)...
  3. windwos -- bat脚本
  4. JavaScript之Style属性学习
  5. android标尺自定义view,android尺子的自定义view——RulerView详解
  6. 薅羊毛的齐家网遭增长瓶颈,互联网家装迎来破局者!1-06-13
  7. 标签体系、用户分群、用户画像「玩味」解读,你沦为形式主义了吗?
  8. [Android]第四次作业
  9. php根据地址获取经纬度
  10. readonly和const比较
  11. 圆形取景框 相机_据说这款设备可以使老旧单反相机解决无线联机拍摄方案
  12. nagios监控mysql主从
  13. JAVA Pattern和Matcher 的用法
  14. 决策树C4.5算法的不足
  15. ask调制与解调matlab仿真,ask调制与解调的matlab仿真.doc
  16. zookeeper因内存不足造成的CPU占用率高
  17. h3c交换机配置远程管理_h3c 交换机配置VLAN和远程管理
  18. yaml-cpp保存标定文件-Node/Emitter
  19. 设置mathtype章节号显示与隐藏
  20. 海思开发板FFmpeg+Nginx,推流RTMP播放(优秀教程收集+实操整理)

热门文章

  1. 计算机在护士行业的应用情况,【医院护理论文】医院护理信息化实施的现状及未来分析(共3632字)...
  2. 黑社会交易用计算机,遭遇网络黑社会亲们都是怎么处理的
  3. 【Tableau Desktop 企业日常技巧6.0】Tableau如何将示例工作簿替换为自定义工作簿?(windows版本)
  4. 全球最大的社交编程及代码托管网站Github介绍
  5. [转]Jarvis OJ- [XMAN]level2/3_x64-Writeup
  6. 学习笔记-Metasploit
  7. 低性能单用户计算机,低性能单用户计算机I/O系统的设计主要考虑解决好CPU与内存、I/O设备在速度上的巨大差距。...
  8. vue环境中bpmn工具实现翻译汉化
  9. 有没有不用布线的家用监控摄像头?
  10. 基于MMDetection训练VOC格式数据集