FlowRL - Matching Reward Distributions for LLM Reasoning
First published: 2026-02-09


1. Motivation & Problem Definition

1.1 Background and Limitations of Traditional RL in LLM Reasoning

Large language model (LLM) reasoning is typically formulated as a conditional generation problem: given a question $\mathbf{x} \in \mathcal{X}$, a policy model $\pi_{\theta}(\mathbf{y}|\mathbf{x})$ generates an answer $\mathbf{y} \in \mathcal{Y}$. The quality of the answer is evaluated by a task-specific reward signal $r(\mathbf{x}, \mathbf{y})$. In reasoning tasks, the reward is usually sparse and terminal (e.g., the correctness of the final answer), so we work with a single reward per response rather than a return (a discounted sum of per-step rewards).

Existing reinforcement learning (RL) methods for LLMs, such as REINFORCE, PPO, and GRPO, adopt a reward-maximization objective:

$$\max_{\theta} \mathbb{E}_{\mathbf{y} \sim \pi_{\theta}(\cdot|\mathbf{x})}\bigl[ r(\mathbf{x}, \mathbf{y}) \bigr].$$

However, this objective tends to concentrate the policy on the dominant reward mode, leading to mode collapse and a lack of diversity in the generated reasoning paths. For complex reasoning tasks (e.g., long-chain mathematical proofs or code generation), capturing a diverse set of valid solutions is crucial for generalization.

1.2 Core Idea of FlowRL

FlowRL shifts the paradigm from reward maximization to reward‑distribution matching. Instead of only pushing the policy toward the highest‑reward answers, FlowRL encourages the policy to generate answers with probabilities proportional to their exponentiated rewards. Formally, the target is:

$$\pi_{\theta}(\mathbf{y}|\mathbf{x}) \propto \exp\!\bigl(\beta \, r(\mathbf{x}, \mathbf{y})\bigr),$$

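where $\beta > 0$ is an inverse-temperature parameter: a larger $\beta$ concentrates probability on the highest-reward answers, while a smaller $\beta$ spreads it more evenly. This Boltzmann-type distribution therefore balances exploitation (high-reward answers) and exploration (lower-reward but still valid answers).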
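As a small illustration of how $\beta$ shapes the target, the sketch below normalizes $\exp(\beta r)$ over a hypothetical set of four candidate answers. The candidates, their rewards, and the $\beta$ values are made-up assumptions for illustration, not settings taken from FlowRL.

```python
import math


def boltzmann(rewards, beta):
    """Normalize exp(beta * r) over a finite candidate set into a probability distribution."""
    # Subtracting the max before exponentiating avoids overflow and does not change the result
    m = max(beta * r for r in rewards)
    weights = [math.exp(beta * r - m) for r in rewards]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical rewards: two correct strategies (1.0), one partially correct (0.5), one wrong (0.0)
rewards = [1.0, 1.0, 0.5, 0.0]
for beta in (1.0, 5.0, 15.0):
    probs = boltzmann(rewards, beta)
    print(beta, [round(p, 4) for p in probs])
```

The two equally rewarded answers always receive equal probability, so the target keeps both solution strategies instead of collapsing onto one; $\beta$ only controls how much mass leaks to the lower-reward answers.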

2. Theoretical Framework: From Reverse KLD to Trajectory Balance

2.1 Distribution Matching via Reverse KL Divergence

To align the policy with the desired Boltzmann distribution, we minimize the reverse Kullback–Leibler (KL) divergence between the policy $\pi_{\theta}$ and a normalized target distribution $\tilde{\pi}$. Because the partition function of the Boltzmann distribution sums $\exp(\beta \, r(\mathbf{x}, \mathbf{y}))$ over all possible responses and is therefore intractable, we introduce a learnable partition function $Z_{\phi}(\mathbf{x})$:

$$\tilde{\pi}(\mathbf{y}|\mathbf{x}) = \frac{\exp\!\bigl(\beta \, r(\mathbf{x}, \mathbf{y})\bigr)}{Z_{\phi}(\mathbf{x})}.$$

The reverse KL divergence is:

$$\mathbb{D}_{\text{KL}}\!\bigl(\pi_{\theta} \parallel \tilde{\pi}\bigr) = \mathbb{E}_{\mathbf{y} \sim \pi_{\theta}(\cdot|\mathbf{x})} \Bigl[ \log \pi_{\theta}(\mathbf{y}|\mathbf{x}) - \log \tilde{\pi}(\mathbf{y}|\mathbf{x}) \Bigr].$$

Substituting $\tilde{\pi}$ gives:

$$\mathbb{D}_{\text{KL}} = \mathbb{E}_{\mathbf{y} \sim \pi_{\theta}(\cdot|\mathbf{x})} \Bigl[ \log \pi_{\theta}(\mathbf{y}|\mathbf{x}) - \beta \, r(\mathbf{x}, \mathbf{y}) + \log Z_{\phi}(\mathbf{x}) \Bigr].$$
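The substitution above uses the learned $Z_\phi$ rather than the true normalizer $Z(\mathbf{x}) = \sum_{\mathbf{y}} \exp(\beta \, r(\mathbf{x}, \mathbf{y}))$, so $\tilde{\pi}$ is only exactly normalized when the two coincide. A quick numerical check on a hypothetical four-answer example (rewards, $\beta$, the policy, and $\log Z_\phi$ are all made-up values) shows that the expression equals the exact KL divergence to the normalized Boltzmann target plus the constant $\log Z_\phi(\mathbf{x}) - \log Z(\mathbf{x})$.

```python
import math


def reverse_kl_objective(policy, rewards, beta, log_Z_phi):
    """E_{y~policy}[log pi(y) - beta * r(y) + log Z_phi]: the reverse-KL expression with a learned log-partition."""
    return sum(p * (math.log(p) - beta * r + log_Z_phi) for p, r in zip(policy, rewards))


def exact_reverse_kl(policy, rewards, beta):
    """KL(policy || target) with target(y) = exp(beta * r(y)) / Z normalized exactly; also returns log Z."""
    log_Z = math.log(sum(math.exp(beta * r) for r in rewards))
    target = [math.exp(beta * r - log_Z) for r in rewards]
    kl = sum(p * (math.log(p) - math.log(q)) for p, q in zip(policy, target))
    return kl, log_Z


rewards = [1.0, 1.0, 0.5, 0.0]   # hypothetical rewards for four candidate answers
beta = 2.0
policy = [0.7, 0.1, 0.1, 0.1]    # hypothetical policy that over-weights one correct answer
log_Z_phi = 1.5                  # arbitrary learned log-partition (not the true value)

kl, log_Z_true = exact_reverse_kl(policy, rewards, beta)
objective = reverse_kl_objective(policy, rewards, beta, log_Z_phi)
# Equal up to floating-point rounding: objective = KL + (log Z_phi - log Z_true)
print(objective, kl + log_Z_phi - log_Z_true)
```

Because that gap does not depend on $\theta$, the gradient with respect to the policy parameters is the same however accurate $Z_\phi$ is.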

2.2 From Flow Equation to Trajectory Balance: Gradient Equivalence

In GFlowNets, the trajectory-balance condition follows from the more fundamental flow-conservation principle. Consider the generation of a complete response $\mathbf{y} = (y_1, \dots, y_T)$ given a prompt $\mathbf{x}$ as a trajectory in a directed acyclic graph whose states are partial responses. Let $F(s)$ denote the (unnormalized) flow at state $s$. The forward policy $\pi_\theta$ (our generation model) determines the transition probabilities. For a complete trajectory, the trajectory-balance condition states that the flow from the initial state $s_0$ (empty response) to the terminal state $s_T = \mathbf{y}$ must satisfy:

$$F(s_0) \cdot \prod_{t=1}^{T} \pi_\theta(y_t \mid \mathbf{y}_{<t}, \mathbf{x}) = R(\mathbf{x}, \mathbf{y}) \cdot \prod_{t=1}^{T} P_B(\mathbf{y}_{<t} \mid \mathbf{y}_{\le t}, \mathbf{x}),$$

where $R(\mathbf{x},\mathbf{y})$ is the reward associated with the terminal state, and $P_B$ is a backward policy giving the probability of the parent state $\mathbf{y}_{<t}$ given the state $\mathbf{y}_{\le t}$ (often chosen as uniform or a simple fixed distribution). The initial flow $F(s_0)$ is the total flow injected into the system, i.e. the partition function, which is modeled by $Z_\phi(\mathbf{x})$. In FlowRL, the reward is defined as the exponentiated, reference-model-adjusted reward:

$$R(\mathbf{x}, \mathbf{y}) = \exp\!\bigl(\beta \, r(\mathbf{x}, \mathbf{y})\bigr) \cdot \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}).$$

Moreover, for simplicity (and following common practice in GFlowNets for sequence generation), the backward policy is taken to be uniform, so that $\prod_{t} P_B(\cdot)$ is a constant. For left-to-right generation this is exact and the constant is 1, because each prefix $\mathbf{y}_{\le t}$ has exactly one parent $\mathbf{y}_{<t}$. Dropping this constant (it can be absorbed into the learned partition function) and using $\prod_{t=1}^{T} \pi_\theta(y_t \mid \mathbf{y}_{<t}, \mathbf{x}) = \pi_\theta(\mathbf{y} \mid \mathbf{x})$, we obtain the simplified trajectory-balance equation:

$$Z_\phi(\mathbf{x}) \cdot \pi_\theta(\mathbf{y} \mid \mathbf{x}) = \exp\!\bigl(\beta \, r(\mathbf{x}, \mathbf{y})\bigr) \cdot \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}). \tag{TB-eq}$$

Taking logarithms on both sides gives the linear constraint:

$$\log Z_\phi(\mathbf{x}) + \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) = \beta \, r(\mathbf{x}, \mathbf{y}) + \log \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}). \tag{log-TB}$$

Since this equality will not hold exactly for every trajectory during training, we enforce it softly by penalizing its squared residual, which gives the trajectory-balance loss:

$$\mathcal{L}_{\text{TB}}(\mathbf{x}, \mathbf{y}; \theta, \phi) = \Bigl( \log Z_\phi(\mathbf{x}) + \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) - \beta \, r(\mathbf{x}, \mathbf{y}) - \log \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}) \Bigr)^2.$$
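The trajectory-balance loss is simply the squared residual of (log-TB). The sketch below evaluates it on a hypothetical four-answer example with a made-up reference policy (single-step responses, so sequence and token probabilities coincide). When $\pi_\theta$ equals the target $\exp(\beta r)\,\pi_{\text{ref}} / Z$ and $\log Z_\phi$ equals the true log-normalizer, every residual is zero; a uniform policy leaves nonzero residuals.

```python
import math


def tb_residual(log_Z, log_pi, reward, log_ref, beta):
    """Residual of the log trajectory-balance equation for one (x, y) pair."""
    return log_Z + log_pi - beta * reward - log_ref


def tb_loss(log_Z, log_pis, rewards, log_refs, beta):
    """Mean squared TB residual over a batch of responses to the same prompt."""
    res = [tb_residual(log_Z, lp, r, lr, beta) for lp, r, lr in zip(log_pis, rewards, log_refs)]
    return sum(x * x for x in res) / len(res)


beta = 2.0
rewards = [1.0, 1.0, 0.5, 0.0]   # hypothetical rewards
ref = [0.1, 0.3, 0.4, 0.2]       # hypothetical reference policy
log_refs = [math.log(p) for p in ref]

# Exact target pi*(y) = exp(beta * r) * pi_ref(y) / Z and its log-normalizer
unnorm = [math.exp(beta * r) * p for r, p in zip(rewards, ref)]
log_Z_true = math.log(sum(unnorm))
target_log_pis = [math.log(u) - log_Z_true for u in unnorm]

print(tb_loss(log_Z_true, target_log_pis, rewards, log_refs, beta))         # ~0 at the optimum
print(tb_loss(log_Z_true, [math.log(0.25)] * 4, rewards, log_refs, beta))  # > 0 for a uniform policy
```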

Now we demonstrate that minimizing this loss is gradient-equivalent, with respect to the policy parameters $\theta$, to minimizing the reverse KL divergence between $\pi_\theta$ and the target distribution $\tilde{\pi}(\mathbf{y} \mid \mathbf{x}) = \exp(\beta \, r(\mathbf{x},\mathbf{y})) \, \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}) / Z_\phi(\mathbf{x})$.

First, compute the gradient of $\mathbb{D}_{\text{KL}}$ with respect to the policy parameters $\theta$. Differentiating through the sampling distribution (the log-derivative trick) produces the score-weighted term below; differentiating the $\log \pi_\theta$ inside the expectation adds $\mathbb{E}_{\mathbf{y} \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(\mathbf{y} \mid \mathbf{x})]$, which is zero because $\sum_{\mathbf{y}} \nabla_\theta \pi_\theta(\mathbf{y} \mid \mathbf{x}) = \nabla_\theta 1 = 0$. Hence:

$$\nabla_{\theta} \mathbb{D}_{\text{KL}} = \mathbb{E}_{\mathbf{y} \sim \pi_\theta(\cdot \mid \mathbf{x})} \Bigl[ \nabla_{\theta} \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) \; \bigl( \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) - \beta \, r(\mathbf{x}, \mathbf{y}) - \log \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}) + \log Z_\phi(\mathbf{x}) \bigr) \Bigr].$$

Next, compute the gradient of the trajectory-balance loss. In practice, $\mathcal{L}_{\text{TB}}$ is averaged over responses sampled from $\pi_\theta$ (or from an older policy with importance weights, when using off-policy data), and the sampled responses are treated as fixed data: no gradient flows through the sampling step. For on-policy sampling, the loss is:

$$\mathcal{L}_{\text{TB}} = \mathbb{E}_{\mathbf{y} \sim \pi_\theta(\cdot \mid \mathbf{x})} \Bigl[ \bigl( \log Z_\phi(\mathbf{x}) + \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) - \beta \, r(\mathbf{x}, \mathbf{y}) - \log \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}) \bigr)^2 \Bigr],$$

with the sampling distribution held fixed when differentiating. Only $\log \pi_\theta$ inside the square depends on $\theta$, so the chain rule gives:

$$\nabla_{\theta} \mathcal{L}_{\text{TB}} = 2 \, \mathbb{E}_{\mathbf{y} \sim \pi_\theta(\cdot \mid \mathbf{x})} \Bigl[ \nabla_{\theta} \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) \; \bigl( \log Z_\phi(\mathbf{x}) + \log \pi_\theta(\mathbf{y} \mid \mathbf{x}) - \beta \, r(\mathbf{x}, \mathbf{y}) - \log \pi_{\text{ref}}(\mathbf{y} \mid \mathbf{x}) \bigr) \Bigr].$$

Comparing this with $\nabla_\theta \mathbb{D}_{\text{KL}}$, the residual terms are identical and the two gradients differ only by a constant factor of 2. Therefore, in expectation over on-policy samples, minimizing the trajectory-balance loss yields the same gradient direction for $\theta$ as minimizing the reverse KL divergence.
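The factor-of-2 relation can be checked numerically. The sketch below uses a hypothetical one-token "response" drawn from a softmax policy over four answers (the logits, rewards, reference probabilities, $\beta$, and $\log Z_\phi$ are made-up values). It computes both gradients by central finite differences and holds the sampling distribution fixed when differentiating the trajectory-balance loss, so the sampled responses act as data. The relation holds even though $\log Z_\phi$ here is not the true normalizer.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reverse_kl(theta, rewards, log_refs, beta, log_Z):
    """E_{y~pi_theta}[log pi - beta*r - log pi_ref + log Z]; the sampling distribution depends on theta."""
    pi = softmax(theta)
    return sum(p * (math.log(p) - beta * r - lr + log_Z)
               for p, r, lr in zip(pi, rewards, log_refs))


def tb_loss_fixed_sampler(theta, sampler_probs, rewards, log_refs, beta, log_Z):
    """E_{y~sampler}[(log Z + log pi_theta - beta*r - log pi_ref)^2] with the sampler held fixed."""
    log_pi = [math.log(p) for p in softmax(theta)]
    return sum(q * (log_Z + lp - beta * r - lr) ** 2
               for q, lp, r, lr in zip(sampler_probs, log_pi, rewards, log_refs))


def num_grad(f, theta, h=1e-6):
    """Central finite-difference gradient of f at theta."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


beta, log_Z = 2.0, 1.0                                  # hypothetical; log_Z is not the true normalizer
rewards = [1.0, 1.0, 0.5, 0.0]                          # hypothetical rewards
log_refs = [math.log(p) for p in (0.1, 0.3, 0.4, 0.2)]  # hypothetical reference policy
theta = [0.5, -0.2, 0.1, 0.3]                           # hypothetical policy logits

g_kl = num_grad(lambda t: reverse_kl(t, rewards, log_refs, beta, log_Z), theta)
sampler = softmax(theta)  # on-policy probabilities, frozen (no gradient through sampling)
g_tb = num_grad(lambda t: tb_loss_fixed_sampler(t, sampler, rewards, log_refs, beta, log_Z), theta)
print("KL grad:", [round(g, 6) for g in g_kl])
print("TB grad:", [round(g, 6) for g in g_tb])  # approximately 2x the KL gradient
```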

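Practical implications: This equivalence makes the trajectory-balance loss a convenient surrogate for the KL objective. The squared-error formulation is numerically stable; it allows simultaneous optimization of both $\theta$ (policy) and $\phi$ (partition function), since its gradient with respect to $\log Z_\phi(\mathbf{x})$ is twice the mean residual and therefore drives $\log Z_\phi$ toward the value that zeroes the average residual; and it naturally accommodates off-policy data through importance sampling. Moreover, it directly enforces the flow-balance equation, which is the cornerstone of the GFlowNet framework.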

2.3 Incorporating Reference Model and Length Normalization

In practice, three modifications are essential for stable training on long reasoning chains:

  1. Reference-model regularization
    To prevent the policy from deviating too far from the original pre-trained model, the target distribution is augmented with a reference policy $\pi_{\text{ref}}$:

    $$\tilde{\pi}(\mathbf{y}|\mathbf{x}) = \frac{\exp\!\bigl(\beta \, r(\mathbf{x}, \mathbf{y})\bigr) \; \pi_{\text{ref}}(\mathbf{y}|\mathbf{x})}{Z_{\phi}(\mathbf{x})}.$$
  2. Length normalization
    Because $\log \pi_{\theta}(\mathbf{y}|\mathbf{x}) = \sum_{t=1}^{|\mathbf{y}|} \log \pi_{\theta}(y_t|\mathbf{y}_{<t},\mathbf{x})$ is a sum over tokens, the residual, and hence the squared loss, can grow with sequence length and cause exploding gradients. FlowRL normalizes the log-probabilities by the response length $|\mathbf{y}|$:

    $$\frac{1}{|\mathbf{y}|} \log \pi_{\theta}(\mathbf{y}|\mathbf{x}), \quad \frac{1}{|\mathbf{y}|} \log \pi_{\text{ref}}(\mathbf{y}|\mathbf{x}).$$
  3. Importance sampling for off-policy correction
    To reuse rollouts collected from an older policy $\pi_{\theta_{\text{old}}}$, FlowRL employs clipped importance weights, where "detach" denotes a stop-gradient (the weight rescales the loss but is not differentiated):

    $$w = \text{clip}\!\left( \frac{\pi_{\theta}(\mathbf{y}|\mathbf{x})}{\pi_{\theta_{\text{old}}}(\mathbf{y}|\mathbf{x})},\; 1-\epsilon,\; 1+\epsilon \right)^{\text{detach}}.$$
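A minimal sketch of the two per-sequence quantities above, computed from per-token log-probabilities in plain Python. The token log-prob values are hypothetical, and $\epsilon = 0.2$ is chosen to match the clipping range used in the code snippet in Section 3.2. It shows that the raw log-ratio against the reference scales with length while the normalized one does not, and how clipping bounds the sequence-level importance ratio.

```python
import math


def length_normalized(token_logps):
    """(1/|y|) * sum_t log pi(y_t | y_<t, x): the mean per-token log-probability."""
    return sum(token_logps) / len(token_logps)


def clipped_importance_weight(token_logps_new, token_logps_old, eps=0.2):
    """clip(pi_theta(y|x) / pi_theta_old(y|x), 1 - eps, 1 + eps) from per-token log-probs.

    Plain floats carry no gradient; in an autodiff framework this value would be detached.
    """
    ratio = math.exp(sum(token_logps_new) - sum(token_logps_old))
    return min(max(ratio, 1.0 - eps), 1.0 + eps)


# Hypothetical per-token log-probs for a 1000-token response
new = [-0.50] * 1000  # current policy
old = [-0.51] * 1000  # policy that generated the rollout
ref = [-0.60] * 1000  # reference model
print(sum(new) - sum(ref))                              # raw log-ratio grows with length (about 100)
print(length_normalized(new) - length_normalized(ref))  # normalized log-ratio stays small (about 0.1)
print(clipped_importance_weight(new, old))              # raw ratio exp(10) is clipped to 1.2
```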

2.4 Final FlowRL Objective

Combining the above ingredients, the complete FlowRL loss becomes:

$$\mathcal{L}_{\text{FlowRL}} = w \cdot \Bigl( \log Z_{\phi}(\mathbf{x}) + \frac{1}{|\mathbf{y}|}\log \pi_{\theta}(\mathbf{y}|\mathbf{x}) - \beta \, \hat{r}(\mathbf{x},\mathbf{y}) - \frac{1}{|\mathbf{y}|}\log \pi_{\text{ref}}(\mathbf{y}|\mathbf{x}) \Bigr)^2,$$

where $\hat{r}$ is the group-normalized reward, computed over the group of responses sampled for the same prompt (e.g., $\hat{r}_i = (r_i - \text{mean}(\mathbf{r}))/\text{std}(\mathbf{r})$).
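Putting the pieces together for a single prompt, the sketch below computes group-normalized rewards and the per-group FlowRL loss from sequence-level log-probabilities. It is plain Python for readability; the population standard deviation and the small `eps` guard are assumptions (the formula above does not specify them), and all numbers in the example are hypothetical.

```python
import math


def group_normalize(rewards, eps=1e-6):
    """(r_i - mean(r)) / std(r) over the responses sampled for one prompt.

    Assumption: population std plus a small eps, so an all-equal group maps to zeros
    instead of dividing by zero.
    """
    mean = sum(rewards) / len(rewards)
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
    return [(r - mean) / (std + eps) for r in rewards]


def flowrl_loss(log_Z, log_pis, log_refs, lengths, rewards, weights, beta=15.0):
    """Mean over one group of w * (log Z + log_pi/|y| - beta*r_hat - log_ref/|y|)^2."""
    r_hat = group_normalize(rewards)
    total = 0.0
    for lp, lr, n, rh, w in zip(log_pis, log_refs, lengths, r_hat, weights):
        residual = log_Z + lp / n - beta * rh - lr / n
        total += w * residual ** 2
    return total / len(log_pis)


# Hypothetical group of four responses to one prompt (reward = final-answer correctness)
rewards = [1.0, 0.0, 0.0, 1.0]
log_pis = [-120.0, -300.0, -250.0, -90.0]    # sequence log-probs under pi_theta
log_refs = [-130.0, -310.0, -240.0, -100.0]  # sequence log-probs under pi_ref
lengths = [200, 500, 400, 150]               # response lengths |y| in tokens
weights = [1.0, 1.0, 1.0, 1.0]               # clipped importance weights (on-policy, so 1)
print(group_normalize(rewards))
print(flowrl_loss(0.0, log_pis, log_refs, lengths, rewards, weights))
```

Because $\log Z_\phi(\mathbf{x})$ is shared by every response to the same prompt, in this form it acts as a per-prompt offset that is fitted alongside the policy.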

3. Implementation Details

3.1 Partition Function Network $Z_{\phi}$

The partition function is parameterized by a small neural network (a 3-layer MLP) that estimates $\log Z_{\phi}(\mathbf{x})$.

  • Input: The hidden-state representation of the prompt $\mathbf{x}$.
    Specifically, we take the mean of the last-layer hidden states of the language model over the prompt tokens, yielding a fixed-dimensional vector $\mathbf{h}_{\mathbf{x}} \in \mathbb{R}^{d}$.
  • Architecture:
    $$\mathbf{h}^{(1)} = \text{GELU}\bigl( \mathbf{W}_1 \mathbf{h}_{\mathbf{x}} + \mathbf{b}_1 \bigr), \quad \mathbf{h}^{(2)} = \text{GELU}\bigl( \mathbf{W}_2 \mathbf{h}^{(1)} + \mathbf{b}_2 \bigr), \quad \log Z_{\phi}(\mathbf{x}) = \mathbf{W}_3 \mathbf{h}^{(2)} + b_3,$$
    where the hidden dimensions match the base LLM’s hidden size.
  • Training: The parameters $\phi$ are updated jointly with $\theta$ using the same optimizer (e.g., Adam).

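3.2 Simplified Python Code Snippet

```python
import torch
import torch.nn as nn


class PartitionFunction(nn.Module):
    """3-layer MLP to estimate log Z_phi(x) from prompt hidden states."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, prompt_embeddings, prompt_mask):
        # prompt_embeddings: [batch, prompt_len, hidden]; prompt_mask: [batch, prompt_len], 1 = real token
        mask = prompt_mask.unsqueeze(-1).to(prompt_embeddings.dtype)
        # Mean over real prompt tokens only, so padding does not dilute the representation
        prompt_repr = (prompt_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        log_Z = self.mlp(prompt_repr).squeeze(-1)             # [batch]
        return log_Z


def sequence_log_prob(logits, response_ids, response_mask):
    """Sum of log-probs of the response tokens; padded positions (mask 0) are ignored.

    logits must be aligned so that logits[:, t] predicts response_ids[:, t].
    """
    token_log_probs = torch.log_softmax(logits, dim=-1)
    token_log_probs = token_log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
    return (token_log_probs * response_mask).sum(dim=-1)     # [batch]


def compute_flowrl_loss(policy_model, ref_model, Z_phi, input_ids, attention_mask,
                        prompt_len, old_log_pi, norm_reward, beta=15.0, clip_eps=0.2):
    """
    Compute FlowRL loss with length normalization and importance sampling.

    Assumes Hugging Face-style causal LMs (forward returns .logits and, with
    output_hidden_states=True, .hidden_states), prompts left-padded to a shared
    prompt_len with responses right-padded after them, old_log_pi holding the
    stored sequence log-probs of the rollout policy, and norm_reward holding the
    pre-computed group-normalized rewards. All per-sample tensors are [batch].
    """
    outputs = policy_model(input_ids=input_ids, attention_mask=attention_mask,
                           output_hidden_states=True)
    with torch.no_grad():
        ref_logits = ref_model(input_ids=input_ids, attention_mask=attention_mask).logits

    response_ids = input_ids[:, prompt_len:]
    response_mask = attention_mask[:, prompt_len:].to(outputs.logits.dtype)

    # Log probabilities from policy and reference model.
    # Logits at position t predict token t + 1, hence the one-step shift.
    log_pi = sequence_log_prob(outputs.logits[:, prompt_len - 1:-1], response_ids, response_mask)
    log_ref = sequence_log_prob(ref_logits[:, prompt_len - 1:-1], response_ids, response_mask)

    # Estimate log partition function from detached prompt hidden states
    prompt_embeddings = outputs.hidden_states[-1][:, :prompt_len].detach()
    log_Z = Z_phi(prompt_embeddings, attention_mask[:, :prompt_len])  # [batch]

    # Length normalization
    lengths = response_mask.sum(dim=-1).clamp(min=1.0)                # [batch]
    norm_log_pi = log_pi / lengths
    norm_log_ref = log_ref / lengths

    # Importance weight pi_theta / pi_theta_old, clipped and detached
    imp_ratio = (log_pi - old_log_pi).exp().detach()
    w = torch.clamp(imp_ratio, 1 - clip_eps, 1 + clip_eps)

    # Trajectory-balance term
    tb_term = log_Z + norm_log_pi - beta * norm_reward - norm_log_ref
    loss = w * (tb_term ** 2)
    return loss.mean()
```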

4. Why Does FlowRL Work? Key Insights

  1. Distribution Matching vs. Reward Maximization
    FlowRL explicitly encourages the policy to cover multiple high‑reward modes instead of collapsing to a single peak. This is crucial for reasoning tasks where diverse solution strategies exist.

  2. The Boltzmann Distribution as a Soft Exploration Mechanism
    The target $\pi \propto \exp(\beta r)$ provides a continuous trade-off between exploitation and exploration, similar in spirit to Boltzmann (softmax) exploration, a smoother and more principled alternative to ε-greedy.

  3. Trajectory Balance as a Stable Optimization Proxy
    The squared‑error loss derived from the flow‑balance condition is numerically more stable than direct policy‑gradient estimation, especially for long sequences.

  4. GFlowNets as a Conceptual Bridge
    Although the core idea is distribution matching, the GFlowNets framework offers an intuitive flow‑based analogy and a theoretically grounded optimization objective (trajectory balance).

  5. Practical Adaptations for LLMs
    Length normalization and reference‑model regularization are critical engineering adaptations that make distribution matching feasible for real‑world LLM reasoning tasks.

5. Conclusion

FlowRL reframes LLM reinforcement learning as a reward-distribution matching problem. By minimizing the reverse KL divergence between the policy and a Boltzmann-type target distribution, implemented via a learnable partition function and optimized through a trajectory-balance loss, FlowRL is reported to achieve better diversity and generalization than reward-maximizing baselines on mathematical and code-reasoning tasks. The method's success stems from its principled departure from pure reward maximization, coupled with practical adaptations (reference-model regularization, length normalization, and importance sampling) for training large language models on long reasoning chains.

References

FlowRL - Matching Reward Distributions for LLM Reasoning
https://adalovelemon.github.io/blog/en/posts/content/paperreading/rl/flowrl/
Author: Ada Lovelemon
Published at: 2026-02-09
