众所周知,现在transformer及其变种是NLP和CV领域已经杀疯了。但其中最核心的self-attention机制因为其O(N2)的时间复杂度(二次依赖问题)被诟病。

在不改变transformer block这个整体架构的前提下,现在学术界解决二次依赖问题的主要是两个思路。一种是实现self-attention的线性化。这方面的工作是很多的,比如Performer[5]、Reformer[6]、Linformer[7]、Nyströmformer[9]、AdaMRA[10]等。关于这部分工作更多的内容大家可以在苏剑林的博客中了解到[8].虽然关于线性attention的工作很多,但参考AdaMRA[10]论文的图。只有Nyströmformer[9]和AdaMRA[10]相较于Transformer能获得速度和效果的双重提升,其他的大多需要付出效果的代价才能获取一定的速度提升。但就是这哥俩由于用了平均池化作为特征聚类,因此无法mask未来信息从而丧失了自回归的能力。因此通过替换线性attention从而提升transformer速度这一思路是必须付出代价的。

另一种思路将self-attention换成其他线性复杂度的部件。比如前段时间谷歌发现用膨胀卷积取代self-attention也能取到不错的效果[1]。而在CV领域杀疯的MLP-Mixer[2],兼具CV和NLP能力的gMLP、aMLP,[3]MLP-Mixer的NLP版本Synthesizer[4]。但都有或多或少的缺点,就比如Synthesizer和gMLP在NLP领域相较于self-attention还是差了点的。而aMLP虽然效果好了吧,但其实还是要用到self-attention,提速的目的还是没达到。不过今年暑假那会,苹果提出的AFT模型[11]号称自己是最快的transformer模型。

上述是标准AFT的公式,其中σ是sigmoid函数,QKV就是sefl-attention的那一套,w是一个训练出来的参数矩阵。不难看出AFT是通过点乘的方式实现的注意力,在做自回归时只需要对W矩阵进行mask即可。并且W矩阵是自带位置信息的,不仅解决了部分线性attention不能做自回归的问题,还顺便把transformer里位置编码的问题给解决了。可以说AFT实现了一举三得。但成也萧何败也萧何,W矩阵是AFT成功的核心也是AFT的最大缺点。一般来说W应该是一个[max_len,max_len]大小的方阵。换而言之AFT所能处理的文本长度受限于W矩阵的大小,如果想要处理一万字的长文本,W矩阵的参数量就快赶上Bert了。为了解决这个问题,下面该本文的主角RWKV出场了。RWKV的原文在RWKV is all you need?一种新语言模型,改进 Transformer - 知乎,不过原文实在过于简短了不便阅读和理解。因此笔者写了此文介绍一下RWKV是怎么实现鱼和熊掌兼得的。

RWKV

整体结构 RWKV的整体结构依然采用的是transformer block的思路,其整体结构如图所示。相较于原始transformer block的结构,RWKV将self-attention替换为Position Encoding和TimeMix,将FFN替换为ChannelMix。其余部分与transfomer一致的。

Position Matrix RWKV采用的位置编码类似于AliBi编码[12]的形式。原文作者并没有给他的位置编码命名,为了便于介绍参考该位置编码主要考虑距离衰减的特性,本文将其命名为distance编码。对于第i个head的第j个token而言,其位置编码如下述公式所示。其中nhead表示头的数量,max_len表示为所允许的最大长度。

目前学术界的主流观点是RNN结构是天然的时序结构,不需要transformer模型必须的位置编码。而如果我们查看RNN的计算流程,可以发现RNN只考虑到当前token及之前的信息,而随着距离的延长前面的信息会逐渐减少。而distance位置编码便是参考RNN时序特点所设计的。

不过RWKV模型中,不会直接对输入的X进行上述计算。而是得到类似AFT中的W矩阵参与后续Time-Mix计算。其中W矩阵的形状为[n_head,seq_len,seq_len]。因此对于W矩阵中的而言,其数值如下述公式所示。

从这里不难看出,AFT中的W矩阵在RWKV中是通过公式得到而不是训练得到的,因此解决了AFT中无法解决长文本,或解决长文本时参数爆炸的问题。

当然,在处理的任务文本长度有限的情况下。比如机器翻译,或者是RWKV目前应用的ai写小说这类应用场景。在这类应用场景中,由于不会面临长文本的情况,因此可以为W矩阵添加更多的位置信息。参考公式如下

其中和分别为形状[n_head,seq_len,1]和[n_head,1,seq_len]的向量,在初始化时为全1矩阵。即将作为W矩阵的初始化。结合该步后,在形式上W矩阵融合了distance编码中的距离信息与相对信息。

值得注意的是,原作者是设计distance编码时专门设计了一个不考虑位置信息衰减的头。即该头的W矩阵是一个全一的下三角矩阵。

