人工智能时代,计算大模型的资源消耗是个基础问题。无论是模型的预训练,还是垂直领域的模型微调,都需要评估显卡需求数量,其中最基础的是显存评估。
本篇我们分析下大模型训练和微调阶段需要的显存大小。
我们首先统一定义核心符号:
- N: 模型总参数量 (单位:十亿, Billion)
- L: 模型层数 (Number of Layers)
- H: 隐藏层维度 (Hidden Dimension,
d_model) - T: 序列长度 (Sequence Length)
- B: 单个GPU的微批次大小 (Micro-batch Size)
- r: LoRA的秩 (Rank)
- b(x): 数据类型
x的字节数 (FP32=4, BF16/FP16=2, INT8/FP8=1, INT4=0.5) - DP, TP, PP, D: 分别代表数据并行、张量并行、流水线并行、ZeRO数据并行的并行度
- GB: 10^9 字节 (注意:与模型参数量保持一致方便计算,本文统一按 10^9 字节估算,而
nvidia-smi等监控工具多以 GiB = 2^30 字节计量,两者间约有7%的数值偏差)
一、模型训练阶段的显存
无论何种训练方案,GPU的显存消耗都主要源于以下三个方面:
1. 模型状态 (Model States)
这是与模型参数量 N 直接相关的静态部分,在全参数训练中占据主导地位。
- **模型权重 (Model Weights,假设FP16):
N * b(FP16)=2 * N字节 - **梯度 (Gradients):
N * b(FP16)=2 * N字节 - 优化器状态 (Optimizer States): 主流的AdamW优化器(m/v共2个)通常以FP32存储,共
8 * N字节。 - 主权重 (Master Weights): 混合精度训练中为保证更新精度,常保留一份FP32权重副本,占用
4 * N字节。
综合来看,在一次标准的FP16全参数训练中,模型状态的总显存占用是一个被广泛引用的经验公式: 显存_model_states ≈ (2+2+8+4) * N = 16 * N 字节
2. 激活值 (Activations)
这是训练过程中的动态部分,是导致长序列、大批量训练显存瓶颈的关键。
估算公式:
显存_act ≈ B * T * H * L * K_factor * b(dtype)K_factor 是一个依赖于具体实现和优化技术的经验常数:
1)无任何优化:
K_factor约为 10-20。2)激活检查点 (Activation Checkpointing):
K_factor降至4-6。3)FlashAttention + 激活检查点:
K_factor可进一步降至2-4。
3. 临时工作区与碎片 (Workspace & Fragmentation)
这部分用于CUDA内核的临时存储、中间变量以及内存碎片。为保证训练稳定性,建议为此预留 10% - 20% 的总显存作为冗余。
二、分布式训练策略对显存的影响
1. 数据并行 (Data Parallelism - DP)
- 显存影响:无显存优化。每个GPU均持有完整的模型状态副本 (
16N),总体上大量增加了显存的占用;同时不适用于单卡无法加载的大模型。
2. 张量并行 (Tensor Parallelism - TP)
显存影响:
模型状态: 权重、梯度、优化器状态均被均匀切分,总占用降低为
~16 * N / TP。激活值: 权重、梯度、优化器状态近似
/TP;激活值在启用序列并行(Sequence Parallelism)等策略下可接近/TP,否则仅部分缩放,因为LayerNorm、残差连接等操作可能仍需全量激活。优点:显著降低模型和激活的显存,通信开销相对可控。
缺点:通信非常密集,对节点内GPU间的高速互联(如NVLink)依赖极高。
3. 流水线并行 (Pipeline Parallelism - PP)
显存影响:
模型状态: 每个GPU只加载部分层,占用降低为
~16 * N / PP。激活值: 每个GPU只需存储其负责层的激活。但需注意,PP的峰值激活占用与并发的流水线切分数(micro-batch chunks)线性相关,这是在提高硬件利用率与控制显存之间的权衡。
优点:大幅降低模型状态和激活显存,通信模式相对简单。
缺点:会产生“流水线气泡”,即部分GPU在等待上游数据时处于空闲状态,导致硬件利用率下降(也有很多针对的优化策略,例如1F1B以及各种衍生策略)。
4. ZeRO (Zero Redundancy Optimizer)
ZeRO-1 & 2: 分别对优化器状态、梯度进行分片,逐步减少冗余,但每个GPU仍需持有完整权重。
ZeRO-3:对所有模型状态(权重、梯度、优化器)进行分片。
单卡显存: 理论上,模型状态占用降至
~16 * N / D。重要提示: 实际峰值会略高于此理论值,因为在前向/后向传播中,需要临时的All-gather缓冲区来动态聚合当前层所需的完整权重。
业界对标: 主流PyTorch FSDP的
fully_sharded策略在理念和效果上与ZeRO-3等价。
三、主流微调方式的显存
1. 全参数微调 (Full Fine-Tuning, Full FT)
“模型状态”是显存的绝对大头。
标准AdamW:
~16N字节。8位优化器 (Adam8bit): 实践中通常仍保留FP32主权重,因此更精确的估算是
权重(2N) + 梯度(2N) + 主权重(4N) + 8位m/v(~2N)≈10N字节。适用场景: 必须依赖ZeRO-3/FSDP等分布式策略,适用于资源充足、追求极致性能的场景。
2. LoRA (Low-Rank Adaptation)
激活值成为主要显存瓶颈。
模型状态:
2N(16位基础模型) +c_opt * N_lora(LoRA训练状态)。N_lora极小。c_opt 取决于优化器:标准AdamW时为16,8位优化器时为10。
适用场景: 单张大显存GPU(≥ 24GB)微调成为可能。
3. QLoRA (Quantized LoRA)
激活值是绝对的显存瓶颈。
模型状态:
~ (0.55-0.9)N(4位基础模型及量化统计数据,常见设定约0.6N**)+c_opt * N_lora。**激活值:与全参数微调和LoRA完全相同,因计算时4位权重复原为FP16。
适用场景: 消费级GPU(≤ 16GB)微调大模型成为现实。
四、快速参考:关键公式清单
- 全参+AdamW+FP16 模型状态:
≈ 16N字节 - 全参+Adam8bit (保留主权重):
≈ 10N字节 - LoRA 参数量 (Attention+MLP经验式):
N_lora ≈ L * 18 * r * H - LoRA 训练状态:
≈ c_opt * N_lora字节 (c_opt=10或16) - QLoRA 底座:
≈ (0.55 ~ 0.9) * N字节 - 激活值:
≈ B * T * L * H * K_factor * b(FP16)(K_factor=2~16) - 推理KV缓存:
≈ B * L * 2 * T * H * b(dtype)
五、实例分析
设定: 7B模型 (L=32, H=4096), T=2048, FP16, 使用FlashAttention+激活检查点 (K_factor≈3)。
案例A: 全参数微调, B=4, AdamW, 单卡
模型状态:
16 * 7B≈ 112 GB激活值:
4*2048*4096*32*3*2≈ 6.4 GB总计: ~118.4 GB。单卡不可行。**
案例B: 全参数微调, B=4, AdamW, 8卡ZeRO-3 (D=8)
模型状态/卡:
112 GB / 8= 14 GB激活值/卡: 6.4 GB
总计/卡: ~20.4 GB (注意:实际峰值会因All-gather缓冲略高)。可在A100 40GB GPU上运行。**
案例C: LoRA, B=4, r=16, 单卡 (假设为LoRA参数使用8位优化器)
LoRA参数量
N_lora:32*18*16*4096≈ 38M模型状态:
14 GB(底座) +38M * 10(LoRA状态) ≈ 14.4 GB激活值: 6.4 GB
总计: ~20.8 GB。可在单张24GB GPU(如RTX 4090)上运行,但显存紧张。**
案例D: QLoRA, B=4, r=16, 单卡 (假设为LoRA参数使用8位优化器)
模型状态:
0.6 * 7B(常见设定) +0.4 GB≈ 4.6 GB激活值: 6.4 GB
总计: ~11.0 GB。可在单张16GB甚至12GB GPU上稳定运行。**
六、决策指南与高级技巧
显存优化顺序:
- 首选: 启用FlashAttention和激活检查点。
- 其次: 减小微批次大小 (B),配合梯度累积。
- 再次: 切换到更节省显存的微调方案 (Full FT -> LoRA -> QLoRA)。
- 最后: 减小序列长度 (T)。
高级技巧:
- CPU/NVMe Offload: 当GPU显存极度紧张时,可将优化器状态、甚至部分权重/激活卸载到CPU内存或NVMe硬盘,以通信带宽换取显存空间。这是ZeRO++和一些框架支持的高级功能。但这种能力是有代价的,代价就是显著(通常是数倍甚至一个数量级)的训练速度下降。