最近在研究EDSR代码(项目地址:https://github.com/sksq96/pytorch-summary)的时候看到了forward_chop function,该参数的help写的是

parser.add_argument('--chop', action='store_true',help='use memory-efficient forward')

于是对其怎样加速进行了学习,将思路和大家分享一下。
这部分代码黏贴如下:

首先是开始部分的代码,我们假设最开始的输入是 tuple(tensor(1, 3, 678, 1020)) (里面的是size, 表示 1 x 3 x 678 x 1020)

def forward_chop(self, *args, shave=10, min_size=160000):n_GPUs = min(self.args.n_GPUs, 4)h, w = args[0].size()[-2:]top, left = slice(0, h//2 + shave), slice(0, w//2 + shave)bottom, right = slice(h - h//2 - shave, h), slice(w - w//2 - shave, w)x_chops = [torch.cat([a[..., top, left],a[..., top, right],a[..., bottom, left],a[..., bottom, right]]) for a in args]

这部分将输入的一张图片拆分成了左上、左下、右上、右下四个部分并按照batch dim进行了拼接。
这样得到 x_chops = [tensor(4, 3, 349, 520)],这里长宽各增加了一个shave * 2
接下来是对图片的大小进行判断:

 y_chops = []if h * w < 4 * min_size:for i in range(0, 4, n_GPUs):x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops]y = P.data_parallel(self.model, *x, device_ids=range(n_GPUs))if not isinstance(y, list):y = [y]if not y_chops:y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y]else:for y_chop, _y in zip(y_chops, y):y_chop.extend(_y.chunk(n_GPUs, dim=0))else:for x in zip(*x_chops):p = tuple(_x.unsqueeze(0) for _x in x)y = self.forward_chop(*p, shave=shave, min_size=min_size)if not isinstance(y, list):y = [y]if not y_chops:y_chops = [[_y] for _y in y]else:for y_chop, _y in zip(y_chops, y):y_chop.append(_y)

