Importance Sampling
For RL, we implement a common variant of the policy gradient objective, used in practical settings where the learner policy \(p\) may differ from the sampling policy \(q\), which commonly happens due to, e.g., non-determinism. The issue is that if these policies differ, then the objective:
\[
\mathcal{L}(\theta) = \mathbb{E}_{x\sim p_\theta}\bigl[A(x)\bigr]
\]
is no longer estimated in an unbiased way, because the samples come from \(x \sim q\) (sampler) rather than the desired \(x \sim p_\theta\) (learner). To correct this bias, we use a modified "importance sampling" objective:
\[
\mathcal{L}_{\text{IS}}(\theta) = \mathbb{E}_{x\sim q}\Bigl[\frac{p_\theta(x)}{q(x)}A(x)\Bigr],
\]
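To see why this correction is unbiased, expand the expectation over \(q\) (assuming \(q(x) > 0\) wherever \(p_\theta(x) > 0\)):
\[
\mathbb{E}_{x\sim q}\Bigl[\frac{p_\theta(x)}{q(x)}A(x)\Bigr]
= \sum_x q(x)\,\frac{p_\theta(x)}{q(x)}\,A(x)
= \sum_x p_\theta(x)\,A(x)
= \mathbb{E}_{x\sim p_\theta}\bigl[A(x)\bigr].
\]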
which yields the correct expected reward. In the formula above:
- \(\log p_\theta(x)\) – `target_logprobs` is from the learner, computed on the forward part of the `forward_backward` pass.
- \(\log q(x)\) – `sampling_logprobs` is from the sampler, recorded during sampling as a correction term.
This is implemented as:

```python
# Compute probability ratio p_theta(x) / q(x) in log space for stability
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss
loss = -(prob_ratio * advantages).sum()
```
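As a concrete check, here is a minimal pure-Python sketch of the same computation (using `math.exp` in place of `torch.exp`; the per-token values are hypothetical, for illustration only):

```python
import math

# Hypothetical per-token values (not from a real model run)
target_logprobs = [-1.0, -2.0]    # log p_theta(x), from the learner
sampling_logprobs = [-1.0, -1.5]  # log q(x), recorded by the sampler
advantages = [1.0, -0.5]          # RL advantages

# Probability ratio p_theta(x) / q(x), computed in log space for stability
prob_ratio = [math.exp(t - s) for t, s in zip(target_logprobs, sampling_logprobs)]

# Importance-weighted policy gradient loss (negated sum, as in the torch snippet)
loss = -sum(r * a for r, a in zip(prob_ratio, advantages))
print(loss)  # tokens where learner and sampler agree get ratio 1; others are reweighted
```

Tokens where the two log-probabilities match get a ratio of exactly 1, so the correction only affects tokens where the learner and sampler disagree.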
Input tensors:
- `target_tokens: array[(N,), int]` — Target token IDs (from the sampler \(q\))
- `logprobs: array[(N,), float]` — `sampling_logprobs` for the tokens
- `advantages: array[(N,), float]` — Advantage values for RL (positive to reinforce, negative to discourage)
Output tensors:
- `logprobs: array[(N,), float]` — `target_logprobs` for the tokens
Output diagnostics:
- `loss:sum` (scalar) — Sum of importance-weighted policy gradient losses \(\mathcal{L}_{\text{IS}}\)
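To illustrate the bias correction end to end, the toy example below (hypothetical distributions, not from the text) compares the naive estimate \(\mathbb{E}_{x\sim q}[A(x)]\) against the importance-weighted one on a small discrete distribution; the latter exactly recovers \(\mathbb{E}_{x\sim p_\theta}[A(x)]\):

```python
# Toy discrete distributions over three outcomes (hypothetical, for illustration)
p = [0.5, 0.3, 0.2]   # learner distribution p_theta
q = [1/3, 1/3, 1/3]   # sampler distribution q (uniform)
A = [1.0, -1.0, 2.0]  # advantages A(x)

# Ground truth: E_{x~p}[A(x)]
true_value = sum(pi * ai for pi, ai in zip(p, A))

# Naive estimate under q (biased when p != q): E_{x~q}[A(x)]
naive = sum(qi * ai for qi, ai in zip(q, A))

# Importance-weighted estimate under q: E_{x~q}[(p(x)/q(x)) * A(x)]
is_estimate = sum(qi * (pi / qi) * ai for pi, qi, ai in zip(p, q, A))

print(true_value, naive, is_estimate)
```

The importance-weighted estimate matches the true value exactly, while the naive average under \(q\) does not.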