What to Do When GPU Memory Isn't Enough: Gradient Accumulation
🎉 How to get RTX 3090 results out of an RTX 3080, a great way to save GPU memory: gradient accumulation 🎉
The OOM Problem #
Friends often hit an OOM error and come asking Dave what to do. If the kind of errors described below have been haunting you, this post is for you.
Out-of-memory (OOM) is a usually undesirable state of computer operation in which no additional memory can be allocated for use by programs. A system in this state cannot load any further programs, and since many programs load extra data into memory during execution, those programs will stop working correctly as well.
When training neural networks, it often happens that the images are large and you also want a bigger batch size, but there simply isn't enough GPU memory, and a large-memory A6000 is pretty expensive 💰. So what can you do?
Well-funded friend: No problem, I have 💰!
Dave: Pardon me, pardon me...
Because Dave's budget is tight, he often wants to squeeze RTX 3090 results out of an RTX 3080. Besides the half-precision training Dave introduced before, there is another trick he uses all the time and recommends to everyone: gradient accumulation.
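The idea in one picture before the full notebook code: loss.backward() adds gradients into each parameter's .grad buffer, so you can run several small batches, scale each loss down, and call optimizer.step() only once per group. Below is a minimal sketch in plain FP32; the toy model, accum_steps, and micro_batches are illustrative assumptions, not part of this notebook.
import torch

# Toy setup (illustrative only): 4 micro-batches of 8 samples act like one batch of 32 per optimizer step.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
accum_steps = 4
micro_batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(8)]

optimizer.zero_grad()
for i, (x, y) in enumerate(micro_batches):
    loss = loss_fn(model(x), y) / accum_steps  # scale so the accumulated gradient is an average
    loss.backward()                            # gradients add up in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()                       # one update per accum_steps micro-batches
        optimizer.zero_grad()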
!pip install seaborn tqdm
import os
import timm
import torch
import random
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
import torchvision.transforms as transforms
import torchvision.datasets as dataset
import seaborn as sns
import numpy as np
from tqdm import tqdm
def seed_torch(seed=99):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False

seed_torch()
# Prepare the MNIST data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dataset.MNIST(root='./', train=True, transform=trans, download=True)
test_set = dataset.MNIST(root='./', train=False, transform=trans)
Normal Training (Half Precision) #
# Dataloader
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=32768,
    num_workers=8,
    shuffle=True)
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
losses = []
scaler = GradScaler()
for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        # forward pass in half precision
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())
        # scale the loss before backward to avoid FP16 gradient underflow
        scaler.scale(loss).backward()
        loss_epoch += loss.item() * len(input)
        # step, update the scaler, and clear gradients every iteration
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))
Loss: 4.541711765797933
Loss: 4.546310216267903
Loss: 3.84944557317098
Loss: 2.795659669748942
Loss: 2.4573636914571124
Loss: 2.2008343949635822
Loss: 2.0022270856221516
Loss: 1.843101537958781
Loss: 1.7295104548772176
Loss: 1.6263628295898438
sns.relplot(kind="line",data=losses);
!nvidia-smi
Thu Dec  9 17:01:18 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:10:00.0 Off |                  Off |
| 30%   23C    P8    20W / 300W |  29987MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    891356      C   ...iconda3-4.7.12/bin/python    29985MiB |
+-----------------------------------------------------------------------------+
Normal training uses about 30 GB of GPU memory.
Adding Gradient Accumulation to Training #
- Use a batch size of 8192
- Gradient accumulation strategy: update the model once every 4 iterations (8192 * 4 = 32768); see the equivalence check after this list
- Replace batch-size-sensitive Batch Norm with Group Norm
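Why dividing the loss by iters_to_accumulate is the right scaling: with mean-reduced losses, summing four losses that are each divided by 4 gives exactly the mean loss over the combined batch, so the accumulated gradient matches the large-batch gradient, provided the model has no batch-statistics layers such as BatchNorm (which is why it gets swapped for GroupNorm below). Here is a small numerical check; the toy network and tensor sizes are made up purely for illustration.
import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3))
crit = torch.nn.CrossEntropyLoss()
x, y = torch.randn(32, 10), torch.randint(0, 3, (32,))

# Gradient from one large batch of 32
net.zero_grad()
crit(net(x), y).backward()
big_grad = [p.grad.clone() for p in net.parameters()]

# Accumulated gradient from 4 micro-batches of 8, each loss divided by 4
net.zero_grad()
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (crit(net(xb), yb) / 4).backward()
acc_grad = [p.grad for p in net.parameters()]

print(all(torch.allclose(b, a, atol=1e-6) for b, a in zip(big_grad, acc_grad)))  # True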
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=8192,
    num_workers=8,
    shuffle=True)
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
losses = []
# Replace the model's BatchNorm2d layers with GroupNorm
def convert_bn_to_gn(model):
    for child_name, child in model.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            num_features = child.num_features
            setattr(model, child_name, torch.nn.GroupNorm(num_groups=1, num_channels=num_features))
        else:
            convert_bn_to_gn(child)

convert_bn_to_gn(model)
model = model.cuda()
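# Optional sanity check (not in the original notebook): confirm that no
# BatchNorm2d layers survived convert_bn_to_gn before training starts.
assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in model.modules()), \
    'BatchNorm2d layers remain after convert_bn_to_gn'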
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = GradScaler()
iters_to_accumulate = 4
for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())
            loss_epoch += loss.item() * len(input)
            # divide the loss so the accumulated gradients average over the large batch
            loss = loss / iters_to_accumulate
        scaler.scale(loss).backward()
        # only step the optimizer once every iters_to_accumulate iterations
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))
Loss: 9.297225899251302
Loss: 5.8672450810750325
Loss: 4.790800758870443
Loss: 3.686163105646769
Loss: 3.2845786440531413
Loss: 2.0480232447306315
Loss: 2.3190684238433836
Loss: 1.8347549475987752
Loss: 1.7392349668502807
Loss: 1.6327584085464477
sns.relplot(kind="line",data=losses);
!nvidia-smi
Thu Dec  9 11:54:48 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:10:00.0 Off |                  Off |
| 30%   23C    P8     8W / 300W |  13701MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    886056      C   ...iconda3-4.7.12/bin/python    13699MiB |
+-----------------------------------------------------------------------------+
GPU memory usage drops markedly to around 13.7 GB, while the impact on the loss is clearly small.