日本a√视频在线,久久青青亚洲国产,亚洲一区欧美二区,免费g片在线观看网站

        <style id="k3y6c"><u id="k3y6c"></u></style>
        <s id="k3y6c"></s>
        <mark id="k3y6c"></mark>
          
          

          <mark id="k3y6c"></mark>

          "); //-->

          博客專欄

          EEPW首頁 > 博客 > 手撕大模型|KVCache 原理及代碼解析

          手撕大模型|KVCache 原理及代碼解析

          發(fā)布人:地平線開發(fā)者 時間:2025-09-13 來源:工程師 發(fā)布文章

          在大型語言模型(LLM)的推理過程中,KV Cache 是一項關(guān)鍵技術(shù),它通過緩存中間計算結(jié)果顯著提升了模型的運行效率。本文將深入解析 KV Cache 的工作原理、實現(xiàn)方式,并通過代碼示例展示其在實際應(yīng)用中的效果。

          一、為什么需要 KV Cache?

          在 Transformer 進(jìn)行自回歸推理(如文本生成,每次生成一個 token 的時候需要結(jié)合前面所有的 token 做 attention 操作)時,計算注意力機制時需要存儲 Key(K) 和 Value(V),以便下一個時間步可以復(fù)用這些緩存,而不必重新計算整個序列。

          在標(biāo)準(zhǔn) Transformer 解碼時,每次生成新 token 時:

          • 需要 重新計算所有之前 token 的 K 和 V,并與當(dāng)前 token 進(jìn)行注意力計算。

          • 計算復(fù)雜度是 O(n2)(對于長度為 n 的序列)。

          img

          而 KV Cache 通過存儲 K 和 V 的歷史值,避免重復(fù)計算:

          • 只需計算 新 token 的 K 和 V,然后將其與緩存的值結(jié)合使用。

          • 計算復(fù)雜度下降到 O(n)(每個 token 只與之前緩存的 token 計算注意力)。

          二、KV Cache 的工作原理

          KV Cache 的核心思想是緩存歷史計算中的鍵(Key)和值(Value)矩陣,避免重復(fù)計算。具體來說:

          1. 在生成第一個 token 時,模型計算并緩存所有輸入 token 的 K 和 V 矩陣

          2. 生成后續(xù) token 時,只需要計算新 token 的查詢(Query)矩陣

          3. 將新的 Q 矩陣與緩存的 K、V 矩陣進(jìn)行注意力計算,同時將新 token 的 K、V 追加到緩存中

          這個過程可以用偽代碼直觀展示:

          初始輸入: [t0, t1, t2]
          首次計算: K=[K0,K1,K2], V=[V0,V1,V2] → 生成t3
          緩存狀態(tài): K=[K0,K1,K2], V=[V0,V1,V2]
          第二次計算: 新Q=Q3
          注意力計算: Attention(Q3, [K0,K1,K2]) → 生成t4
          更新緩存: K=[K0,K1,K2,K3], V=[V0,V1,V2,V3]
          第三次計算: 新Q=Q4
          注意力計算: Attention(Q4, [K0,K1,K2,K3]) → 生成t5
          更新緩存: K=[K0,K1,K2,K3,K4], V=[V0,V1,V2,V3,V4]
          ...

          通過這種方式,每次新生成 token 時,只需計算新的 Q 矩陣并與歷史 KV 矩陣進(jìn)行注意力計算,將時間復(fù)雜度從 O (n2) 降低到 O (n),極大提升了長序列生成的效率。

          下面,我們結(jié)合示意圖進(jìn)一步剖析一下 KV Cache 部分的邏輯。

          img

          img

          img

          img

          KV Cache 核心節(jié)約的時間有三大塊:

          1. 前面 n-1 次的 Q 的計算,當(dāng)然這塊對于一次一個 token 的輸出本來也沒有用;

          2. 同理還有 Attention 計算時對角矩陣變?yōu)樽詈笠恍?,?b 是同理的,這樣 mask 矩陣也就沒有什么用了;

          3. 前面 n-1 次的 K 和 V 的計算,也就是上圖紫色部分,這部分是實打?qū)嵄?Cache 過不需要再重新計算的部分。

          這里還有個 softmax 的問題,softmax 原本就是針對同一個 query 的所有 key 的計算,所以并不受影響。

          2.1 KV Cache 的技術(shù)細(xì)節(jié)
          1. 緩存結(jié)構(gòu)

          KV Cache 通常為每個注意力頭維護(hù)獨立的緩存,結(jié)構(gòu)如下:

          1. Key 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

          2. Value 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

          其中,seq_len 會隨著生成過程動態(tài)增長,直到達(dá)到模型最大序列長度限制。

          1. 內(nèi)存與速度的權(quán)衡

          KV Cache 雖然提升了速度,但需要額外的內(nèi)存存儲緩存數(shù)據(jù)。以 GPT-3 175B 模型為例,每個 token 的 KV 緩存約占用 20KB 內(nèi)存,當(dāng)生成 1000 個 token 時,單個樣本就需要約 20MB 內(nèi)存。在批量處理時,內(nèi)存消耗會線性增加。

          實際應(yīng)用中需要根據(jù)硬件條件在以下方面進(jìn)行權(quán)衡:

          1. 最大緩存長度(影響能處理的序列長度)

          2. 批量大?。ㄓ绊懖l(fā)處理能力)

          3. 精度選擇(FP16 比 FP32 節(jié)省一半內(nèi)存)

          4. 滑動窗口機制

          當(dāng)處理超長序列時,一些模型(如 Llama 2)采用滑動窗口機制,只保留最近的 N 個 token 的 KV 緩存,以控制內(nèi)存占用。這種機制在犧牲少量上下文信息的情況下,保證了模型能處理更長的對話。

          四、代碼實現(xiàn)解析

          下面以 PyTorch 為例,展示 KV Cache 在自注意力計算中的實現(xiàn)方式。

          1. 基礎(chǔ)自注意力實現(xiàn)(無緩存)

          首先看一下標(biāo)準(zhǔn)的自注意力計算,沒有緩存機制:

          import torch
          import torch.nn as nn
          import torch.nn.functional as F
          class SelfAttention(nn.Module):
              def __init__(self, embed_dim, num_heads):
                  super().__init__()
                  self.embed_dim = embed_dim
                  self.num_heads = num_heads
                  self.head_dim = embed_dim // num_heads
                  
                  # 定義Q、K、V投影矩陣
                  self.q_proj = nn.Linear(embed_dim, embed_dim)
                  self.k_proj = nn.Linear(embed_dim, embed_dim)
                  self.v_proj = nn.Linear(embed_dim, embed_dim)
                  self.out_proj = nn.Linear(embed_dim, embed_dim)
              
              def forward(self, x):
                  batch_size, seq_len, embed_dim = x.shape
                  
                  # 計算Q、K、V
                  q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  
                  # 計算注意力分?jǐn)?shù)
                  attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
                  attn_probs = F.softmax(attn_scores, dim=-1)
                  
                  # 應(yīng)用注意力權(quán)重
                  output = attn_probs @ v
                  output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
                  
                  return self.out_proj(output)
          1. 帶 KV Cache 的自注意力實現(xiàn)

          下面修改代碼,加入 KV Cache 機制:

          class CachedSelfAttention(nn.Module):
              def __init__(self, embed_dim, num_heads):
                  super().__init__()
                  self.embed_dim = embed_dim
                  self.num_heads = num_heads
                  self.head_dim = embed_dim // num_heads
                  
                  # 定義投影矩陣
                  self.q_proj = nn.Linear(embed_dim, embed_dim)
                  self.k_proj = nn.Linear(embed_dim, embed_dim)
                  self.v_proj = nn.Linear(embed_dim, embed_dim)
                  self.out_proj = nn.Linear(embed_dim, embed_dim)
                  
                  # 初始化緩存
                  self.cache_k = None
                  self.cache_v = None
              
              def forward(self, x, use_cache=False):
                  batch_size, seq_len, embed_dim = x.shape
                  
                  # 計算Q、K、V
                  q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                  
                  # 如果使用緩存且緩存存在,則拼接歷史KV
                  if use_cache and self.cache_k is not None:
                      k = torch.cat([self.cache_k, k], dim=-2)
                      v = torch.cat([self.cache_v, v], dim=-2)
                  
                  # 如果使用緩存,更新緩存
                  if use_cache:
                      self.cache_k = k
                      self.cache_v = v
                  
                  # 計算注意力分?jǐn)?shù)(注意這里的k是包含歷史緩存的)
                  attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
                  attn_probs = F.softmax(attn_scores, dim=-1)
                  
                  # 應(yīng)用注意力權(quán)重
                  output = attn_probs @ v
                  output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
                  
                  return self.out_proj(output)
              
              def reset_cache(self):
                  """重置緩存,用于新序列的生成"""
                  self.cache_k = None
                  self.cache_v = None
          1. 生成過程中的緩存使用

          在文本生成時,我們可以這樣使用帶緩存的注意力機制:

          def generate_text(model, input_ids, max_length=50):
              # 初始化模型緩存
              model.reset_cache()
              
              # 處理初始輸入
              output = model(input_ids, use_cache=True)
              next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
              generated = [next_token]
              
              # 生成后續(xù)token
              for _ in range(max_length - 1):
                  # 只輸入新生成的token
                  output = model(next_token, use_cache=True)
                  next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
                  generated.append(next_token)
                  
                  # 如果生成結(jié)束符則停止
                  if next_token.item() == 102:  # 假設(shè)102是[SEP]的id
                      break
              
              return torch.cat(generated, dim=1)
          五、KV Cache 的優(yōu)化策略

          在實際部署中,為了進(jìn)一步提升 KV Cache 的效率,還會采用以下優(yōu)化策略:

          1. 分頁 KV Cache(Paged KV Cache):借鑒內(nèi)存分頁機制,將連續(xù)的 KV 緩存分割成固定大小的塊,提高內(nèi)存利用率,代表實現(xiàn)有 vLLM。

          2. 動態(tài)緩存管理:根據(jù)輸入序列長度動態(tài)調(diào)整緩存大小,在批量處理時優(yōu)化內(nèi)存分配。

          3. 量化緩存:使用 INT8 或 INT4 等低精度格式存儲 KV 緩存,在犧牲少量精度的情況下大幅減少內(nèi)存占用。

          4. 選擇性緩存:對于一些不重要的層或注意力頭,選擇性地不進(jìn)行緩存,平衡速度和內(nèi)存。

          六、總結(jié)

          KV Cache 通過緩存中間計算結(jié)果,有效解決了 Transformer 模型在生成式任務(wù)中的效率問題,是大模型能夠?qū)崿F(xiàn)實時交互的關(guān)鍵技術(shù)之一。理解 KV Cache 的工作原理和實現(xiàn)方式,對于優(yōu)化大模型推理性能、解決實際部署中的挑戰(zhàn)具有重要意義。

          七、參考鏈接

          https://zhuanlan.zhihu.com/p/670515231

          https://zhuanlan.zhihu.com/p/714288577

          https://zhuanlan.zhihu.com/p/715921106https://zhuanlan.zhihu.com/p/19489285169

          https://medium.com/@joaolages/kv-caching-explained-276520203249


          *博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。



          相關(guān)推薦

          技術(shù)專區(qū)

          關(guān)閉