1 研究动机

选择这篇论文来读,有一点奇文共欣赏的意思。 区别于现在主流的框架比拼算力,本文重新思考是不是可以通过lstm 和 单头的attention就可以在现在的数据集上完成大型框架类似的指标。

作者在文章里花了很大的篇幅去讨论,如何会去思考来构建sha-rnn这个模型。他类比了计算机的发展史和摩尔定律,讨论了语言模型和tokern。 作者认为减少缓存,让语言模型的实现可以跑在较低的资源上,不失为一个值得去研究的技术方向,就像计算机发展当年的故事,如果所有的研究都投入在集群和大型机,怎么会有二十世纪末期微机的大行其道。作者认为,即使是transformer已经是主流,也可以继续尝试用lstm + attention,通过精心的设计,仔细的调差,一样可以用显存消耗较小的模型达到较好的效果。

2 研究内容和方法

sha-rnn的设计架构,如下图所示,仔细看其实并没有特别出彩的地方。撇除那些各条路线上的FusedLayerNorm (LN)层,其实架构和transformer是非常接近的。 沿用传统的lstm 而不是算力消耗或者说参数量更大的self attention层。 attention的k,q,v其实均来自lstm的输出,然后依然是类似transformer的旁路设计(残差)。具体可以看源码关于这一块的核心设计。

def forward(self, h, pe, attn_mask, mem=None, hidden=None):new_mem = Noneh = self.lnstart(h)if self.rnn:x, new_hidden = self.rnn(h, None if hidden is None else hidden)#x = self.rnn_down(self.drop(x))# Trim the end off if the size is differentninp = h.shape[-1]z = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)# Divide the hidden size evenly into chunksz = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)# Collapse the chunks through summation#h = h + self.drop(x).sum(dim=-2)x = self.drop(z).sum(dim=-2)#x = x + z.sum(dim=-2)h = h + x if self.residual else x.float()focus, new_mem = None, []if self.attn is not None:mh = self.lnmem(h)h = self.lnmid(h)if mem is not None:bigh = torch.cat([mem, mh], dim=0)else:bigh = mhnew_mem = bigh[-len(pe):]q, k = h, bighx, focus = checkpoint(self.attn, q, k, bigh, attn_mask)#x, focus = tcheckpoint(self.attn, q, k, bigh, attn_mask)x = self.drop(x)h = x + hif self.ff:h, x = self.lnff(h), self.lnxff(h)x = checkpoint(self.ff, x)#x = tcheckpoint(self.ff, h)x = self.drop(x)h = x + hreturn h, new_mem, new_hidden, focus

sha-rnn关于attention的设计,最主要的着眼点还是减少矩阵乘法带来的消耗,从下图可以看出,整个过程其实只有一次的矩阵乘法

3 实验

对于论文的实验 ,我们主要关注 ENWIK8这个数据集,源码中还包含wikitext-2,wikitext-103和PTB等数据集。下图展示 sha-rnn和其他模型的参数对比:

对于sha-rnn训练的实验结果和截图如下:


其实训练的过程,也应用了很多基本的技巧,比如warmup,比如一开始训练(作者建议32个epoch,实际我因为意外大概训练了10个左右,其实bpc和loss基本已经变化很小),我decay一下lr,又先后训练了两个epoch和1个epoch,最后的结果如下:

4 创新点和个人点评

本文其实架构的创新不是特别大,但是思路其实有可取之处,特别是坚持保留主流之外其他架构设计的可能性,是非常值得我们研究者学习的一种精神。而且,作者的代码,有大量的工程和试验的部分,都是值得学习和借鉴的,比如boom层的设计中的切块。 最后,其实,文章还有很多的细节,我后续读参考文献及其代码,会补充或者单开文章来写,比如作者用的优化器LAMB,以及英伟达的混合精度和分布式训练的库APEX,当然作者提到的tokenization attack也待补充。

