混合精度训练真的降低显存占用吗

717
2 分钟阅读
混合精度训练真的降低显存占用吗
"AI摘要: 本文探讨了混合精度训练(AMP)是否真的能降低显存占用。通过分析模型参数、梯度、优化器参数和激活值在FP32与FP16下的存储需求,指出虽然前向推理时显存减半,但反向传播因数值范围限制需转回FP32并保留原始FP32模型副本,导致总显存变为原来的1.5倍。然而实际训练中梯度和激活值才是主要开销,综合计算显示混合精度仍可减少整体显存使用量(如7B模型从112G降至84G),同时提升计算速度。"

混合精度训练AMP是指将模型参数从fp32转化为fp16进行推理,由于位数少一半,在GPU运算的时候,速度更快,可以加速模型训练,但是混合精度训练真的一定降低显存占用吗

模型训练的参数一般分为:

  1. 模型参数本身(比如7B的模型,fp32精度,光放入显卡就要28G显存)
  2. 梯度(每个要优化的参数都带着一个梯度,如果全参数优化,就和模型参数等价,也是28G)
  3. 优化器参数(部分优化器为了让梯度优化更准更快,附带计算了很多额外动量,如AdamW,又额外带上了2 * 梯度参数,就是56G显存)
  4. 激活值,在反向传播的时候,计算梯度需要使用当前节点在前向计算时候输出的激活值,这部分非常灵活,取决于batch,layer层数等等

前向推理

模型参数变为fp16,前向推理的所有显存占用砍一半(梯度,优化器额外参数,激活值)

反向传播

由于反向传播在链式求导的时候涉及到梯度的乘法和除法,在fp16情况下非常容易溢出(fp16表达区间比较小)

因此,在梯度反向传播的时候,为了表达更大范围的数值,会将fp16再转化为fp32

但是模型参数是fp16,因此为了计算精度不丢失,必须要保存一份原始模型fp32参数

因此,混合精度会导致模型参数保存两份,显存占用为原来的32\frac{3}{2} (7B的模型,现在需要占用42G)

显存占用到底下降还是上升

在模型训练时,梯度和激活值才是显存占用大头,如7B模型(28G),优化器在fp32占据28 + 56 = 84G,共112G。

在混合精度下,模型参数42G,优化器42G,共84G。

因此混合精度可以降低显存占用,并且fp16的计算速度更快