✅ 我是丁师兄,专注于智能驾驶大模型,持续分享LLM面试干货。
✅ 大模型1v1辅导,已帮助多名同学成功上岸
Transformer 模型现在很火,内存优化又很重要。上周面试了一个 985 大学的女生,跟她谈到了 Transformer 模型的内存优化问题。
那么这个女生到底给出了哪些关于 Transformer 模型内存优化的独特思路呢?一起来看看。
Transformer 中文本是逐个 token 生成的,每次新的预测都会基于之前生成的所有 tokens 的上下文信息。
这种对顺序数据的依赖会减慢生成过程,因为每次预测下一个 token 都需要重新处理序列中所有之前的 tokens。
例如,要预测第 100 个 token,模型必须使用前 99 个 token 的信息,需要对这些 token 进行复杂的矩阵运算。预测第 101 个 token 时,也要对前 99 个 token 做类似计算,以及对第 100 个 token 的新计算。
如何简化呢?
答案是使用 KV 缓存。KV 缓存通过保存这些计算结果,使模型可以在生成后续 tokens 时直接访问这些结果,而不需要重新计算。
换句话说,在生成第 101 个 token 时,模型只需从 KV 缓存中检索前 99 个 token 的已存储数据,并只对第 100 个 token 执行必要的计算。
KV 缓存通常使用 float16 或 bfloat16 数据类型以 16 位的精度存储张量。对于一个 token,KV 缓存会为每一层和每个注意力头存储一对张量(键和值)。
层数 × KV 注意力头的数量 × 注意力头的维度 × (位宽 / 8) × 2
最后的 "2" 是因为有两组张量,也就是键和值。位宽通常为 16 位,由于 8 位是 1 字节,因此我们将位宽除以 8,这样在 KV 缓存中每 16 位参数占用 2 个字节。
32 × 8 × 128 × 2 × 2 = 131,072
注意:Llama 3 8B 有 32 个注意力头,不过由于 GQA 的存在,只有 8 个注意力头用于键和值。
从上面可以看到,对于一个 token,KV 缓存占用 131,072 字节,差不多 0.1 MB。这看起来好像不大,但对于许多不同类型的应用,大模型需要生成成千上万的 tokens。
举个例子,如果我们想利用 Llama 3 8B 的全部 context 大小(8192),KV 缓存将为 8191 个 token 存储键值张量,差不多占用 1.1 G 内存。换句话说,对于一块 24G 显存的消费级 GPU,KV 缓存将占用其总内存的 4.5%。
80 × 8 × 128 × 2 × 2 = 327,680
对于 8191 个 token,Llama 3 70B 的 KV 缓存将占用 2.7 GB。并且注意,这只是单个序列的内存消耗,如果我们进行批量解码,还需要将这个值乘以 batch size。
比如 batch size=32 的 Llama 3 8B 模型,将需要 35.2 GB 的 GPU 显存,一块消费级 GPU 显然搞不定了。
因此虽然在推理阶段用 KV 缓存可以提高处理速度,并且已经是业界标准做法,但是 KV 缓存在深层模型和长序列场景下,也会占据大量 GPU 内存。
而实际开发中,我们可以通过 KV 缓存量化,来降低推理阶段的 LLM 内存需求。后面我们将通过实际的例子(Llama 3 8B 模型),来看看如何对 KV 缓存进行量化的!
微信扫一扫
关注该公众号