首页 文章详情

Transformer太大了,我要把它微调成RNN

极市平台 | 1303 2021-04-12 00:00 0 0 0
UniSMS (合一短信)
↑ 点击蓝字 关注极市平台

作者丨炼丹学徒
来源丨夕小瑶的卖萌屋
编辑丨极市平台

极市导读

 

本文介绍了来自一篇微软的论文:Finetuning Pretrained Transformers into RNNs,在保持性能的情况下,将预训练好的Transformer模型微调到其RNN变体,极大地降低显存使用和计算开销。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

从前车马很慢,显卡跑的也慢,一生只够爱一个RNN。后来时代进步了,数据量和计算力阔绰了,堆叠起来的Transformer能够在更深更宽的模型结构里吃下去更多的数据。从19年的预训练浪潮开始,暴力美学兴起,更深的Transformer更久的预训练更大的模型参数量,暴力出奇迹一个个NLP榜单被刷新,但谁又记得起来当初Transformer论文里“解决RNN无法并行化训练问题”的追求效率的motivation呢?身在普通高校,手握2080Ti和Titan V,向着大厂的预训练模型望洋兴叹,我们开始怀念起当初人人都训练得起的LSTM和GRU。那是精巧轻量的模型,那是人人都刷的起SOTA的时代。

今天这篇来自微软的论文告诉我们,大厂里有一些研究员也还是爱我们的,Finetuning Pretrained Transformers into RNNs,在保持性能的情况下,将预训练好的Transformer模型微调到其RNN变体,极大地降低显存使用和计算开销。

论文题目:  Finetuning Pretrained Transformers into RNNs

论文链接:  https://arxiv.org/abs/2103.13076_Arxiv

本文提出的模型名为 T2R,代表 Transformer to RNN 。转换的过程为 swap-then-finetune ,即,对于一个预训练好的 Transformer 模型,我们将其 的注意力计算改为线性 的替换模块,然后进行微调。可以预感到,其核心就在于如何用线性的子层对注意力层进行模拟。接下来,我们对其进行详解。

概述

在2019年EMNLP论文 Transformer Dissection [1] 中,作者提出:可以将注意力层的相似度计算()替换为核函数的分数

ICML'20的另一工作_Transformers are RNNs_ [2]则在此基础上进一步优化,提出了将的注意力计算替换为线性的模块

今天要讲的 T2R 这篇文章是紧随上面 ICML'20 这篇工作进行的。之前 Transformers are RNNs 的方法中,使用的核函数没有参数,不可训。而 T2R 把核函数里封装了一个MLP变成可训练的。T2R原文的推导直接使用了 Transformers are RNNsTransformer Dissection 的结论,因而推导过程并不完整。我们今天也沿着T2R的思路进行讲解,如果想要更深入了解 Transformer 转 RNN 领域的,可以阅读下面两篇论文:

[1] Tsai et al. Transformer Dissection: A Unified Understanding of Transformer's Attention via the Lens of Kernel. EMNLP 2019 [2] Katharopoulos et al. Transformers are RNNs: Fast autoregressive transformers with linear attention. ICML 2020

Transformer开销

Transformer 由多头注意力层、前馈层、层归一化层堆叠后组成。本篇论文中要替换的,就是其中的多头注意力层。

在开始讲解如何替换之前,我们还是先梳理一下传统Transformer的多头注意力层。整个计算过程可以总结如下图所示:

▲传统Transformer的多头注意力层计算过程

这张图我们自下往上看。首先,我们将多头注意力层的source隐状态记作,target隐状态记作。

如何理解此处的source和target: 比如,在解码器的编码器-解码器注意力层中,就是编码器端的序列长度,就是解码器端的长度。在自回归推断的解码器自注意力层中,就是已生成序列(加上自己)的长度,等于1,指当前要预测的这个字符。

从隐状态,我们通过线性变换得到。则,注意力层的输出为:

其中, 操作 旨在计算和的相似度(这里划重点!等一会儿就要对这个计算动手脚了!):

上述的多头注意力的计算是我们熟知的。论文对其复杂度进行了分析。设多头数为,每个头的隐状态长度,每个的隐状态总长 ,则有如下结论:

  • 特征计算:即由隐状态计算得到的过程,复杂度分别为 , 和
  • 注意力计算: 由 计算得到最终输出的过程,复杂度为 ,与 的长度成平方关系。
  • 推断时的显存:,与已经解码的长度线性相关。

