Custom Loss Functions

For use cases not covered by the built-in losses, Tinker provides the more flexible (but slower) methods forward_backward_custom and forward_backward_custom_async, which can compute a much more general class of loss functions.

Usage

A custom loss function receives the data and the model's logprobs, and returns a scalar loss plus optional metrics:

import torch
import tinker

def logprob_squared_loss(data: list[tinker.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    # logprobs contains one tensor of per-token logprobs per Datum;
    # concatenate them before applying the elementwise square.
    loss = (torch.cat(logprobs) ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}

Call it with forward_backward_custom:

future = training_client.forward_backward_custom(data, logprob_squared_loss)
result = future.result()
print(f"Loss: {result.loss}, Metrics: {result.metrics}")

You can also define loss functions that operate on multiple sequences at a time. For example, a loss function that computes the variance of the logprobs across sequences (practically useless, but illustrative) can be implemented as:

def variance_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    flat_logprobs = torch.cat(logprobs)
    variance = torch.var(flat_logprobs)
    return variance, {"variance_loss": variance.item()}

A more practical use case is computing a Bradley-Terry loss on pairwise comparison data -- a classic approach in RL from human feedback, introduced and popularized by Learning to Summarize. Direct Preference Optimization similarly computes a loss over pairs of sequences; see the DPO & Preferences tutorial for more details.
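As an illustration of the pairwise case, here is a minimal sketch of a Bradley-Terry-style loss. Its data layout is this example's own convention, not something the API enforces: sequences are assumed to arrive as alternating (chosen, rejected) pairs, and each completion is scored by its summed logprobs.

```python
import torch

def pairwise_bt_loss(data, logprobs):
    # Convention assumed by this sketch: data[2i] is the chosen completion
    # and data[2i + 1] the rejected one for comparison i.
    chosen_scores = torch.stack([lp.sum() for lp in logprobs[0::2]])
    rejected_scores = torch.stack([lp.sum() for lp in logprobs[1::2]])
    # Bradley-Terry: maximize P(chosen > rejected) = sigmoid(score difference).
    loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
    return loss, {"bt_loss": loss.item()}
```

In a real reward-model setup the score would typically come from a value head rather than raw logprobs; the summed-logprob score here is only a placeholder to keep the sketch self-contained.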

Multi-target loss

You can also use forward_backward_custom with multi-target target_tokens to define losses over a small candidate set at specific sequence positions. For example, the following masked cross-entropy loss trains a model on a multiple-choice question by renormalizing over four answer-letter tokens instead of the full vocabulary. Because this renormalization is nonlinear in the extracted target-token logprobs, it's somewhat cumbersome to express in terms of the cross_entropy primitive:

import torch
import tinker

# Assumes `tokenizer` and `training_client` have already been constructed.

messages = [{"role": "user", "content": (
    "What is the capital of France?\n"
    "A) London\n"
    "B) Paris\n"
    "C) Berlin\n"
    "D) Madrid\n"
    "Answer Letter:"
)}]
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Extract token IDs for the four answer letters.
a_token = tokenizer.encode(" A", add_special_tokens=False)[0]
b_token = tokenizer.encode(" B", add_special_tokens=False)[0]
c_token = tokenizer.encode(" C", add_special_tokens=False)[0]
d_token = tokenizer.encode(" D", add_special_tokens=False)[0]

target_tokens = torch.zeros(len(prompt_tokens), 4, dtype=torch.long)
target_tokens[-1] = torch.tensor([a_token, b_token, c_token, d_token])

datum = tinker.Datum(
    model_input=tinker.ModelInput.from_ints(prompt_tokens),
    loss_fn_inputs={"target_tokens": target_tokens},
)

CORRECT_ANSWER_IDX = 1  # B) Paris

def masked_ce_loss(data, logprobs_list):
    # logprobs_list[i] has shape [N, K=4] for each datum
    # Renormalize over just the four answer tokens instead of the full vocabulary.
    answer_logprobs = logprobs_list[0][-1]  # shape [K=4]
    answer_logprobs = answer_logprobs - torch.logsumexp(answer_logprobs, dim=-1)
    loss = -answer_logprobs[CORRECT_ANSWER_IDX]
    return loss, {"masked_ce": loss.item()}

future = training_client.forward_backward_custom(
    [datum],
    masked_ce_loss,
)
result = future.result()

If you're using a custom loss function that you think is generally useful, please let us know, and we'll add it to the list of built-in loss functions.

The async versions of these methods are covered in the Async Patterns tutorial.

How forward_backward_custom works

You don't need to read the following section to use forward_backward_custom, but read on if you're curious about how it works under the hood. Tinker does NOT pickle your function or send it to the server. Instead, Tinker decomposes the gradient computation into a forward call followed by a forward_backward call on an appropriately designed weighted cross-entropy loss, which lets it compute exactly the right gradient.

Mathematically, this works as follows. First, consider the full nonlinear loss function:

loss = compute_loss_from_logprobs(compute_target_logprobs(params))

We construct a loss function that is linear in the logprobs, but has the same gradient with respect to params as the full nonlinear loss:

logprobs = compute_target_logprobs(params)
surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs
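This equivalence is easy to verify in isolation. The following standalone PyTorch sketch (with a toy compute_target_logprobs standing in for the model's forward pass) checks that the surrogate loss has exactly the same gradient with respect to params as the full nonlinear loss:

```python
import torch

params = torch.randn(4, requires_grad=True)

def compute_target_logprobs(p):
    # Stand-in for the model: any differentiable map from params to logprobs.
    return torch.log_softmax(p, dim=-1)

def compute_loss_from_logprobs(lp):
    # Any nonlinear function of the logprobs.
    return (lp ** 2).sum()

# Gradient of the full nonlinear loss.
full_loss = compute_loss_from_logprobs(compute_target_logprobs(params))
(full_grad,) = torch.autograd.grad(full_loss, params)

# Surrogate: first get dLoss/dLogprobs, then form a loss that is
# *linear* in the logprobs, with those gradients as fixed weights.
logprobs = compute_target_logprobs(params)
(logprob_grads,) = torch.autograd.grad(
    compute_loss_from_logprobs(logprobs), logprobs, retain_graph=True
)
surrogate_loss = (logprobs * logprob_grads.detach()).sum()
(surrogate_grad,) = torch.autograd.grad(surrogate_loss, params)

# The surrogate reproduces the full gradient (up to float error).
assert torch.allclose(full_grad, surrogate_grad, atol=1e-6)
```

Because logprob_grads is detached, the surrogate is linear in the logprobs, which is exactly the form a weighted cross-entropy loss can express on the server.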

Here's what happens under the hood — the client and server collaborate in two passes:

1. Prepare data (client): build the list of Datum objects

2. Forward pass (server): compute logprobs for the target tokens

3. Custom loss (client): loss = custom_fn(logprobs)

4. Backward (client): loss.backward() → grad_outputs

5. Forward-backward (server): surrogate loss sum(logprobs × grad_outputs) → weight gradients

Since forward_backward_custom does an extra forward pass, it requires 1.5x as many FLOPs as a single forward_backward. It'll also take up to 3x as long (wall time), due to implementation details of how forward_backward operations are scheduled.