
Tutorial 204: Weights and Checkpoints

Prerequisites

Run it interactively:

curl -O https://raw.githubusercontent.com/thinking-machines-lab/tinker-cookbook/main/tutorials/204_weights.py && marimo edit 204_weights.py

Tinker stores model checkpoints on the server. This tutorial covers the full checkpoint lifecycle:

  1. Save -- save_weights_for_sampler (for inference) and save_state (for resuming training)
  2. Resume -- create_training_client_from_state to continue training from a checkpoint
  3. Manage -- RestClient to list, set TTL, publish, and unpublish checkpoints
  4. Download -- weights.download to pull checkpoints to local disk for merging or serving
Train --> save_weights_for_sampler --> create_sampling_client (inference)
      |                            \-> weights.download (local export)
      \-> save_state ----------------> create_training_client_from_state (resume)
import warnings

warnings.filterwarnings("ignore", message="IProgress not found")

import tinker
import torch
from tinker import TensorData

Setup: train for one step

We need some trained weights to checkpoint. Let's create a LoRA training client and run a single training step.

import marimo as mo

api_key = mo.ui.text(kind="password", label="Paste your Tinker API key")
api_key  # noqa: B018
import os

mo.stop(
    "TINKER_API_KEY" not in os.environ and not api_key.value,
    "Paste your API key above",
)

if api_key.value:
    os.environ["TINKER_API_KEY"] = api_key.value

MODEL_NAME = "Qwen/Qwen3.5-4B"

service_client = tinker.ServiceClient()
training_client = await service_client.create_lora_training_client_async(
    base_model=MODEL_NAME, rank=16
)
tokenizer = training_client.get_tokenizer()

# One quick SFT step
text = "The Pythagorean theorem states that a^2 + b^2 = c^2."
ids = tokenizer.encode(text)
model_input = tinker.ModelInput.from_ints(ids[:-1])
target_tokens = ids[1:]
_weights = [1.0] * len(target_tokens)

datum = tinker.Datum(
    model_input=model_input,
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
        "weights": TensorData.from_torch(torch.tensor(_weights)),
    },
)

fb_future = await training_client.forward_backward_async([datum], loss_fn="cross_entropy")
await fb_future.result_async()
optim_future = await training_client.optim_step_async(tinker.AdamParams(learning_rate=1e-4))
await optim_future.result_async()
print("Training step complete")
Output
Training step complete
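The cell above builds one next-token prediction example: the model sees ids[:-1] and is scored on ids[1:], with every target weighted 1.0. Here is a pure-Python sketch of that shift. The make_shifted_example helper and its prompt_len parameter are hypothetical: prompt_len shows how the weights list could mask out prompt tokens (a common SFT choice, not something this tutorial does), and with the default prompt_len=0 it reproduces the tutorial's all-ones weights.

```python
from typing import List, Tuple


def make_shifted_example(
    ids: List[int], prompt_len: int = 0
) -> Tuple[List[int], List[int], List[float]]:
    """Build (model inputs, targets, loss weights) for next-token prediction.

    Input position j predicts token ids[j + 1]. Targets that still belong to the
    first `prompt_len` tokens get weight 0.0, so only the completion is trained.
    """
    if len(ids) < 2:
        raise ValueError("need at least two tokens to form an input/target pair")
    inputs = ids[:-1]
    targets = ids[1:]
    # targets[j] is ids[j + 1]; mask it if ids[j + 1] is part of the prompt
    weights = [0.0 if j + 1 < prompt_len else 1.0 for j in range(len(targets))]
    return inputs, targets, weights


inputs, targets, weights = make_shifted_example([10, 11, 12, 13, 14], prompt_len=3)
print(inputs)   # [10, 11, 12, 13]
print(targets)  # [11, 12, 13, 14]
print(weights)  # [0.0, 0.0, 1.0, 1.0]
```

The first completion token (13) is still trained, because it is predicted from the last prompt token.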

Saving checkpoints

Tinker has two types of saves:

Method                     What it saves               Use case
save_weights_for_sampler   Model weights only          Create a SamplingClient for inference
save_state                 Weights + optimizer state   Resume training later

Both return an APIFuture whose result contains a path -- a tinker:// URI that identifies the checkpoint.

# Save weights for inference (sampler checkpoint)
sampler_future = training_client.save_weights_for_sampler("tutorial-sampler")
sampler_path = (await sampler_future.result_async()).path
print(f"Sampler weights saved to: {sampler_path}")

# Save full state for resuming training
state_future = training_client.save_state("tutorial-state")
state_path = (await state_future.result_async()).path
print(f"Training state saved to:  {state_path}")
Output
Sampler weights saved to: tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler
Training state saved to:  tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/weights/tutorial-state
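The tinker:// paths printed above appear to follow a tinker://<model_id>/<kind>/<name> layout, where kind is sampler_weights for save_weights_for_sampler and weights for save_state. That layout is inferred from this output, not from a documented spec, so treat the hypothetical parse_tinker_path helper below as a convenience sketch. The part after the model ID matches what list_checkpoints reports as checkpoint_id later in this tutorial.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TinkerPath:
    model_id: str       # e.g. "<run_id>:train:0"
    checkpoint_id: str  # e.g. "sampler_weights/tutorial-sampler"

    @property
    def kind(self) -> str:
        # "sampler_weights" (save_weights_for_sampler) or "weights" (save_state)
        return self.checkpoint_id.split("/", 1)[0]

    @property
    def name(self) -> str:
        return self.checkpoint_id.split("/", 1)[1]


def parse_tinker_path(path: str) -> TinkerPath:
    """Split a tinker:// URI into model ID and checkpoint ID (layout assumed)."""
    prefix = "tinker://"
    if not path.startswith(prefix):
        raise ValueError(f"not a tinker:// path: {path!r}")
    model_id, sep, checkpoint_id = path[len(prefix):].partition("/")
    if not sep or "/" not in checkpoint_id:
        raise ValueError(f"unexpected tinker:// layout: {path!r}")
    return TinkerPath(model_id=model_id, checkpoint_id=checkpoint_id)


p = parse_tinker_path(
    "tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler"
)
print(p.model_id)      # 86e22e25-3753-575e-80cb-4644401e3a66:train:0
print(p.kind, p.name)  # sampler_weights tutorial-sampler
```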

TTL on checkpoints

You can set a time-to-live (TTL), in seconds, when saving. Once the TTL elapses, the checkpoint expires and is deleted. This is useful for intermediate checkpoints during training that you don't need to keep.

# Save with a 1-hour TTL
ephemeral_future = training_client.save_weights_for_sampler(
    "tutorial-ephemeral", ttl_seconds=3600
)
ephemeral_path = (await ephemeral_future.result_async()).path
print(f"Ephemeral checkpoint (1h TTL): {ephemeral_path}")
Output
Ephemeral checkpoint (1h TTL): tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-ephemeral
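In a longer training loop you might keep a few milestone checkpoints and let the rest expire. Below is a minimal sketch of such a retention policy. The ttl_for_step helper, INTERMEDIATE_TTL_SECONDS, and the keep_every schedule are hypothetical, and it assumes save_weights_for_sampler treats ttl_seconds=None as "no TTL".

```python
from typing import Optional

INTERMEDIATE_TTL_SECONDS = 3600  # hypothetical choice: intermediate saves live for 1 hour


def ttl_for_step(step: int, total_steps: int, keep_every: int) -> Optional[int]:
    """Return a ttl_seconds value for the checkpoint saved at `step`.

    None means "no TTL" (keep indefinitely): used for the final checkpoint and
    for every `keep_every`-th step. All other checkpoints get a short TTL.
    """
    is_final = step == total_steps
    is_milestone = keep_every > 0 and step % keep_every == 0
    return None if (is_final or is_milestone) else INTERMEDIATE_TTL_SECONDS


for step in range(1, 7):
    ttl = ttl_for_step(step, total_steps=6, keep_every=4)
    print(step, ttl)
    # In a real loop the value would be passed along (assuming None is accepted):
    # training_client.save_weights_for_sampler(f"step-{step}", ttl_seconds=ttl)
```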

Resuming training from a checkpoint

create_training_client_from_state loads the saved weights but starts with a fresh optimizer state. To also restore Adam's momentum and second-moment estimates, use create_training_client_from_state_with_optimizer.

# Resume training from the saved state (weights only, fresh optimizer)
resumed_client = await service_client.create_training_client_from_state_async(state_path)
print(f"Resumed training client from: {state_path}")

# The resumed client has the same trained weights but a fresh optimizer.
# Use create_training_client_from_state_with_optimizer_async instead
# to also restore the full optimizer state (Adam moment estimates, etc.).
info = await resumed_client.get_info_async()
print(f"Model ID: {info.model_id}")
Output
Resumed training client from: tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/weights/tutorial-state
Model ID: 86e22e25-3753-575e-80cb-4644401e3a66:train:1
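Why does the optimizer state matter? Assuming Tinker's AdamParams behaves like textbook Adam (the service's internal implementation isn't shown here), a fresh optimizer starts with zero moment estimates, so after bias correction its first update has a magnitude close to the learning rate regardless of the gradient's size. A restored optimizer instead blends the new gradient with its accumulated history. Here is a scalar, pure-Python sketch; AdamState and adam_update are illustrative names, not Tinker APIs.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamState:
    m: float = 0.0  # first-moment (momentum) estimate
    v: float = 0.0  # second-moment estimate
    t: int = 0      # number of steps taken so far


def adam_update(state: AdamState, grad: float, lr: float = 1e-4,
                beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Apply one textbook Adam step to `state` and return the parameter change."""
    state.t += 1
    state.m = beta1 * state.m + (1 - beta1) * grad
    state.v = beta2 * state.v + (1 - beta2) * grad * grad
    m_hat = state.m / (1 - beta1 ** state.t)  # bias-corrected first moment
    v_hat = state.v / (1 - beta2 ** state.t)  # bias-corrected second moment
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


# Train for 100 steps on a small, consistent gradient.
trained = AdamState()
for _ in range(100):
    adam_update(trained, grad=0.01)

# Resume two ways, then take one step on a very different gradient.
restored = AdamState(trained.m, trained.v, trained.t)  # analogous to ..._with_optimizer
fresh = AdamState()                                    # analogous to ..._from_state
restored_delta = adam_update(restored, grad=-0.5)
fresh_delta = adam_update(fresh, grad=-0.5)
print(f"restored: {restored_delta:.2e}  fresh: {fresh_delta:.2e}")
```

With these numbers the fresh optimizer steps by almost exactly the learning rate, while the restored one takes a smaller step because its history (momentum built up on gradients of the opposite sign) tempers the update. That continuity is what restoring the optimizer preserves.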

Using the sampler checkpoint for inference

The sampler checkpoint can be loaded as a SamplingClient for inference. Use service_client.create_sampling_client(model_path=...).

# Create a sampling client from the saved checkpoint
fine_tuned_sampler = await service_client.create_sampling_client_async(model_path=sampler_path)

prompt_text = "The Pythagorean theorem"
prompt_ids = tokenizer.encode(prompt_text)
prompt = tinker.ModelInput.from_ints(prompt_ids)

result = await fine_tuned_sampler.sample_async(
    prompt=prompt,
    sampling_params=tinker.SamplingParams(max_tokens=50, temperature=0.5, stop=["\n"]),
    num_samples=3,
)

for sequence in result.sequences:
    print(prompt_text + tokenizer.decode(sequence.tokens))
Output
The Pythagorean theorem is given by the formula $a^2 + b^2 = c^2$, where $a$ and $b$ are the lengths of the legs of a right triangle, and $c$ is the length of the hypotenuse.
The Pythagorean theorem states that for any right triangle, the square of the length of the hypotenuse (the side opposite the right angle) equals the sum of the squares of the lengths of the other two sides. This is written as:  

The Pythagorean theorem states that in a right triangle, the square of the hypotenuse equals the sum of the squares of the other two sides: $a^2 + b^2 = c^2$. This relationship can be visualized by constructing squares on each
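The sampling cell above uses temperature=0.5. Assuming Tinker applies the standard definition (logits divided by the temperature before the softmax; the service's sampler code isn't shown here), this sketch shows how a temperature below 1.0 sharpens the next-token distribution toward the most likely token, so samples vary less than at temperature 1.0. The softmax_with_temperature helper and the toy logits are illustrative only.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Convert logits to probabilities after dividing them by `temperature` (> 0)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.0]
for t in (1.0, 0.5):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, [round(p, 3) for p in probs])
```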

Managing checkpoints with RestClient

The RestClient provides REST API access for checkpoint management. You get one via service_client.create_rest_client().

rest_client = service_client.create_rest_client()

# Get the run ID from the training client's model_id
_info = await training_client.get_info_async()
model_id = _info.model_id
# model_id format: "<run_id>:train:<seq>"
print(f"Model ID: {model_id}")
Output
Model ID: 86e22e25-3753-575e-80cb-4644401e3a66:train:0
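The comment in the cell above gives the model ID format as <run_id>:train:<seq>. Here is a hypothetical parse_model_id helper based only on that comment. In this tutorial's output, the resumed client earlier reported train:1 under the same run ID as this original train:0 client, and the sketch checks that with the printed values. Note that the tutorial passes the full model_id, not just the run ID, to list_checkpoints_async.

```python
from typing import NamedTuple


class ModelId(NamedTuple):
    run_id: str
    kind: str
    seq: int


def parse_model_id(model_id: str) -> ModelId:
    """Split "<run_id>:train:<seq>" into its parts (format assumed from the comment)."""
    run_id, kind, seq = model_id.rsplit(":", 2)  # rsplit tolerates colons in run_id
    return ModelId(run_id=run_id, kind=kind, seq=int(seq))


original = parse_model_id("86e22e25-3753-575e-80cb-4644401e3a66:train:0")
resumed = parse_model_id("86e22e25-3753-575e-80cb-4644401e3a66:train:1")
print(original.run_id == resumed.run_id, original.seq, resumed.seq)
```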

List checkpoints

list_checkpoints shows all checkpoints for a training run. list_user_checkpoints shows checkpoints across all your training runs.

# List checkpoints for this training run
checkpoints_response = await rest_client.list_checkpoints_async(model_id)
print(f"Found {len(checkpoints_response.checkpoints)} checkpoints:")
for cp in checkpoints_response.checkpoints:
    print(f"  [{cp.checkpoint_type}] {cp.checkpoint_id}")
Output
Found 3 checkpoints:
  [sampler] sampler_weights/tutorial-sampler
  [training] weights/tutorial-state
  [sampler] sampler_weights/tutorial-ephemeral
# List up to 5 of your checkpoints across all training runs
all_checkpoints = await rest_client.list_user_checkpoints_async(limit=5)
print(f"Recent checkpoints across all runs ({len(all_checkpoints.checkpoints)}):")
for _cp in all_checkpoints.checkpoints:
    print(f"  {_cp.tinker_path} ({_cp.checkpoint_type})")
Output
Recent checkpoints across all runs (5):
  tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-ephemeral (sampler)
  tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/weights/tutorial-state (training)
  tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler (sampler)
  tinker://0c1995b2-1b97-57fe-b7d6-0352f93be9f8:train:0/weights/tutorial-state (training)
  tinker://0c1995b2-1b97-57fe-b7d6-0352f93be9f8:train:0/sampler_weights/tutorial-sampler (sampler)
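Once you have a listing, it is often handy to group it by training run, for example to decide which old runs to clean up. The sketch below works on plain data so it makes no API calls. CheckpointRecord is a hypothetical stand-in that models only the two attributes the loop above prints (tinker_path and checkpoint_type), populated with the paths from that output.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class CheckpointRecord:
    # Hypothetical stand-in for one entry of list_user_checkpoints_async(...).checkpoints
    tinker_path: str
    checkpoint_type: str  # "sampler" or "training", as printed above


def group_by_run(records: List[CheckpointRecord]) -> Dict[str, List[CheckpointRecord]]:
    """Group checkpoints by the model ID segment ("<run_id>:train:<seq>") of their path."""
    groups: Dict[str, List[CheckpointRecord]] = defaultdict(list)
    for rec in records:
        model_id = rec.tinker_path.removeprefix("tinker://").split("/", 1)[0]
        groups[model_id].append(rec)
    return dict(groups)  # insertion order follows the listing order


RUN_A = "tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0"
RUN_B = "tinker://0c1995b2-1b97-57fe-b7d6-0352f93be9f8:train:0"
records = [
    CheckpointRecord(f"{RUN_A}/sampler_weights/tutorial-ephemeral", "sampler"),
    CheckpointRecord(f"{RUN_A}/weights/tutorial-state", "training"),
    CheckpointRecord(f"{RUN_A}/sampler_weights/tutorial-sampler", "sampler"),
    CheckpointRecord(f"{RUN_B}/weights/tutorial-state", "training"),
    CheckpointRecord(f"{RUN_B}/sampler_weights/tutorial-sampler", "sampler"),
]

for model_id, group in group_by_run(records).items():
    kinds = sorted({rec.checkpoint_type for rec in group})
    print(f"{model_id}: {len(group)} checkpoint(s), types={kinds}")
```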

Set TTL on existing checkpoints

You can change or remove the TTL on any checkpoint after creation.

# Set a 7-day TTL on the sampler checkpoint
await rest_client.set_checkpoint_ttl_from_tinker_path_async(
    sampler_path, ttl_seconds=7 * 24 * 3600
)
print(f"Set 7-day TTL on {sampler_path}")

# Remove TTL (keep indefinitely)
await rest_client.set_checkpoint_ttl_from_tinker_path_async(sampler_path, ttl_seconds=None)
print(f"Removed TTL on {sampler_path}")
Output
Set 7-day TTL on tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler
Removed TTL on tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler

Publish and unpublish checkpoints

Publishing a checkpoint makes it accessible to other users. Only the owner can publish or unpublish.

# Publish the checkpoint
await rest_client.publish_checkpoint_from_tinker_path_async(sampler_path)
print(f"Published: {sampler_path}")

# Unpublish it
await rest_client.unpublish_checkpoint_from_tinker_path_async(sampler_path)
print(f"Unpublished: {sampler_path}")
Output
Published: tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler
Unpublished: tinker://86e22e25-3753-575e-80cb-4644401e3a66:train:0/sampler_weights/tutorial-sampler

Downloading weights locally

weights.download fetches a checkpoint archive from Tinker storage and extracts it to a local directory. This is the first step for merging LoRA weights into a full model or serving with vLLM.

import asyncio

from tinker_cookbook import weights

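# Download the sampler checkpoint to a local directory. weights.download is a
# blocking call, so run it in a worker thread instead of on the notebook's event loop.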
adapter_dir = await asyncio.to_thread(
    weights.download,
    tinker_path=sampler_path,
    output_dir="/tmp/tinker-tutorials/weights-adapter",
)
print(f"Downloaded adapter to: {adapter_dir}")
Output
Downloaded adapter to: /tmp/tinker-tutorials/weights-adapter
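Before merging or serving, it can help to check what actually landed on disk. The exact file names depend on the checkpoint format and aren't listed in this tutorial, so this hypothetical list_local_files helper simply walks the directory with pathlib and reports each file with its size.

```python
from pathlib import Path
from typing import List, Tuple, Union


def list_local_files(directory: Union[str, Path]) -> List[Tuple[str, int]]:
    """Return (relative path, size in bytes) for every file under `directory`, sorted."""
    root = Path(directory)
    if not root.is_dir():
        raise FileNotFoundError(f"no such directory: {directory}")
    return sorted(
        (str(p.relative_to(root)), p.stat().st_size)
        for p in root.rglob("*")
        if p.is_file()
    )


# Usage after the download cell (adapter_dir is the value returned by weights.download):
# for name, size in list_local_files(adapter_dir):
#     print(f"{size:>12,d}  {name}")
```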

After downloading

Once you have the adapter files locally, see the deployment tutorials for next steps:

  • Export a Merged HuggingFace Model -- merge LoRA into a standalone model with weights.build_hf_model()
  • Build a PEFT LoRA Adapter -- convert to PEFT format for serving with vLLM or SGLang via weights.build_lora_adapter()
  • Publish to HuggingFace Hub -- upload models with custom model cards via weights.publish_to_hf_hub()

Checkpoint lifecycle summary

create_lora_training_client()
    |
    v
[Train: forward_backward + optim_step]
    |
    +-- save_weights_for_sampler("name")
    |       |
    |       +-- create_sampling_client(model_path=...) --> inference
    |       +-- weights.download(tinker_path=...) -------> local export
    |       +-- rest_client.publish_checkpoint_from_tinker_path(...)
    |       +-- rest_client.set_checkpoint_ttl_from_tinker_path(...)
    |
    +-- save_state("name")
            |
            +-- create_training_client_from_state(path) --> resume training
            +-- create_training_client_from_state_with_optimizer(path) --> resume with optimizer