一、训练显存计算

1. 模型参数

  • 模型的大小直接影响显存的需求

  • 每个模型参数(如权重和偏置)都占用一定的显存

    模型参数显存占用=模型参数数量×每个参数所占字节数

  • 如果是单精度浮点数(FP32),每个参数占用4字节

  • 如果使用半精度浮动点数(FP16),每个参数占用2字节

  • 假设模型有1B个参数,使用FP32,则所需的显存为:

$$
10^9 \times 4 \text{ Bytes} = 4 \text{ GB}
$$

2. 激活值(Activations)

  • 在训练过程中,每一层神经网络都会产生“激活值”,即通过该层的输出,这些激活值必须在前向传播和反向传播过程中存储,因此会占用显存

  • 激活值的大小与每一层的输出大小和批量大小(batch size)密切相关

  • 假设模型的每层输出激活值,且批量大小为 $N$,每个样本的输入大小为 $S$(例如,假设每个输入是一个 1024 维的向量)

  • 每层的激活尺寸将随网络架构变化,为了简化,假设网络中总共有 $L$ 层,且每一层的输出和输入维度相同

  • 激活值的内存需求:

$$
\text{Activations GPU} = L \times N \times S \times 4 \text{Bytes}
$$

3. 梯度(Gradients)

  • 在反向传播过程中,优化器会计算每个参数的梯度,梯度的数量和模型参数一样多
  • 因此,梯度会占用和模型参数一样多的显存

4. 优化器状态(Optimizer States)

  • 常见的优化器(如Adam)不仅需要存储梯度,还需要存储一组额外的参数,这些优化器状态的大小取决于所使用的优化器
  • 以Adam为例,它会为每个参数存储两个额外的矩阵(一阶矩和二阶矩,即动量和梯度的平方),所以优化器状态的显存占用通常是梯度和参数显存占用的两倍

5. 批量大小(Batch Size)

  • 批量大小直接影响显存占用,因为在每次训练迭代中,所有输入数据都需要加载到显存中进行计算
  • 批量大小越大,显存需求越高
  • 例如,假设输入图像的尺寸为224x224x3,使用FP32,每个图像占用的显存为:

$$
224 \times 224 \times 3 \times 4 \text{ Bytes} = 600,192 \text{ Bytes} \approx 600 \text{ KB}
$$

  • 如果批量大小为32,则输入数据的显存占用为:

$$
32 \times 600 \text{ KB} = 19.2 \text{ MB}
$$

6. 总显存消耗计算

总显存需求 = 模型参数显存 + 激活显存 + 梯度显存 + 优化器状态显存 + 输入数据显存

  • 假设以下条件:

    • 模型有1B个参数

    • 使用FP32,每个参数、梯度占用4字节

    • 批量大小为32

    • 激活大小与输入大小类似,假设每层的激活显存占用为64KB

  • 那么显存的粗略估计为:

    • 模型参数显存:4GB

    • 梯度显存:4GB

    • 优化器状态显存:8GB

    • 输入数据显存:19.2MB

    • 激活显存(假设有5层):5 × 64KB = 320KB ≈ 0.3MB

  • 总显存需求约为:

$$
4 \text{GB} \ (\text{Model}) + 12.8 \text{MB} \ (\text{Activations}) + 4 \text{GB} \ (\text{Gradient}) + 8 \text{GB} \ (\text{Optimizer}) = 16.0128 \text{GB}
$$

7. 其他因素

  • 激活检查点(Activation Checkpointing): 如果启用激活检查点,某些激活值在反向传播时会重新计算,减少显存消耗
  • 混合精度(Mixed Precision): 使用半精度浮动点数(FP16)可以将显存需求减少一半,因为每个参数、梯度和激活值的字节数都会减少一半

二、分布式训练

