Importance Sampling

For RL, we implement a common variant of the policy gradient objective, designed for the practical setting in which the learner policy \(p_\theta\) differs from the sampling policy \(q\). Such mismatches are common, e.g., because of non-determinism. When the two policies differ, the objective:

\[ \mathcal{L}(\theta) = \mathbb{E}_{x\sim p_\theta}\bigl[A(x)\bigr] \]

is biased when estimated from samples, because they are drawn from \(x \sim q\) (sampler) rather than the desired \(x \sim p_\theta\) (learner). To correct the bias, we use a modified "importance sampling" objective:

\[ \mathcal{L}_{\text{IS}}(\theta) = \mathbb{E}_{x\sim q}\Bigl[\frac{p_\theta(x)}{q(x)}A(x)\Bigr], \]

which equals \(\mathcal{L}(\theta)\) in expectation, provided \(q(x) > 0\) wherever \(p_\theta(x) > 0\). In the formula above:

  • \(\log p_\theta(x)\) (target_logprobs) is from the learner, computed on the forward part of the forward_backward pass.
  • \(\log q(x)\) (sampling_logprobs) is from the sampler, recorded during sampling as a correction term.
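
The bias correction above can be checked exactly on a tiny discrete distribution. The sketch below is illustrative only: the three-outcome space and every number in it are made up, and none of the variable names come from the library. It enumerates the expectations instead of sampling, so the comparison is exact up to floating-point error.

```python
# Toy check of the importance sampling identity on a 3-outcome space.
# All values are invented for illustration.
p = [0.5, 0.3, 0.2]   # learner policy p_theta(x)
q = [0.4, 0.4, 0.2]   # sampling policy q(x), slightly different from p
A = [1.0, -1.0, 2.0]  # advantage A(x) for each outcome

# Desired objective: E_{x ~ p}[A(x)]
target = sum(pi * ai for pi, ai in zip(p, A))

# Naive value that ignores the mismatch: E_{x ~ q}[A(x)]
naive = sum(qi * ai for qi, ai in zip(q, A))

# Importance-weighted value: E_{x ~ q}[p(x) / q(x) * A(x)]
corrected = sum(qi * (pi / qi) * ai for pi, qi, ai in zip(p, q, A))

print(f"target={target:.3f} naive={naive:.3f} corrected={corrected:.3f}")
```

The naive value differs from the target, while the importance-weighted value matches it. In practice the expectation over \(q\) is estimated from sampled tokens rather than enumerated.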

This is implemented as:

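import torch

# Compute probability ratio p_theta(x) / q(x) from the log-probabilities
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss (negated so that minimizing it maximizes the objective)
loss = -(prob_ratio * advantages).sum()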
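
To see how the loss behaves on concrete numbers, here is a dependency-free sketch of the same computation in plain Python, together with its analytic gradient with respect to the learner's log-probabilities. The helper name `importance_sampling_loss` and the toy values are hypothetical and are not part of the library.

```python
import math

def importance_sampling_loss(target_logprobs, sampling_logprobs, advantages):
    """Hypothetical helper: returns (loss, d loss / d target_logprobs)."""
    # Per-token probability ratio p_theta(x) / q(x)
    ratios = [math.exp(t - s) for t, s in zip(target_logprobs, sampling_logprobs)]
    loss = -sum(r * a for r, a in zip(ratios, advantages))
    # d/dt exp(t - s) = exp(t - s), so each gradient entry is -ratio * advantage
    grads = [-r * a for r, a in zip(ratios, advantages)]
    return loss, grads

# Toy inputs, invented for illustration
sampling_logprobs = [-1.0, -2.0, -0.5]
advantages = [1.0, -0.5, 2.0]

# On-policy case: learner and sampler agree, so every ratio is exactly 1
loss_on, grads_on = importance_sampling_loss(
    sampling_logprobs, sampling_logprobs, advantages
)

# Off-policy case: the learner assigns a higher probability to the first token
target_logprobs = [-0.5, -2.0, -0.5]
loss_off, grads_off = importance_sampling_loss(
    target_logprobs, sampling_logprobs, advantages
)

print(loss_on, grads_on)
print(loss_off, grads_off)
```

When the two policies agree, the gradient reduces to the negated advantages, which is the usual policy gradient signal. When they differ, each token's contribution is rescaled by its probability ratio.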

Input tensors:

  • target_tokens: array[(N,), int] — Target token IDs (from the sampler \(q\))
  • logprobs: array[(N,), float] — sampling_logprobs for the tokens
  • advantages: array[(N,), float] — Advantage values for RL (positive to reinforce, negative to discourage)

Output tensors:

  • logprobs: array[(N,), float] — target_logprobs for the tokens

Output diagnostics:

  • loss:sum (scalar) — Sum of importance-weighted policy gradient losses \(\mathcal L_{\text{IS}}\)