Skip to content

人工智能时代,计算大模型的资源消耗是个基础问题。无论是模型的预训练,还是垂直领域的模型微调,都需要评估显卡需求数量,其中最基础的是显存评估。

本篇我们分析下大模型训练和微调阶段需要的显存大小。

我们首先统一定义核心符号:

  • N: 模型总参数量 (单位:十亿, Billion)
  • L: 模型层数 (Number of Layers)
  • H: 隐藏层维度 (Hidden Dimension, d_model)
  • T: 序列长度 (Sequence Length)
  • B: 单个GPU的微批次大小 (Micro-batch Size)
  • r: LoRA的秩 (Rank)
  • b(x): 数据类型 x 的字节数 (FP32=4, BF16/FP16=2, INT8/FP8=1, INT4=0.5)
  • DP, TP, PP, D: 分别代表数据并行、张量并行、流水线并行、ZeRO数据并行的并行度
  • GB: 10^9 字节 (注意:与模型参数量保持一致方便计算,本文统一按 10^9 字节估算,而nvidia-smi等监控工具多以 GiB = 2^30 字节计量,两者间约有7%的数值偏差)

一、模型训练阶段的显存

无论何种训练方案,GPU的显存消耗都主要源于以下三个方面:

1. 模型状态 (Model States)

这是与模型参数量 N 直接相关的静态部分,在全参数训练中占据主导地位。

  • **模型权重 (Model Weights,假设FP16): N * b(FP16) = 2 * N 字节
  • **梯度 (Gradients): N * b(FP16) = 2 * N 字节
  • 优化器状态 (Optimizer States): 主流的AdamW优化器(m/v共2个)通常以FP32存储,共8 * N 字节
  • 主权重 (Master Weights): 混合精度训练中为保证更新精度,常保留一份FP32权重副本,占用4 * N 字节

综合来看,在一次标准的FP16全参数训练中,模型状态的总显存占用是一个被广泛引用的经验公式: 显存_model_states ≈ (2+2+8+4) * N = 16 * N 字节

2. 激活值 (Activations)

这是训练过程中的动态部分,是导致长序列、大批量训练显存瓶颈的关键。

  • 估算公式: 显存_act ≈ B * T * H * L * K_factor * b(dtype)

  • K_factor 是一个依赖于具体实现和优化技术的经验常数:

  • 1)无任何优化: K_factor 约为 10-20。

  • 2)激活检查点 (Activation Checkpointing): K_factor 降至4-6

  • 3)FlashAttention + 激活检查点: K_factor 可进一步降至2-4

3. 临时工作区与碎片 (Workspace & Fragmentation)

这部分用于CUDA内核的临时存储、中间变量以及内存碎片。为保证训练稳定性,建议为此预留 10% - 20% 的总显存作为冗余。

二、分布式训练策略对显存的影响

1. 数据并行 (Data Parallelism - DP)

  • 显存影响:无显存优化。每个GPU均持有完整的模型状态副本 (16N),总体上大量增加了显存的占用;同时不适用于单卡无法加载的大模型。

2. 张量并行 (Tensor Parallelism - TP)

  • 显存影响:

  • 模型状态: 权重、梯度、优化器状态均被均匀切分,总占用降低为~16 * N / TP

  • 激活值: 权重、梯度、优化器状态近似 /TP激活值在启用序列并行(Sequence Parallelism)等策略下可接近 /TP,否则仅部分缩放,因为LayerNorm、残差连接等操作可能仍需全量激活。

  • 优点:显著降低模型和激活的显存,通信开销相对可控。

  • 缺点:通信非常密集,对节点内GPU间的高速互联(如NVLink)依赖极高。

3. 流水线并行 (Pipeline Parallelism - PP)

  • 显存影响:

  • 模型状态: 每个GPU只加载部分层,占用降低为~16 * N / PP

  • 激活值: 每个GPU只需存储其负责层的激活。但需注意,PP的峰值激活占用与并发的流水线切分数(micro-batch chunks)线性相关,这是在提高硬件利用率与控制显存之间的权衡。

  • 优点:大幅降低模型状态和激活显存,通信模式相对简单。

  • 缺点:会产生“流水线气泡”,即部分GPU在等待上游数据时处于空闲状态,导致硬件利用率下降(也有很多针对的优化策略,例如1F1B以及各种衍生策略)。

4. ZeRO (Zero Redundancy Optimizer)

  • ZeRO-1 & 2: 分别对优化器状态、梯度进行分片,逐步减少冗余,但每个GPU仍需持有完整权重。

  • ZeRO-3:对所有模型状态(权重、梯度、优化器)进行分片

  • 单卡显存: 理论上,模型状态占用降至~16 * N / D

  • 重要提示: 实际峰值会略高于此理论值,因为在前向/后向传播中,需要临时的All-gather缓冲区来动态聚合当前层所需的完整权重。

  • 业界对标: 主流PyTorch FSDP的fully_sharded策略在理念和效果上与ZeRO-3等价。