Time-shit 在介绍TimeMix之前,要先介绍一下RWKV所使用的Time-shit技巧。

原文:Time-shift: 一行代码,免费提高 Transformer 性能(无参数,无耗时) - 知乎

Time-shiit是原作者提出的一种几乎零成本提升模型效果的trick,实现代码如下所示。

Torch实现
C=x.shape[-1]
self.time_shift = nn.ZeroPad2d((0,0,1,0))
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
Keras实现
d=K.shape(x)[-1]
x=K.concatenate(K.temporal_padding(x,(1,0))[:,:-1,:d//2],x[:,:,d//2:])

可以看出不论哪个框架也就两行就能实现了,为了便于读者理解。假设存在一个3x4的矩阵。

在经过time-shift后变为

其实就相当于插入一个小的RNN,实验表明简单的trick能让模型的更快更好地收敛。

TimeMix TimeMix是RWKV中用于代替self-attention的部分,基于AFT的基础上做出改进兼具了线性的速度和较好的性能。在进行该步前,需要对输入的x进行time-shift。

同self-attention中的QKV矩阵一样,RWKV中也有对应的RKV矩阵。对与输出矩阵中第i个头的第j个token而言计算步骤如下所示。

这其中是一个[hiden_size,hiden_size]大小的方阵,与常规attention一样用于最后的输出。而是一个[seq_len,hiden_size]大小的矩阵,其作用笔者猜测应该是类似于bias的作用。

ChannelMix ChannelMix 是RWKV中用于替代FFN的部分。类似于tiny attention之于attention。ChannelMix本质上来说是一个tiny TimeMix。

在进行该步计算前,和TimeMix一样要先进行一次time-shift。随后依然要计算出RKV矩阵和W权重。不过有所不同的是在这一步中假设输入x的维度是embed_size,则R的维度应和X相同。KV的维度是用户所自定义的hidden_size,W的形状为[hidden_size,embed_size].

通过设置较小的hidden_size可以实现一个tiny版TimeMix,能在对性能影响较小的情况下实现提速。当hidden_size==embed_size时,可以看作一个不考虑位置信息和归一化的TimeMix或者看作点乘式的FFN。

具体计算公式如下所示

总结 本文介绍了一种鱼和熊掌兼得的模型。既能和AFT一样兼具通用性和高效,distance位置编码的设计使得模型也具备面对超长文本的能力。

实际实验效果可以去看原文的内容,本文只对其结构进行介绍。但总体而言,笔者测试过基于GPT的ai写小说和基于RWKV的ai写小说。相比较而言,RWKV的写出来的文章会更流畅,并且在训练时收敛速度页更快。

参考文献

[1] Are Pre-trained Convolutions Better than Pre-trained Transformers https://arxiv.org/pdf/2105.03322.pdf

[2] MLP-Mixer: An all-MLP Architecture for Vision https://arxiv.org/pdf/2105.01601.pdf

[3] Pay Attention to MLPs https://arxiv.org/pdf/2105.08050.pdf

[4] Synthesizer: Rethinking Self-Attention in Transformer Models https://arxiv.org/abs/2005.00743

[5] Rethinking Attention with Performers https://arxiv.org/abs/2009.14794

[6] Reformer: The Efficient Transformer https://arxiv.org/abs/2001.04451

[7] Linformer: Self-Attention with Linear Complexity https://arxiv.org/abs/2006.04768

[8] 线性Attention的探索:Attention必须有个Softmax吗? https://spaces.ac.cn/archives/7546

[9] Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention https://arxiv.org/abs/2102.03902

[10] Adaptive Multi-Resolution Attention with Linear Complexity https://arxiv.org/abs/2108.04962

[11] An Attention Free Transformer https://arxiv.org/abs/2105.14103

[12] Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

https://arxiv.org/abs/2108.12409




RWKV:一种鱼和熊掌兼得的线性transformer模型 - 知乎

https://www.youtube.com/watch?v=oaP8_fUFVWw

RWKV:一种鱼和熊掌兼得的线性transformer模型相关推荐

  1. 鱼和熊掌兼得——解密阿里云PCDN如何实现高质量低价格

    在内容分发领域,之前CDN产品一直占据着主导地位,CDN的质量好,但价格偏高,正应了那句老话"一分价钱一分货".质量和价格似乎是矛盾体,鱼和熊掌不可兼得.随着云计算的高速发展,不断 ...

  2. python gui web_一篇让你大开眼界的Python教程:让Web和GUI鱼和熊掌兼得

    本期教程给大家更新绝对让你大吃一惊的效果-web嵌入GUI 随着Web技术的蓬勃发展,以网页形式在浏览器上显示图表已经逐渐成为一种主流的形式. 网页的实现是由HTML.CSS和Javascript三者 ...

  3. 生活,又怎能鱼和熊掌兼得?

    ❣️ 身在井隅,心向璀璨| 第139篇 如何成为一名有钱的程序员,一直是我在思索的问题. 老婆说,"要脚踏实地,不要总是在做梦." 事实上,我也的确是在"脚踏实地&quo ...

  4. 鱼和熊掌兼得!这些应用是如何使用 Material Design 的?

    在任何产品团队看来,应用构建的速度.美学还有可用性常常是不可兼得的.但是现在,凭借最新版本的 Material Design,团队可以在不牺牲质量的情况下 "兼得鱼和熊掌".自从我 ...

  5. 云端容灾演练,如何鱼和熊掌兼得?

    在数字化进程不断加快的今天,一个优秀的组织或企业都有一套优秀的灾备系统,而一套优秀的灾备系统一定也有一个与之匹配的灾备演练机制. 两千多年前,有个国王因为拿灾备演练当成儿戏,最后不仅丢了性命,还把江山 ...

  6. DevOps与合规性:鱼和熊掌兼得指南

    本文转自微信号EAWorld.扫描下方二维码,关注成功后,回复"普元方法+",将会获得热门课堂免费学习机会!本文转自微信号EAWorld. 编者按:很多行业身处强力监管领域,因而格 ...

  7. 机房租赁,如何鱼和熊掌兼得?

    对于企业而言,选择一个合适的机房来托管服务器,是一件需要深思熟虑的事情.机房的重点参数和运营商的服务能力就如鱼和熊掌,企业在选择的时候自然希望兼得.那么,关注以下几点,会帮助您选择到一个更合适的机房. ...

  8. 新华三 VDI java,鱼和熊掌兼得:新华三vGPU云桌面方案革新VDI性能体验

    在过去的很长时间里,虚拟桌面基础架构(VDI)在企业办公环境中长短互现,其优势体现在海量部署时的可管理性.安全性.成本及能效方面,但面对一些特殊的应用场景,例如3D渲染.CAD.视频编辑等,VDI往往 ...

  9. 鱼和熊掌兼得:同时使用 JPA 和 Mybatis

    前言 JPA 和 Mybatis 的争论由来已久,还记得在 2 年前我就在 spring4all 社区就两者孰优孰劣的话题发表了观点,我当时是力挺 JPA 的,这当然跟自己对 JPA 熟悉程度有关,但 ...

  10. 解《鱼和熊掌不可兼得》

    鱼和熊掌不可兼得 以下言论是本人扮演角色的个人言论,与本人无关 想必各位肯定知道"鱼和熊掌不可兼得"这件事,当年上学的时候学了而已,看了看译文,如果说有人问你:鱼和熊掌为什么不可兼 ...

最新文章

  1. dwarf tower
  2. 排序算法c语言和oc实现的,几种常用的排序算法,OC实现
  3. 【死磕 Spring】—– IOC 之 Factory 实例化 bean
  4. helm3添加harbor仓库:带鉴权--username --password
  5. 体验 ASP.NET Core 1.1 中预编译 MVC Razor 视图
  6. Qt编写echart仪表盘JS交互程序支持webkit和webengine(开源)
  7. 重修 mongoDB 系列(一) 配置环境
  8. struts1基础入门
  9. 为什么以太网有最短帧长度的要求_线束工程师:车载以太网介绍
  10. hbase的region分区
  11. 用python写Hello World
  12. 学年总结(2015-2016学年回顾)
  13. Unicast与Multicast
  14. 怎样测试手机性能软件,如何检测手机性能的软件
  15. 病毒全攻略:我是怎样让你感冒的
  16. 用Python的matplotlib绘制残差分析散点误差棒图
  17. 济南江苏商会成立 全国工商联·万祥军:商协社团厚德聚苏商
  18. 使用sil9233a芯片控制海思hi3531d的hdmi输入
  19. 在Linux中网络性能审计,安全以及排错
  20. DZ插件 [1314]模板自由切换 2.0.1版

热门文章

  1. UI 的设计, 就是精雕细琢, 然后简洁明了
  2. 【电磁泄漏还原】通过matlab实现电磁泄漏信号的还原
  3. 益聚星荣:国人纷纷“逃离”支付宝,“投奔”银行,不只是因为利率降了?
  4. php eof bof,bof或者eof中有一个是真,或者当前的纪录已被删除,所需的操作要求...
  5. html 内容的高度设定,让文字不超过固定宽度和高度怎么设置?
  6. 1077 皇宫看守(树形dp)
  7. 艾美捷内毒素纯化树脂说明书
  8. Android:ComponentCallbacks/ComponentCallbacks2与glide
  9. WoShop国内多商户商城系统直播短视频分销积分秒杀100%无加密
  10. EAS BOS 8.0 样板工程(成品检验取样单)