CISPO
CISPO (Chen et al., 2024; Khatri et al., 2024) is a policy-gradient method that uses a clipped importance ratio as a stop-gradient coefficient on the policy gradient. Unlike PPO, which clips the surrogate objective directly, CISPO clips the ratio itself and uses it to weight the log-probability term.
The CISPO objective is:
\[
\mathcal{L}_{\text{CISPO}}(\theta) = \mathbb{E}_{x \sim q}\left[\textbf{sg}\left( \text{clip}\left(\frac{p_\theta(x)}{q(x)}, 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}}\right) \right) \cdot \log p_\theta(x) \cdot A(x)\right]
\]
This is implemented as:
```python
# Importance ratio p_theta(x) / q(x), computed in log space for stability
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Clip the ratio to [1 - eps_low, 1 + eps_high]
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Detach the clipped ratio: this is the stop-gradient sg(.) in the objective,
# so gradients flow only through the log-probability term
cispo_objective = clipped_ratio.detach() * target_logprobs * advantages
# The CISPO loss is the negative of the objective, summed over tokens
loss = -cispo_objective.sum()
```
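The snippet above assumes the surrounding tensors already exist. Below is a self-contained sketch; the function name `cispo_loss`, the toy token values, and the clip thresholds 0.8/1.2 are illustrative assumptions, not part of the spec:

```python
import torch

def cispo_loss(target_logprobs, sampling_logprobs, advantages,
               clip_low=0.8, clip_high=1.2):
    """Sketch of the CISPO loss: clipped, stop-gradient importance ratio
    times log-probability times advantage, negated and summed.

    clip_low / clip_high play the role of 1 - eps_low and 1 + eps_high.
    """
    # Importance ratio p_theta(x) / q(x), from the log-prob difference
    prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
    clipped_ratio = torch.clamp(prob_ratio, clip_low, clip_high)
    # detach() implements sg(.): gradients flow only through target_logprobs
    objective = clipped_ratio.detach() * target_logprobs * advantages
    return -objective.sum()

# Toy usage: three tokens where the target policy has drifted from the sampler
target = torch.tensor([-1.0, -2.0, -0.5], requires_grad=True)
sampling = torch.tensor([-1.1, -1.5, -0.5])
adv = torch.tensor([1.0, -1.0, 0.5])
loss = cispo_loss(target, sampling, adv)
loss.backward()
```

Because the clipped ratio is detached, the gradient of each token's loss is just the negative of its clipped ratio times its advantage, applied to the log-probability gradient, i.e. REINFORCE reweighted by the clipped importance ratio.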
Input tensors:
- `target_tokens`: array[(N,), int] — Target token IDs (from the sampler \(q\))
- `logprobs`: array[(N,), float] — `sampling_logprobs` for the tokens
- `advantages`: array[(N,), float] — Advantage values for RL
Output tensors:
- `logprobs`: array[(N,), float] — `target_logprobs` for the tokens
Output diagnostics:
- `loss:sum` (scalar) — Sum of CISPO losses
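The spec lists per-token tensors of shape `(N,)`. As a hypothetical illustration (the gather-based indexing and the shapes `N`, `V` below are assumptions, not part of the spec), `target_logprobs` can be extracted from model logits using `target_tokens`:

```python
import torch

# Hypothetical shapes: N sampled tokens, V vocabulary entries
N, V = 4, 10
logits = torch.randn(N, V, requires_grad=True)
target_tokens = torch.randint(0, V, (N,))      # array[(N,), int]

# Log-probabilities over the vocabulary, shape (N, V)
log_probs = torch.log_softmax(logits, dim=-1)
# target_logprobs: log p_theta of each sampled token, array[(N,), float]
target_logprobs = log_probs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)
```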