Single Headed Attention RNN: Stop ThinkingWith Your Head 论文笔记相关推荐

  1. Single Headed Attention RNN: Stop Thinking With Your Head

    这篇论文的语言真的很有趣,很多地方我翻译的不是很好大家谅解,可以尽量欣赏原文! Abstract 语言建模的主要方法都是痴迷于我年轻时的电视节目--变形金刚和芝麻街.我们选择了老办法和经过验证的技术, ...

  2. CVPR 2020 HAN:《Hypergraph Attention Networks for Multimodal Learning》论文笔记

    目录 简介 动机 贡献 方法 实验 简介 本文提出了一种用于多模态学习的超图注意力网络,作者来自Kakao公司和首尔大学. Kakao公司的主要产品是Kakao talk,类似于国内的微信,且腾讯是其 ...

  3. 【显著性物体检测】【ECCV2018】Reverse Attention for Salient Object Detection【论文笔记】

    简介:在不怎么增加计算量的前提下,采用从粗到精的思想,由高级特征到低级特征,补全显著性检测的轮廓[最近很多都是基于这个思想].模型的速度与效果都占优.具体关注,是怎么实现特征的多级利用的. ECSSD ...

  4. Spatially and Temporally Efficient Non-local Attention Net work for Video-based Re-Id 论文笔记

    作者的目标非常明确(刷分),利用注意力机制,首次将Mars数据集的rank-1突破90难关. Abstract 在神经网络中利用注意力机制来学习图像特征是近几年来比较流行的方法,同样地,也适用于视频序 ...

  5. 《Single Image Depth Prediction with Wavelet Decomposition》论文笔记

    参考代码:wavelet-monodepth 1. 概述 导读:对一幅深度图进行分析可以观察到其是由一些平滑区域和边缘区域组合起来的,对应的可以参考频域中的低频和高频分量.而这篇文章正是提出一种基于频 ...

  6. 【论文笔记】Dynamic Convolution: Attention over Convolution Kernels

    Dynamic Convolution: Attention over Convolution Kernels,CVPR2020 论文地址:https://openaccess.thecvf.com/ ...

  7. 【论文笔记】Neural Relation Extraction with Multi-lingual Attention

    一.概要   该paper发于ACL2017上,作者主要基于关系事实通常在各种语言中存在某种模式表达,并且不同语言之间的模式是不同的这两个动机,针对于当前存在的单语言关系抽取的方法,从而存在忽略不同语 ...

  8. 论文笔记2:Deep Attention Recurrent Q-Network

    参考文献:[1512.01693] Deep Attention Recurrent Q-Network (本篇DARQN) [1507.06527v3] Deep Recurrent Q-Learn ...

  9. Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion论文笔记

    CVPR2021论文笔记 题目:Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Mod ...

最新文章

  1. 使用 python 的单人AI 扫雷游戏
  2. linux shell awk BEGIN END 处理文本之前之后执行操作 简介
  3. RESTORE DATABASE的standby选项
  4. [电脑问题解决]在windows 8.1升级后,电脑重启时不显示ubuntu的系统引导界面,而总是直接进入windows
  5. MX250和MX350哪个好一点,区别和差距在哪里?
  6. Block实现iOS回调
  7. mysql为什么要重建索引_MySQL表索引为什么会遭破坏?
  8. dcp9030cdn定影_兄弟DCP-9030CDN驱动下载
  9. 微软快捷键截图_如何在Microsoft Office的屏幕提示中显示快捷键
  10. 《Beyond Part Models: Person Retrieval with Refined Part Pooling 》PCB论文解读
  11. 金蝶云苍穹开发实用整理
  12. 海康硬盘录像机 rtsp/onvif 视频配置
  13. 计算机无法用630打印机,爱普生LQ-630型针式打印机突然打印几行乱码,就不工作了...
  14. 中国人民银行面试题目(经典题目2)
  15. Linux-C C语言编译过程
  16. bitcoinj生成中文助记词
  17. 盘点北京周边最适合爬的10座山
  18. php 没有后缀名下载,javascript - 没有后缀名的链接?
  19. 【不想读paper的时候看看】阅读文献?
  20. MVP模式 项目练习 Pas --新闻 音乐 图片 三个模块

热门文章

  1. java 正则 小数_java用正则表达式判断是否是小数的方法
  2. Towards Fast, Accurate and Stable 3D Dense Face Alignment(3DDFA_V2)论文与项目学习
  3. [kubernetes]-filebeat以sidecar模式收集pod日志
  4. 华蓥计算机培训机构,华蓥考研专业课培训班
  5. 简单利用C语言 解决停车场管理问题
  6. pivot的用法,一看就会
  7. [论文笔记]Swarm of micro flying robots in the wild
  8. 在亚马逊上你知道怎么定价-跨境知道
  9. 剪贴板增强工具Ditto
  10. 社会责任审核-安全出口