What to Do When GPU Memory Runs Out: Gradient Accumulation

🎉 How to get RTX 3090 results out of an RTX 3080, a good way to save GPU memory: gradient accumulation 🎉
Uploaded by Dave · 3 years ago

Troubled by OOM #

Friends often come to Dave asking what to do after hitting an OOM error. If you have been plagued by error messages like the ones below, this post is for you.

[Screenshots of CUDA out-of-memory error messages]

Out of memory (OOM) is a generally undesirable state of computer operation in which no additional memory can be allocated for programs to use. A system in this state cannot load any further programs, and since many programs load extra data into memory during execution, those programs stop working correctly.

When training neural networks, you often have large images and want to increase the batch size, but GPU memory is not enough, and an A6000 with its large memory is quite expensive 💰. So what can you do?

Well-funded friend: No problem, I have money 💰!

Dave: No offense meant, no offense meant...

Since Dave's budget is modest, he often wants to get RTX 3090 results out of an RTX 3080. Besides the half-precision training Dave introduced earlier, there is another trick he has been using all along and recommends to everyone: gradient accumulation.

What is gradient accumulation #

Gradient accumulation sums the training gradients over several mini-batches and then performs a single weight update. The benefit is that you only need the GPU memory of a single mini-batch, while the effective batch size is the sum of all the accumulated mini-batches.

[Illustration: gradients from several mini-batches are accumulated before one weight update]
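
Before the full notebook code below, here is a minimal sketch of the idea (the names loader, model, optimizer, loss_fn and accum_steps are placeholders, not the variables used later in this notebook):

# Minimal sketch of gradient accumulation (placeholder names, no mixed precision)
accum_steps = 4                       # number of mini-batches to accumulate per update
optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y)
    (loss / accum_steps).backward()   # gradients add up across backward() calls
    if (i + 1) % accum_steps == 0:    # one optimizer step every accum_steps mini-batches
        optimizer.step()
        optimizer.zero_grad()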

!pip install seaborn tqdm
import os
import timm
import torch
import random

from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

import torchvision.transforms as transforms
import torchvision.datasets as dataset
import seaborn as sns
import numpy as np

from tqdm import tqdm

def seed_torch(seed=99):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
seed_torch()
# Prepare the MNIST data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dataset.MNIST(root='./', train=True, transform=trans, download=True)
test_set = dataset.MNIST(root='./', train=False, transform=trans)

Normal training (half precision) #

# Dataloader
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=32768,
                 num_workers=8,
                 shuffle=True)

model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

losses = []
scaler = GradScaler()

for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())

        scaler.scale(loss).backward()
        loss_epoch += loss.item() * len(input)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))
Loss:  4.541711765797933
Loss:  4.546310216267903
Loss:  3.84944557317098
Loss:  2.795659669748942
Loss:  2.4573636914571124
Loss:  2.2008343949635822
Loss:  2.0022270856221516
Loss:  1.843101537958781
Loss:  1.7295104548772176
Loss:  1.6263628295898438
sns.relplot(kind="line",data=losses);

[Loss curve for the baseline run]

!nvidia-smi
Thu Dec  9 17:01:18 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:10:00.0 Off |                  Off |
| 30%   23C    P8    20W / 300W |  29987MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    891356      C   ...iconda3-4.7.12/bin/python    29985MiB |
+-----------------------------------------------------------------------------+

Normal training uses about 30 GB of GPU memory.

Adding gradient accumulation to training #

  • Use a batch size of 8192
  • Gradient accumulation strategy: update the model once every 4 iterations (8192 * 4 = 32768)
  • Replace batch-size-sensitive Batch Norm with Group Norm, since BN statistics are computed per mini-batch and would still reflect a batch of 8192 rather than the effective 32768
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=8192,
                 num_workers=8,
                 shuffle=True)
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
losses = []

# Replace every BatchNorm2d in the model with GroupNorm (normalizes per sample, so it is independent of batch size)
def convert_bn_to_gn(model):
    for child_name, child in model.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            num_features = child.num_features
            setattr(model, child_name, torch.nn.GroupNorm(num_groups=1, num_channels=num_features))
        else:
            convert_bn_to_gn(child)

convert_bn_to_gn(model)
model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()

scaler = GradScaler()
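# accumulate gradients over 4 mini-batches: effective batch size 8192 * 4 = 32768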
iters_to_accumulate = 4

for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())
            loss_epoch += loss.item() * len(input)
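            # scale the loss so the summed gradients match the average over the effective batch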
            loss = loss / iters_to_accumulate

        scaler.scale(loss).backward()
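        # step the optimizer and clear gradients only every iters_to_accumulate iterations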
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))
Loss:  9.297225899251302
Loss:  5.8672450810750325
Loss:  4.790800758870443
Loss:  3.686163105646769
Loss:  3.2845786440531413
Loss:  2.0480232447306315
Loss:  2.3190684238433836
Loss:  1.8347549475987752
Loss:  1.7392349668502807
Loss:  1.6327584085464477
sns.relplot(kind="line",data=losses);

[Loss curve for the gradient-accumulation run]

!nvidia-smi
Thu Dec  9 11:54:48 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:10:00.0 Off |                  Off |
| 30%   23C    P8     8W / 300W |  13701MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    886056      C   ...iconda3-4.7.12/bin/python    13699MiB |
+-----------------------------------------------------------------------------+

GPU memory usage drops noticeably, to about 13.7 GB, while the effect on the loss is small.
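
One detail worth noting: with the loop above, if the number of mini-batches per epoch is not divisible by iters_to_accumulate, the gradients of the last few batches are never applied (here 60000 / 8192 gives 8 batches per epoch, which 4 divides exactly, so nothing is lost). A sketch of one way to guard against this in the general case, assuming the same variables as the loop above, is to also step on the final batch:

num_batches = len(train_loader)
for i, (input, target) in enumerate(train_loader):
    with autocast():
        output = model(input.cuda())
        loss = loss_fn(output, target.cuda()) / iters_to_accumulate
    scaler.scale(loss).backward()
    # also step on the last batch so leftover accumulated gradients are not dropped
    if (i + 1) % iters_to_accumulate == 0 or (i + 1) == num_batches:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()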

[Comparison of the loss curves for the two runs]
