Skip to content

在前两篇文章中,我们了解了 transformers 库的生态位和如何使用参数来控制文本生成。本篇文章,我们深入其内部,探究那些参数究竟是如何驱动底层算法的。当模型完成一次前向传播,输出一个数万维度的 logits 向量后,究竟是怎样的魔法决定了下一个词元是谁?

本文将以算法的视角,解析两大核心解码策略——束搜索 (Beam Search) 和采样 (Sampling) 的内部工作流程。我们将重点剖析 Logits Processors 和 Logits Warpers 这两个关键组件,并最终厘清“模型本身”与“解码策略”之间清晰的界限。读完本文,你将对文本生成的每一个环节都了如指掌。

一、 解码的统一前奏:Logits Processors & Warpers

在任何解码策略(无论是束搜索还是采样)启动之前,模型输出的原始 logits 向量都需要经过一个两阶段的预处理流程。这个流程是保证生成质量和执行特定约束的关键。

阶段一:Logits Processors (规则执行者)

这是一系列确定性的规则,它们在采样之前对 logits 进行修改,以强制执行某些约束。可以把它们看作是“语法和规则警察”。

  • 作用:应用硬性或软性的惩罚/奖励。

常见成员

  • MinLengthLogitsProcessor:在序列长度未达到 min_length 之前,强制将 eos_token_id (结束符) 的概率设为负无穷,防止句子提前结束。
  • RepetitionPenaltyLogitsProcessor:对已经在上文中出现过的词元,降低它们的 logits 值,以减少重复 (repetition_penalty)。
  • NoRepeatNGramLogitsProcessor:如果一个 n-gram (如一个三词短语) 已经生成过,则将导致该 n-gram 再次出现的下一个词元的概率设为负无穷。
  • ForcedEOSTokenLogitsProcessor:当达到 max_length 时,强制将 eos_token_id 之外所有词元的概率设为负无穷,确保序列正常结束。

阶段二:Logits Warpers (概率控制)

在规则执行完毕后,如果启用了采样 (do_sample=True),Logits Warpers 就会登场。它们负责对概率分布进行“整形”,以控制采样的随机性。

  • 作用:改变概率分布的形状,使其更“陡峭”或更“平滑”。

常见成员

  • TemperatureLogitsWarper:通过除以 temperature 值来缩放 logits,从而改变分布的平滑度。
  • TopKLogitsWarper:将概率最低的一批词元的 logits 设为负无穷,只保留概率最高的 k 个。
  • TopPLogitsWarper (Nucleus Sampling):动态保留一个最小的词元集,使其累积概率超过 top_p

数据流总结:原始 Logits -> (Logits Processors) -> 规则约束后的 Logits -> (Logits Warpers, if sampling) -> 整形后的 Logits -> Softmax -> 最终概率分布 -> 解码策略

束搜索的核心思想是:不要过早地做出最优选择,而是在每一步都保留多种可能性。

  • 核心参数num_beams (宽度), max_new_tokens (深度), length_penalty (评分校正)。

图解步骤 (num_beams=2)

  1. 初始化:将输入 prompt 复制 num_beams 次,作为初始的 2 个“束”(beam)。

2. 第一步扩展

  • 模型对初始输入进行一次前向传播,得到 logits。
  • 应用 Logits Processors (此时通常没有采样,不经过 Warpers)。
  • 计算所有词元的 log_softmax 概率。
  • 从整个词汇表中,选出总概率(此时就是第一步的概率)最高的 2 个词元,形成 2 条新的候选序列。假设是 "I have a good" 和 "I have a nice"。

3. 第二步扩展

  • 将 "I have a good" 和 "I have a nice" 作为新的输入,分别进行前向传播,得到 2 组 logits。
  • 对于每一组 logits,计算 log_softmax 并将其加到各自路径的累计对数概率上。
  • 现在我们有了 2 (beams) * V (vocab_size) 个可能的扩展。从这成千上万个可能性中,再次选出累计总分最高的 2 条序列。

4. 循环与终止

  • 重复第 3 步,直到序列遇到 eos_token 或达到 max_length
  • 遇到 eos 的序列会被放入“已完成”的候选池中。
  • 评分校正:为了避免模型偏爱短句(因为对数概率是负数,累加越多值越小),在比较最终候选序列时,会使用 length_penalty 对分数进行归一化,公式通常为 score / (length ** length_penalty)
  1. 最终选择:从“已完成”的候选池中,选出经过长度惩罚校正后分数最高的序列作为最终输出。

深度 vs. 宽度

  • 宽度:由 num_beams 决定,即同时探索多少条路径。
  • 深度:由 max_new_tokenseos_token_id 和 early_stopping 等停止条件共同决定,即每条路径能走多远。

三、 算法详解 2:采样 (Sampling)

采样的核心思想是:根据概率分布进行随机选择,从而引入多样性和创造性。

  • 核心参数do_sample=Truetop_ptop_ktemperature
  • 图解步骤 (纯采样, num_beams=1)
  1. 初始化:输入 prompt。

2. 第一步生成

  • 模型对输入进行前向传播,得到 logits。
  • 应用 Logits Processors (如 repetition_penalty)。
  • 应用 Logits Warpers (如先用 temperature 缩放,再用 top_p 截断)。
  • 对处理后的 logits 应用 softmax 得到最终的概率分布。
  • 根据这个概率分布,随机抽样一个词元作为下一个词。假设是 "day"。

3. 第二步生成

  • 将 "I have a good day" 作为新的输入,重复上述过程,再次随机抽样下一个词元。

4. 循环与终止

  • 不断重复,直到抽样到 eos_token 或达到 max_length

整个过程就像掷一个被 top_p 等参数“动过手脚”的骰子,每一步都充满了可能性,但又被限制在了一个合理的范围内。

四、 模型与推理引擎的边界:谁负责什么?

经过以上分析,我们可以清晰地划分出模型和推理引擎(即 transformers 的 generate 函数)的责任:

  • 模型的责任单一且纯粹——提供 Logits
  • 模型本身是一个巨大的权重矩阵,它的唯一工作就是接收一串词元 ID,经过复杂的矩阵运算后,输出一个代表下一个词元概率分布的 logits 向量。模型本身并不知道什么是束搜索,也不知道什么是 top_p
  • 推理引擎/解码策略的责任执行算法,选择下一个词元
  • generate 函数扮演了推理引擎的角色。它接收模型输出的 logits,然后负责执行我们上面讨论的所有后续步骤:应用 Logits Processors 和 Warpers,实现束搜索的扩展与剪枝,或实现采样的随机抽样。

generation_config.json 的角色: 它就像一个“契约”或“推荐信”。模型作者通过这个文件,向推理引擎推荐一套他们认为最适合该模型的默认解码参数。推理引擎会读取这份推荐,但最终用户在调用 generate 时拥有最终决定权。

五、结语

generate 函数并非一个神秘的黑盒,而是一个设计精良、层次分明的算法执行器。它将模型训练与推理解码清晰地解耦:模型专注于“预测”,而解码策略专注于“选择”

  • 束搜索通过广度优先的探索和剪枝,在庞大的搜索空间中寻找一条近似最优的路径,适合需要准确性连贯性的任务。
  • 采样则通过在受控的概率分布上进行随机选择,为生成过程注入了多样性创造力,适合需要新颖性的开放式任务。

至此,我们完成了从库的生态定位,到参数的实践应用,再到底层算法原理的完整探索。

持续沉淀企业 AI 技术内容。