
PPO

PPO (Schulman et al., 2017) addresses a weakness of standard policy gradient methods: a single large update can move the policy far from the distribution that generated the data. PPO uses a clipped objective that keeps each update within a close neighborhood of the sampling distribution. This matters most when taking multiple gradient steps on the same batch of rollouts.

The objective clips the importance ratio \(\frac{p_\theta(x)}{q(x)}\) to prevent large policy updates, where \(p_\theta\) is the learner policy and \(q\) is the sampling policy. Clipping and the loss computation are applied token-wise: the loss for each token is computed independently.
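
Because the objective is applied per token, the ratio is formed separately for each token from two log-probabilities rather than from whole-sequence probabilities. The following is a minimal pure-Python sketch; the function name token_ratios is hypothetical and the probabilities are illustrative values, not real model outputs.

```python
import math


def token_ratios(target_logprobs, sampling_logprobs):
    """Per-token importance ratios p_theta(x_t) / q(x_t).

    Exponentiating the log-probability difference avoids dividing
    two very small probabilities directly.
    """
    if len(target_logprobs) != len(sampling_logprobs):
        raise ValueError("logprob arrays must have the same length")
    return [math.exp(lp - lq) for lp, lq in zip(target_logprobs, sampling_logprobs)]


# Illustrative values for a 3-token completion.
sampling = [math.log(0.5), math.log(0.2), math.log(0.9)]
target = [math.log(0.6), math.log(0.1), math.log(0.9)]
ratios = token_ratios(target, sampling)
print([round(r, 3) for r in ratios])  # [1.2, 0.5, 1.0]
```

Each ratio is then clipped and weighted on its own, so one token that has drifted far from the sampler does not change how the other tokens are treated.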

The PPO clipping objective is:

\[ \mathcal{L}_{\text{CLIP}}(\theta) = -\mathbb{E}_{x \sim q}\left[\text{clip}\left(\frac{p_\theta(x)}{q(x)}, 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}}\right) \cdot A(x)\right] \]

The final PPO loss takes the elementwise minimum of the unclipped and clipped objectives, which gives a pessimistic (lower) bound on the unclipped objective:

\[ \mathcal{L}_{\text{PPO}}(\theta) = -\mathbb{E}_{x \sim q}\left[\min\left(\frac{p_\theta(x)}{q(x)} \cdot A(x), \text{clip}\left(\frac{p_\theta(x)}{q(x)}, 1-\epsilon_{\text{low}}, 1+\epsilon_{\text{high}}\right) \cdot A(x)\right)\right] \]

where \(\epsilon_{\text{low}}\) and \(\epsilon_{\text{high}}\) are hyperparameters (both 0.2 by default in Tinker; see Custom clipping thresholds to override them).
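
To see what the minimum does, it helps to evaluate the per-token objective for each combination of advantage sign and ratio. This is a minimal pure-Python sketch using \(\epsilon_{\text{low}} = \epsilon_{\text{high}} = 0.2\); the function name ppo_token_objective is hypothetical.

```python
def ppo_token_objective(ratio, advantage, eps_low=0.2, eps_high=0.2):
    """Per-token PPO objective: min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A).

    The per-token loss is the negative of this value.
    """
    clipped = min(max(ratio, 1 - eps_low), 1 + eps_high)
    return min(ratio * advantage, clipped * advantage)


cases = [(1.5, 1.0), (0.5, 1.0), (1.5, -1.0), (0.5, -1.0)]
for r, a in cases:
    print(f"ratio={r}, A={a:+}: objective={ppo_token_objective(r, a):+.2f}")
# ratio=1.5, A=+1.0: objective=+1.20   capped at (1 + eps_high) * A
# ratio=0.5, A=+1.0: objective=+0.50   unclipped term is smaller, kept
# ratio=1.5, A=-1.0: objective=-1.50   unclipped term is smaller, kept
# ratio=0.5, A=-1.0: objective=-0.80   capped at (1 - eps_low) * A
```

The gain from moving a token's probability in the direction its advantage favors is capped. A move in the unfavorable direction keeps its full penalty.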

This is implemented as:

```python
import torch

# Inputs are per-token tensors of shape (N,):
# target_logprobs (learner), sampling_logprobs (sampler), advantages.

# Compute probability ratio p_theta(x) / q(x) in log space
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
# Take the elementwise minimum (pessimistic bound)
ppo_objective = torch.min(unclipped_objective, clipped_objective)
# PPO loss is the negative of the objective, summed over tokens
loss = -ppo_objective.sum()
```
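
In the torch implementation, clipping takes effect through the gradient. torch.clamp passes no gradient for ratios outside its range, and torch.min routes the gradient to whichever term it selects, so a token stops contributing to the update only when the clipped term is selected. The sketch below checks this numerically in pure Python. It takes a central finite difference with respect to one token's target log-probability and assumes thresholds of 0.8 and 1.2 (that is, \(\epsilon = 0.2\)). The function names are hypothetical.

```python
import math


def ppo_token_loss(target_lp, sampling_lp, advantage, low=0.8, high=1.2):
    """Negative per-token PPO objective, mirroring the torch code above."""
    ratio = math.exp(target_lp - sampling_lp)
    clipped = min(max(ratio, low), high)
    return -min(ratio * advantage, clipped * advantage)


def grad_wrt_target_lp(sampling_lp, ratio, advantage, h=1e-6):
    """Central finite-difference derivative of the loss w.r.t. the target logprob."""
    target_lp = sampling_lp + math.log(ratio)
    plus = ppo_token_loss(target_lp + h, sampling_lp, advantage)
    minus = ppo_token_loss(target_lp - h, sampling_lp, advantage)
    return (plus - minus) / (2 * h)


sampling_lp = math.log(0.3)
for ratio, adv in [(1.5, 1.0), (1.1, 1.0), (0.5, 1.0), (0.5, -1.0)]:
    g = grad_wrt_target_lp(sampling_lp, ratio, adv)
    print(f"ratio={ratio}, A={adv:+}: dloss/dlogp={g:+.3f}")
# ratio=1.5, A=+1.0: dloss/dlogp=+0.000
# ratio=1.1, A=+1.0: dloss/dlogp=-1.100
# ratio=0.5, A=+1.0: dloss/dlogp=-0.500
# ratio=0.5, A=-1.0: dloss/dlogp=+0.000
```

A token stops receiving gradient only when its ratio has already moved past a threshold in the direction its advantage favors. A token that moved the other way (ratio 0.5 with a positive advantage) keeps its full gradient, which pulls it back toward the sampler.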

Input tensors:

  • target_tokens: array[(N,), int] — Target token IDs (from the sampler \(q\))
  • logprobs: array[(N,), float] — Sampling log-probabilities (sampling_logprobs) for the tokens
  • advantages: array[(N,), float] — Advantage values for RL
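
Advantages are given per token, while many RL setups compute one advantage per sampled sequence. The following is a minimal sketch of one way to produce the three aligned arrays. It rests on assumptions this page does not state. First, it assumes the standard next-token layout, where the target at position i is token i+1 of the full prompt-plus-completion sequence. Second, it copies the sequence advantage onto every completion token. Third, it gives prompt positions an advantage of 0 and a placeholder logprob of 0.0, so they add nothing to the loss. The helper build_ppo_inputs is hypothetical, and packing the arrays into Tinker's request types is not shown.

```python
def build_ppo_inputs(prompt_tokens, completion_tokens, completion_logprobs, advantage):
    """Build aligned per-token arrays for one sampled sequence (hypothetical helper).

    Assumes target[i] = full_sequence[i + 1]; prompt targets get advantage 0.0
    and a placeholder logprob of 0.0, so they contribute zero loss and gradient.
    """
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one token")
    if len(completion_tokens) != len(completion_logprobs):
        raise ValueError("need one sampling logprob per completion token")
    n_prompt_targets = len(prompt_tokens) - 1  # prompt tokens that appear as targets
    full_sequence = list(prompt_tokens) + list(completion_tokens)
    return {
        "target_tokens": full_sequence[1:],
        "logprobs": [0.0] * n_prompt_targets + [float(lp) for lp in completion_logprobs],
        "advantages": [0.0] * n_prompt_targets + [float(advantage)] * len(completion_tokens),
    }


inputs = build_ppo_inputs([101, 102, 103], [7, 8], [-0.4, -1.2], advantage=0.5)
print(inputs["target_tokens"])  # [102, 103, 7, 8]
print(inputs["advantages"])     # [0.0, 0.0, 0.5, 0.5]
```

Zero-advantage positions are harmless whatever their logprob, because both terms inside the minimum are then 0.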

Output tensors:

  • logprobs: array[(N,), float] — Learner log-probabilities (target_logprobs) for the tokens

Output diagnostics:

  • loss:sum (scalar) — Sum of the per-token PPO losses
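
Only the summed loss is reported as a diagnostic. Because the call also returns the learner's target_logprobs, the ratios can be inspected on the client. The following is a minimal sketch that computes the fraction of tokens whose ratio falls outside the clipping range. The helper clip_fraction is hypothetical, and it assumes thresholds of 0.8 and 1.2.

```python
import math


def clip_fraction(target_logprobs, sampling_logprobs, low=0.8, high=1.2):
    """Fraction of tokens whose ratio p_theta / q lies outside [low, high]."""
    if len(target_logprobs) != len(sampling_logprobs):
        raise ValueError("logprob arrays must have the same length")
    ratios = [math.exp(t - s) for t, s in zip(target_logprobs, sampling_logprobs)]
    if not ratios:
        return 0.0
    outside = sum(1 for r in ratios if r < low or r > high)
    return outside / len(ratios)


# Illustrative values: ratios are 1.1, 0.5, 1.0 and 1.5.
target = [math.log(p) for p in (0.55, 0.1, 0.9, 0.3)]
sampling = [math.log(p) for p in (0.5, 0.2, 0.9, 0.2)]
print(clip_fraction(target, sampling))  # 0.5
```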

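Custom clipping thresholds

The clipping thresholds can be overridden for a call by passing loss_fn_config to forward_backward_async. The values bound the ratio itself, so clip_low_threshold corresponds to \(1-\epsilon_{\text{low}}\) and clip_high_threshold to \(1+\epsilon_{\text{high}}\). The example below uses \(\epsilon = 0.1\) on both sides: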

```python
# Inside an async function, with training_client and data already set up
fwd_bwd_future = await training_client.forward_backward_async(
    data=data,
    loss_fn="ppo",
    loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1},
)
fwd_bwd_result = await fwd_bwd_future.result_async()
```
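
The config takes thresholds on the ratio rather than \(\epsilon\) values, so a small helper can convert from the \(\epsilon\) notation of the equations and reject invalid settings. The following is a minimal sketch. The helper ppo_loss_fn_config is hypothetical, and it assumes only the two config keys shown in the example.

```python
def ppo_loss_fn_config(eps_low, eps_high):
    """Map epsilon values to ratio thresholds (1 - eps_low, 1 + eps_high)."""
    if not (0 < eps_low < 1) or eps_high <= 0:
        raise ValueError("need 0 < eps_low < 1 and eps_high > 0")
    return {"clip_low_threshold": 1 - eps_low, "clip_high_threshold": 1 + eps_high}


# Symmetric eps = 0.1 gives the thresholds used in the example.
print(ppo_loss_fn_config(0.1, 0.1))  # {'clip_low_threshold': 0.9, 'clip_high_threshold': 1.1}
```

The objective allows \(\epsilon_{\text{low}}\) and \(\epsilon_{\text{high}}\) to differ. For example, ppo_loss_fn_config(0.2, 0.28) could be passed as loss_fn_config to let the ratio rise further above 1 than it may fall below it.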