Cross-Entropy
For supervised learning, we implement the standard cross-entropy loss (i.e., negative log-likelihood), which optimizes the policy \(p_\theta\) to maximize the log-probability of the tokens \(x\):
\[
\mathcal{L}(\theta) = -\sum_{t} w_t \log p_\theta(x_t \mid x_{<t})
\]
where each weight \(w_t\) is either 0 or 1, typically generated by renderer.build_supervised_example(), which returns (model_input, weights) and uses the weights to mark the assistant turns to train on.
This is implemented as:
# Apply weights and compute elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum() # scalar
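As a framework-free sketch of the same arithmetic (the log-probabilities and weights below are made-up illustrative values, not output of the API):

```python
# Hypothetical per-token log-probabilities of the target tokens under the policy.
target_logprobs = [-0.1, -2.3, -0.5, -1.2]
# Weights mask out the prompt (0.0) and train only on the assistant turn (1.0).
weights = [0.0, 0.0, 1.0, 1.0]

elementwise_loss = [-lp * w for lp, w in zip(target_logprobs, weights)]
loss = sum(elementwise_loss)  # scalar: 0.5 + 1.2 = 1.7
```

Only the positions with nonzero weight contribute to the loss, which is how the prompt is excluded from training.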
Input tensors:
- target_tokens: array[(N,), int] | array[(N, K), int] — Target token IDs
- weights: array[(N,), float] | array[(N, K), float] — Token-level loss weights (typically from the renderer)
Output tensors:
- logprobs: array[(N,), float] | array[(N, K), float] — Log probabilities of the requested target tokens
Output diagnostics:
- loss:sum (scalar) — Sum of weighted cross-entropy losses
Top-K distillation
When the input tensors have shape (N, K), cross_entropy extracts the student's logprobs for \(K\) target tokens per position and computes:
\[
\mathcal{L}(\theta) = -\sum_{t} \sum_{k=1}^{K} w_{t,k} \log p_\theta(x_{t,k} \mid x_{<t})
\]
where \(w_{t,k}\) are the user-provided weights. For soft-target distillation, set \(w_{t,k} = \tilde{p}_{\text{teacher}}(x_{t,k})\), the teacher's probabilities renormalized over the top-K tokens. For hard-target training, put weight 1.0 on a single candidate and 0.0 on the rest.
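The two weighting schemes can be sketched in plain Python for a single position (the teacher logprobs below are made-up values, with \(K{=}3\) for brevity):

```python
import math

# Hypothetical teacher top-K log-probabilities at one position (K = 3).
teacher_topk_logprobs = [-0.2, -1.9, -3.0]

# Soft targets: renormalize over the top-K via log-sum-exp, then exponentiate.
lse = math.log(sum(math.exp(lp) for lp in teacher_topk_logprobs))
soft_weights = [math.exp(lp - lse) for lp in teacher_topk_logprobs]
# soft_weights sums to 1.0: a proper distribution over the K candidates.

# Hard targets: weight 1.0 on a single candidate, 0.0 on the rest.
hard_weights = [1.0, 0.0, 0.0]
```

The renormalization matters because the top-K candidates capture only part of the teacher's probability mass; without it the weights would sum to less than 1.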
The example below distills from Qwen3-235B-A22B-Instruct-2507 (teacher) into Qwen3-30B-A3B-Instruct-2507 (student) with \(K{=}20\). We sample a completion from the teacher, recover its top-K distribution at each generated position, and train the student to match it.
import tinker
import torch
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-235B-A22B-Instruct-2507")
K = 20
service_client = tinker.ServiceClient()
teacher_sampling_client = await service_client.create_sampling_client_async(
base_model="Qwen/Qwen3-235B-A22B-Instruct-2507"
)
student_training_client = await service_client.create_lora_training_client_async(
base_model="Qwen/Qwen3-30B-A3B-Instruct-2507"
)
messages = [{"role": "user", "content": "Write a Python function to compute the nth Fibonacci number efficiently."}]
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
# 1. Sample a completion from the teacher.
sample_response = await teacher_sampling_client.sample_async(
prompt=tinker.ModelInput.from_ints(prompt_tokens),
num_samples=1,
sampling_params=tinker.SamplingParams(max_tokens=512, temperature=0.7),
)
sampled_tokens = list(sample_response.sequences[0].tokens)
# 2. Teacher-force the completion to recover top-K logprobs at each position.
teacher_forced = tinker.ModelInput.from_ints(prompt_tokens + sampled_tokens)
topk_response = await teacher_sampling_client.sample_async(
prompt=teacher_forced,
num_samples=1,
sampling_params=tinker.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
topk_prompt_logprobs=K,
)
output_toks_with_logprobs = topk_response.topk_prompt_logprobs[len(prompt_tokens):]
seq_len = len(output_toks_with_logprobs)
teacher_tokens = torch.tensor(
[[tok_id for tok_id, _ in row] for row in output_toks_with_logprobs], dtype=torch.long
) # [seq_len, K]
teacher_logprobs = torch.tensor(
[[lp for _, lp in row] for row in output_toks_with_logprobs]
) # [seq_len, K]
# Renormalize teacher logprobs over top-K via logsumexp.
teacher_logprobs -= torch.logsumexp(teacher_logprobs, dim=-1, keepdim=True)
# Build shifted student_input: prompt + completion[:-1], so position t predicts token t+1.
student_input = prompt_tokens + sampled_tokens[:-1]
gen_start = len(prompt_tokens) - 1
target_tokens = torch.zeros(len(student_input), K, dtype=torch.long)
target_tokens[gen_start : gen_start + seq_len] = teacher_tokens
weights = torch.zeros(len(student_input), K)
weights[gen_start : gen_start + seq_len] = teacher_logprobs.exp()
datum = tinker.Datum(
model_input=tinker.ModelInput.from_ints(student_input),
loss_fn_inputs={
"target_tokens": target_tokens,
"weights": weights,
},
)
fwdbwd_future = await student_training_client.forward_backward_async(
[datum], loss_fn="cross_entropy"
)
fwdbwd_result = await fwdbwd_future.result_async()
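The index arithmetic in the shifted-input construction above is easy to get off by one, so here is a toy trace of the same logic, independent of the Tinker API (token IDs are arbitrary):

```python
# Toy check of the shifted alignment: with prompt [10, 11] and completion
# [20, 21, 22], the model input drops the final completion token, and the
# position holding the last prompt token is the first one trained on.
prompt_tokens = [10, 11]
sampled_tokens = [20, 21, 22]

student_input = prompt_tokens + sampled_tokens[:-1]  # [10, 11, 20, 21]
gen_start = len(prompt_tokens) - 1                   # position 1 predicts token 20

# Each trained position t predicts sampled token t - gen_start.
targets = {t: sampled_tokens[t - gen_start] for t in range(gen_start, len(student_input))}
# → {1: 20, 2: 21, 3: 22}
```

Every sampled token receives a target, and no prompt position other than the last is trained on, matching the zero-initialized weights in the full example.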