Loss functions in Tinker
For most use cases, you can use the Tinker API's built-in loss functions by passing a string identifier to `forward_backward`, which supports cross-entropy and policy gradient objectives. When you need more control, `forward_backward_custom` enables arbitrary differentiable loss functions at the cost of an additional forward pass; we explain both approaches in this doc.
When you call `forward_backward`, you specify a loss function using a string that selects from a predetermined set of options, covering the most common losses used for language model training.
- Input: `forward_backward` expects a certain set of input tensors, passed in via `datum.loss_fn_inputs`, which is a dict mapping `str` to either a numpy or torch tensor
- Output: `forward_backward` returns a `ForwardBackwardOutput`, which has a set of output tensors in `fwd_bwd_result.loss_fn_outputs`
For an example of using `forward_backward`, see `rl/train.py` in the Cookbook:
import tinker
import torch
from tinker import TensorData
# Create training data with required inputs
datum = tinker.Datum(
    model_input=input_tokens,
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
        "logprobs": TensorData.from_torch(torch.tensor(sampling_logprobs)),  # Reference logprobs
        "advantages": TensorData.from_torch(torch.tensor(advantages)),
    },
)
# Option 1: Use importance sampling REINFORCE
fwd_bwd_result = await training_client.forward_backward_async(
    [datum], loss_fn="importance_sampling"
)
# Option 2: Use PPO with clipping
fwd_bwd_result = await training_client.forward_backward_async(
    [datum], loss_fn="ppo"
)
Basic loss functions
Currently, the Tinker API supports `cross_entropy` (for supervised learning), `importance_sampling` (for RL), and `ppo` (for RL). We denote the training model as $\pi_\theta$, the sampling distribution as $\pi_{\text{samp}}$, and advantages as $A$. Also, for notational simplicity we omit the query and denote the full model completion sequence of tokens as $x = (x_1, \dots, x_N)$.
All losses are applied at the token level, and the tensors below have shape `(N,)` where `N` is `model_input.length`. They can be provided as `numpy.ndarray` or `torch.Tensor`, and the return values will use the same tensor type.
Supervised learning: cross_entropy
For SL, we implement the standard cross-entropy loss (i.e., negative log-likelihood), which optimizes the policy $\pi_\theta$ to maximize the log-probability of the tokens $x_t$:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{N} w_t \log \pi_\theta(x_t \mid x_{<t})$$

where $w_t$ (`weights`) is either 0 or 1, typically generated from `renderers.build_supervised_example` (i.e., to specify the desired assistant turns to train on).
This is implemented as:
# Apply weights and compute elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum() # scalar
- Input tensors:
  - `target_tokens: array[(N,), int]` - Target token IDs
  - `weights: array[(N,), float]` - Token-level loss weights (typically from the renderer)
- Output tensors:
  - `logprobs: array[(N,), float]` - Log probabilities of predicted tokens
- Output diagnostics:
  - `loss:sum` (scalar) - Sum of weighted cross-entropy losses
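Putting this together, a supervised training example might look like the following sketch. Here `input_tokens`, `target_tokens`, and `weights` are placeholders assumed to come from your data pipeline (e.g., via `renderers.build_supervised_example`):

import torch
import tinker
from tinker import TensorData

# Build a supervised example; weights of 0 mask out tokens we don't train on
datum = tinker.Datum(
    model_input=input_tokens,  # e.g., produced by the renderer
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
        "weights": TensorData.from_torch(torch.tensor(weights)),
    },
)

fwd_bwd_result = await training_client.forward_backward_async(
    [datum], loss_fn="cross_entropy"
)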
Policy gradient: importance_sampling
For RL, we implement a common variant of the policy gradient objective, used in practical settings where the learner policy $\pi_\theta$ may differ from the sampling policy $\pi_{\text{samp}}$, which is common due to, e.g., non-determinism. To remove the bias caused by this difference, we can use a modified "importance sampling" objective:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{N} \frac{\pi_\theta(x_t \mid x_{<t})}{\pi_{\text{samp}}(x_t \mid x_{<t})}\, A_t$$

which yields the correct expected reward. This is implemented as:
# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss
loss = -(prob_ratio * advantages).sum()
where `advantages` may be computed by additionally subtracting a baseline from the rewards.
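As an illustrative sketch (the function name and the reward-to-advantage scheme here are hypothetical; how you compute advantages is entirely up to your own code), a simple group-mean baseline broadcast to token level could look like:

import torch

def token_level_advantages(group_rewards: list[float], token_counts: list[int]) -> list[torch.Tensor]:
    # Subtract the group-mean reward as a baseline, then broadcast each sequence's
    # advantage over its tokens to get an (N,)-shaped tensor per datum
    rewards = torch.tensor(group_rewards)
    centered = rewards - rewards.mean()
    return [torch.full((n,), adv.item()) for adv, n in zip(centered, token_counts)]

# e.g., three rollouts in a group with rewards 1.0, 0.0, 0.5 and different lengths
advantages_per_datum = token_level_advantages([1.0, 0.0, 0.5], [12, 9, 15])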
- Input tensors:
  - `target_tokens: array[(N,), int]` - Target token IDs (from the sampler $\pi_{\text{samp}}$)
  - `logprobs: array[(N,), float]` - Reference log probabilities for the target tokens
  - `advantages: array[(N,), float]` - Advantage values for RL
- Output tensors:
  - `logprobs: array[(N,), float]` - Log probabilities for the target tokens
- Output diagnostics:
  - `loss:sum` (scalar) - Sum of importance-weighted policy gradient losses
Proximal Policy Optimization: ppo
PPO (Schulman et al., 2017) addresses issues with standard policy gradient methods by introducing a clipping objective that limits policy updates within a close neighborhood of the sampling distribution. This prevents updates that are too large in policy space, especially when taking multiple gradient steps on the same rollout distribution.
The objective clips the importance ratio $r_t(\theta) = \frac{\pi_\theta(x_t \mid x_{<t})}{\pi_{\text{samp}}(x_t \mid x_{<t})}$ to prevent large policy updates, where $\pi_\theta$ is the learner policy and $\pi_{\text{samp}}$ is the sampling policy. Note that the PPO clipping and loss computation is applied token-wise, computing the loss for each token independently.

The PPO clipping objective is:

$$\operatorname{clip}\big(r_t(\theta),\, 1-\epsilon_{\text{low}},\, 1+\epsilon_{\text{high}}\big)\, A_t$$

The final PPO loss combines the clipped and unclipped objectives:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{N} \min\Big(r_t(\theta)\, A_t,\; \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon_{\text{low}},\, 1+\epsilon_{\text{high}}\big)\, A_t\Big)$$

where $\epsilon_{\text{low}}$ and $\epsilon_{\text{high}}$ are hyperparameters (currently fixed to 0.2 in Tinker).
This is implemented as:
# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
# Take minimum (most conservative)
ppo_objective = torch.min(unclipped_objective, clipped_objective)
# PPO loss is negative of objective
loss = -ppo_objective.sum()
Currently this function has the same inputs and outputs as the importance sampling REINFORCE objective above.
Additional Notes:
- The loss formulation above is quite general, since the user can organize the data generation and advantage estimation in their own code. For example, the main RL training scripts in the Tinker Cookbook use group-based rollouts with per-group advantage centering similar to GRPO (Shao et al., 2024).
- The functional implementations of REINFORCE and PPO do not use an additional KL term like the original GRPO work, which has been noted to be mathematically inconsistent (Zhang et al., 2025; Tang et al., 2025). However, it is possible to include a KL regularization term as part of the reward, which is mathematically correct; we provide this option in our RL training code and examples (see the `incorporate_kl_penalty` function, and the sketch after this list).
- Notice that for all objectives we sum the token-level losses over the sequence length, unlike some other loss implementations. If you would like to explore different aggregation schemes, you can fold that into the advantage tensor computation.
- We currently have fixed the hyperparameters $\epsilon_{\text{low}}$ and $\epsilon_{\text{high}}$ to 0.2.
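To make the KL-as-reward note above concrete, here is a rough sketch (not the Cookbook's `incorporate_kl_penalty` signature; `ref_logprobs` and `kl_coef` are illustrative names) of folding a per-token KL estimate into the reward stream before advantages are computed:

import torch

def add_kl_penalty_to_rewards(
    rewards: torch.Tensor,            # per-token rewards, shape (N,)
    sampling_logprobs: torch.Tensor,  # logprobs of sampled tokens under the sampling policy
    ref_logprobs: torch.Tensor,       # logprobs of the same tokens under a frozen reference policy
    kl_coef: float = 0.1,
) -> torch.Tensor:
    # Per-token estimate of the KL between the sampling and reference policies
    kl_per_token = sampling_logprobs - ref_logprobs
    # Fold the penalty into the rewards; advantages are then computed downstream as usual
    return rewards - kl_coef * kl_per_token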
Flexible loss functions: forward_backward_custom
For use cases outside of the above, we've provided the more flexible (but slower) methods `forward_backward_custom` and `forward_backward_custom_async` to compute a more general class of loss functions.
Usage
Here's a simple example of a custom loss function:
def logprob_squared_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    # logprobs is a list of per-sequence tensors, so flatten it before squaring
    flat_logprobs = torch.cat(logprobs)
    loss = (flat_logprobs ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}
You can call this loss function with `forward_backward_custom` like:
loss, metrics = training_client.forward_backward_custom(data, logprob_squared_loss)
You can also define loss functions which operate on multiple sequences at a time. For example, a loss function that computes the variance across the sequences (although practically useless) can be implemented as:
def variance_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    flat_logprobs = torch.cat(logprobs)
    variance = torch.var(flat_logprobs)
    return variance, {"variance_loss": variance.item()}
A more practical use case would be to compute a Bradley-Terry loss on pairwise comparison data -- a classic approach in RL from human feedback, as introduced and popularized by Learning to Summarize. Similarly, we can implement Direct Preference Optimization, which also computes a loss involving pairs of sequences; see the DPO guide for more details.
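As a hedged sketch of the pairwise idea (not the Cookbook's implementation), the following assumes `data` is packed as consecutive (chosen, rejected) pairs and uses each sequence's summed token logprobs as its score:

import torch
import torch.nn.functional as F
from tinker import Datum

def pairwise_preference_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    # Score each sequence by its summed token logprobs
    scores = torch.stack([lp.sum() for lp in logprobs])
    # Assumes data = [chosen_0, rejected_0, chosen_1, rejected_1, ...]
    chosen, rejected = scores[0::2], scores[1::2]
    # Bradley-Terry-style negative log-likelihood of preferring chosen over rejected
    loss = -F.logsigmoid(chosen - rejected).sum()
    return loss, {"pairwise_preference_loss": loss.item()}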
If you're using a custom loss function that you think is generally useful, please let us know, and we'll add it to the list of built-in loss functions.
We detail the `async` versions of these methods in the Async and Futures section of these docs.
How `forward_backward_custom` works
You don't need to read the following section to use `forward_backward_custom`, but read on if you're curious about how it works under the hood. Tinker does NOT pickle your function or send it to the server. Instead, Tinker decomposes the gradient computation into a `forward` call followed by a `forward_backward` call on an appropriately designed weighted cross-entropy loss, which lets it compute exactly the right gradient.
Mathematically, this works as follows. First, consider the full nonlinear loss function:
loss = compute_loss_from_logprobs(compute_target_logprobs(params))
We construct a loss function that is linear in the logprobs, but has the same gradient with respect to params as the full nonlinear loss:
logprobs = compute_target_logprobs(params)
surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs
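As a toy sanity check (standalone PyTorch, not Tinker code), you can verify that the linear surrogate yields the same parameter gradient as the full nonlinear loss:

import torch

# Toy "model": logprobs are a differentiable function of params
params = torch.randn(4, requires_grad=True)

def compute_target_logprobs(p):
    return torch.log_softmax(p, dim=0)

def compute_loss_from_logprobs(lp):
    return (lp ** 2).sum()  # any nonlinear loss of the logprobs

# Gradient of the full nonlinear loss with respect to params
logprobs = compute_target_logprobs(params)
loss = compute_loss_from_logprobs(logprobs)
(true_grad,) = torch.autograd.grad(loss, params)

# Surrogate: compute dLoss/dLogprobs once, treat it as a fixed weight vector
logprobs = compute_target_logprobs(params)
(logprob_grads,) = torch.autograd.grad(compute_loss_from_logprobs(logprobs), logprobs)
surrogate_loss = (compute_target_logprobs(params) * logprob_grads.detach()).sum()
(surrogate_grad,) = torch.autograd.grad(surrogate_loss, params)

print(torch.allclose(true_grad, surrogate_grad))  # True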
Here's a diagram showing what happens under the hood, on the client and server.
Since `forward_backward_custom` does an extra forward pass, it requires 1.5x as many FLOPs as a single `forward_backward`. It can also take up to 3x as long in wall-clock time, due to implementation details of how `forward_backward` operations are scheduled.