
Tutorial 204: Weights and Checkpoints

Prerequisites

Run it interactively

curl -O https://raw.githubusercontent.com/thinking-machines-lab/tinker-cookbook/main/tutorials/204_weights.py && uv run marimo edit 204_weights.py

Tinker stores model checkpoints on the server. This tutorial covers the full checkpoint lifecycle:

  1. Save -- save_weights_for_sampler (for inference) and save_state (for resuming training)
  2. Resume -- create_training_client_from_state to continue training from a checkpoint
  3. Manage -- RestClient to list, set TTL, publish, and unpublish checkpoints
  4. Download -- weights.download to pull checkpoints to local disk for merging or serving
Train --> save_weights_for_sampler --> create_sampling_client (inference)
      \-> save_state ----------------> create_training_client_from_state (resume)
                                   \-> weights.download (local export)
import warnings

warnings.filterwarnings("ignore", message="IProgress not found")

import tinker
import torch
from tinker import TensorData

Setup: train for one step

We need a trained checkpoint to work with. Let's create a training client and run a single training step.

MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"

service_client = tinker.ServiceClient()
training_client = await service_client.create_lora_training_client_async(
    base_model=MODEL_NAME, rank=16
)
tokenizer = training_client.get_tokenizer()

# One quick SFT step
text = "The Pythagorean theorem states that a^2 + b^2 = c^2."
ids = tokenizer.encode(text)
model_input = tinker.ModelInput.from_ints(ids[:-1])
target_tokens = ids[1:]
weights = [1.0] * len(target_tokens)

datum = tinker.Datum(
    model_input=model_input,
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
        "weights": TensorData.from_torch(torch.tensor(weights)),
    },
)
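The off-by-one slicing above is standard next-token prediction: each input position is scored against the token that follows it in the original sequence. A toy illustration of the shift, with made-up token ids:

```python
# Toy token ids standing in for tokenizer.encode(text)
ids = [101, 7, 42, 9, 102]

inputs = ids[:-1]   # what the model sees
targets = ids[1:]   # what it must predict at each position

# Position i of the input is paired with token i+1 of the sequence
pairs = list(zip(inputs, targets))
```

Note that `inputs` and `targets` always have the same length, which is why `weights` above is built with `len(target_tokens)` entries.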

fb_future = await training_client.forward_backward_async([datum], loss_fn="cross_entropy")
await fb_future.result_async()
optim_future = await training_client.optim_step_async(tinker.AdamParams(learning_rate=1e-4))
await optim_future.result_async()
print("Training step complete")

Saving checkpoints

Tinker has two types of saves:

| Method | What it saves | Use case |
| --- | --- | --- |
| save_weights_for_sampler | Model weights only | Create a SamplingClient for inference |
| save_state | Weights + optimizer state | Resume training later |

Both return an APIFuture whose result contains a path -- a tinker:// URI that identifies the checkpoint.
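The shape of this submit-then-await pattern can be mimicked with plain asyncio. Everything below is a toy stand-in, not the real client: the function body, the SaveResult class, and the tinker://demo-run/... path format are invented for illustration.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class SaveResult:
    path: str  # a tinker:// URI identifying the checkpoint

# Simulates the awaitable save call: the real work happens server-side,
# and awaiting it yields a result object carrying the checkpoint path.
async def save_weights_for_sampler_async(name: str) -> SaveResult:
    await asyncio.sleep(0)  # pretend the server is doing the save
    return SaveResult(path=f"tinker://demo-run/{name}")

result = asyncio.run(save_weights_for_sampler_async("tutorial-sampler"))
```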

# Save weights for inference (sampler checkpoint)
sampler_result = await training_client.save_weights_for_sampler_async("tutorial-sampler")
sampler_path = sampler_result.path
print(f"Sampler weights saved to: {sampler_path}")

# Save full state for resuming training
state_result = await training_client.save_state_async("tutorial-state")
state_path = state_result.path
print(f"Training state saved to:  {state_path}")

TTL on checkpoints

You can set a time-to-live when saving. Checkpoints expire and are deleted after the TTL. This is useful for intermediate checkpoints during training.

# Save with a 1-hour TTL
ephemeral_result = await training_client.save_weights_for_sampler_async(
    "tutorial-ephemeral", ttl_seconds=3600
)
print(f"Ephemeral checkpoint (1h TTL): {ephemeral_result.path}")
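Computing `ttl_seconds` by hand is error-prone. A small stdlib helper (hypothetical, not part of tinker) keeps the durations readable:

```python
from datetime import timedelta

def ttl(**kwargs) -> int:
    """Convert a human-readable duration to whole seconds for ttl_seconds."""
    return int(timedelta(**kwargs).total_seconds())

one_hour = ttl(hours=1)    # 3600
seven_days = ttl(days=7)   # 604800
```

With this helper, the bare `ttl_seconds=3600` above becomes `ttl_seconds=ttl(hours=1)`.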

Resuming training from a checkpoint

create_training_client_from_state loads weights (but resets optimizer state). Use create_training_client_from_state_with_optimizer to also restore Adam momentum.

# Resume training from the saved state (weights only, fresh optimizer)
resumed_client = await service_client.create_training_client_from_state_async(state_path)
print(f"Resumed training client from: {state_path}")

# The resumed client has the same trained weights but a fresh optimizer.
# You can also use create_training_client_from_state_with_optimizer_async
# to restore the full optimizer state (Adam momentum, etc).
info = resumed_client.get_info()
print(f"Training run ID: {info.training_run_id}")

Using the sampler checkpoint for inference

The sampler checkpoint can be loaded as a SamplingClient for inference. Use service_client.create_sampling_client(model_path=...).

# Create a sampling client from the saved checkpoint
fine_tuned_sampler = await service_client.create_sampling_client_async(model_path=sampler_path)

prompt_text = "The Pythagorean theorem"
prompt_ids = tokenizer.encode(prompt_text)
prompt = tinker.ModelInput.from_ints(prompt_ids)

result = await fine_tuned_sampler.sample_async(
    prompt=prompt,
    sampling_params=tinker.SamplingParams(max_tokens=50, temperature=0.5, stop=["\n"]),
    num_samples=1,
)

print(prompt_text + tokenizer.decode(result.sequences[0].tokens))

Managing checkpoints with RestClient

The RestClient provides REST API access for checkpoint management. You get one via service_client.create_rest_client().

rest_client = service_client.create_rest_client()

# Get the training run ID from the training client
run_info = training_client.get_info()
run_id = run_info.training_run_id
print(f"Training run: {run_id}")

List checkpoints

list_checkpoints shows all checkpoints for a training run. list_user_checkpoints shows checkpoints across all your training runs.

# List checkpoints for this training run
checkpoints_response = rest_client.list_checkpoints(run_id).result()
print(f"Found {len(checkpoints_response.checkpoints)} checkpoints:")
for cp in checkpoints_response.checkpoints:
    print(f"  [{cp.checkpoint_type}] {cp.checkpoint_id}")
# List all your checkpoints across training runs
all_checkpoints = rest_client.list_user_checkpoints(limit=5).result()
print(f"Recent checkpoints across all runs ({len(all_checkpoints.checkpoints)}):")
for _cp in all_checkpoints.checkpoints:
    print(f"  {_cp.training_run_id}/{_cp.checkpoint_id} ({_cp.checkpoint_type})")
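Filtering the listing client-side is plain Python. The dicts below are hypothetical stand-ins for the checkpoint records (the real objects expose checkpoint_id and checkpoint_type as attributes, and the exact type strings are assumptions):

```python
# Hypothetical stand-ins for entries in checkpoints_response.checkpoints
checkpoints = [
    {"checkpoint_id": "tutorial-sampler", "checkpoint_type": "sampler"},
    {"checkpoint_id": "tutorial-state", "checkpoint_type": "state"},
    {"checkpoint_id": "tutorial-ephemeral", "checkpoint_type": "sampler"},
]

# Keep only the sampler checkpoints, e.g. to pick one for inference
sampler_ids = [
    cp["checkpoint_id"]
    for cp in checkpoints
    if cp["checkpoint_type"] == "sampler"
]
```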

Set TTL on existing checkpoints

You can change or remove the TTL on any checkpoint after creation.

# Set a 7-day TTL on the sampler checkpoint
rest_client.set_checkpoint_ttl_from_tinker_path(
    sampler_path, ttl_seconds=7 * 24 * 3600
).result()
print(f"Set 7-day TTL on {sampler_path}")

# Remove TTL (keep indefinitely)
rest_client.set_checkpoint_ttl_from_tinker_path(sampler_path, ttl_seconds=None).result()
print(f"Removed TTL on {sampler_path}")

Publish and unpublish checkpoints

Publishing a checkpoint makes it accessible to other users. Only the owner can publish or unpublish.

# Publish the checkpoint
rest_client.publish_checkpoint_from_tinker_path(sampler_path).result()
print(f"Published: {sampler_path}")

# Unpublish it
rest_client.unpublish_checkpoint_from_tinker_path(sampler_path).result()
print(f"Unpublished: {sampler_path}")

Downloading weights locally

weights.download fetches a checkpoint archive from Tinker storage and extracts it to a local directory. This is the first step for merging LoRA weights into a full model or serving with vLLM.

from tinker_cookbook import weights

# Download the sampler checkpoint to a local directory
adapter_dir = weights.download(
    tinker_path=sampler_path,
    output_dir="/tmp/tinker-tutorial-adapter",
)
print(f"Downloaded adapter to: {adapter_dir}")

After downloading

Once you have the adapter files locally, see the deployment tutorials for next steps.

Checkpoint lifecycle summary

create_lora_training_client()
    |
    v
[Train: forward_backward + optim_step]
    |
    +-- save_weights_for_sampler("name")
    |       |
    |       +-- create_sampling_client(model_path=...) --> inference
    |       +-- weights.download(tinker_path=...) -------> local export
    |       +-- rest_client.publish_checkpoint_from_tinker_path(...)
    |       +-- rest_client.set_checkpoint_ttl_from_tinker_path(...)
    |
    +-- save_state("name")
            |
            +-- create_training_client_from_state(path) --> resume training
            +-- create_training_client_from_state_with_optimizer(path) --> resume with optimizer