KV Cache

·1594·4 分钟·
AI摘要: 本文介绍了LLM(大型语言模型)中KV Cache的重要性,它是加速运算的关键步骤,确保在对话过程中输入文本增长时,推理速度不受影响。文章详细解释了Self Attention机制和因果掩码的概念,并讨论了KV Cache如何通过缓存计算结果来优化预测过程。

KV Cache是LLM中加速运算的非常重要的一步,能够保证模型在对话过程中,输入文本越来越长却不影响推理速度。

如图所示,在LLM推理过程中,KV Cache占据了很大一部分的显存。

image-20241201142926563

注意力机制

LLM的流程本质是输入n个token, 输出第n+1个token,当得到第n +1个token,此时再根据0~n+1个token预测第n+2个token,以此反复。

LLM中预测下一个token中,最关键的一步是Self Attention的计算,假设我们输入的N个token得到了N个q,k,v向量,用数学表示就是qi,ki,vi(i(0,n))q_i, k_i, v_i (i \in (0, n))

Self Attention的过程是对N个token进行qkv计算,具体的公式为j=0nSoftmax(1dkqiTkj)vj\sum_{j = 0}^{n} Softmax( \frac{1}{\sqrt{d_k}} q_i^T k_j) v_j, 去除固定系数1dk\frac{1}{\sqrt{d_k}}, 得到j=0nSoftmax(qiTkj)vj\sum_{j = 0}^n Softmax( q_i^T k_j ) v_j

这样就得到了第ii个token的注意力向量为j=0nSoftmax(qiTkj)vj\sum_{j = 0}^n Softmax( q_i^T k_j) v_j

因果掩码

在预测第 i+1i + 1 个token的过程中,llm只能看到 00ii 个位置的token,而不能看到 i+1i+1 以及之后的token。从数学上描述就是j=0iSoftmax(qiTkj)vj\sum_{j = 0}^i Softmax( q_i^T k_j) v_j , 注意 jj 的范围变化。

KV Cache

可以注意到,计算i+1i + 1个token的时候,需要对于每个 qiq_i都要计算j=0iSoftmax(qiTkj)vj\sum_{j = 0}^i Softmax( q_i^T k_j) v_j , 其中kjvjk_j v_j 是共有的。

其次当计算i+2i + 2个token的时候,需要对于每个qiq_i都要计算j=0i+1Softmax(qiTkj)vj=j=0iSoftmax(qiTkj)vj+qiTkj+1vj+1\sum_{j = 0}^{i + 1} Softmax( q_i^T k_j) v_j = \sum_{j = 0}^i Softmax( q_i^T k_j ) v_j + q_i^T k_{j + 1} v_{j + 1} , 又出现kjvjk_j v_j

因此kv可以进行缓存,当计算i+2i + 2个token的时候,只需要再计算一下kj+1vj+1k_{j + 1} v_{j + 1}就行

image-20241201164808502

从矩阵的角度更容易理解,如图计算出了过往所有Token的KV,公式为QKTV=Q(KTV)QK^T V = Q (K^TV) 。在通过当前Token来计算下一个Token时,公式为qiKiiTVii=qikiTvi+qiK(i1)(i1)TV(i1)(i1)q_i K_{i * i}^TV_{i * i} = q_i k_i^T v_i + q_i K^T_{(i - 1) * (i - 1)} V_{(i - 1)* (i - 1)} , 因此缓存KTVK^TV就行了

为什么没有Q Cache

从上面的推导可以知道,预测下一个token并不需要和以前token的Q向量进行计算,只需要使用当前token对应的Q向量,因此没必要缓存

Kaggle学习赛初探