1. 分布式训练的基本概念

  • 主节点 (Master Node)

    • 主节点是分布式训练系统中的一个角色,通常负责管理任务调度、参数初始化和分布式通信的协调

    • 在 PyTorch 分布式环境中,主节点通常具有以下职责:

      • 初始化分布式环境(init_process_group

      • 分发模型参数或数据到其他节点

      • 处理日志记录或统计信息

    • 在 PyTorch 中,主节点由环境变量 MASTER_ADDRMASTER_PORT 指定

    • 主节点并不一定承担实际训练任务,可能只负责通信和协调

  • 进程组 (Process Group)

    • 进程组是 PyTorch 分布式通信的基本单元,表示一组参与通信的分布式进程

      • 定义哪些进程参与分布式操作

      • 提供通信接口,如广播(Broadcast)、全归约(AllReduce)等

    • 初始化:

      • 使用 torch.distributed.init_process_group() 初始化
      • 支持多种后端,如 NCCL(GPU)、Gloo(CPU/GPU)和 MPI
    • 如果有 4 个 GPU 参与训练,可以初始化一个包含 4 个进程的进程组,每个 GPU 对应一个进程

  • Rank

    • 在分布式训练中,Rank 是用于标识每个进程的唯一编号,用于区分不同的计算任务和通信操作

      • Global Rank:在整个分布式系统中唯一标识一个进程

      • Local Rank:在某个节点(物理机器)中标识该节点内的进程编号

    • Rank 决定进程的角色,例如哪一个进程是主节点。

    • 在多 GPU 训练中,Rank 通常和 GPU ID 对应

    • 如果有 8 个 GPU 和 2 个节点,每个节点有 4 个 GPU:

      • 节点 1 的全局 Rank 为 0, 1, 2, 3。
      • 节点 2 的全局 Rank 为 4, 5, 6, 7。
  • NCCL (NVIDIA Collective Communication Library)

    • NCCL 是 NVIDIA 提供的高性能通信库,专为 GPU 集群上的深度学习训练优化

    • 提供高效的分布式通信操作,包括广播、归约(Reduce)、全归约(AllReduce)、聚合(AllGather)等

    • 利用 GPU 的高速 NVLink、PCIe 和 InfiniBand 实现低延迟、高带宽通信

      • 支持多 GPU 和多节点通信
      • 深度集成到 PyTorch 和 TensorFlow 等框架中
      • 自动优化通信拓扑,减少通信瓶颈
  • Backend

    • 分布式通信的底层实现方式,在 PyTorch 中常见的后端包括:

      • NCCL:用于 GPU 通信,性能最佳

      • Gloo:支持 CPU 和 GPU,适用于多种环境

      • MPI:使用消息传递接口,适合高性能计算环境

2. 分布式训练技术

DDP(Distributed Data Parallel)
  • 基本原理
    • 每个 GPU 保持一份模型的完整副本
    • 数据被DistributedSampler划分为多个子集,每个 GPU 处理一个子集(mini-batch)
    • 各 GPU 独立完成前向传播和梯度计算
    • 收集梯度,对梯度进行平均得到全局梯度(AllReduce 操作)
    • 通过通信同步梯度,更新模型参数(NCCL)
  • 细节
    • 优化器的随机种子也相同
  • 与DP(Data Parallel)相比
    • DP中,所有数据和梯度需通过主 GPU,导致通信开销大,主设备容易成为瓶颈
    • DP中,主 GPU 的额外显存占用导致不平衡
    • DP是单进程、多线程的,但它只能工作在单机上
FSDP(Fully Sharded Data Parallel)
  • 基本原理

    • 一种参数分片技术,为超大规模模型的分布式训练设计

    • 参数分片:模型的参数被分片存储到不同的 GPU,每个 GPU 仅存储自己负责的参数分片(Sharded Parameters)

    • 数据分发:与DDP类似

    • 前向传播:FSDP 将需要的参数分片加载到显存,在完成计算后,卸载这些参数以节省显存

    • 后向传播:梯度计算完成后,梯度也被分片,并通过 AllReduce 操作同步到所有 GPU

    • 优化器分片:优化器状态(如动量)也被分片存储并在必要时同步

  • 参数 ShardingStrategy 的不同取值决定了模型的划分方式

    • FULL_SHARD:将模型参数、梯度和优化器状态都切分到不同的GPU上,类似ZeRO-3
    • SHARD_GRAD_OP:将梯度、优化器状态切分到不同的GPU上,每个GPU仍各自保留一份完整的模型参数,类似ZeRO-2
    • NO_SHARD:不切分任何参数,类似ZeRO-0
  • 配置文件示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
fsdp_config:
world_size: 8
local_rank: 0
shard_optimizer_state: true
mixed_precision: true
fp16: true
activation_checkpointing: true
device: cuda
offload_params: true
offload_optimizer_state: false
use_reentrant: true
model_parallel_size: 2
checkpoint_interval: 1000
checkpoint_dir: ./checkpoints
  • 参数说明
参数名称 可选值 示例 含义
world_size 正整数 8 分布式训练中的总进程数,通常等于设备数
local_rank 正整数 0 当前进程的本地进程号,用于标识分布式训练中的某个节点
shard_optimizer_state truefalse true 是否对优化器状态进行切分以节省内存
mixed_precision truefalse true 是否启用混合精度训练,以提高训练速度并减少内存占用
fp16 truefalse true 是否启用16位浮动点数精度 (FP16),在一些硬件上可以显著提升训练速度
activation_checkpointing truefalse true 是否启用激活检查点,以减少显存使用
device cudacpu cuda 训练使用的设备类型,通常为 cuda(GPU)或 cpu
offload_params truefalse true 是否将模型参数卸载到CPU内存而不是GPU显存
offload_optimizer_state truefalse false 是否将优化器状态卸载到CPU内存而不是GPU显存
use_reentrant truefalse true 是否使用重入式 (reentrant) 机制,这对于某些特定的训练优化是必需的
model_parallel_size 正整数 2 模型并行的大小,指定了在多个设备之间划分模型的数量
checkpoint_interval 正整数 1000 每隔多少步进行一次检查点保存
checkpoint_dir 字符串(路径) ./checkpoints 存储检查点文件的目录路径
DeepSpeed
  • ZeRO技术(Zero Redundancy Optimizer)

    • Stage 1:分片优化器状态
    • Stage 2:分片梯度和优化器状态
    • Stage 3:分片模型参数、梯度和优化器状态
  • 模型并行

    • Pipeline Parallelism:将模型切分成多个阶段,跨设备或节点并行训练
    • Tensor Parallelism:将张量切分为多个分块,在不同设备上并行计算
  • 其他技术

    • Activation Checkpointing:节省激活存储空间
    • Offloading:将优化器状态和部分参数存储到 CPU 或 NVMe 设备,降低 GPU 显存压力
  • 配置文件示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
{
"train_batch_size": 32,
"gradient_accumulation_steps": 1,
"steps_per_print": 200,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
},
"zero_optimization": {
"stage": 2,
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 3e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 1e-7,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"wall_clock_breakdown": false
}
  • 参数说明
参数名称 可选值/类型 示例 含义
train_batch_size 整数 32 每个训练步骤的批量大小
gradient_accumulation_steps 整数 1 每隔多少步进行一次梯度更新(梯度累积)
steps_per_print 整数 200 每多少步打印一次日志信息
fp16.enabled 布尔值 true 启用混合精度训练(FP16)
fp16.loss_scale 整数 0 动态损失缩放因子,通常设置为0表示自动选择
fp16.initial_scale_power 整数 16 初始损失缩放的幂
zero_optimization.stage 0, 1, 2, 3 2 Zero 优化阶段,0表示没有优化,1表示优化参数,2表示优化梯度,3表示优化参数和梯度
zero_optimization.offload_param.device cpu, nvme cpu 参数卸载到的设备类型,可以选择 cpunvme
zero_optimization.offload_optimizer.device cpu, nvme cpu 优化器状态卸载到的设备类型
optimizer.type 字符串(如 Adam, AdamW 等) Adam 使用的优化器类型
optimizer.params.lr 浮动数值 3e-5 学习率
optimizer.params.betas 数组 [0.9, 0.999] 优化器的 beta 参数
optimizer.params.eps 浮动数值 1e-8 优化器的 eps 参数
scheduler.type 字符串(如 WarmupLR, CosineAnnealingLR 等) WarmupLR 学习率调度器类型
scheduler.params.warmup_min_lr 浮动数值 1e-7 学习率的最小值(用于预热阶段)
scheduler.params.warmup_max_lr 浮动数值 3e-5 学习率的最大值(用于预热阶段)
scheduler.params.warmup_num_steps 整数 500 预热阶段的步数
wall_clock_breakdown 布尔值 false 是否打印每个阶段的壁钟时间分解
对比
特性 DDP FSDP DeepSpeed
并行方式 数据并行 数据并行 + 参数分片 数据并行 + ZeRO 分片
显存需求 高(完整模型参数和梯度) 中(参数、梯度分片) 低(参数、梯度、优化器状态分片)
通信开销 低(只同步梯度) 中(同步分片参数和梯度) 高(ZeRO 通信优化,依赖通信效率)
适用模型 中小型模型 大模型 超大模型(数百亿到万亿参数)
显存优化 参数级别分片 全量分片(参数、梯度、优化器状态)
适用场景 单机多GPU 单机/多节点 单机/多节点

3. 分布式训练库

torchrun
  • 功能

    • torchrun 是 PyTorch 提供的一个 CLI 工具,用于管理分布式训练任务,特别是基于 torch.distributed 的分布式环境
    • 负责初始化 PyTorch 的分布式训练环境
    • 自动设置分布式所需的主节点地址 (MASTER_ADDR) 和端口 (MASTER_PORT)
    • 简化进程启动,支持多 GPU 和多节点训练
    • 替代了旧的 torch.distributed.launch 工具,提供更易用和灵活的接口
  • 工作原理

    • torchrun 为每个 GPU 启动一个进程
    • 使用 NCCL(默认后端)或其他后端初始化通信
    • 进程通过 torch.distributed.init_process_group 互相通信,完成梯度同步
  • 单卡多GPU

    1
    torchrun --nproc_per_node=4 train.py
  • 多卡多GPU

    1
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=12345 train.py
Accelerate
  • 功能

    • 自动处理多 GPU、TPU 和混合精度训练(FP16)
    • 自动设备分配:在 CPU、单 GPU 和多 GPU 环境中无缝切换
    • 支持分布式训练的关键操作,如梯度累积、参数同步等
  • 示例代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    from accelerate import Accelerator

    accelerator = Accelerator()

    # 自动分配设备
    device = accelerator.device
    model.to(device)

    # 自动包装模型和优化器
    model, optimizer = accelerator.prepare(model, optimizer)

    for batch in dataloader:
    # 自动将数据分配到正确设备
    batch = batch.to(accelerator.device)
    outputs = model(batch)

    # 自动反向传播和梯度同步
    loss = loss_fn(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()