Skip to content

C.6 Top-k / Top-p Sampling 与 Temperature

解码策略是 LLM 面试的常考题,和 RL 直接相关(RLHF 训练后的模型怎么采样、temperature 如何影响策略分布)。


Temperature

一句话记忆

logits 除以 T 再 softmax。T 大 → 更随机,T 小 → 更确定。

伪代码

scaled_logits = logits / temperature
probs = softmax(scaled_logits)

记忆方法

  • T0T \to 0:趋向 argmax(贪婪),相当于确定性策略
  • T=1T = 1:原始分布
  • TT \to \infty:趋向均匀分布,相当于随机策略

RL 视角:temperature 就是策略的探索程度。

PyTorch 实现

python
def sample_with_temperature(logits, temperature=1.0):
    if temperature < 1e-8:
        return logits.argmax(dim=-1)  # 贪婪
    scaled = logits / temperature
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Top-k Sampling

一句话记忆

只保留概率最大的 k 个 token,其余设为 -inf,再 softmax 归一化采样。

伪代码

top_k_values, top_k_indices = topk(logits, k)
logits[not in top_k] = -inf
probs = softmax(logits)
sample from probs

Python 实现

python
import numpy as np

def top_k_filtering(logits, k):
    """
    logits: [vocab_size]
    返回: 过滤后的 logits(非 top-k 位置为 -inf)
    """
    if k >= len(logits):
        return logits
    threshold = np.sort(logits)[-k]  # 第 k 大的值
    logits_filtered = np.where(logits >= threshold, logits, -np.inf)
    return logits_filtered

PyTorch 实现

python
import torch

def top_k_filtering(logits, k):
    """
    logits: [B, vocab_size] 或 [vocab_size]
    """
    if k <= 0:
        return logits
    top_k = min(k, logits.size(-1))
    # 找到第 k 大的值作为阈值
    threshold = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < threshold, float('-inf'))

def top_k_sample(logits, k, temperature=1.0):
    logits = top_k_filtering(logits, k)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Top-p (Nucleus) Sampling

一句话记忆

把 token 按概率从大到小排,累加到 p 就停,只保留前这些。

伪代码

sorted_logits = sort_desc(logits)
sorted_probs = softmax(sorted_logits)
cumulative_probs = cumsum(sorted_probs)

# 累积概率超过 p 的位置设为 -inf
cutoff_mask = cumulative_probs - sorted_probs > p
sorted_logits[cutoff_mask] = -inf

# 还原顺序,softmax,采样

记忆方法

Top-k 是固定数量,Top-p 是固定概率质量。Top-p 更灵活:确定性强时选几个 token 就够了,不确定时可能需要很多 token 才能凑够 p。

面试常问区别:

Top-kTop-p
筛选依据固定保留 k 个保留概率质量前 p
适应性不随分布变化自动适应分布尖锐程度
极端情况k=1 → 贪婪p=0 → 贪婪,p=1 → 不限制

Python 实现

python
import numpy as np

def top_p_filtering(logits, p):
    """
    logits: [vocab_size]
    """
    sorted_indices = np.argsort(logits)[::-1]  # 降序
    sorted_logits = logits[sorted_indices]
    sorted_probs = np.exp(sorted_logits - sorted_logits.max())
    sorted_probs = sorted_probs / sorted_probs.sum()
    cumulative_probs = np.cumsum(sorted_probs)

    # 找到累积概率超过 p 的位置(保留至少一个 token)
    cutoff = cumulative_probs - sorted_probs > p
    sorted_logits[cutoff] = -np.inf

    # 还原顺序
    result = np.empty_like(logits)
    result[sorted_indices] = sorted_logits
    return result

PyTorch 实现

python
import torch

def top_p_filtering(logits, p):
    """
    logits: [B, vocab_size]
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # 移除累积概率超过 p 的 token(保留至少一个)
    sorted_mask = (cumulative_probs - sorted_probs) > p
    sorted_logits[sorted_mask] = float('-inf')

    # 还原原始顺序
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

def top_p_sample(logits, p, temperature=1.0):
    logits = top_p_filtering(logits, p)
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

实际组合用法

工程中通常 Top-k + Top-p + Temperature 一起用:

python
def generate_sample(logits, temperature=1.0, top_k=50, top_p=0.9):
    # 1. Temperature 缩放
    logits = logits / max(temperature, 1e-8)
    # 2. Top-k 过滤
    logits = top_k_filtering(logits, top_k)
    # 3. Top-p 过滤
    logits = top_p_filtering(logits, top_p)
    # 4. 采样
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

易错点

易错说明
Top-p 的 cumsum 方向必须降序排列后再 cumsum,升序没有意义
Top-p 保留至少一个cumsum - current_prob > p 而不是 cumsum > p,否则第一个 token 可能被误杀
Top-k 的阈值topk().values[..., -1] 取第 k 大的值,不是 sort 后取 index
还原顺序Top-p 排序后要 scatter 回原位,忘了还原会导致采样错乱
Temperature=0要特殊处理为 argmax,不能真的除以 0

Built for reusable bilingual course delivery