稀疏注意力再添一員,華為諾亞推出高效選擇注意力架構(gòu)ESA
AIxiv專欄是機(jī)器之心發(fā)布學(xué)術(shù)、技術(shù)內(nèi)容的欄目。過(guò)去數(shù)年,機(jī)器之心AIxiv專欄接收?qǐng)?bào)道了2000多篇內(nèi)容,覆蓋全球各大高校與企業(yè)的頂級(jí)實(shí)驗(yàn)室,有效促進(jìn)了學(xué)術(shù)交流與傳播。如果您有優(yōu)秀的工作想要分享,歡迎投稿或者聯(lián)系報(bào)道。投稿郵箱:[email protected];[email protected]
當(dāng) DeepSeek 的 NSA 與月之暗面的 MoBA 以稀疏注意力掀起長(zhǎng)序列技術(shù)熱潮,行業(yè)對(duì) “效率革命” 的追逐迎來(lái)關(guān)鍵一躍 —— 華為諾亞方舟實(shí)驗(yàn)室正式發(fā)布全新 ESA 算法(Efficient Selective Attention)。
論文地址:https://arxiv.org/pdf/2502.14477
通過(guò)稀疏化注意力的創(chuàng)新設(shè)計(jì),ESA 突破了大模型在長(zhǎng)文本處理中的瓶頸。ESA 不僅實(shí)現(xiàn)了數(shù)倍序列長(zhǎng)度的拓展,還引入獨(dú)創(chuàng)的動(dòng)態(tài)計(jì)算范式,結(jié)合鄰域影響力有效避免了單純選擇 top-ranked token 所帶來(lái)的性能損失。通過(guò)對(duì)關(guān)鍵 token 的精確選擇,ESA 在優(yōu)化長(zhǎng)序列處理效率的同時(shí),提升了計(jì)算性能,為大模型在長(zhǎng)序列任務(wù)中的應(yīng)用帶來(lái)了新的可能性。
在大語(yǔ)言模型的推理過(guò)程中,長(zhǎng)序列模型的訓(xùn)練需要極高的算力和海量數(shù)據(jù)支持,理想的解決方案是通過(guò)短序列的訓(xùn)練成果外推到長(zhǎng)序列。然而,隨著序列長(zhǎng)度的增加,注意力計(jì)算的復(fù)雜度呈平方級(jí)增長(zhǎng),這使得高效且準(zhǔn)確的長(zhǎng)序列推理成為了一大挑戰(zhàn)。為此,研究人員提出了多種方法,以應(yīng)對(duì)這一挑戰(zhàn)。
ESA 方案正是在這一背景下提出的創(chuàng)新外推解決方案。ESA 通過(guò)對(duì) query 和 key 的低維壓縮,有效減少了 token 選擇的計(jì)算復(fù)雜度。該方案通過(guò)靈活高效地選擇關(guān)鍵 token 進(jìn)行注意力計(jì)算,大幅度降低了 LLMs 在處理長(zhǎng)文本時(shí)的計(jì)算負(fù)擔(dān),且在性能上與全注意力外推方法相當(dāng),甚至在高倍外推場(chǎng)景下優(yōu)于全注意力算法,實(shí)現(xiàn)了上下文長(zhǎng)度的有效擴(kuò)展。
1. 高效外推
當(dāng)大模型訓(xùn)練長(zhǎng)度有限,隨著序列長(zhǎng)度的增長(zhǎng),一方面會(huì)出現(xiàn) OOD (out-of-distribution) 的問(wèn)題,另一方面注意力計(jì)算量會(huì)迅速增大。現(xiàn)有的研究表明,注意力矩陣具有稀疏性,對(duì)于長(zhǎng)序列而言,稀疏程度進(jìn)一步擴(kuò)大。選擇性注意力(Selective Attention)利用了稀疏性這一特性,選擇部分 token 來(lái)計(jì)算注意力,結(jié)合外推的位置編碼能將短序列模型應(yīng)用到長(zhǎng)序列任務(wù)上的同時(shí),顯著降低計(jì)算量。在計(jì)算稀疏注意力時(shí)細(xì)粒度的 token 選擇方法能夠更加靈活、精準(zhǔn)地定位到關(guān)鍵信息。然而,token 粒度選擇會(huì)引入巨大的計(jì)算開(kāi)銷。這引出了一個(gè)核心的問(wèn)題:如何在選擇性注意力方法中平衡靈活性與效率。針對(duì)這一挑戰(zhàn),ESA 方法通過(guò)將 query 和 key 進(jìn)行低維壓縮,顯著降低 token 選擇的計(jì)算復(fù)雜度,在外推場(chǎng)景下實(shí)現(xiàn) token 粒度動(dòng)態(tài)稀疏注意力機(jī)制。
具體而言,ESA 包括以下兩個(gè)核心步驟:
高效選擇:ESA 引入了一種基于 query 感知的 token 粒度選擇機(jī)制,基于壓縮后的 query 和 key 計(jì)算 token 的重要性分?jǐn)?shù),同時(shí)考慮周圍 token 的影響(鄰距影響力),以避免直接選擇 top-ranked token 導(dǎo)致的性能下降。
注意力計(jì)算:在選擇關(guān)鍵 token 后,ESA 使用被選中的 token 的完整的 query 和 key 進(jìn)行注意力計(jì)算,而非對(duì)所有前序 token 進(jìn)行計(jì)算,從而大幅降低復(fù)雜度。
2.ESA:基于 token 粒度的高效選擇性注意力
ESA 的主要?jiǎng)?chuàng)新點(diǎn)在于通過(guò) token 粒度選擇性注意力機(jī)制,在保持模型準(zhǔn)確率的同時(shí)顯著降低計(jì)算復(fù)雜度。具體來(lái)說(shuō),與現(xiàn)有的長(zhǎng)序列外推方法不同,ESA 提出了一種基于 token 的細(xì)粒度選擇注意力,能夠在 prefilling 和 decoding 階段動(dòng)態(tài)選擇最關(guān)鍵的少量 token,而不是固定 block 選擇或者永久丟棄不重要的 token。首先,ESA 將 query 和 key 經(jīng)過(guò)簡(jiǎn)單的一層 MLP 壓縮到原有維度的大約 3.2%,在低維空間計(jì)算重要性分?jǐn)?shù),顯著降低計(jì)算復(fù)雜度;其次,根據(jù)重要性分?jǐn)?shù)選擇 topk 的 token,控制 key 的長(zhǎng)度是固定的,這樣將注意力計(jì)算由原有的平方復(fù)雜度降低為線性復(fù)雜度。雖然選擇 token 是平方復(fù)雜度,但是由于將 query 和 key 壓縮到了更低維的空間,使得對(duì)于算力要求大大降低。
ESA 算法示意圖
ESA 的具體實(shí)現(xiàn)方式如下:輸入序列的 token 被分為 4 部分,注意力包括全局注意力和 window 的局部注意力,初始 token 和 ESA 選擇的 topk 中間 token 拼接起來(lái)計(jì)算全局注意力,localtoken 用于計(jì)算 window 的注意力,兩部分注意力進(jìn)行融合計(jì)算最終的注意力。ESA 按照 chunked-prefill 緩存 key 和 value,即基于當(dāng)前 chunk 的 query 選擇重要的中間 tokens,計(jì)算 token 的重要性時(shí)兼顧當(dāng)前的所有 query;在解碼階段,只需要考慮當(dāng)前的一個(gè) token 的 query 即可。如果計(jì)算中間某個(gè) token 重要性,需要計(jì)算和當(dāng)前所有 token 的重要性,其中單個(gè) token 的重要性用 query 和 key 的點(diǎn)積表示:
這里 H 是 head 的數(shù)量,為了降低復(fù)雜度 ESA 整合了所有的 head。為了進(jìn)一步降低計(jì)算復(fù)雜度,不要求準(zhǔn)確計(jì)算重要性分?jǐn)?shù),而是更關(guān)注相對(duì)大小,ESA 將 query 和 key 分別通過(guò)一層 MLP 進(jìn)行壓縮。ESA 采取 offline 的方式學(xué)習(xí) MLP 的權(quán)重:
ESA 使用一個(gè)小的校準(zhǔn)數(shù)據(jù)集用模型進(jìn)行推理,保存中間的 query、key 和 value,用于訓(xùn)練降維 MLP,只增加了極少量的降低 query 和 key 大小的網(wǎng)絡(luò)權(quán)重,且無(wú)需對(duì)模型微調(diào)。
為了確保分?jǐn)?shù)的相對(duì)大小,避免某個(gè) token 在重要性分?jǐn)?shù)中占據(jù)主導(dǎo)地位,ESA 對(duì)分?jǐn)?shù)進(jìn)行修正:
進(jìn)一步的,作者發(fā)現(xiàn)僅選擇 topk 的 token 模型在大海撈針任務(wù)中只能檢索到部分信息,提出了鄰距影響力的概念,即對(duì)于某個(gè)中間的 token,其重要性分?jǐn)?shù)不僅取決于自身的分?jǐn)?shù),還受到周圍 token 的影響,更新后的分?jǐn)?shù)為:
在選擇完重要 token 后,ESA 使用完整的 query、key 和 value 計(jì)算注意力,最終的注意力輸出如下所示:
ESA 的計(jì)算復(fù)雜度降低主要來(lái)源于低維的 query 和 key 計(jì)算重要性分?jǐn)?shù)以及選擇完成以后的線性注意力計(jì)算復(fù)雜度,經(jīng)過(guò)理論計(jì)算,一步 attention 計(jì)算在長(zhǎng)序列場(chǎng)景下能降低為原有的:
實(shí)際實(shí)驗(yàn)中我們將 query 和 key 壓縮為原有的 3.2%,一步 attention 計(jì)算量在輸入序列足夠長(zhǎng)時(shí)理論能降低至 1.6% 左右。
3. 實(shí)驗(yàn)結(jié)果
論文選擇開(kāi)源訓(xùn)練集 Pile 的 2 條 Books3 樣本收集用于訓(xùn)練降維 MLP 的 qk 樣本,query 和 key 從 4096 壓縮為 128,壓縮比例約為 l3.2%,注意力計(jì)算的窗口長(zhǎng)度約為 6k。為了將開(kāi)源的短序列模型應(yīng)用到長(zhǎng)序列中,ESA 沿用了 Infllm 的外推位置編碼設(shè)置,使用 Llama-3-8B-Instruct 和 Mistral-7B-Instruct-v0.2,在多個(gè)公開(kāi)的長(zhǎng)序列基準(zhǔn)測(cè)試中驗(yàn)證了 ESA 的性能,包括 Longbench、InfiniteBench、NeedleBench 等。作者對(duì)比了 full attention 的外推方法和同類型的基于 window 的外推方法,且同類型方法的 window 長(zhǎng)度一致。實(shí)驗(yàn)結(jié)果表明,ESA 通過(guò)高效靈活選擇重要的 token,總體性能在外推倍數(shù)足夠大時(shí)候優(yōu)于 full attention 的方法,且均明顯優(yōu)于同類型的方法,尤其在 multi needles 檢索場(chǎng)景下例如數(shù)星星和 NeedleBench,在其他同類型方法失效的時(shí)候,ESA 仍然有較高的準(zhǔn)確率。
ESA 不對(duì)每個(gè) head 單獨(dú)選擇 token,而是將所有 head 整合到一起計(jì)算重要性分?jǐn)?shù),有利于降低計(jì)算復(fù)雜度,提升效率,為了驗(yàn)證這一操作對(duì)算法的影響,作者做的對(duì)比實(shí)驗(yàn)如下所示,可以看出這樣的整合對(duì)于算法影響有限。
論文研究了鄰距影響力的超參數(shù)影響,結(jié)果如下所示,對(duì)不同的測(cè)評(píng)集該參數(shù)的影響不同,取值較小有利于 multi needles 類型的檢索任務(wù),取值較大則有利于 single needle 類型任務(wù),這可能是由于單針檢索任務(wù)只需要關(guān)注 ground truth 所在的片段即可,增大鄰距影響力有利于 attention 集中到較長(zhǎng)的片段上。
4. 總結(jié)
ESA 有效平衡了長(zhǎng)序列外推場(chǎng)景下的選擇性注意力中的靈活性和計(jì)算效率,用于在不進(jìn)行模型參數(shù)增量微調(diào)的情況下擴(kuò)展上下文長(zhǎng)度。ESA 的核心思想是在每個(gè)步驟中選擇固定數(shù)量的最重要 token 來(lái)計(jì)算注意力,利用注意力矩陣的稀疏性。當(dāng)輸入序列足夠長(zhǎng)時(shí),ESA 通過(guò)將 query 和 key 壓縮為低維表征,有效降低選擇 token 的計(jì)算復(fù)雜度。實(shí)驗(yàn)評(píng)估表明,ESA 能夠有效處理長(zhǎng)度為訓(xùn)練長(zhǎng)度 4 倍甚至 25 倍的各種長(zhǎng)序列任務(wù)。未來(lái)的研究需要探索更準(zhǔn)確、更高效的選擇重要 token 的方法,以及軟硬件協(xié)同的高效外推方案。
轉(zhuǎn)載請(qǐng)注明來(lái)自浙江中液機(jī)械設(shè)備有限公司 ,本文標(biāo)題:《稀疏注意力再添一員,華為諾亞推出高效選擇注意力架構(gòu)ESA》
百度分享代碼,如果開(kāi)啟HTTPS請(qǐng)參考李洋個(gè)人博客
每一天,每一秒,你所做的決定都會(huì)改變你的人生!
還沒(méi)有評(píng)論,來(lái)說(shuō)兩句吧...