本文试图以最清晰的方式手动推导 Transformers 每一步的参数量到显存、计算量问题。理解底层,才能更好的做训练和优化。本文内容包括最基本的模型训练和推理过程中的显存占用,以及对其中中间激活值的优化方案和显存占用。****
训练中的显存占用分两块
写在前面的前置知识,
设模型参数量为 $\Phi$,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32),总参数量 =,总数为 $2 \Phi + 2 \Phi + (4+4+4) \Phi = 16 \Phi$。
优化器 | K值 | 构成 |
---|---|---|
adamw | 12 | fp32 主权重 4 + 动量 4 +方差 4 |
SGD | 8 | fp32 主权重 4 + 动量 4 |
bitsandbytes | 6 | fp32 主权重 + 动量 1 + 方差 1 |
LOMO |
激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。
中间激活值占用显存分两个部分分析:Attention 和 MLP,Embedding 没有中间值
Attention 部分计算公式如下: