
Cross-Entropy

For supervised learning, we implement the standard cross-entropy loss (i.e., negative log-likelihood), which optimizes the policy \(p_\theta\) to maximize the weighted log-probability of the target tokens \(x_t\):

\[ \mathcal{L}(\theta) = -\sum_{t} w_t \cdot \log p_\theta(x_t) \]

where each weight \(w_t\) is either 0 or 1, typically generated by renderer.build_supervised_example(), which returns (model_input, weights) (i.e., the weights select the assistant turns to train on).

This is implemented as:

# target_logprobs: per-token log-probabilities of target_tokens under the model
# Apply weights and compute the elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum()  # scalar
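As a minimal numeric sketch of this reduction (the log-probabilities below are hypothetical values, standing in for what the model would actually produce):

```python
import torch

# Hypothetical per-token log-probabilities for a 4-token sequence.
target_logprobs = torch.tensor([-0.5, -1.2, -0.3, -2.0])
# Train only on the last two tokens (e.g., the assistant turn).
weights = torch.tensor([0.0, 0.0, 1.0, 1.0])

elementwise_loss = -target_logprobs * weights
loss = elementwise_loss.sum()  # 0.3 + 2.0 = 2.3
```

Positions with weight 0 contribute nothing to the gradient, so prompt and user tokens are ignored.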

Input tensors:

  • target_tokens: array[(N,), int] | array[(N, K), int] — Target token IDs
  • weights: array[(N,), float] | array[(N, K), float] — Token-level loss weights (typically from the renderer)

Output tensors:

  • logprobs: array[(N,), float] | array[(N, K), float] — Log probabilities of the requested target tokens

Output diagnostics:

  • loss:sum (scalar) — Sum of weighted cross-entropy losses

Top-K distillation

When the input tensors have shape (N, K), cross_entropy extracts the student's logprobs for \(K\) target tokens per position and computes:

\[ \mathcal{L}_{\text{CE-topK}}(\theta) = -\sum_{t}\sum_{k=1}^{K} w_{t,k} \cdot \log p_\theta(x_{t,k}) \]

where \(w_{t,k}\) are the user-provided weights. For soft-target distillation, set \(w_{t,k} = \tilde{p}_{\text{teacher}}(x_{t,k})\), the teacher's probabilities renormalized over the top-K tokens. For hard-target training, put weight 1.0 on a single candidate and 0.0 on the rest.
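Both weighting schemes can be sketched directly in torch (the teacher log-probabilities here are illustrative values, not real model output):

```python
import torch

# Hypothetical teacher top-K logprobs for 2 positions with K=3.
teacher_topk_logprobs = torch.tensor([[-0.1, -2.0, -3.0],
                                      [-0.5, -1.0, -4.0]])

# Soft targets: renormalize over the top-K so each row sums to 1.
soft_weights = (teacher_topk_logprobs
                - torch.logsumexp(teacher_topk_logprobs, dim=-1, keepdim=True)).exp()

# Hard targets: weight 1.0 on one candidate per position, 0.0 on the rest.
hard_weights = torch.zeros_like(teacher_topk_logprobs)
hard_weights[torch.arange(2), teacher_topk_logprobs.argmax(dim=-1)] = 1.0
```

With hard targets the loss reduces to the standard single-token cross-entropy, so the top-K shape is a strict generalization of the (N,) case.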

The example below distills from Qwen3-235B-A22B-Instruct-2507 (teacher) into Qwen3-30B-A3B-Instruct-2507 (student) with \(K{=}20\). We sample a completion from the teacher, recover its top-K distribution at each generated position, and train the student to match it.

import tinker
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Instruct-2507")
K = 20

service_client = tinker.ServiceClient()
teacher_sampling_client = await service_client.create_sampling_client_async(
    base_model="Qwen/Qwen3-235B-A22B-Instruct-2507"
)
student_training_client = await service_client.create_lora_training_client_async(
    base_model="Qwen/Qwen3-30B-A3B-Instruct-2507"
)

messages = [{"role": "user", "content": "Write a Python function to compute the nth Fibonacci number efficiently."}]
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# 1. Sample a completion from the teacher.
sample_response = await teacher_sampling_client.sample_async(
    prompt=tinker.ModelInput.from_ints(prompt_tokens),
    num_samples=1,
    sampling_params=tinker.SamplingParams(max_tokens=512, temperature=0.7),
)
sampled_tokens = list(sample_response.sequences[0].tokens)

# 2. Teacher-force the completion to recover top-K logprobs at each position.
teacher_forced = tinker.ModelInput.from_ints(prompt_tokens + sampled_tokens)
topk_response = await teacher_sampling_client.sample_async(
    prompt=teacher_forced,
    num_samples=1,
    sampling_params=tinker.SamplingParams(max_tokens=1),
    include_prompt_logprobs=True,
    topk_prompt_logprobs=K,
)

output_toks_with_logprobs = topk_response.topk_prompt_logprobs[len(prompt_tokens):]
seq_len = len(output_toks_with_logprobs)

teacher_tokens = torch.tensor(
    [[tok_id for tok_id, _ in row] for row in output_toks_with_logprobs], dtype=torch.long
)  # [seq_len, K]
teacher_logprobs = torch.tensor(
    [[lp for _, lp in row] for row in output_toks_with_logprobs]
)  # [seq_len, K]

# Renormalize teacher logprobs over top-K via logsumexp.
teacher_logprobs -= torch.logsumexp(teacher_logprobs, dim=-1, keepdim=True)

# Build shifted student_input: prompt + completion[:-1], so position t predicts token t+1.
student_input = prompt_tokens + sampled_tokens[:-1]
gen_start = len(prompt_tokens) - 1  # first position that predicts a generated token

target_tokens = torch.zeros(len(student_input), K, dtype=torch.long)
target_tokens[gen_start : gen_start + seq_len] = teacher_tokens

weights = torch.zeros(len(student_input), K)
weights[gen_start : gen_start + seq_len] = teacher_logprobs.exp()

datum = tinker.Datum(
    model_input=tinker.ModelInput.from_ints(student_input),
    loss_fn_inputs={
        "target_tokens": target_tokens,
        "weights": weights,
    },
)
fwdbwd_future = await student_training_client.forward_backward_async(
    [datum], loss_fn="cross_entropy"
)
fwdbwd_result = await fwdbwd_future.result_async()
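As a sanity check on the tensor construction above, each position inside the generated span should carry weights summing to 1 (the renormalized teacher distribution), while prompt positions carry zero weight. A toy-shaped sketch of that invariant (the dimensions and random logprobs here are placeholders):

```python
import torch

P, S, K = 4, 3, 5  # toy prompt length, completion length, and top-K

# Stand-in for the teacher's renormalized top-K logprobs, one row per generated token.
teacher_logprobs = torch.randn(S, K)
teacher_logprobs -= torch.logsumexp(teacher_logprobs, dim=-1, keepdim=True)

# Same construction as the example: shifted input of length P + S - 1.
weights = torch.zeros(P + S - 1, K)
gen_start = P - 1
weights[gen_start : gen_start + S] = teacher_logprobs.exp()
```

If a supervised row sums to much less than 1, the teacher's top-K cut off significant mass and a larger \(K\) may be warranted.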