Decoding Strategies
First published: 2025-06-03

Decoding strategies are the decision rules that turn a model’s next-token scores into actual tokens. The model provides a distribution over the vocabulary; the decoder decides how to pick the next token, how to stop, and how to trade off diversity vs. reliability.

This note focuses on practical, commonly used strategies:

  • Deterministic: greedy decoding, beam search
  • Stochastic: sampling, temperature, top-$k$, top-$p$ (nucleus)
  • Controls and constraints: repetition penalties, no-repeat $n$-grams, stop sequences
  • A short addendum: contrastive search and speculative decoding

1. Notation: what the model actually outputs

Let $x_{1:t}$ be the tokens generated so far. A causal language model (e.g., GPT-2-like) outputs logits $z \in \mathbb{R}^{|V|}$ for the next token:

$$z = f_\theta(x_{1:t})$$

These logits become a probability distribution by softmax:

$$p(x_{t+1}=v \mid x_{1:t}) = \frac{\exp(z_v)}{\sum_{u\in V}\exp(z_u)}$$
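As a minimal sketch of the softmax step, the snippet below uses plain Python lists instead of tensors so it runs without any dependencies. The four-element `toy_logits` vocabulary is a made-up illustration, not real GPT-2 output. Subtracting the maximum logit before exponentiating is a standard trick to avoid overflow and does not change the result.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    m = max(logits)                           # shift so the largest exponent is 0
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy logits for a hypothetical 4-token vocabulary.
toy_logits = [2.0, 1.0, 0.1, -1.0]
toy_probs = softmax(toy_logits)
print([round(p, 3) for p in toy_probs])
print("sum:", sum(toy_probs))
```

The ordering of the probabilities always matches the ordering of the logits, which is why greedy decoding can take the argmax of either.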

Important detail:

  • model.generate(...) returns token ids (the chosen tokens), not logits.
  • To get logits/scores during generation in HuggingFace, you must request them (see code below).

2. A vivid example (correctly extracting logits)

We feed GPT-2 the prompt "The cat ran after a" and generate 4 tokens. We will print:

  • the chosen next token ids (what decoding decided)
  • the top-5 logits at each step (what the model predicted)

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

prompt = "The cat ran after a"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    out = model.generate(
        input_ids,
        max_new_tokens=4,
        do_sample=False,  # greedy by default
        return_dict_in_generate=True,
        output_scores=True,
    )

sequences = out.sequences[0]                 # token ids (prompt + generated)
prompt_len = input_ids.shape[1]
new_token_ids = sequences[prompt_len:]       # only the new tokens (robust to early EOS)

print("New token ids:", new_token_ids.tolist())
print("Generated text:", tokenizer.decode(sequences, skip_special_tokens=True))

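# out.scores is a tuple with one entry per generated step.
# Each entry has shape [batch, vocab] and holds the scores after any logits
# processors have run; with none configured, as here, they should match the raw logits.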
for step, logits in enumerate(out.scores, start=1):
    topk = torch.topk(logits[0], k=5)
    ids = topk.indices.tolist()
    vals = topk.values.tolist()
    toks = [tokenizer.decode([i]) for i in ids]
    print(f"Step {step} top-5 tokens:")
    for tok, logit in zip(toks, vals):
        print(f"  {tok!r:>10}  logit={logit:8.3f}")

Conceptually:

  • The model computes logits $z$.
  • The decoder turns logits into the next token (greedy, beam, sampling, …).

3. Temperature: sharpening or flattening the distribution

Before choosing a token, it is common to apply temperature scaling with $T > 0$:

$$p_T(v \mid x_{1:t}) = \text{softmax}\left(\frac{z}{T}\right)_v$$

Effects:

  • $T \to 0^+$: distribution becomes very peaked (approaches greedy / argmax)
  • $T = 1$: original distribution
  • $T > 1$: flatter distribution (more randomness)

In practice, temperature is one of the most important knobs for “creativity vs. correctness”.
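To make the effect of temperature concrete, here is a small dependency-free sketch. The logits are toy values, and `softmax_with_temperature` and `entropy` are illustrative helper names rather than library functions. It reports the top probability and the entropy of the distribution at a low, neutral, and high temperature.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed stably."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

toy_logits = [2.0, 1.0, 0.1, -1.0]  # hypothetical 4-token vocabulary
for T in (0.2, 1.0, 2.0):
    probs = softmax_with_temperature(toy_logits, T)
    print(f"T={T}: top prob={max(probs):.3f}, entropy={entropy(probs):.3f}")
```

Lower temperatures concentrate mass on the top token (lower entropy), while higher temperatures spread it out.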

4. Greedy decoding (deterministic)

4.1 Definition

Greedy decoding picks the most likely next token at every step:

$$x_{t+1} = \arg\max_{v\in V} \; p(v \mid x_{1:t})$$

4.2 Pros / cons

  • Pros: fast, stable, often best for “exactness” tasks with strong local signal.
  • Cons: can be repetitive; can miss globally better continuations; brittle for open-ended generation.

4.3 Minimal “manual greedy” code

This shows what generation is doing internally (one token at a time):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

prompt = "The cat ran after a"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

max_new_tokens = 4
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits  # [batch, seq_len, vocab]
        next_logits = logits[:, -1, :]    # [batch, vocab]
        next_id = torch.argmax(next_logits, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

5. Beam search (deterministic, “global-ish”)

Greedy decoding is only locally optimal: it commits to the best token at each step and never revisits that choice. Beam search instead keeps several hypotheses alive in parallel.

5.1 Objective

An autoregressive model defines sequence probability:

$$P(x_{1:T}) = \prod_{t=1}^{T} P(x_t \mid x_{1:t-1})$$

It is numerically convenient to maximize log-probability:

$$\log P(x_{1:T}) = \sum_{t=1}^{T} \log P(x_t \mid x_{1:t-1})$$

Beam search approximates:

$$\arg\max_{x_{1:T}} \log P(x_{1:T})$$

by keeping only the top $B$ partial sequences at each step.
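The pruning step can be sketched on a toy problem. `TOY_MODEL` below is a hypothetical lookup table standing in for a real language model, and the sketch ignores EOS handling and length normalization. It is chosen so that the greedy path (beam width 1) commits to "A" and ends up with a less probable sequence than a width-2 beam finds.

```python
import math

# Hypothetical toy model: maps a prefix (tuple of tokens) to next-token probabilities.
TOY_MODEL = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.5, "y": 0.5},
    ("B",): {"x": 0.9, "y": 0.1},
}

def next_log_probs(prefix):
    return {tok: math.log(p) for tok, p in TOY_MODEL[prefix].items()}

def beam_search(num_beams, num_steps):
    """Return the surviving (prefix, log-probability) pairs, best first."""
    beams = [((), 0.0)]
    for _ in range(num_steps):
        candidates = []
        for prefix, score in beams:
            for tok, lp in next_log_probs(prefix).items():
                candidates.append((prefix + (tok,), score + lp))
        # Keep only the num_beams highest-scoring partial sequences.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams

greedy = beam_search(num_beams=1, num_steps=2)[0]
best = beam_search(num_beams=2, num_steps=2)[0]
print("beam width 1:", greedy[0], round(math.exp(greedy[1]), 2))
print("beam width 2:", best[0], round(math.exp(best[1]), 2))
```

With a beam width of 1 this procedure reduces to greedy decoding, which is why the same function serves for both runs.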

5.2 Length normalization / length penalty

Pure log-probability tends to favor shorter sequences (because it sums negative numbers). A common fix is a length penalty. One popular form (used in NMT) is:

$$\text{score}(x_{1:T}) = \frac{\log P(x_{1:T})}{\left(\frac{5+T}{5+1}\right)^\alpha}$$

where $\alpha \in [0, 2]$ controls how much we prefer longer sequences.
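A quick numeric check of this formula, using two made-up hypotheses (the log-probabilities and lengths are illustrative, not model output): with alpha set to 0 the shorter hypothesis wins on raw log-probability, while with alpha set to 1 the longer one scores higher after normalization.

```python
def length_normalized_score(log_prob, length, alpha):
    """log P divided by ((5 + T) / (5 + 1)) ** alpha, as in the formula above."""
    penalty = ((5 + length) / (5 + 1)) ** alpha
    return log_prob / penalty

# Hypothetical hypotheses as (log P, length T).
short_hyp = (-4.0, 4)
long_hyp = (-6.0, 10)

for alpha in (0.0, 1.0):
    s = length_normalized_score(*short_hyp, alpha)
    l = length_normalized_score(*long_hyp, alpha)
    winner = "short" if s > l else "long"
    print(f"alpha={alpha}: short={s:.3f}, long={l:.3f} -> {winner} wins")
```

Because log-probabilities are negative, dividing by a larger penalty moves the score toward zero, which is what benefits longer sequences.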

5.3 Pros / cons

  • Pros: good for tasks with a “single best answer” (translation, some summarization, constrained generation).
  • Cons: for open-ended text, can increase dullness and repetition; more compute: roughly $B$ times the work.

5.4 HuggingFace usage

out = model.generate(
    input_ids,
    max_new_tokens=80,
    num_beams=4,
    do_sample=False,
    early_stopping=True,
    length_penalty=1.0,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))

6. Sampling (stochastic)

Sampling draws the next token from the probability distribution instead of always taking the max:

$$x_{t+1} \sim \text{Categorical}(p(\cdot \mid x_{1:t}))$$

Sampling is the main tool for diversity.

6.1 Why naive sampling can be bad

The long tail of the vocabulary contains many low-probability tokens that can derail coherence. Practical decoders therefore often truncate the distribution before sampling.
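A back-of-the-envelope sketch shows why the tail matters. All numbers here are assumptions chosen for illustration: five plausible tokens hold 80% of the mass, the remaining 20% is spread over 10,000 tail tokens, and the same distribution is reused at every step (a simplification). Each tail token is individually negligible, yet over a 50-token generation at least one tail token is almost certain to be drawn.

```python
import random

head = [0.4, 0.2, 0.1, 0.06, 0.04]      # 5 plausible tokens (assumed), total mass 0.8
tail_size = 10_000                       # assumed number of implausible tokens
tail_mass = 1.0 - sum(head)              # about 0.2, spread over the tail
per_tail_token = tail_mass / tail_size   # each tail token is individually negligible

steps = 50
# Probability that at least one of `steps` independent draws lands in the tail.
p_any_tail = 1.0 - (1.0 - tail_mass) ** steps

rng = random.Random(0)
draws = rng.choices(["head", "tail"], weights=[1.0 - tail_mass, tail_mass], k=steps)

print(f"per-tail-token probability: {per_tail_token:.6f}")
print(f"P(at least one tail token in {steps} steps): {p_any_tail:.5f}")
print("tail draws in one simulated run:", draws.count("tail"))
```

Truncation methods such as top-k and nucleus sampling address exactly this: they zero out the tail before drawing.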

7. Top-$k$ sampling

Top-$k$ sampling keeps only the $k$ most probable tokens and renormalizes:

  1. Let $S_k$ be the set of the $k$ tokens with highest probability.
  2. Define $p'(v) \propto p(v)$ for $v \in S_k$, and $p'(v) = 0$ otherwise.
  3. Sample $x_{t+1} \sim \text{Categorical}(p')$.

Intuition:

  • Small $k$ behaves more like greedy.
  • Large $k$ approaches naive sampling.

HuggingFace:

out = model.generate(
    input_ids,
    max_new_tokens=120,
    do_sample=True,
    top_k=50,
    temperature=0.8,
)

8. Top-$p$ (nucleus) sampling

Top-$p$ chooses the smallest token set whose cumulative probability is at least $p$.

Let tokens be sorted by probability: $p_1 \ge p_2 \ge \cdots$. Choose the smallest $m$ such that:

$$\sum_{i=1}^{m} p_i \ge p$$

Then sample from those $m$ tokens after renormalization.

Why this is often better than top-$k$:

  • The “right” number of plausible tokens changes with context.
  • Top-$p$ adapts the candidate set size automatically.
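The adaptive behavior can be seen by computing the nucleus size directly. This is a minimal sketch in plain Python; `nucleus_size` is an illustrative helper and the three distributions are invented examples of a confident, a moderately uncertain, and a maximally uncertain context.

```python
def nucleus_size(probs, top_p):
    """Smallest m such that the m largest probabilities sum to at least top_p."""
    total = 0.0
    for m, p in enumerate(sorted(probs, reverse=True), start=1):
        total += p
        if total >= top_p:
            return m
    return len(probs)

peaked = [0.92, 0.05, 0.02, 0.01]   # model is confident
medium = [0.5, 0.3, 0.15, 0.05]     # a few plausible options
flat = [0.25, 0.25, 0.25, 0.25]     # model is uncertain

for name, dist in (("peaked", peaked), ("medium", medium), ("flat", flat)):
    print(name, "-> nucleus size:", nucleus_size(dist, top_p=0.9))
```

A fixed top-k would use the same candidate count in all three cases, whereas the nucleus shrinks when the model is confident and grows when it is not.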

HuggingFace:

out = model.generate(
    input_ids,
    max_new_tokens=120,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
)

Rule of thumb: start with top_p=0.9 and tune temperature.

9. A single-step implementation: temperature + top-$k$ + top-$p$

The following helper illustrates the core mechanics.

import torch

def sample_next_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
):
    # logits: [vocab]
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    logits = logits / temperature

    # Top-k truncation
    if top_k is not None and top_k > 0:
        values, _ = torch.topk(logits, k=min(top_k, logits.numel()))
        cutoff = values[-1]
        logits = torch.where(logits < cutoff, torch.tensor(float("-inf"), device=logits.device), logits)

    # Convert to probs for top-p
    probs = torch.softmax(logits, dim=-1)

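    if top_p is not None and 0 < top_p < 1:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cdf = torch.cumsum(sorted_probs, dim=-1)
        # Keep a token if the mass *before* it is still below top_p, so the token
        # that crosses the threshold is included (smallest set with mass >= top_p).
        # The first token always has zero mass before it, so it is always kept.
        keep = (cdf - sorted_probs) < top_p

        filtered = torch.zeros_like(probs)
        filtered[sorted_idx[keep]] = probs[sorted_idx[keep]]
        probs = filtered / filtered.sum()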

    next_id = torch.multinomial(probs, num_samples=1)
    return next_id

This is essentially what “sampling” decoders do, with a few extra details for speed and numerical stability.

10. Preventing repetition and degeneration

Many decoding failures come from degeneration: loops, repeated phrases, or bland continuations.

10.1 Repetition penalty

In practice, decoders often penalize tokens that already appear in the context (or appeared recently). HuggingFace exposes a common heuristic for this as the repetition_penalty argument; a value of 1.0 means no penalty, and values above 1.0 make previously seen tokens less likely.

out = model.generate(
    input_ids,
    max_new_tokens=160,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
    repetition_penalty=1.1,
)
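The heuristic behind this kind of penalty is commonly described as follows: for every token id already in the context, divide its logit by the penalty if the logit is positive and multiply it by the penalty if it is negative, so that the score always moves down when the penalty exceeds 1. The sketch below illustrates that rule on plain Python lists; `apply_repetition_penalty` is an illustrative name, and a library implementation may differ in details.

```python
def apply_repetition_penalty(logits, previous_ids, penalty):
    """Return a copy of logits with already-seen token ids made less likely."""
    out = list(logits)
    for i in set(previous_ids):
        # Positive logits shrink toward zero; negative logits become more negative.
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out

toy_logits = [3.0, -1.0, 0.5, 2.0]   # hypothetical 4-token vocabulary
penalized = apply_repetition_penalty(toy_logits, previous_ids=[0, 1, 0], penalty=1.2)
print(penalized)
```

Tokens 0 and 1 appeared in the context and are pushed down; tokens 2 and 3 are untouched. Note that the penalty is applied once per distinct token, regardless of how often it appeared.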

10.2 No-repeat $n$-gram constraint

For a chosen $n$, forbid generating any $n$-gram that already appeared.

out = model.generate(
    input_ids,
    max_new_tokens=160,
    do_sample=True,
    top_p=0.9,
    no_repeat_ngram_size=3,
)

This is a hard constraint (can improve readability, but may overconstrain and reduce fluency).
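One way to realize this constraint is to compute, at each step, the set of tokens that would complete an n-gram already present in the sequence, and then set their logits to negative infinity before choosing. The helper below is a simple, unoptimized sketch of the first part; `banned_next_tokens` is an illustrative name, not a library function.

```python
def banned_next_tokens(token_ids, n):
    """Tokens that would complete an n-gram already present in token_ids."""
    if n <= 0 or len(token_ids) < n:
        return set()
    # The last n-1 tokens form the prefix of the n-gram about to be completed.
    prefix = tuple(token_ids[len(token_ids) - (n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(token_ids) - n + 1):
        ngram = token_ids[i:i + n]
        if tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned

history = [5, 6, 7, 5, 6]   # hypothetical token ids generated so far
print(banned_next_tokens(history, n=3))
```

Here the sequence ends in (5, 6), and the trigram (5, 6, 7) has already occurred, so token 7 is the only one banned at this step.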

10.3 Stop sequences / EOS

Decoding also decides when to stop. Common stop rules:

  • Emit an EOS token
  • Hit max_new_tokens
  • Match a stop string (e.g., "\n\n", "</answer>")

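With HuggingFace generate, stop strings have traditionally required a custom StoppingCriteria. Recent transformers releases also accept a stop_strings argument, which needs the tokenizer passed to generate as well; check the documentation for your installed version.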
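Independent of any library, the stop-string rule itself is simple to express on decoded text. The sketch below assumes you decode the generated tokens after each step (for example in a manual loop) and check the result; `truncate_at_stop` is a hypothetical helper name.

```python
def truncate_at_stop(text, stop_strings):
    """Cut text at the earliest stop string; also report whether one matched."""
    cut = len(text)
    for stop in stop_strings:
        if not stop:            # ignore empty stop strings
            continue
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut], cut < len(text)

stops = ["\n\n", "</answer>"]
print(truncate_at_stop("Paris</answer> and more text", stops))
print(truncate_at_stop("no stop string here", stops))
```

In a generation loop, the boolean tells you when to break, and the truncated text is what you return. Matching on decoded text rather than token ids avoids missing a stop string that is split across token boundaries.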

11. Choosing a strategy: practical guidance

There is no universally best decoder; choose based on task.

11.1 “Single correct answer” tasks

Examples: translation, extraction with strict schema, deterministic tool calls.

  • Prefer: greedy or beam search
  • Typical settings: do_sample=False, num_beams=1..5
  • Add constraints: stop tokens, formatting rules, sometimes no_repeat_ngram_size

11.2 “Helpful but varied” tasks

Examples: brainstorming, writing, ideation.

  • Prefer: nucleus sampling + moderate temperature
  • Typical settings: top_p=0.9, temperature=0.7..1.0, repetition_penalty=1.05..1.2

11.3 “Code generation”

Code is sensitive to small mistakes and needs consistency.

  • Often good: lower temperature, smaller nucleus
  • Typical settings: temperature=0.2..0.6, top_p=0.8..0.95
  • Add: stop sequences (e.g., stop after closing triple backticks), constraints if possible

11.4 A quick parameter cheat sheet

  • Too boring / repetitive: increase temperature slightly, increase top_p, add mild repetition_penalty
  • Too random / incoherent: decrease temperature, decrease top_p, consider greedy
  • Too short: increase max_new_tokens, and in beam search consider a larger length_penalty (in HuggingFace, positive values favor longer sequences)

12. Two modern add-ons (optional but useful)

12.1 Contrastive search (quality vs repetition)

Contrastive search tries to balance:

  • high probability under the model (coherence)
  • low similarity to recent hidden states (avoid repetition)

HuggingFace has a mode for it:

out = model.generate(
    input_ids,
    max_new_tokens=160,
    penalty_alpha=0.6,
    top_k=4,
)

This can produce less repetitive text than greedy/beam without full stochastic sampling.
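The trade-off can be sketched as a scoring rule. As commonly formulated, each of the top-k candidates is scored as (1 - alpha) times its model probability minus alpha times its maximum cosine similarity to the hidden states of the context tokens. The two-dimensional "hidden states" and probabilities below are invented for illustration, and `contrastive_score` is an illustrative helper, not the library's internal function.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def contrastive_score(prob, cand_hidden, context_hiddens, alpha):
    """(1 - alpha) * model confidence - alpha * max similarity to the context."""
    degeneration = max(cosine(cand_hidden, h) for h in context_hiddens)
    return (1 - alpha) * prob - alpha * degeneration

# Hypothetical hidden states of the tokens generated so far.
context = [[1.0, 0.0], [0.9, 0.1]]
# Hypothetical candidates: token -> (probability, hidden state).
candidates = {
    "repeat": (0.6, [1.0, 0.0]),   # likely, but nearly identical to the context
    "fresh": (0.4, [0.0, 1.0]),    # somewhat less likely, but dissimilar
}

alpha = 0.6
scores = {tok: contrastive_score(p, h, context, alpha) for tok, (p, h) in candidates.items()}
best = max(scores, key=scores.get)
print(scores, "->", best)
```

With alpha set to 0 the rule falls back to picking the most probable candidate; raising alpha shifts the choice toward candidates that differ from what has already been generated.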

12.2 Speculative decoding (speed, not a style)

Speculative decoding uses a small “draft” model to propose several tokens ahead and a larger “target” model to verify them. This accelerates generation without changing the target model's output distribution, provided the accept/reject step is implemented correctly. It is mostly an engineering optimization.
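The claim that the distribution is preserved rests on the accept/reject rule used in the standard speculative sampling formulation: a token drawn from the draft distribution q is accepted with probability min(1, p/q) under the target distribution p, and on rejection a replacement is drawn from the normalized positive part of p - q. The sketch below checks this for a single position with made-up distributions; a real system applies it to several drafted tokens per target forward pass, which this toy does not model.

```python
import random

def speculative_step(p_target, q_draft, rng):
    """One accept/reject step. Returns (token, draft_was_accepted)."""
    vocab = list(range(len(p_target)))
    x = rng.choices(vocab, weights=q_draft, k=1)[0]           # draft proposes x ~ q
    if rng.random() < min(1.0, p_target[x] / q_draft[x]):     # accept with prob min(1, p/q)
        return x, True
    # On rejection, resample from the positive part of (p - q);
    # rng.choices normalizes the weights internally.
    residual = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    return rng.choices(vocab, weights=residual, k=1)[0], False

p = [0.5, 0.3, 0.2]   # hypothetical target-model distribution
q = [0.2, 0.2, 0.6]   # hypothetical draft-model distribution

rng = random.Random(0)
n = 10_000
counts = [0, 0, 0]
accepted = 0
for _ in range(n):
    tok, ok = speculative_step(p, q, rng)
    counts[tok] += 1
    accepted += ok

print("empirical:", [round(c / n, 3) for c in counts], "target:", p)
print("acceptance rate:", round(accepted / n, 3))
```

The empirical frequencies should track the target distribution even though every proposal came from the draft. The acceptance rate depends on how closely the draft matches the target, and that is what determines the speedup.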

13. Summary

  • The model outputs logits; decoding decides what tokens to emit.
  • Greedy: fastest and stable, but can be dull.
  • Beam search: better for structured tasks; can be dull for open-ended text.
  • Sampling (top-$k$/top-$p$ + temperature): best for diversity; needs controls to avoid incoherence.
  • Repetition controls and stopping rules are part of decoding in practice.
