DRO

DRO (Richemond et al., 2024; Kimi Team et al., 2025) is a general off-policy (and even offline) reinforcement learning method that constrains the policy update with a quadratic penalty term. Note that this loss uses a different (soft) formulation of advantage estimation, which must be implemented on the client side.
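The exact soft-advantage estimator depends on your reward and value setup, so it is not part of this loss's inputs. As a hedged illustration only (the function name and the reward/value inputs are assumptions, not part of this API): the KL-regularized optimality condition underlying DRO, \(\log(\pi^*/\mu) = (r - V^*)/\beta\), suggests scaling a reward-minus-baseline residual by \(1/\beta\):

```python
import numpy as np

def soft_advantages(rewards, values, beta=0.1):
    """Hypothetical client-side soft-advantage sketch: reward minus a
    learned value baseline, scaled by 1/beta (assumed form, not this API)."""
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    return (rewards - values) / beta
```

How the baseline `values` is obtained (e.g. a learned value head) is entirely up to the client.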

The DRO objective is:

\[ \mathcal{L}_{\text{DRO}}(\theta) = \mathbb{E}_{x \sim q}\left[\log p_\theta(x) \cdot A(x) - \frac{1}{2}\beta \left(\log \frac{p_\theta(x)}{q(x)}\right)^2\right] \]
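Differentiating the objective with respect to \(\log p_\theta(x)\) gives \(A(x) - \beta\,(\log p_\theta(x) - \log q(x))\), so the quadratic penalty pulls the updated policy's log-probabilities back toward the sampler's. A quick standalone numeric check of that derivative (not part of the API):

```python
import numpy as np

def dro_objective(t, s, a, beta):
    # t: target logprob, s: sampling logprob, a: advantage
    return t * a - 0.5 * beta * (t - s) ** 2

t, s, a, beta = -1.0, -1.2, 0.5, 0.1
eps = 1e-6
# Central finite difference of the objective w.r.t. the target logprob
num_grad = (dro_objective(t + eps, s, a, beta)
            - dro_objective(t - eps, s, a, beta)) / (2 * eps)
analytic_grad = a - beta * (t - s)  # = 0.48 here
```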

This is implemented as:

```python
# Compute quadratic penalty term
quadratic_term = (target_logprobs - sampling_logprobs) ** 2
# Compute DRO objective
dro_objective = target_logprobs * advantages - 0.5 * beta * quadratic_term
# DRO loss is negative of objective
loss = -dro_objective.sum()
```
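The snippet above assumes the tensors already exist; a self-contained NumPy sketch of the same computation (the function name is ours, not part of this API):

```python
import numpy as np

def dro_loss(target_logprobs, sampling_logprobs, advantages, beta=0.1):
    """Sum of per-token DRO losses, i.e. the negated DRO objective."""
    quadratic_term = (target_logprobs - sampling_logprobs) ** 2
    dro_objective = target_logprobs * advantages - 0.5 * beta * quadratic_term
    return -dro_objective.sum()
```

When the target and sampling log-probabilities coincide (fully on-policy), the penalty vanishes and the loss reduces to the plain policy-gradient surrogate \(-\sum \log p_\theta \cdot A\).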

Input tensors:

  • target_tokens: array[(N,), int] — Target token IDs (from the sampler \(q\))
  • logprobs: array[(N,), float] — sampling_logprobs for the tokens
  • advantages: array[(N,), float] — Advantage values for RL

Output tensors:

  • logprobs: array[(N,), float] — target_logprobs for the tokens

Output diagnostics:

  • loss:sum (scalar) — Sum of DRO losses

Custom beta

```python
fwd_bwd_future = await training_client.forward_backward_async(
    data=data,
    loss_fn="dro",
    loss_fn_config={"beta": 0.05}
)
fwd_bwd_result = await fwd_bwd_future.result_async()
```