注意力层的RNN替代方案

T2R的注意力层计算过程则如下图所示:

首先,我们注意到原始的注意力计算中, 和 的相似度计算方式()需要先进行点乘,放缩后再进行指数运算,难以开展后续的近似优化。所以这里的关键之处就在于,T2R把的相似度计算方案替换为核函数的乘积

此处,和的参数都是通过一个单层MLP学习得到的。是维矩阵,是维bias向量,即,T2R的相似度计算核函数将原本维的向量降到了维然后进行相似度计算。对于多头计算中的每一个头,他们的和是独立学出来的。因此,T2R在每一层中,共增加了个可学习的参数(小于总参数量的2%)。

我们把新的相似度计算方法代入到注意力的输出式中,得到:

记,,则:

而根据 Transformers are RNNs [2] 的结论,此处的可以视作RNN递归的隐状态。比如,在解码器端做自回归生成时,每个词向它前文的单词进行注意力计算来预测下一个词,和可以被定义为递归的隐状态:

注意到我们主要讨论的函数是针对来计算相似度的,而是由喂入该层的隐状态线性变化得到的。为了加速推断速度,具体实现中把和代入,得到从隐状态,直接线性变换得到的结果,从而在推断的时候不需要计算,而从隐状态直接计算得到相似度的值,即:

其中,

此时的开销:

  • 特征计算:我们记输出维的特征向量,则生成的复杂度为 , 和
  • 注意力计算: 由计算得到最终输出的过程,假设k<<M,N,此时复杂度为,与的长度成线性关系。
  • 推断时的显存:假设k<<M,则占用显存,为常数。

Transformer和T2R对比

讲到这里,我们再对比一下传统Transformer和T2R的差异:

  • 特征计算:计算不变,计算由, 降为,
  • 注意力计算: 由降为,平方->线性。
  • 推断时的显存:由降为,线性->常数。

实验

数据集的效果

T2R主要使用ELU和RFA作为baseline进行比较。ELU和RFA为此前的另外两篇使用核函数转Transformer为RNN工作。因为ELU和RFA的核函数都是不可训练的,所以无法取代预训练好的模型里的注意力层进行功能上的替换和拟合。

首先,T2R在语言模型上开展了实验。数据集使用WikiText-103,评测指标使用困惑度 perplexity 。发现T2R因为在核函数中放置了可训练的MLP,在加载预训练模型时获得更大的收益。

此外,T2R在翻译任务上开展实验,使用数据集 WMT14 EN-DE,WMT14 EN-FR 和 WMT17 ZH-EN。研究员们发现虽然随机初始化时,T2R弱于另外两个baseline,但是加载预训练后反超另外两个baseline。

生成时的加速和显存节省

研究员发现 T2R 比另外两个模型的推断速度更快(如下左图所示),因为使用了更小的特征维度,以及更快的特征计算方法。对于推断时的显存占用,Transformer 随着输出序列的增长而线性增加,转为 RNN 结构的模型则保持常数(如下右图所示)。

消融实验

随着核函数输出特征尺寸的增大,其效果也更加接近Transformer。相比于之前的工作,T2R 可以通过控制特征尺寸从而在效果和速度间权衡。

小结

本文提出的T2R,在 Transformers are RNNs 的基础上,将无参数的核函数封装为 MLP 加激活函数,从而可训练。在此基础上,T2R 替换掉预训练 Transformer 的注意力层,从而降低了计算消耗和显存使用,并且得到和原预训练模型相似的结果。

推荐阅读


解决训练不稳定性,何恺明团队新作来了!自监督学习+Transformer=MoCoV3

2021-04-06

一个 Transformer,很强;两个,更强?

2021-04-01

站在CNN肩膀上的巨大腾飞,Swin Transformer:实现对各类SOTA的降维打击

2021-03-29



# CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart2)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~


△点击卡片关注极市平台,获取最新CV干货

觉得有用麻烦给个在看啦~  
good-icon 0
favorite-icon 0
收藏
回复数量: 0
    暂无评论~~
    Ctrl+Enter