一、前言
大模型(參數(shù)規(guī)模通常數(shù)十億至萬(wàn)億級(jí))在處理復(fù)雜任務(wù)時(shí)面臨三大核心問(wèn)題:
顯式關(guān)聯(lián)的局限性:傳統(tǒng) Multi-head Attention 依賴(lài)輸入數(shù)據(jù)的顯式特征(如文本中的詞向量、圖像中的像素特征)計(jì)算注意力,難以捕捉深層語(yǔ)義(如 “同義詞替換”“上下文隱喻”)或抽象結(jié)構(gòu)(如 “邏輯推理鏈”)。
數(shù)據(jù)效率與泛化瓶頸:大模型訓(xùn)練需海量數(shù)據(jù),但在低資源語(yǔ)言、專(zhuān)業(yè)領(lǐng)域(如醫(yī)學(xué)、法律)中,顯式關(guān)聯(lián)數(shù)據(jù)稀缺,導(dǎo)致模型泛化能力驟降。
多模態(tài)融合難點(diǎn):跨模態(tài)任務(wù)(如圖文生成、視頻理解)中,不同模態(tài)的特征空間差異大(如文本的離散符號(hào) vs 圖像的連續(xù)像素),顯式關(guān)聯(lián)(如 “圖像中的貓” 與文本 “貓”)之外的隱式關(guān)聯(lián)(如 “圖像風(fēng)格” 與 “文本情感”)難以建模。
在前面的文章中,筆者已經(jīng)講解了 LLM 推理的關(guān)鍵技術(shù)-KV Cache(【手撕大模型】KVCache 原理及代碼解析),但是隨著大模型功能的不斷強(qiáng)化,其容量也在增加,當(dāng)前的 KVCache 技術(shù)已經(jīng)不能滿(mǎn)足發(fā)展需要了,所以,各種針對(duì)于 KVCache 優(yōu)化的技術(shù)應(yīng)時(shí)而生。
二、優(yōu)化 KV cache 的方法
參考 https://zhuanlan.zhihu.com/p/16730036197
當(dāng)前,業(yè)界針對(duì) KV Cache 的優(yōu)化方法可以總結(jié)為有四類(lèi):
共享 KV:多個(gè) Head 共享使用 1 組 KV,將原來(lái)每個(gè) Head 一個(gè) KV,變成 1 組 Head 一個(gè) KV,來(lái)壓縮 KV 的存儲(chǔ)。代表方法:GQA,MQA 等。
窗口 KV:針對(duì)長(zhǎng)序列控制一個(gè)計(jì)算 KV 的窗口,KV cache 只保存窗口內(nèi)的結(jié)果(窗口長(zhǎng)度遠(yuǎn)小于序列長(zhǎng)度),超出窗口的 KV 會(huì)被丟棄,通過(guò)這種方法能減少 KV 的存儲(chǔ),當(dāng)然也會(huì)損失一定的長(zhǎng)文推理效果。代表方法:Longformer 等。
量化壓縮:基于量化的方法,通過(guò)更低的 Bit 位來(lái)保存 KV,將單 KV 結(jié)果進(jìn)一步壓縮,代表方法:INT8/INT4 等。
計(jì)算優(yōu)化:通過(guò)優(yōu)化計(jì)算過(guò)程,減少訪存換入換出的次數(shù),讓更多計(jì)算在片上存儲(chǔ) SRAM 進(jìn)行,以提升推理性能,代表方法:flashAttention 等。
共享 KV 主要有兩種方法,MQA 和 GQA 都是 Google 提出的,詳見(jiàn): MQA(2019),GQA(2023)。
三、MQA &
MQA(多查詢(xún)注意力)和 GQA(分組查詢(xún)注意力)作為自注意力機(jī)制的優(yōu)化版本,主要作用是加快推理進(jìn)程、減少內(nèi)存占用,同時(shí)努力維持模型原有的性能表現(xiàn)。
以 Llama 7B 模型為例,其隱藏層維度為 4096,這意味著每個(gè) K、V 向量都包含 4096 個(gè)數(shù)據(jù)。若采用半精度浮點(diǎn)(float16)格式存儲(chǔ),單個(gè) Transformer 模塊中,單序列的 K、V 緩存空間就達(dá)到 4096×2×2=16KB。由于 Llama 2 包含 32 個(gè) Transformer 模塊,單個(gè)序列在整個(gè)模型中的緩存需求便為 16KB×32=512KB。
那么多序列的情況呢?倘若句子長(zhǎng)度為 1024,緩存空間就會(huì)增至 512MB。目前英偉達(dá)性能頂尖的 H100 顯卡,其 SRAM 緩存約為 50MB,A100 則為 40MB,顯然難以滿(mǎn)足需求。盡管可將數(shù)據(jù)存于 GPU 顯存(DRAM),但會(huì)對(duì)性能產(chǎn)生影響。7B 規(guī)模的模型已是如此,175B 規(guī)模的模型面臨的問(wèn)題更嚴(yán)峻。
解決這一問(wèn)題的思路可從硬件與軟件兩方面展開(kāi):
硬件層面,可借助 HBM(高帶寬內(nèi)存)提高數(shù)據(jù)讀取速度;或擺脫馮?諾依曼架構(gòu)的束縛,改變計(jì)算單元從內(nèi)存讀取數(shù)據(jù)的方式,轉(zhuǎn)而以存儲(chǔ)為核心,構(gòu)建計(jì)算與存儲(chǔ)一體化的 “存內(nèi)計(jì)算” 模式,例如采用 “憶阻器” 技術(shù)。
軟件層面則通過(guò)算法優(yōu)化來(lái)解決,Llama 2 所采用的 GQA(分組查詢(xún)注意力)便是其中一種方案。
下面將通過(guò)圖示來(lái)展示 MQA、GQA 與傳統(tǒng) MHA(多頭注意力)的差異:
多頭注意力機(jī)制(MHA)就是多個(gè)頭各自擁有自己的 Q,K,V 來(lái)算各自的 Self-Attention,而 MQA(Multi Query Attention)就是 Q 依然保持多頭,但是 K,V 只有一個(gè),所有多頭的 Q 共享一個(gè) K,V ,這樣做雖然能最大程度減少 KV Cache 所需的緩存空間,但是可想而知參數(shù)的減少意味著精度的下降,所以為了在精度和計(jì)算之間做一個(gè) trade-off,GQA (Group Query Attention)孕育而生,即 Q 依然是多頭,但是分組共享 K,V,即減少了 K,V 緩存所需的緩存空間,也暴露了大部分參數(shù)不至于精度損失嚴(yán)重。
四、MQA
MQA 的思路比較簡(jiǎn)單,詳見(jiàn)上圖,每一層的所有 Head,共享同一個(gè) KV 來(lái)計(jì)算 Attention。相對(duì)于 MHA 的單個(gè) Token 需要保存的 KV 數(shù)減少了 n_h 倍(head 數(shù)量),即每一層共享使用一個(gè) Q 向量和一個(gè) V 向量。
使用 MQA 的模型包括 PaLM、StarCoder、Gemini 等。很明顯,MQA 直接將 KV Cache 減少到了原來(lái)的 1/n_h,這是非??捎^的,單從節(jié)省顯存角度看已經(jīng)是天花板了。
效果方面,目前看來(lái)大部分任務(wù)的損失都比較有限,且 MQA 的支持者相信這部分損失可以通過(guò)進(jìn)一步訓(xùn)練來(lái)彌補(bǔ)回。此外,注意到 MQA 由于共享了 K、V,將會(huì)導(dǎo)致 Attention 的參數(shù)量減少了將近一半,而為了模型總參數(shù)量的不變,通常會(huì)相應(yīng)地增大 FFN/GLU 的規(guī)模,這也能彌補(bǔ)一部分效果損失。
五、GQA
GQA 是平衡了 MQA 和 MHA 的一種折中的方法,不是每個(gè) Head 一個(gè) KV,也不是所有 Head 共享一個(gè) KV,而是對(duì)所有 Head 分組,比如分組數(shù)為 g ,那么每組: n_h/g 個(gè) Head 共享一個(gè) KV。當(dāng) g=1 時(shí),GQA 就等價(jià)于 MQA,當(dāng) g=n_h 時(shí), GQA 就等價(jià)于 MHA。
為了方便更清晰的理解 GQA 和 MQA ,使用一個(gè) Token 計(jì)算 KV 過(guò)程來(lái)進(jìn)行演示:
總結(jié)下單 token 計(jì)算下,幾種方法 KV Cache 的存儲(chǔ)量(模型層數(shù):l,每層 Head 數(shù)量:n_h )
六、參考鏈接
https://zhuanlan.zhihu.com/p/16730036197
六、參考鏈接
https://zhuanlan.zhihu.com/p/16730036197
https://spaces.ac.cn/archives/10091
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。