三、主流微调方式的显存

1. 全参数微调 (Full Fine-Tuning, Full FT)

  • “模型状态”是显存的绝对大头。

  • 标准AdamW: ~16N 字节。

  • 8位优化器 (Adam8bit): 实践中通常仍保留FP32主权重,因此更精确的估算是 权重(2N) + 梯度(2N) + 主权重(4N) + 8位m/v(~2N) ≈10N 字节

  • 适用场景: 必须依赖ZeRO-3/FSDP等分布式策略,适用于资源充足、追求极致性能的场景。

2. LoRA (Low-Rank Adaptation)

  • 激活值成为主要显存瓶颈。

  • 模型状态: 2N (16位基础模型) + c_opt * N_lora (LoRA训练状态)。N_lora极小。

  • c_opt 取决于优化器:标准AdamW时为16,8位优化器时为10。

  • 适用场景: 单张大显存GPU(≥ 24GB)微调成为可能。

3. QLoRA (Quantized LoRA)

  • 激活值是绝对的显存瓶颈。

  • 模型状态: ~ (0.55-0.9)N (4位基础模型及量化统计数据,常见设定约0.6N**)+ c_opt * N_lora。**

  • 激活值:与全参数微调和LoRA完全相同,因计算时4位权重复原为FP16。

  • 适用场景: 消费级GPU(≤ 16GB)微调大模型成为现实。

四、快速参考:关键公式清单

  • 全参+AdamW+FP16 模型状态: ≈ 16N 字节
  • 全参+Adam8bit (保留主权重): ≈ 10N 字节
  • LoRA 参数量 (Attention+MLP经验式): N_lora ≈ L * 18 * r * H
  • LoRA 训练状态: ≈ c_opt * N_lora 字节 (c_opt=10或16)
  • QLoRA 底座: ≈ (0.55 ~ 0.9) * N 字节
  • 激活值: ≈ B * T * L * H * K_factor * b(FP16) (K_factor=2~16)
  • 推理KV缓存: ≈ B * L * 2 * T * H * b(dtype)

五、实例分析

设定: 7B模型 (L=32, H=4096), T=2048, FP16, 使用FlashAttention+激活检查点 (K_factor≈3)。

  • 案例A: 全参数微调, B=4, AdamW, 单卡

  • 模型状态: 16 * 7B ≈ 112 GB

  • 激活值: 4*2048*4096*32*3*2 ≈ 6.4 GB

  • 总计: ~118.4 GB。单卡不可行。**

  • 案例B: 全参数微调, B=4, AdamW, 8卡ZeRO-3 (D=8)

  • 模型状态/卡: 112 GB / 8 = 14 GB

  • 激活值/卡: 6.4 GB

  • 总计/卡: ~20.4 GB (注意:实际峰值会因All-gather缓冲略高)。可在A100 40GB GPU上运行。**

  • 案例C: LoRA, B=4, r=16, 单卡 (假设为LoRA参数使用8位优化器)

  • LoRA参数量 N_lora32*18*16*4096 ≈ 38M

  • 模型状态: 14 GB (底座) + 38M * 10 (LoRA状态) ≈ 14.4 GB

  • 激活值: 6.4 GB

  • 总计: ~20.8 GB。可在单张24GB GPU(如RTX 4090)上运行,但显存紧张。**

  • 案例D: QLoRA, B=4, r=16, 单卡 (假设为LoRA参数使用8位优化器)

  • 模型状态: 0.6 * 7B (常见设定) + 0.4 GB ≈ 4.6 GB

  • 激活值: 6.4 GB

  • 总计: ~11.0 GB。可在单张16GB甚至12GB GPU上稳定运行。**

六、决策指南与高级技巧

显存优化顺序:

  • 首选: 启用FlashAttention激活检查点
  • 其次: 减小微批次大小 (B),配合梯度累积。
  • 再次: 切换到更节省显存的微调方案 (Full FT -> LoRA -> QLoRA)。
  • 最后: 减小序列长度 (T)

高级技巧:

  • CPU/NVMe Offload: 当GPU显存极度紧张时,可将优化器状态、甚至部分权重/激活卸载到CPU内存或NVMe硬盘,以通信带宽换取显存空间。这是ZeRO++和一些框架支持的高级功能。但这种能力是有代价的,代价就是显著(通常是数倍甚至一个数量级)的训练速度下降。
返回专题 · AI 工程落地上一篇:【AI项目落地】一文教你精准评估大模型推理成本下一篇:【AI项目落地】揭秘大模型推理的“省钱”黑科技:KV量化

持续沉淀企业 AI 技术内容。