Saving and loading weights and optimizer state
During training, you'll need to save checkpoints for two main purposes: sampling (to test your model) and resuming training (to continue from where you left off). The TrainingClient provides three methods to handle these cases:
- save_weights_for_sampler(): saves a copy of the model weights that can be used for sampling.
- save_state(): saves the weights and the optimizer state, so you can fully resume training from this checkpoint.
- load_state(): loads the weights and the optimizer state from a checkpoint saved with save_state(), restoring the run exactly where it left off.
Note that save_weights_for_sampler() is faster and requires less storage space than save_state().
Both save_* functions require a name parameter: a string that you can set to identify the checkpoint within the current training run. For example, you can name your checkpoints "0000", "0001", "step_1000", etc.
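In a training loop, a common pattern is to derive the name from the step counter so checkpoints sort in order. The following is a minimal sketch; the save interval, loop bounds, and omitted training step are illustrative, not part of the API.
```python
import tinker

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)

# Zero-padded, step-based names ("step_0000", "step_0100", ...) sort
# correctly and record when each checkpoint was taken.
save_every = 100
for step in range(300):
    # ... run your training step here (omitted) ...
    if step % save_every == 0:
        training_client.save_weights_for_sampler(name=f"step_{step:04d}")
```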
The return value contains a path field, a fully-qualified path that will look something like tinker://<model_id>/<name>. This path is persistent and can be loaded later by a new ServiceClient or TrainingClient.
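Since the returned path is just a string, one simple pattern is to record it somewhere durable so a later process can pick it up. A minimal sketch, assuming a training_client created as in the examples below; the local file name is arbitrary:
```python
import pathlib

# Save full state and record the returned tinker:// path locally.
result = training_client.save_state(name="0001").result()
pathlib.Path("last_checkpoint.txt").write_text(result.path)

# ...later, possibly from a new process with a fresh TrainingClient...
resume_path = pathlib.Path("last_checkpoint.txt").read_text()
training_client.load_state(resume_path)
```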
Example: Saving for sampling
```python
# Setup
import tinker

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)

# Save a checkpoint that you can use for sampling
sampling_path = training_client.save_weights_for_sampler(name="0000").result().path

# Create a sampling client with that checkpoint
sampling_client = service_client.create_sampling_client(model_path=sampling_path)

# Shortcut: Combine these steps with:
sampling_client = training_client.save_weights_and_get_sampling_client(name="0000")
```
Example: Saving to resume training
Use save_state() and load_state() when you need to pause and continue training with full optimizer state preserved:
```python
# Save a checkpoint that you can resume from
resume_path = training_client.save_state(name="0010").result().path

# Load that checkpoint
training_client.load_state(resume_path)
```
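Because the tinker:// path is persistent, the same pattern works from a brand-new process, for example after an interruption. A minimal sketch, assuming the base model and LoRA rank match the original run and that the path returned by save_state() above was recorded somewhere:
```python
import tinker

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)

# The tinker://<model_id>/<name> string returned by save_state() in the
# original run.
resume_path = "tinker://<model_id>/0010"
training_client.load_state(resume_path)
# Training now continues with the optimizer state (momentum, learning rate
# schedule, etc.) exactly as it was at the checkpoint.
```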
When to use save_state() and load_state():
- Multi-step training pipelines (e.g. supervised learning followed by reinforcement learning), as sketched after this list
- Adjusting hyperparameters or data mid-run
- Recovery from interruptions or failures
- Any scenario where you need to preserve exact optimizer state (momentum, learning rate schedules, etc.)
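For instance, a two-phase pipeline can checkpoint full state at the end of the first phase and resume from it at the start of the second. A minimal sketch; the phase functions are hypothetical stand-ins for your own training loops:
```python
import tinker

def run_supervised_phase(client):
    """Hypothetical placeholder for a supervised fine-tuning loop."""

def run_rl_phase(client):
    """Hypothetical placeholder for a reinforcement learning loop."""

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)

# Phase 1: supervised fine-tuning, then checkpoint weights + optimizer state.
run_supervised_phase(training_client)
sl_path = training_client.save_state(name="sl_final").result().path

# Phase 2: reinforcement learning, resumed from the supervised checkpoint.
training_client.load_state(sl_path)
run_rl_phase(training_client)
```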