FlowRL: Matching Reward Distributions for LLM Reasoning
1. Motivation & Problem Definition
1.1 Background and Limitations of Traditional RL in LLM Reasoning
Large language model (LLM) reasoning is typically formulated as a conditional generation problem: given a question $x$, a policy model $\pi_\theta$ generates an answer $y$. The quality of the answer is evaluated by a task-specific reward signal $r(x, y)$. In reasoning tasks, the reward is usually sparse and terminal (e.g., correctness of the final answer), so we work with a single one-step reward rather than a return (i.e., a discounted sum of rewards over time steps).
Existing reinforcement learning (RL) methods for LLMs, such as REINFORCE, PPO, and GRPO, adopt a reward-maximization objective:
$$\max_{\theta}\; \mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi_\theta(\cdot \mid x)}\big[\, r(x, y) \,\big]$$
where $\mathcal{D}$ is the distribution of questions.
However, this approach tends to over‑fit the dominant reward mode, leading to mode collapse and a lack of diversity in generated reasoning paths. For complex reasoning tasks (e.g., long‑chain mathematical proofs or code generation), capturing a diverse set of valid solutions is crucial for generalization.
1.2 Core Idea of FlowRL
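FlowRL shifts the paradigm from reward maximization to reward-distribution matching. Instead of only pushing the policy toward the highest-reward answers, FlowRL encourages the policy to generate answers with probabilities proportional to their exponentiated rewards. Formally, for a question $x$, an answer $y$, and a reward $r(x, y)$, the target is:
$$\pi_\theta(y \mid x) \;\propto\; \exp\big(\beta\, r(x, y)\big)$$
where $\beta$ is an inverse-temperature parameter: a larger $\beta$ concentrates probability mass on the highest-reward answers, while a smaller $\beta$ spreads it more evenly. This Boltzmann-type distribution naturally balances exploitation (high-reward answers) and exploration (lower-reward but still valid answers).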
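To make the effect of the temperature concrete, the following minimal sketch computes a Boltzmann target over a handful of candidate answers. The reward values and the helper name `boltzmann_target` are illustrative assumptions, not part of FlowRL itself.

```python
import math

def boltzmann_target(rewards, beta):
    """Return probabilities proportional to exp(beta * r) for each reward."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(beta * r for r in rewards)
    weights = [math.exp(beta * r - m) for r in rewards]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical rewards for four candidate answers to one question.
rewards = [1.0, 0.9, 0.2, 0.0]

soft = boltzmann_target(rewards, beta=1.0)    # mass spread over all answers
sharp = boltzmann_target(rewards, beta=15.0)  # mass on the two best answers
```

With the larger `beta`, nearly all of the mass sits on the two highest-reward answers, yet the second-best answer still keeps a non-trivial share instead of being discarded as pure reward maximization would do.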
2. Theoretical Framework: From Reverse KLD to Trajectory Balance
2.1 Distribution Matching via Reverse KL Divergence
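To align the policy with the desired Boltzmann distribution, we minimize the reverse Kullback–Leibler (KL) divergence between the policy $\pi_\theta$ and a normalized target distribution $\tilde{\pi}$. Because the partition function of the Boltzmann distribution is intractable (it requires summing over all possible answers), we introduce a learnable partition function $Z_\phi(x)$:
$$\tilde{\pi}(y \mid x) = \frac{\exp\big(\beta\, r(x, y)\big)}{Z_\phi(x)}$$
The reverse KL divergence is:
$$D_{\mathrm{KL}}\big(\pi_\theta \,\|\, \tilde{\pi}\big) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\log \frac{\pi_\theta(y \mid x)}{\tilde{\pi}(y \mid x)}\right]$$
Substituting the definition of $\tilde{\pi}$ gives:
$$D_{\mathrm{KL}}\big(\pi_\theta \,\|\, \tilde{\pi}\big) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\log \pi_\theta(y \mid x) + \log Z_\phi(x) - \beta\, r(x, y)\Big]$$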
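As a small numerical illustration, the sketch below evaluates this expectation exactly on a three-answer toy problem, where the partition function can still be computed by brute force. The rewards, the value of `beta`, and the function name `reverse_kl` are assumptions made for the example.

```python
import math

def reverse_kl(policy, rewards, beta, log_z):
    """E_{y ~ policy}[log policy(y) + log Z - beta * r(y)] over a finite set."""
    return sum(
        p * (math.log(p) + log_z - beta * r)
        for p, r in zip(policy, rewards)
        if p > 0.0
    )

rewards = [1.0, 0.5, 0.0]   # hypothetical rewards for three answers
beta = 2.0

# Exact log-partition function, available here only because the set is tiny.
log_z_true = math.log(sum(math.exp(beta * r) for r in rewards))
target = [math.exp(beta * r - log_z_true) for r in rewards]

# A policy equal to the target has (numerically) zero divergence.
kl_matched = reverse_kl(target, rewards, beta, log_z_true)
# A policy collapsed onto the best answer has a strictly positive divergence.
kl_collapsed = reverse_kl([0.98, 0.01, 0.01], rewards, beta, log_z_true)
```

The collapsed policy is penalized even though it places almost all of its mass on the highest-reward answer, which is the behavior that distinguishes distribution matching from reward maximization.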
2.2 From Flow Equation to Trajectory Balance: Gradient Equivalence
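In GFlowNets, the trajectory balance condition emerges from the more fundamental flow conservation principle. Consider the generation of a complete response $y = (y_1, \dots, y_T)$ given a prompt $x$ as a trajectory $s_0 \to s_1 \to \dots \to s_T$ in a directed acyclic graph, where $s_t$ is the prefix consisting of the first $t$ tokens. Let $F(s)$ denote the flow (probability mass) at state $s$. The forward policy $P_F(s_t \mid s_{t-1}) = \pi_\theta(y_t \mid x, y_{<t})$ (our generation model) determines the transition probabilities. For a complete trajectory, the flow equation states that the probability flow from the initial state $s_0$ (empty response) to the terminal state $s_T$ must satisfy:
$$F(s_0) \prod_{t=1}^{T} P_F(s_t \mid s_{t-1}) = R(x, y) \prod_{t=1}^{T} P_B(s_{t-1} \mid s_t)$$
where $R(x, y)$ is the reward associated with the terminal state, and $P_B$ is a backward policy (often chosen as uniform or a simple fixed distribution). The initial flow $F(s_0)$ is exactly the partition function $Z_\phi(x)$, as it represents the total probability mass injected into the system. In FlowRL, the reward is defined as the exponentiated, reference-model-adjusted reward:
$$R(x, y) = \exp\big(\beta\, r(x, y)\big)\, \pi_{\mathrm{ref}}(y \mid x)$$
Moreover, for simplicity (and following common practice in GFlowNets for sequence generation), the backward policy is taken to be uniform, so that $\prod_{t=1}^{T} P_B(s_{t-1} \mid s_t)$ is a constant (in autoregressive generation each prefix has exactly one parent, so the product equals 1). Dropping this constant (any such factor can be absorbed into the learned partition function) and using $\prod_{t=1}^{T} P_F(s_t \mid s_{t-1}) = \pi_\theta(y \mid x)$, we obtain the simplified trajectory-balance equation:
$$Z_\phi(x)\, \pi_\theta(y \mid x) = \exp\big(\beta\, r(x, y)\big)\, \pi_{\mathrm{ref}}(y \mid x)$$
Taking logarithms on both sides gives the linear constraint:
$$\log Z_\phi(x) + \log \pi_\theta(y \mid x) = \beta\, r(x, y) + \log \pi_{\mathrm{ref}}(y \mid x)$$
Since this equality does not hold for every trajectory while the policy is still being trained, we turn it into a squared-error objective, the trajectory-balance loss:
$$\mathcal{L}_{\mathrm{TB}}(\theta, \phi) = \Big(\log Z_\phi(x) + \log \pi_\theta(y \mid x) - \beta\, r(x, y) - \log \pi_{\mathrm{ref}}(y \mid x)\Big)^2$$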
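Now we demonstrate that minimizing this loss is gradient-equivalent to minimizing the reverse KL divergence between $\pi_\theta$ and the target distribution $\tilde{\pi}(y \mid x) = R(x, y) / Z_\phi(x)$, where $R(x, y) = \exp(\beta\, r(x, y))\, \pi_{\mathrm{ref}}(y \mid x)$ (this reduces to the plain Boltzmann target of Section 2.1 when the reference term is dropped). Write the residual of the log-form balance equation as $\delta_\theta(x, y) = \log Z_\phi(x) + \log \pi_\theta(y \mid x) - \log R(x, y)$, so that $D_{\mathrm{KL}}(\pi_\theta \,\|\, \tilde{\pi}) = \mathbb{E}_{y \sim \pi_\theta}[\delta_\theta(x, y)]$.
First, compute the gradient of $D_{\mathrm{KL}}$ with respect to the policy parameters $\theta$. Using the log-derivative trick together with $\mathbb{E}_{y \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(y \mid x)] = 0$:
$$\nabla_\theta D_{\mathrm{KL}}\big(\pi_\theta \,\|\, \tilde{\pi}\big) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\delta_\theta(x, y)\, \nabla_\theta \log \pi_\theta(y \mid x)\Big] \tag{KL-grad}$$
Next, compute the gradient of the trajectory-balance loss. Note that $\mathcal{L}_{\mathrm{TB}}$ is an expectation over the same distribution (or an importance-weighted version thereof when using off-policy data). For on-policy sampling, we have:
$$\mathcal{L}_{\mathrm{TB}}(\theta, \phi) = \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\delta_\theta(x, y)^2\Big]$$
Treating the sampled responses as fixed (no gradient flows through the sampling distribution) and applying the chain rule with $\nabla_\theta \delta_\theta(x, y) = \nabla_\theta \log \pi_\theta(y \mid x)$, we obtain:
$$\nabla_\theta \mathcal{L}_{\mathrm{TB}} = 2\, \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\delta_\theta(x, y)\, \nabla_\theta \log \pi_\theta(y \mid x)\Big] \tag{TB-grad}$$
Comparing KL-grad and TB-grad, the two gradients are proportional (differing only by a constant factor of 2). Therefore, minimizing the trajectory-balance loss yields the same gradient direction as minimizing the reverse KL divergence.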
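The factor-of-2 relationship can be checked numerically on a toy problem. The sketch below assumes a softmax policy over three answers, arbitrary illustrative values for the log-rewards and for log Z, and finite-difference gradients; the trajectory-balance loss is averaged under sampling weights that are held fixed, mirroring on-policy samples through which no gradient flows.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def residuals(logits, log_r, log_z):
    """delta_i = log Z + log pi_i - log R_i for every answer i."""
    pi = softmax(logits)
    return [log_z + math.log(p) - lr_i for p, lr_i in zip(pi, log_r)]

def reverse_kl(logits, log_r, log_z):
    """E_pi[delta]: reverse KL to the (unnormalized) target R / Z."""
    pi = softmax(logits)
    return sum(p * d for p, d in zip(pi, residuals(logits, log_r, log_z)))

def tb_loss(logits, log_r, log_z, sample_weights):
    """Squared residuals averaged under fixed sampling weights."""
    deltas = residuals(logits, log_r, log_z)
    return sum(w * d * d for w, d in zip(sample_weights, deltas))

def finite_diff_grad(f, logits, eps=1e-6):
    """Central finite-difference gradient of f with respect to the logits."""
    grad = []
    for j in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[j] += eps
        down[j] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

logits = [0.3, -0.7, 1.1]   # arbitrary policy parameters
log_r = [2.0, 0.5, 1.0]     # hypothetical log R(x, y) values
log_z = 0.4                 # arbitrary, deliberately not the true log Z
frozen = softmax(logits)    # on-policy sampling weights, held fixed

kl_grad = finite_diff_grad(lambda t: reverse_kl(t, log_r, log_z), logits)
tb_grad = finite_diff_grad(lambda t: tb_loss(t, log_r, log_z, frozen), logits)
factor_two_holds = all(abs(tb - 2 * kl) < 1e-5 for tb, kl in zip(tb_grad, kl_grad))
```

If the sampling weights were recomputed inside `tb_loss` instead of being frozen, an extra term would appear and the proportionality would no longer hold exactly, which is why the equivalence is stated for gradients with the sampler held fixed.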
Practical implications: This equivalence provides a robust surrogate objective. The squared-error formulation is numerically stable, allows simultaneous optimization of both $\theta$ (policy) and $\phi$ (partition function), and naturally accommodates off-policy data through importance sampling. Moreover, it directly enforces the flow-balance equation, which is the cornerstone of the GFlowNet framework.
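The joint optimization of policy and partition function can be sketched on a toy problem. The example below is an illustration under simplifying assumptions, not the training procedure itself: a softmax policy over three answers, a scalar `log_z`, a uniform reference model (so its term drops out), exact expectations instead of sampled rollouts, and plain gradient descent with a hand-picked learning rate.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

rewards = [0.0, 1.0, 2.0]   # hypothetical rewards for three answers
beta = 1.0
logits = [0.0, 0.0, 0.0]    # policy parameters (theta)
log_z = 0.0                 # partition-function parameter (phi)
lr = 0.2

for _ in range(2000):
    pi = softmax(logits)
    # Residual of the log-form balance equation for each answer.
    delta = [log_z + math.log(p) - beta * r for p, r in zip(pi, rewards)]
    mean_delta = sum(p * d for p, d in zip(pi, delta))
    # Gradients of E_pi[delta^2] with the sampling weights pi held fixed.
    grad_logits = [2 * p * (d - mean_delta) for p, d in zip(pi, delta)]
    grad_log_z = 2 * mean_delta
    logits = [t - lr * g for t, g in zip(logits, grad_logits)]
    log_z -= lr * grad_log_z

true_log_z = math.log(sum(math.exp(beta * r) for r in rewards))
target = [math.exp(beta * r - true_log_z) for r in rewards]
learned = softmax(logits)
```

In this toy setting the learned policy converges to the Boltzmann target and `log_z` converges to the true log-partition function, even though the latter is never supplied to the optimizer.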
2.3 Incorporating Reference Model and Length Normalization
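In practice, three modifications are essential for stable training on long reasoning chains:
Reference-model regularization
To prevent the policy from deviating too far from the original pre-trained model, the target distribution is augmented with a reference policy $\pi_{\mathrm{ref}}$:
$$\tilde{\pi}(y \mid x) = \frac{\exp\big(\beta\, r(x, y)\big)\, \pi_{\mathrm{ref}}(y \mid x)}{Z_\phi(x)}$$
Length normalization
Because $\log \pi_\theta(y \mid x) = \sum_{t=1}^{|y|} \log \pi_\theta(y_t \mid x, y_{<t})$, the loss can grow with sequence length, causing gradient explosion. FlowRL normalizes the log-probabilities by the response length $|y|$:
$$\frac{1}{|y|} \log \pi_\theta(y \mid x), \qquad \frac{1}{|y|} \log \pi_{\mathrm{ref}}(y \mid x)$$
Importance sampling for off-policy correction
To reuse rollouts collected from an older policy $\pi_{\mathrm{old}}$, FlowRL employs clipped importance weights:
$$w = \mathrm{clip}\left(\frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{old}}(y \mid x)},\; 1 - \epsilon,\; 1 + \epsilon\right)^{\mathrm{detach}}$$
where $\epsilon$ is the clipping range and the superscript indicates that the weight is detached from the gradient computation.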
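The length-normalization and clipping steps can be illustrated without any deep-learning framework. In the sketch below, the per-token log-probabilities are made-up values, the function names are hypothetical, and `eps=0.2` is an assumed clipping range.

```python
import math

def length_normalized_log_prob(token_log_probs):
    """Average per-token log-probability, i.e. (1 / |y|) * log pi(y | x)."""
    return sum(token_log_probs) / len(token_log_probs)

def clipped_importance_weight(log_pi, log_pi_old, eps=0.2):
    """clip(pi / pi_old, 1 - eps, 1 + eps), computed from log-probabilities."""
    ratio = math.exp(log_pi - log_pi_old)
    return min(max(ratio, 1.0 - eps), 1.0 + eps)

# Two hypothetical responses with the same per-token quality but different lengths.
short_response = [-0.5] * 10
long_response = [-0.5] * 1000

raw_short = sum(short_response)    # -5.0
raw_long = sum(long_response)      # -500.0, dominates a squared-error loss
norm_short = length_normalized_log_prob(short_response)
norm_long = length_normalized_log_prob(long_response)

# A ratio of e^1 (about 2.72) is clipped down to the upper bound 1 + eps.
w_high = clipped_importance_weight(log_pi=-10.0, log_pi_old=-11.0)
# A ratio of e^-1 (about 0.37) is clipped up to the lower bound 1 - eps.
w_low = clipped_importance_weight(log_pi=-11.0, log_pi_old=-10.0)
```

After normalization both responses contribute a term of the same magnitude, so the squared residual no longer scales with the square of the sequence length.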
2.4 Final FlowRL Objective
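Combining the above ingredients, the complete FlowRL loss becomes:
$$\mathcal{L}_{\mathrm{FlowRL}} = w \cdot \left(\log Z_\phi(x) + \frac{1}{|y|} \log \pi_\theta(y \mid x) - \beta\, \hat{r}(x, y) - \frac{1}{|y|} \log \pi_{\mathrm{ref}}(y \mid x)\right)^2$$
where $w$ is the clipped importance weight and $\hat{r}(x, y)$ is the group-normalized reward (e.g., $\hat{r}_i = (r_i - \mathrm{mean}(r)) / \mathrm{std}(r)$, with the mean and standard deviation taken over the group of responses sampled for the same prompt).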
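A scalar version of this objective can be sketched as follows. The use of the population standard deviation, the small `eps` added to the denominator, and the function names are assumptions made for the example; the reward values are invented.

```python
import statistics

def group_normalize(rewards, eps=1e-6):
    """(r_i - mean) / (std + eps) over the responses sampled for one prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption here
    return [(r - mean) / (std + eps) for r in rewards]

def flowrl_sample_loss(log_z, norm_log_pi, norm_log_ref, norm_reward,
                       weight, beta=15.0):
    """w * (log Z + log pi / |y| - beta * r_hat - log pi_ref / |y|)^2."""
    residual = log_z + norm_log_pi - beta * norm_reward - norm_log_ref
    return weight * residual ** 2

# Hypothetical binary correctness rewards for four responses to one prompt.
group_rewards = [1.0, 0.0, 0.0, 1.0]
normalized = group_normalize(group_rewards)

# A sample whose residual is exactly zero contributes no loss.
balanced = flowrl_sample_loss(log_z=0.0, norm_log_pi=-0.5, norm_log_ref=-0.5,
                              norm_reward=0.0, weight=1.0)
# A residual of 1.0 with importance weight 1.2 contributes 1.2.
unbalanced = flowrl_sample_loss(log_z=1.0, norm_log_pi=-0.5, norm_log_ref=-0.5,
                                norm_reward=0.0, weight=1.2)
```

The `eps` term only guards against a zero standard deviation when every response in a group receives the same reward.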
3. Implementation Details
3.1 Partition Function Network
The partition function $Z_\phi(x)$ is parameterized by a small neural network (a 3-layer MLP) that estimates $\log Z_\phi(x)$.
- Input: The hidden-state representation of the prompt $x$. Specifically, we take the mean of the last-layer hidden states of the language model over the prompt tokens, yielding a fixed-dimensional vector $h_x$.
- Architecture: $\log Z_\phi(x) = \mathrm{MLP}_\phi(h_x)$, three linear layers with GELU activations in between and a scalar output, where the hidden dimensions match the base LLM’s hidden size.
- Training: The parameters $\phi$ are updated jointly with $\theta$ using the same optimizer (e.g., Adam).
3.2 Simplified Python Code Snippet
import torch
import torch.nn as nn


class PartitionFunction(nn.Module):
    """3-layer MLP to estimate log Z_phi(x)."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, prompt_embeddings):
        # prompt_embeddings: [batch_size, seq_len, hidden_dim]
        prompt_repr = prompt_embeddings.mean(dim=1)  # [batch_size, hidden_dim]
        log_Z = self.mlp(prompt_repr).squeeze(-1)    # [batch_size]
        return log_Z


def compute_flowrl_loss(prompt, response, reward, policy_model, ref_model,
                        Z_phi, old_log_pi, beta=15.0, clip_eps=0.2):
    """
    Compute FlowRL loss with length normalization and importance sampling.

    old_log_pi: sequence log-probabilities stored with the rollouts, [batch].
    get_log_prob and response.lengths are placeholder interfaces used by this
    simplified snippet, not the API of a specific library.
    """
    # Get prompt embeddings from the policy model (no gradient into the LLM)
    with torch.no_grad():
        prompt_outputs = policy_model(prompt, output_hidden_states=True)
        prompt_embeddings = prompt_outputs.hidden_states[-1]  # [batch, seq_len, hidden]

    # Estimate log partition function
    log_Z = Z_phi(prompt_embeddings)  # [batch]

    # Log probabilities from the policy and the frozen reference model
    log_pi = policy_model.get_log_prob(response, prompt)  # [batch]
    with torch.no_grad():
        log_ref = ref_model.get_log_prob(response, prompt)  # [batch]

    # Length normalization
    lengths = response.lengths.float()  # [batch]
    norm_log_pi = log_pi / lengths
    norm_log_ref = log_ref / lengths

    # Group-normalized reward (pre-computed)
    norm_reward = reward  # [batch]

    # Clipped importance weight, detached so it only rescales the loss
    imp_ratio = (log_pi - old_log_pi).exp().detach()
    w = torch.clamp(imp_ratio, 1.0 - clip_eps, 1.0 + clip_eps)

    # Trajectory-balance term
    tb_term = log_Z + norm_log_pi - beta * norm_reward - norm_log_ref
    loss = w * (tb_term ** 2)
    return loss.mean()

4. Why Does FlowRL Work? Key Insights
Distribution Matching vs. Reward Maximization
FlowRL explicitly encourages the policy to cover multiple high-reward modes instead of collapsing to a single peak. This is crucial for reasoning tasks where diverse solution strategies exist.
The Boltzmann Distribution as a Soft Exploration Mechanism
The Boltzmann target $\pi(y \mid x) \propto \exp(\beta\, r(x, y))$ provides a continuous trade-off between exploitation and exploration, analogous to a “softened” ε-greedy strategy but more principled.
Trajectory Balance as a Stable Optimization Proxy
The squared-error loss derived from the flow-balance condition is numerically more stable than direct policy-gradient estimation, especially for long sequences.
GFlowNets as a Conceptual Bridge
Although the core idea is distribution matching, the GFlowNets framework offers an intuitive flow-based analogy and a theoretically grounded optimization objective (trajectory balance).
Practical Adaptations for LLMs
Length normalization and reference-model regularization are critical engineering adaptations that make distribution matching feasible for real-world LLM reasoning tasks.
5. Conclusion
FlowRL re-frames LLM reinforcement learning as a reward-distribution matching problem. It minimizes the reverse KL divergence between the policy and a Boltzmann-type target distribution, implemented via a learnable partition function and optimized through a trajectory-balance loss. In doing so, FlowRL targets greater diversity and better generalization than pure reward maximization on mathematical and code-reasoning tasks. The method’s success stems from its principled departure from pure reward maximization, coupled with practical adaptations for training large language models on long reasoning chains.

