混合精度训练真的降低显存占用吗

·717·2 分钟·
混合精度训练AMP是指将模型参数从fp32转化为fp16进行推理,由于位数少一半,在GPU运算的时候,速度更快,可以加速模型训练,但是混合精度训练真的一定降低显存占用吗 模型训练的参数一般分为: 1. 模型参数本身(比如7B的模型,fp32精度,光放入显卡就要28G显存) 2. 梯度(每个要优化

混合精度训练AMP是指将模型参数从fp32转化为fp16进行推理,由于位数少一半,在GPU运算的时候,速度更快,可以加速模型训练,但是混合精度训练真的一定降低显存占用吗

模型训练的参数一般分为:

  1. 模型参数本身(比如7B的模型,fp32精度,光放入显卡就要28G显存)
  2. 梯度(每个要优化的参数都带着一个梯度,如果全参数优化,就和模型参数等价,也是28G)
  3. 优化器参数(部分优化器为了让梯度优化更准更快,附带计算了很多额外动量,如AdamW,又额外带上了2 * 梯度参数,就是56G显存)
  4. 激活值,在反向传播的时候,计算梯度需要使用当前节点在前向计算时候输出的激活值,这部分非常灵活,取决于batch,layer层数等等

前向推理

模型参数变为fp16,前向推理的所有显存占用砍一半(梯度,优化器额外参数,激活值)

反向传播

由于反向传播在链式求导的时候涉及到梯度的乘法和除法,在fp16情况下非常容易溢出(fp16表达区间比较小)

因此,在梯度反向传播的时候,为了表达更大范围的数值,会将fp16再转化为fp32

但是模型参数是fp16,因此为了计算精度不丢失,必须要保存一份原始模型fp32参数

因此,混合精度会导致模型参数保存两份,显存占用为原来的32\frac{3}{2} (7B的模型,现在需要占用42G)

显存占用到底下降还是上升

在模型训练时,梯度和激活值才是显存占用大头,如7B模型(28G),优化器在fp32占据28 + 56 = 84G,共112G。

在混合精度下,模型参数42G,优化器42G,共84G。

因此混合精度可以降低显存占用,并且fp16的计算速度更快