本文试图以最清晰的方式手动推导 Transformers 每一步的参数量到显存、计算量问题。理解底层,才能更好的做训练和优化。本文内容包括最基本的模型训练和推理过程中的显存占用,以及对其中中间激活值的优化方案和显存占用。****

1 训练过程

训练中的显存占用分两块

  1. 模型状态,参数、梯度和优化器状态
  2. 剩余状态, 激活值、临时buffer

1-1 模型状态显存

写在前面的前置知识,

设模型参数量为 $\Phi$,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32),总参数量 =,总数为 $2 \Phi + 2 \Phi + (4+4+4) \Phi = 16 \Phi$。

  1. 这部分比较固定,且在整个训练过程中都要存在显存中。
  2. 一般只能通过并行切分(Tensor Parallelism/Pipeline Parallism)能减少。
  3. 不同优化器的 K 值不同,算法的中间变量、框架的实现都有可能有一定区别。
优化器 K值 构成
adamw 12 fp32 主权重 4 + 动量 4 +方差 4
SGD 8 fp32 主权重 4 + 动量 4
bitsandbytes 6 fp32 主权重 + 动量 1 + 方差 1
LOMO

1-2 中间激活值显存

激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。

中间激活值占用显存分两个部分分析:Attention 和 MLP,Embedding 没有中间值

Attention 部分计算公式如下: