Some Common Fixes for PyTorch GPU Out-of-Memory (OOM) Errors

torch.cuda.memory_allocated(device)

By logging the amount of GPU memory currently allocated at each step of the training loop, you can see which step contributes the largest share.

for i, batch in enumerate(train_loader):
    print("1:", torch.cuda.memory_allocated(0))  # baseline before the forward pass

    outputs = model(**batch)
    print("2:", torch.cuda.memory_allocated(0))  # after forward: activations are held for backward

    loss = outputs.loss
    print("3:", torch.cuda.memory_allocated(0))  # after reading the loss

    loss.backward()
    print("4:", torch.cuda.memory_allocated(0))  # after backward: activations freed, gradients allocated

    optimizer.step()
    optimizer.zero_grad()
    print("5:", torch.cuda.memory_allocated(0))  # after the optimizer step and zero_grad

Accumulating the computation graph

Many tensors in PyTorch are backed by the autograd computation graph, so it is easy to accidentally let the graph accumulate.

For example, appending the loss to a list inside a for loop, or summing losses across iterations, causes the graph to accumulate: every tensor from earlier iterations is kept alive, and GPU memory eventually blows up.

`with torch.no_grad()` prevents gradients from being computed, but it does not release tensors that have already been accumulated; only `detach()` or `item()` does that.
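
A minimal, self-contained sketch of the difference (the toy model and names are illustrative, not from the original post): accumulating the loss tensor itself keeps every step's autograd history alive, while `item()` or `detach()` stores only the value.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.randn(512, 512, requires_grad=True, device=device)
optimizer = torch.optim.SGD([w], lr=1e-3)

total_loss = 0.0
for step in range(1000):
    x = torch.randn(64, 512, device=device)
    loss = (x @ w).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Wrong: `loss` still carries autograd history, so the graph of every
    # step gets chained into `total_loss` and memory keeps growing.
    # total_loss += loss

    # Right: keep only the Python number (or a detached tensor).
    total_loss += loss.item()   # or: total_loss += loss.detach()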



import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding

import swanlab
from swanlab.integration.transformers import SwanLabCallback


swanlab_callback = SwanLabCallback(
    project="map", 
    experiment_name=config['name'],
)

callbacks = []

if ENV == "SERVER":
    callbacks.append(swanlab_callback)



class SupConTrainer(Trainer):
    def __init__(self, contrastive_weight=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.contrastive_weight = contrastive_weight

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):

        # ⚡ hidden_states are needed during training (for the contrastive loss)
        # print("1:", torch.cuda.memory_allocated(0))
        # standard forward pass
        outputs = model(**inputs, output_hidden_states=True)
        # print("outputs:", outputs)
        # print("2:", torch.cuda.memory_allocated(0))
        logits = outputs.logits
        labels = inputs["labels"]

        # -------- 1. Cross-entropy loss --------
        ce_loss = F.cross_entropy(logits, labels)
        # print("3:", torch.cuda.memory_allocated(0))
        # -------- 2. Contrastive learning loss --------
        # take the last layer's hidden states (batch_size, hidden_dim)
        hidden_states = outputs.hidden_states[-1][:, 0, :]  # [CLS] vector

        # L2-normalize
        hidden_states = F.normalize(hidden_states, dim=-1)
        # print("4:", torch.cuda.memory_allocated(0))
        # similarity matrix, shape (batch, batch)
        similarity_matrix = torch.matmul(hidden_states, hidden_states.T)
        # print("5:", torch.cuda.memory_allocated(0))
        # only samples with the same label count as positives
        mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # (batch, batch)

        # InfoNCE / SupCon loss
        logits_contrastive = similarity_matrix / 0.1  # temperature 0.1
        contrastive_loss = -torch.log_softmax(logits_contrastive, dim=1)[mask].mean()
        # print("6:", torch.cuda.memory_allocated(0))
        # -------- 3. Total loss --------
        loss = ce_loss + self.contrastive_weight * contrastive_loss
        # print("7:", torch.cuda.memory_allocated(0))
        # else:
        #     # ⚡ eval阶段只需要 logits,避免返回hidden_states
        #     outputs = model(**inputs, output_hidden_states=False)
        #     logits = outputs.logits
        #     labels = inputs["labels"]
        #     loss = self.ce_loss(logits, labels)

        outputs = {
            "loss": outputs['loss'],
            "logits": logits,
        }

        return (loss, outputs) if return_outputs else loss
        

trainer = SupConTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_map3,
    # data_collator=data_collator,
    callbacks=callbacks,
    contrastive_weight=config['contrastive_weight'],  # weight of the contrastive loss
)


trainer.train()

The code above is a customized Trainer. If the `outputs` dict were not rebuilt, the `hidden_states` it contains would be kept and passed back to the caller, where they accumulate step after step and eventually cause an OOM.
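
For contrast, a minimal sketch (illustrative, not from the post) of the leaky pattern versus the trimmed one; `compute_loss_leaky` and `compute_loss_slim` are hypothetical names standing in for the Trainer override:

import torch.nn.functional as F

# Leaky: returning the full ModelOutput means outputs.hidden_states stays
# referenced by whatever collects the outputs (e.g. the evaluation loop).
def compute_loss_leaky(model, inputs, return_outputs=False):
    outputs = model(**inputs, output_hidden_states=True)
    loss = F.cross_entropy(outputs.logits, inputs["labels"])
    return (loss, outputs) if return_outputs else loss

# Trimmed (what the trainer above does): hand back only what the caller needs.
def compute_loss_slim(model, inputs, return_outputs=False):
    outputs = model(**inputs, output_hidden_states=True)
    loss = F.cross_entropy(outputs.logits, inputs["labels"])
    slim = {"loss": loss, "logits": outputs.logits}
    return (loss, slim) if return_outputs else loss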