[LLM]从零开始搭建GPT-2(2)

07.25.24 23:04

这一章将探讨gpt2训练的一些基本优化方法。


1. 浮点精度优化(Tensor Cores)

在pythorch框架,张量的默认浮点数类型是float32,每一个浮点数都会占用32位,这其实是比较大的空间占用,对于深度学习方面的训练来说这种程度的精度是没必要的。

torch.set_float32_matmul_precision('high') # 设置张量默认为TF32
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
logits, loss = model(x, y) # 设置张量为BF16

这并不会把模型里所有的参数都转化为bfloat16或者tfloat32,而是会和float32共存。


2. 引入torch.compile

torch.compile十分的强大,可以显著减少芯片的读写次数,仅需一行代码就可以大大提高训练速度。

model = torch.compile(model)

把模型传给compile后,它会在执行代码前会先总览一遍代码,因此在执行计算时它是会知道接下来的运算的,比如:

A * (B + C) / D

在开始进行乘法运算时程序就已经知道接下来会进行的加法和除法运算,以及要参加运算的变量。compile会在芯片上保留计算的中间变量,大量超速运算,一次性完成相关变量的所有运算,最后,一次读写把结果返回HBM(相当于GPU的内存)。


3. Flash Attention

torch.compile是很强大,但目前有些计算的优化还是遗漏了,这里我们可以使用Flash Attention,也是一行代码,更多信息点击这里

y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention


4. 替换为‘漂亮数字’

这是一种非常简单粗暴的方法,但总会有意想不到的提升效果。‘漂亮数字’就是12,32,64这种可以被2多次整除的数字,我们可以从头查阅代码,找到‘丑陋数字’用相近的漂亮数字替换,比如gpt2代码里的词汇表大小50257就是一个十分‘丑陋’的数字,就可以用50304来替换。

Comments