PPO
PPO (Schulman et al., 2017) addresses a weakness of standard policy gradient methods: when taking multiple gradient steps on the same rollout data, the policy can drift far from the distribution that generated the samples. PPO introduces a clipping objective that keeps updates within a close neighborhood of the sampling distribution, preventing updates that are too large in policy space.
The objective clips the importance ratio \(\frac{p_\theta(x)}{q(x)}\), where \(p_\theta\) is the learner policy and \(q\) is the sampling policy. Note that the PPO clipping and loss computation are applied token-wise: the loss for each token is computed independently.
The PPO clipping objective is:

\[
\mathrm{clip}\!\left(\frac{p_\theta(x)}{q(x)},\; 1-\epsilon_{\text{low}},\; 1+\epsilon_{\text{high}}\right)
\]
The final PPO loss combines the clipped and unclipped objectives, taking the elementwise minimum and summing over tokens:

\[
L^{\text{PPO}}(\theta) = -\sum_t \min\!\left(\frac{p_\theta(x_t)}{q(x_t)}\, A_t,\;\; \mathrm{clip}\!\left(\frac{p_\theta(x_t)}{q(x_t)},\; 1-\epsilon_{\text{low}},\; 1+\epsilon_{\text{high}}\right) A_t\right)
\]
where \(\epsilon_{\text{low}}\) and \(\epsilon_{\text{high}}\) are hyperparameters (currently fixed to 0.2 in Tinker).
This is implemented as:

```python
import torch

# Clip thresholds derived from the epsilon hyperparameters
# (with eps_low = eps_high = 0.2, the interval is [0.8, 1.2])
clip_low_threshold = 1 - eps_low
clip_high_threshold = 1 + eps_high

# Compute probability ratio p_theta(x) / q(x) from logprobs
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)

# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)

# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages

# Take the minimum (most conservative)
ppo_objective = torch.min(unclipped_objective, clipped_objective)

# PPO loss is the negative of the objective, summed over tokens
loss = -ppo_objective.sum()
```
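For intuition, a quick check with made-up numbers shows how the `min` behaves asymmetrically: with a positive advantage the clipped term caps the objective, while with a negative advantage the unclipped, more pessimistic term wins, so clipping does not limit the penalty:

```python
import torch

# Made-up example: ratio already above the 1 + eps_high = 1.2 threshold
prob_ratio = torch.tensor([1.5, 1.5])
advantages = torch.tensor([1.0, -1.0])

clipped_ratio = torch.clamp(prob_ratio, 0.8, 1.2)
ppo_objective = torch.min(prob_ratio * advantages, clipped_ratio * advantages)
# Positive advantage: capped at 1.2; negative advantage: unclipped -1.5 is kept
```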
Input tensors:

- `target_tokens`: array[(N,), int] — Target token IDs (from the sampler \(q\))
- `logprobs`: array[(N,), float] — `sampling_logprobs` for the tokens
- `advantages`: array[(N,), float] — Advantage values for RL

Output tensors:

- `logprobs`: array[(N,), float] — `target_logprobs` for the tokens

Output diagnostics:

- `loss:sum` (scalar) — Sum of PPO clipped losses
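Putting the pieces together, the computation above can be sketched as a standalone PyTorch function. This is a minimal sketch, not Tinker's actual API; `ppo_loss` is a hypothetical helper, and it takes logprobs directly rather than token IDs:

```python
import torch

def ppo_loss(target_logprobs: torch.Tensor,
             sampling_logprobs: torch.Tensor,
             advantages: torch.Tensor,
             eps_low: float = 0.2,
             eps_high: float = 0.2) -> torch.Tensor:
    """Token-wise PPO clipped loss over (N,)-shaped tensors, summed over tokens."""
    prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
    clipped_ratio = torch.clamp(prob_ratio, 1 - eps_low, 1 + eps_high)
    objective = torch.min(prob_ratio * advantages, clipped_ratio * advantages)
    return -objective.sum()
```

Note that when the target and sampling logprobs coincide (the first gradient step on fresh on-policy data), the ratio is 1 everywhere, clipping is inactive, and the loss reduces to the negated sum of advantage-weighted terms.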