在前两篇文章中,我们了解了 transformers 库的生态位和如何使用参数来控制文本生成。本篇文章,我们深入其内部,探究那些参数究竟是如何驱动底层算法的。当模型完成一次前向传播,输出一个数万维度的 logits 向量后,究竟是怎样的魔法决定了下一个词元是谁?
本文将以算法的视角,解析两大核心解码策略——束搜索 (Beam Search) 和采样 (Sampling) 的内部工作流程。我们将重点剖析 Logits Processors 和 Logits Warpers 这两个关键组件,并最终厘清“模型本身”与“解码策略”之间清晰的界限。读完本文,你将对文本生成的每一个环节都了如指掌。
一、 解码的统一前奏:Logits Processors & Warpers
在任何解码策略(无论是束搜索还是采样)启动之前,模型输出的原始 logits 向量都需要经过一个两阶段的预处理流程。这个流程是保证生成质量和执行特定约束的关键。
阶段一:Logits Processors (规则执行者)
这是一系列确定性的规则,它们在采样之前对 logits 进行修改,以强制执行某些约束。可以把它们看作是“语法和规则警察”。
- 作用:应用硬性或软性的惩罚/奖励。
常见成员
MinLengthLogitsProcessor:在序列长度未达到min_length之前,强制将eos_token_id(结束符) 的概率设为负无穷,防止句子提前结束。RepetitionPenaltyLogitsProcessor:对已经在上文中出现过的词元,降低它们的 logits 值,以减少重复 (repetition_penalty)。NoRepeatNGramLogitsProcessor:如果一个 n-gram (如一个三词短语) 已经生成过,则将导致该 n-gram 再次出现的下一个词元的概率设为负无穷。ForcedEOSTokenLogitsProcessor:当达到max_length时,强制将eos_token_id之外所有词元的概率设为负无穷,确保序列正常结束。
阶段二:Logits Warpers (概率控制)
在规则执行完毕后,如果启用了采样 (do_sample=True),Logits Warpers 就会登场。它们负责对概率分布进行“整形”,以控制采样的随机性。
- 作用:改变概率分布的形状,使其更“陡峭”或更“平滑”。
常见成员
TemperatureLogitsWarper:通过除以temperature值来缩放 logits,从而改变分布的平滑度。TopKLogitsWarper:将概率最低的一批词元的 logits 设为负无穷,只保留概率最高的k个。TopPLogitsWarper(Nucleus Sampling):动态保留一个最小的词元集,使其累积概率超过top_p。
数据流总结:原始 Logits -> (Logits Processors) -> 规则约束后的 Logits -> (Logits Warpers, if sampling) -> 整形后的 Logits -> Softmax -> 最终概率分布 -> 解码策略
二、 算法详解 1:束搜索 (Beam Search)
束搜索的核心思想是:不要过早地做出最优选择,而是在每一步都保留多种可能性。
- 核心参数:
num_beams(宽度),max_new_tokens(深度),length_penalty(评分校正)。
图解步骤 (num_beams=2)
- 初始化:将输入 prompt 复制
num_beams次,作为初始的 2 个“束”(beam)。
2. 第一步扩展
- 模型对初始输入进行一次前向传播,得到 logits。
- 应用
Logits Processors(此时通常没有采样,不经过 Warpers)。 - 计算所有词元的
log_softmax概率。 - 从整个词汇表中,选出总概率(此时就是第一步的概率)最高的 2 个词元,形成 2 条新的候选序列。假设是 "I have a good" 和 "I have a nice"。
3. 第二步扩展
- 将 "I have a good" 和 "I have a nice" 作为新的输入,分别进行前向传播,得到 2 组 logits。
- 对于每一组 logits,计算
log_softmax并将其加到各自路径的累计对数概率上。 - 现在我们有了
2 (beams) * V (vocab_size)个可能的扩展。从这成千上万个可能性中,再次选出累计总分最高的 2 条序列。
4. 循环与终止
- 重复第 3 步,直到序列遇到
eos_token或达到max_length。 - 遇到
eos的序列会被放入“已完成”的候选池中。 - 评分校正:为了避免模型偏爱短句(因为对数概率是负数,累加越多值越小),在比较最终候选序列时,会使用
length_penalty对分数进行归一化,公式通常为score / (length ** length_penalty)。
- 最终选择:从“已完成”的候选池中,选出经过长度惩罚校正后分数最高的序列作为最终输出。
深度 vs. 宽度:
- 宽度:由
num_beams决定,即同时探索多少条路径。 - 深度:由
max_new_tokens、eos_token_id和early_stopping等停止条件共同决定,即每条路径能走多远。
三、 算法详解 2:采样 (Sampling)
采样的核心思想是:根据概率分布进行随机选择,从而引入多样性和创造性。
- 核心参数:
do_sample=True,top_p,top_k,temperature。 - 图解步骤 (纯采样,
num_beams=1):
- 初始化:输入 prompt。
2. 第一步生成
- 模型对输入进行前向传播,得到 logits。
- 应用
Logits Processors(如repetition_penalty)。 - 应用
Logits Warpers(如先用temperature缩放,再用top_p截断)。 - 对处理后的 logits 应用
softmax得到最终的概率分布。 - 根据这个概率分布,随机抽样一个词元作为下一个词。假设是 "day"。
3. 第二步生成
- 将 "I have a good day" 作为新的输入,重复上述过程,再次随机抽样下一个词元。
4. 循环与终止
- 不断重复,直到抽样到
eos_token或达到max_length。
整个过程就像掷一个被 top_p 等参数“动过手脚”的骰子,每一步都充满了可能性,但又被限制在了一个合理的范围内。
四、 模型与推理引擎的边界:谁负责什么?
经过以上分析,我们可以清晰地划分出模型和推理引擎(即 transformers 的 generate 函数)的责任:
- 模型的责任:单一且纯粹——提供 Logits。
- 模型本身是一个巨大的权重矩阵,它的唯一工作就是接收一串词元 ID,经过复杂的矩阵运算后,输出一个代表下一个词元概率分布的 logits 向量。模型本身并不知道什么是束搜索,也不知道什么是
top_p。 - 推理引擎/解码策略的责任:执行算法,选择下一个词元。
generate函数扮演了推理引擎的角色。它接收模型输出的 logits,然后负责执行我们上面讨论的所有后续步骤:应用Logits Processors和Warpers,实现束搜索的扩展与剪枝,或实现采样的随机抽样。
generation_config.json 的角色: 它就像一个“契约”或“推荐信”。模型作者通过这个文件,向推理引擎推荐一套他们认为最适合该模型的默认解码参数。推理引擎会读取这份推荐,但最终用户在调用 generate 时拥有最终决定权。
五、结语
generate 函数并非一个神秘的黑盒,而是一个设计精良、层次分明的算法执行器。它将模型训练与推理解码清晰地解耦:模型专注于“预测”,而解码策略专注于“选择”。
- 束搜索通过广度优先的探索和剪枝,在庞大的搜索空间中寻找一条近似最优的路径,适合需要准确性和连贯性的任务。
- 采样则通过在受控的概率分布上进行随机选择,为生成过程注入了多样性和创造力,适合需要新颖性的开放式任务。
至此,我们完成了从库的生态定位,到参数的实践应用,再到底层算法原理的完整探索。