这里由于 678 x 1020 - 4 * 160000 = 51560 > 0,所以走的 else语句
这里得到 p = (1, 3, 349, 520) 然后再作为输入调用forward_chop。前面还是一样的,只不过在这里走的是if语句。注意此时if语句中的x 的大小是和n_GPUs 有关的,也就是说如果你的n_GPUs == 1,那么你 x 的size为(1, 3, 184, 270), 如果是2,就是(2, 3, 184, 270),这样也就将多个batch 作为并行输入到你的GPU device 上加快速度。这里由于我的是1个GPU,整个for循环运行完之后可以得到y_chops = [[tensor(1, 3, 386, 540), tensor(1, 3, 386, 540), tensor(1, 3, 386, 540), tensor(1, 3, 386, 540)]]
接下来就是恢复原大小了:

 h *= self.args.scalew *= self.args.scaletop, left = slice(0, h//2), slice(0, w//2)bottom, right = slice(h - h//2, h), slice(w - w//2, w)bottom_r, right_r = slice(h//2 - h, None), slice(w//2 - w, None)b, c = y_chops[0][0].size()[:-2]y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops]for y_chop, _y in zip(y_chops, y):_y[..., top, left] = y_chop[0][..., top, left]_y[..., top, right] = y_chop[1][..., top, right_r]_y[..., bottom, left] = y_chop[2][..., bottom_r, left]_y[..., bottom, right] = y_chop[3][..., bottom_r, right_r]if len(y) == 1:y = y[0]return y

可以看到 通过 _y 将拆分的四个部分再按位置拼接回原来的大小,最后得到一个tensor(1, 3, 698, 1040)和原输入大小相同的tensor。
总结一下,在这里通过将一张大的图片分成四个部分并行做输入进行测试,再将分别测试得到的输出拼接成原大小得到原图直接做forward的结果,相当于是减少了长和宽增加了batch_size。

以上是个人观点,如有不对还请大佬指出。

Pytorch 在 forward 函数中加速神经网络。相关推荐

  1. 1分钟理解pytorch的reshape函数中-1表示的意义

    先说答案,reshape函数中-1代表的是n,什么意思呢,函数中另一个参数决定了-1的值,看下面三张图就很容易理解了 定义34的张量,reshape(-1, 1),你把它想象成要转换成n1的矩阵,那是 ...

  2. Pytorch的BCEWithLogitsLoss函数中忽视标签怎么实现

    1.尝试: >>> import torch >>> from torch import nn >>> loss = nn.BCEWithLogi ...

  3. forward函数——浅学深度学习框架中的forward

    1.什么是forward函数 (本应该出一篇贯穿神经网络的文章的,但是由于时间关系,就先浅浅记录一下,加深自己的理解吧吧). forward 函数是深度学习框架中常见的一个函数,用于定义神经网络的前向 ...

  4. 继承nn.Module后的 init与forward函数【trian_val、vgg16、faster_rcnn、rpn】.py 学习 文件结构 大工程安排

    本篇文章主要是用来学习大工程的构造,具体包括如何进行init和forward,如何层层递进,高层设置输入,传入底层的input中. 从train_val.py中的初始化vgg开始,这里调用了vgg的初 ...

  5. 神经网络模型中class的forward函数何时调用_总结深度学习PyTorch神经网络箱使用...

    ↑ 点击蓝字 关注极市平台来源丨计算机视觉联盟编辑丨极市平台 极市导读 本文介绍了Pytorch神经网络箱的使用,包括核心组件.神经网络实例.构建方法.优化器比较等内容,非常全面.>>加入 ...

  6. pytorch中的forward函数详细理解

    文章目录 前言 forward 的使用 forward 使用的解释 前言 最近在使用pytorch的时候,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 fo ...

  7. Pytorch中什么时候调用forward()函数

    Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型.下面继承Module类构造本节开头提到的多层感知机.这里定义的MLP类重载了Module类 ...

  8. PyTorch函数中的__call__和forward函数

    初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录 init & call 代码: class A():def __init__(self):print('init ...

  9. pytorch自定义forward和backward函数

    pytorch会自动求导,但是当遇到无法自动求导的时候,需要自己认为定义求导过程,这个时候就涉及到要定义自己的forward和backward函数. 举例如下: 看到这里,大家应该会有很多疑问,比如: ...

最新文章

  1. mysql存储的判断if_if在数据库mysql存储中判断
  2. win32应用程序创建流程
  3. Django框架(24.Django中的模板的自定义过滤器)
  4. boost::mpl模块实现upper_bound相关的测试程序
  5. 线程基础知识系列(三)线程的同步
  6. 余额宝好日子到头,以后不能随存随取了!
  7. thinkpaidE480office安装文件夹
  8. Tomcat 日志文件分割
  9. LINUX没有SVN,怎么知道哪些文件修改了
  10. 创维E900V21E机顶盒刷机固件 解决:不用设置有线自动连网
  11. 无线路由器密码破解最新教程完整版
  12. Qt之如何识别小键盘(数字键盘)
  13. keras优化算法_Keras实现两个优化器:Lookahead和LazyOptimizer
  14. java相对路径的写法格式_java相对路径的写法
  15. 白糖详细 制造工艺、等级划分、国家标准号和注意事项
  16. 最新!7月份火爆Github的热门Python项目
  17. 水苔可以种什么植物? 湖南水苔农业开发有限公司
  18. 华硕无线路由打印机服务器,华硕RT-AC86U路由器怎么共享打印机
  19. [原]解密Airbnb 自助BI神器:Superset 颠覆 Tableau
  20. 80端口被占用的解决办法

热门文章

  1. 计算机桌面右下没有网络连接,笔记本电脑,WIN10系统,右下角没有网络连接..._网络编辑_帮考网...
  2. [Err] 1055 - Expression #1 of ORDER BY clause is not in GROUP BY clause and contains nonaggregated解决
  3. Spring3和Yii示范程序性能初探
  4. 选择云计算机时首要考虑的因素是,用户选择云计算时的首要考虑因素是什么
  5. 龙之谷穿越java游戏_打开次元梦境! 《龙之谷》平行世界大穿越
  6. java 实现邮件的发送, 抄送及多附件
  7. openwrt网上资料
  8. 我的2011 — 珍惜大四美好的时光
  9. 爆点游戏搭建教程H5
  10. 堆和栈最通俗的_堆与栈的区别