Some things you need to know about mixed precision (amp)

References:

Zhihu, OpenMMLab

How torch.cuda.amp converts precision

  1. Maintain an FP32 master copy of the model parameters
  2. At every iteration:
    1. Copy the parameters and cast them into an FP16 model
    2. Forward pass (with the FP16 model parameters)
    3. Multiply the loss by a scale factor $s$
    4. Backward pass (FP16 model parameters and parameter gradients)
    5. Multiply the parameter gradients by $1/s$
    6. Use the FP16 gradients to update the FP32 master parameters (see the sketch after this list)
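
To make these steps concrete, here is a minimal hand-written sketch of the same loop without GradScaler, assuming a toy linear model on a CUDA device and a fixed scale factor s (all names here are illustrative, and real amp adjusts s dynamically). The scaling in steps 2.3-2.5 matters because gradients smaller than FP16's smallest subnormal value (about 6e-8) would otherwise underflow to zero.

import copy

import torch

model_fp32 = torch.nn.Linear(16, 4).cuda()          # step 1: FP32 master copy
optimizer = torch.optim.SGD(model_fp32.parameters(), lr=1e-2)
s = 2.0 ** 10                                        # fixed loss-scale factor (assumption)

for _ in range(10):                                  # step 2: each iteration
    x = torch.randn(8, 16, device="cuda")
    target = torch.randn(8, 4, device="cuda")

    model_fp16 = copy.deepcopy(model_fp32).half()    # 2.1 copy and cast to FP16
    output = model_fp16(x.half())                    # 2.2 forward in FP16
    loss = torch.nn.functional.mse_loss(output.float(), target)
    (loss * s).backward()                            # 2.3 + 2.4 scaled backward, FP16 gradients

    optimizer.zero_grad()
    for p32, p16 in zip(model_fp32.parameters(), model_fp16.parameters()):
        p32.grad = p16.grad.float() / s              # 2.5 unscale the gradients
    optimizer.step()                                 # 2.6 update the FP32 master copy

The autocast/GradScaler examples below automate exactly this bookkeeping.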

Basic usage

For single-GPU training, this is what you write in the main thread:

import torch
from torch import optim
from torch.cuda.amp import autocast, GradScaler

# amp relies on Tensor Cores for its speedup, so the model parameters must be CUDA tensors
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# The GradScaler object handles gradient scaling automatically
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # Run the forward pass inside the autocast-enabled region
        with autocast():
            # The forward pass runs on an FP16 copy of the model
            output = model(input)
            loss = loss_fn(output, target)
        # Scale the loss with the scaler, then backprop to get scaled (FP16) gradients:
        # first loss = loss * s, then backward()
        scaler.scale(loss).backward()
        # scaler.step() automatically unscales the gradients first, then updates the parameters;
        # if any gradient contains nan/inf, the step is skipped
        scaler.step(optimizer)
        # Update the scale factor s: shrink it if this step was skipped, otherwise grow it periodically
        scaler.update()

Gradient clipping

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscale the gradients in place so clipping is applied against the true (unscaled) threshold
        scaler.unscale_(optimizer)

        # Clip the gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # Since unscale_() was called explicitly, scaler.step() updates the parameters as usual,
        # and it still skips the step if the gradients contain nan/inf
        scaler.step(optimizer)
        scaler.update()

Gradient accumulation

scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            # Normalize the loss by the number of accumulation steps
            loss = loss / iters_to_accumulate

        # Scale the normalized loss and backprop
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired 
            # (e.g., to allow clipping unscaled gradients)

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

Gradient penalty

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        # To avoid overflow/underflow, first compute the scaled gradients from the scaled loss, outside the autocast region
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)
        # Unscale the gradients
        inv_scale = 1./scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]
        # Inside the autocast region, add the gradient penalty term to the loss
        with autocast():
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        scaler.scale(loss).backward()

        # may unscale_ here if desired 
        # (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()

Multiple models

For example, when training a GAN:

scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        with autocast():
            output0 = model0(input)
            output1 = model1(input)
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)

        # (retain_graph here is unrelated to amp, it's present because in this
        # example, both backward() calls share some sections of graph.)
        scaler.scale(loss0).backward(retain_graph=True)
        scaler.scale(loss1).backward()

        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)

        scaler.step(optimizer0)
        scaler.step(optimizer1)

        scaler.update()