cover_image

上周面的一个985女生,问了Transformer模型的内存优化

丁师兄 丁师兄大模型 2024年08月21日 19:37

 我是丁师兄,专注于智能驾驶大模型,持续分享LLM面试干货。


 大模型1v1辅导,已帮助多名同学成功上岸





Transformer 模型现在很火,内存优化又很重要。上周面试了一个 985 大学的女生,跟她谈到了 Transformer 模型的内存优化问题。

那么这个女生到底给出了哪些关于 Transformer 模型内存优化的独特思路呢?一起来看看。

01
什么是Transformer模型中的KV缓存?

Transformer 中文本是逐个 token 生成的,每次新的预测都会基于之前生成的所有 tokens 的上下文信息。

这种对顺序数据的依赖会减慢生成过程,因为每次预测下一个 token 都需要重新处理序列中所有之前的 tokens。

例如,要预测第 100 个 token,模型必须使用前 99 个 token 的信息,需要对这些 token 进行复杂的矩阵运算。预测第 101 个 token 时,也要对前 99 个 token 做类似计算,以及对第 100 个 token 的新计算。

如何简化呢?

答案是使用 KV 缓存KV 缓存通过保存这些计算结果,使模型可以在生成后续 tokens 时直接访问这些结果,而不需要重新计算。

换句话说,在生成第 101 个 token 时,模型只需从 KV 缓存中检索前 99 个 token 的已存储数据,并只对第 100 个 token 执行必要的计算。 

02
如何估算KV缓存消耗的内存大小?

KV 缓存通常使用 float16 或 bfloat16 数据类型以 16 位的精度存储张量。对于一个 token,KV 缓存会为每一层和每个注意力头存储一对张量(键和值)。

这些张量的大小由注意力头的维度决定,这对张量的总内存消耗(以字节为单位)可以通过以下公式计算: 
层数 × KV 注意力头的数量 × 注意力头的维度 × (位宽 / 8) × 2

最后的 "2" 是因为有两组张量,也就是键和值。位宽通常为 16 位,由于 8 位是 1 字节,因此我们将位宽除以 8,这样在 KV 缓存中每 16 位参数占用 2 个字节。

我们以 Llama 3 8B 为例,这个公式就变为:
32 × 8 × 128 × 2 × 2 = 131,072

注意:Llama 3 8B 有 32 个注意力头,不过由于 GQA 的存在,只有 8 个注意力头用于键和值。

从上面可以看到,对于一个 token,KV 缓存占用 131,072 字节,差不多 0.1 MB。这看起来好像不大,但对于许多不同类型的应用,大模型需要生成成千上万的 tokens。

举个例子,如果我们想利用 Llama 3 8B 的全部 context 大小(8192),KV 缓存将为 8191 个 token 存储键值张量,差不多占用 1.1 G 内存。换句话说,对于一块 24G 显存的消费级 GPU,KV 缓存将占用其总内存的 4.5%。

而对于更大的模型,KV 缓存增长得更快。比如对于 Llama 3 70B,它有 80 层,公式变为:
80 × 8 × 128 × 2 × 2 = 327,680  

对于 8191 个 token,Llama 3 70B 的 KV 缓存将占用 2.7 GB。并且注意,这只是单个序列的内存消耗,如果我们进行批量解码,还需要将这个值乘以  batch size。

比如 batch size=32 的 Llama 3 8B 模型,将需要 35.2 GB 的 GPU 显存,一块消费级 GPU 显然搞不定了。

因此虽然在推理阶段用 KV 缓存可以提高处理速度,并且已经是业界标准做法,但是 KV 缓存在深层模型和长序列场景下,也会占据大量 GPU 内存。

而实际开发中,我们可以通过 KV 缓存量化,来降低推理阶段的 LLM 内存需求。后面我们将通过实际的例子(Llama 3 8B 模型),来看看如何对 KV 缓存进行量化的!



END


加入学习




 我是丁师兄,专注于智能驾驶大模型,持续分享LLM面试干货。


 大模型1v1辅导,已帮助多名同学成功上岸

图片

微信:dsxaigc

微信扫一扫
关注该公众号

继续滑动看下一个
丁师兄大模型
向上滑动看下一个