首先我们先定义模型中的一些参数,transformer的层数记为 $l$,隐藏层维度为 $h$,注意力头数为 $a$,词表大小为 $V$,训练数据的批次大小为 $b$,序列长度为 $s$。
1. 模型参数量
transformer模型由 $l$ 个相同的层组成,每个层分为两部分:self-attention块和MLP块。
其中,self-attention块的模型参数有Q、K、V 的权重矩阵 $W_{Q}、W_{K}、W_{V}$ 和偏置,以及输出权重矩阵 $W_{O}$ 和对应的偏置,我们以4个权重矩阵的形状为[h,h]
为例,4个偏置的形状为[h]
,则这里self-attention块的参数量为 $4 * h^2 + 4h$。
MLP模块通常由2个线性层组成,一般地,第一个线性层主要进行升维,将维度从 $h$ 映射到 $4h$,第二个线性层再降维,将维度从 $4h$ 映射到 $h$。因此这里第一个线性层的权重矩阵 $W_1$ 的形状为[h,4h]
,偏置的形状为[4h]
。第二个线性层权重矩阵 $W_2$ 的形状为[4h,h]
,偏置形状为[h]
。可得这里MLP模块的参数量为 $2* 4h^2 + 5h$。
同时,self-attention块和MLP块中还各包含layer normalization(LN层),以每个模块包含一个LN层为例。每个LN层包含了2个可训练模型参数:缩放参数 $γ$ 和平移参数 $β$,形状都为[h]
。2个layer normalization的参数量为 $2 * 2h$。
因此,通过上面的参数量可以得到每个transformer block的参数量为 $12h^2 + 13h$。
此外,transformer模型中的词嵌入矩阵也涉及比较多的模型参数,其中每个词向量的维度通常为隐藏层维度 $h$,词嵌入矩阵的参数量则为 $Vh$。
综上,一个 $l$ 层的transformer模型的参数量一般为, $l(12h^2 + 13h) + Vh$,若模型隐藏层的维度 $h$ 远大于词表大小 $V$时,我们可以会略一次项,模型参数量约等于 $l(12h^2 + 13h)$。
2. 模型显存占用量
2.1 训练阶段显存
在模型训练阶段,显存占用主要在以下四部分:模型参数、反向传播计算得到的梯度、优化器状态 以及 前向计算过程中产生的中间激活,这里主要分析前三部分。在大模型训练过程中,通常使用AdamW优化器,并一般使用混合精度训练来加速模型训练,这里我们以优化器为AdamW的混合精度训练为例来分析训练阶段显存占用情况:
在训练阶段,每1个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam优化器梯度的一阶动量和二阶动量)。这里我们设模型的参数量为 $Φ$, 那么梯度的元素数量为Φ,AdamW优化器的元素数量为2Φ。float16数据类型的元素占2个bytes,float32数据类型的元素占4个bytes。在混合精度训练中,会使用float16的模型参数进行前向传递和反向传递($2Φ$),计算得到float16的梯度($2Φ$);在优化器更新模型参数时,会使用float32的优化器状态($4Φ$ + $4Φ$)、float32的梯度($4Φ$)、float32的模型参数($4Φ$)来更新模型参数。总结下来为:
- model 本身的参数 fp16 (+$2Φ$)
- 梯度 fp16 (+$2Φ$)
- 优化器momentum fp32 (+$4Φ$)
- 优化器 var fp32 (+$4Φ$)
- 优化器中的模型参数 fp32 (+$4Φ$)
- 优化器中的梯度fp32(根据实现不同,不一定需要全部保存下来,如果只保留部分梯度,可以忽略)(+$4Φ$)
通过上面的计算得到,对于每个可训练模型参数,占用了(2 + 4) + (2 + 4) + (4 + 4) = 20 bytes。所以,使用AdamW优化器和混合精度训练来训练参数量为 $Φ$ 的模型,模型参数、梯度和优化器状态占用的显存大小为 $20Φ$ bytes。
2.2 推理阶段显存
在推理阶段,没有优化器状态和梯度,也不需要保存中间激活,因此模型在推理阶段占用的显存要远小于训练阶段。推理阶段显存占用的大头是模型参数部分,如果使用fp16进行推理,显存占用大约是 $2Φ$ bytes。如果使用KV cache来加速推理过程,KV cache也需要占用显存。
3. 计算量FLOPs估计
首先,对于 $A \in R^{1\times n}, B\in R^{n\times 1}$,计算$AB$需要进行$n$次乘法运算和$n$次加法运算,共计$2n$次浮点数运算,需要$2n$的FLOPs。对于$A\in R^{m\times n},B\in R^{n\times p}$,计算$AB$需要的浮点数运算次数为 $2mnp$。
在一次训练迭代中,假设输入数据的形状为 [b, s]。我们先分析self-attention块的计算,计算公式如下:
$$ Q = xW_Q, K = xW_K, V = xW_V, $$ $$ x_{out} = softmax\left(\frac{QK^T}{\sqrt{h}}\right) \cdot V \cdot W_o + x $$
计算 $Q, K, V$:矩阵乘法的输入和输出形状为 [b, s, h] × [h, h] → [b, s, h]。计算量为 $3 * 2bsh^2 = 6bsh^2$。
$QK^T$ 矩阵乘法的输入和输出形状为 [b, head_num, s, per_head_hidden_size] × [b, head_num, per_head_hidden_size, s] → [b, head_num, s, s]。计算量为 $2bs^2h$。
计算在 $V$ 上的加权score · $V$,矩阵乘法的输入和输出形状为 [b, head_num, s, s ] × [b, head_num, s, per_head_hidden_size] → [b, head_num, s, per_head_hidden_size] 。计算量为 $2bs^2h$。
attention 后的线性映射,矩阵乘法的输入和输出形状为 [b, s, h] × [h, h] → [b, s, h]。计算量为 $2bsh^2$。
接下来分析 MLP 块的计算,计算公式如下:
$$ x = f_{gelu}(x_{out}W_1)W_2 + x_{out} $$
第一个线性层,矩阵乘法的输入和输出形状为 [b, s, h] × [h, 4h] → [b, s, 4h]。计算量为 $8bsh^2$。
第二个线性层,矩阵乘法的输入和输出形状为 [b, s, 4h] × [4h, h] → [b, s, h]。计算量为 $8bsh^2$。
将上述计算量相加,得到每个 transformer 层的计算量大约为 $24bsh^2 + 4bs^2h$。
此外,另一个计算量的大头是 logits 的计算,将隐藏向量映射为词表大小。矩阵乘法的输入和输出形状为 [b, s, h] × [h, V] → [b, s, V],计算量为 $2bshV$。
因此,对于一个 l 层的 transformer 模型,输入数据形状为 [b, s] 的情况下,一次训练迭代的计算量为 $l * (24bsh^2 + 4bs^2h) + 2bshV$。
- 当隐藏维度 $h$ 比较大,且远大于序列长度 $s$ 时,我们可以忽略一次项,计算量可以近似为 $24bsh^2 * l$。前面提到模型参数量可近似为 $12lh^2$,所以当输入的tokens数为 $bs$时,$\frac{24bsh^2l} {12lh^2\times bs} = 2$。由此我们可以近似认为在一次前向传递中,对于每个token,每个模型参数需要进行2次浮点数运算,即一次乘法运算和一次加法运算。