Skip to content

Tutorial 304: RL with Config

Prerequisites

Run it interactively [source]

curl -O https://raw.githubusercontent.com/thinking-machines-lab/tinker-cookbook/main/tutorials/304_rl_with_config.py && marimo edit 304_rl_with_config.py

Configure and run a full RL pipeline using the cookbook's RL abstractions with RLDatasetBuilder.

In tutorials 301-302 you wrote RL loops by hand. The cookbook also provides rl.train.Config and rl.train.main(), which together handle:

  • Rollout collection (sync or async)
  • Advantage computation and data assembly
  • Pipelined training steps
  • Checkpointing, evaluation, and logging
from collections.abc import Sequence

import chz

from tinker_cookbook import renderers
from tinker_cookbook.rl.problem_env import ProblemEnv, ProblemGroupBuilder
from tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, RLDatasetBuilder
from tinker_cookbook.tokenizer_utils import get_tokenizer

Step 1 -- Define a ProblemEnv for math

We reuse the ProblemEnv pattern: the model answers arithmetic questions and gets reward 1 for correct answers.

import random

class ArithmeticEnv(ProblemEnv):
    """Single-turn env: solve a simple arithmetic problem.

    The prompt asks for ``a <op> b``; grading is done by ``check_answer`` and
    ``check_format``, which the ``ProblemEnv`` base class consumes.
    """

    # Supported operators mapped to the function that computes the answer.
    # An explicit table avoids silently grading an unknown operator as "*".
    _OPS = {
        "+": lambda x, y: x + y,
        "-": lambda x, y: x - y,
        "*": lambda x, y: x * y,
    }

    def __init__(self, renderer, a, b, op):
        """Create an env for the question ``a op b``.

        Raises:
            ValueError: if ``op`` is not one of the supported operators.
        """
        if op not in self._OPS:
            raise ValueError(
                f"Unsupported operator {op!r}; expected one of {sorted(self._OPS)}"
            )
        super().__init__(renderer)
        self.a, self.b, self.op = a, b, op
        self.answer = str(self._OPS[op](a, b))

    def get_question(self):
        return f"What is {self.a} {self.op} {self.b}? Reply with just the number."

    def check_answer(self, response):
        """Return True if the last number in ``response`` equals the answer.

        A plain substring test would accept wrong answers that merely contain
        the right digits (e.g. "123" for an answer of 12), so the response is
        parsed instead. Thousands separators are ignored ("1,234" counts) and
        a trailing ".0" is tolerated.
        """
        import re

        numbers = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
        if not numbers:
            return False
        return float(numbers[-1]) == int(self.answer)

    def check_format(self, response):
        # Any non-whitespace reply counts as well-formatted.
        return bool(response.strip())

    def get_reference_answer(self):
        return self.answer

Step 2 -- Build an RLDatasetBuilder

The RLDatasetBuilder is a chz dataclass that the config system can serialize. It constructs the RLDataset at training time and returns a (train_dataset, test_dataset) pair; the test dataset may be None.

from functools import partial

class ArithmeticDataset(RLDataset):
    """Generates batches of arithmetic problems.

    Each batch is a deterministic function of ``(seed, index)``. Requesting
    the same index twice yields the same problems regardless of call order,
    e.g. if training resumes mid-run or batches are requested out of order.
    """

    def __init__(self, renderer, batch_size, num_batches, group_size, seed=42):
        self.renderer = renderer
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.group_size = group_size
        self.seed = seed
        # Default RNG for _make_group_builder when no per-batch RNG is given.
        # get_batch() does not use it; it derives its own RNG per index.
        self.rng = random.Random(seed)

    def _make_group_builder(self, rng=None):
        """Sample one problem and wrap it in a ProblemGroupBuilder.

        Args:
            rng: ``random.Random`` to draw the problem from; defaults to
                ``self.rng``.
        """
        rng = self.rng if rng is None else rng
        # Sample ONE problem per group. ProblemGroupBuilder makes
        # `group_size` envs from this thunk, all the same problem -- GRPO
        # centers advantages within a group, so a group must be one problem
        # sampled group_size times, not a mix of different problems.
        a = rng.randint(1, 300)
        b = rng.randint(1, 300)
        op = rng.choice(["+", "*"])
        return ProblemGroupBuilder(
            env_thunk=partial(ArithmeticEnv, self.renderer, a, b, op),
            num_envs=self.group_size,
            dataset_name="arithmetic",
        )

    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:
        """Return ``batch_size`` group builders for batch ``index``."""
        # str seeds are hashed deterministically by random.Random (not via
        # PYTHONHASHSEED), so this stream depends only on (seed, index).
        batch_rng = random.Random(f"{self.seed}:{index}")
        return [self._make_group_builder(batch_rng) for _ in range(self.batch_size)]

    def __len__(self) -> int:
        return self.num_batches

@chz.chz
class ArithmeticDatasetBuilder(RLDatasetBuilder):
    """Serializable config that builds the arithmetic RL dataset on demand."""

    model_name: str  # used to load the tokenizer
    renderer_name: str  # name passed to renderers.get_renderer
    batch_size: int = 4  # problem groups per batch
    num_batches: int = 20  # length of the training dataset
    group_size: int = 4  # rollouts per problem

    async def __call__(self):
        """Build the datasets.

        Returns:
            ``(train_dataset, test_dataset)``. This tutorial has no held-out
            set, so the second element is ``None``. The original version
            built ``train_ds`` but never returned it.
        """
        tokenizer = get_tokenizer(self.model_name)
        renderer = renderers.get_renderer(self.renderer_name, tokenizer)
        train_ds = ArithmeticDataset(
            renderer, self.batch_size, self.num_batches, self.group_size
        )
        return train_ds, None

Step 3 -- Create the RL Config and run

rl.train.Config accepts the dataset builder, model name, learning rate, and many optional knobs (KL penalty, loss function, async mode, etc.).

from tinker_cookbook.rl import train as rl_train

MODEL_NAME = "Qwen/Qwen3.5-4B"

# Dataset recipe: a chz config object that constructs the RLDataset when
# training starts (see Step 2).
_dataset_builder = ArithmeticDatasetBuilder(
    model_name=MODEL_NAME,
    renderer_name="qwen3_5_disable_thinking",
    batch_size=4,
    num_batches=20,
    group_size=4,
)

rl_config = rl_train.Config(
    # Run identity / output location
    log_path="/tmp/tinker-tutorials/rl-config",
    recipe_name="tutorial_rl",
    model_name=MODEL_NAME,
    dataset_builder=_dataset_builder,
    # Optimization
    learning_rate=1e-5,
    loss_fn="importance_sampling",
    lora_rank=32,
    # Sampling
    max_tokens=64,
    # Schedule -- kept short for the tutorial
    max_steps=10,
    eval_every=5,
    save_every=5,
)

# Echo the key settings (prefixes are pre-padded so values line up).
for _prefix, _value in (
    ("Model:         ", rl_config.model_name),
    ("Learning rate: ", rl_config.learning_rate),
    ("Loss function: ", rl_config.loss_fn),
    ("Max tokens:    ", rl_config.max_tokens),
):
    print(f"{_prefix}{_value}")
Output
Model:         Qwen/Qwen3.5-4B
Learning rate: 1e-05
Loss function: importance_sampling
Max tokens:    64
# Masked text box where the user can paste a Tinker API key; its value is
# read by the next cell to populate TINKER_API_KEY.
api_key = mo.ui.text(kind="password", label="Paste your Tinker API key")
# Bare expression so the notebook displays the widget as this cell's output
# (noqa: B018 silences the "useless expression" lint).
api_key  # noqa: B018
import os

# Stop here (showing the message) unless a key is available, either already
# in the environment or pasted into the api_key text box.
mo.stop(
    "TINKER_API_KEY" not in os.environ and not api_key.value,
    "Paste your API key above",
)

# A pasted key takes precedence over any TINKER_API_KEY already set.
if api_key.value:
    os.environ["TINKER_API_KEY"] = api_key.value

# Run the full RL pipeline; outputs (metrics, logs, checkpoints) go under
# rl_config.log_path. NOTE: top-level `await` relies on the notebook runtime
# supporting it -- in a plain script wrap this in asyncio.run(...).
await rl_train.main(rl_config)

Step 4 -- Inspect reward curves

After training, check log_path for metrics (they are also logged to the console and, optionally, to W&B). Key metrics to watch:

  • env/all/reward/total -- average reward across trajectories
  • env/all/correct -- fraction of correct answers
  • optim/kl_sample_train_v1 -- KL divergence from the sampling policy
from pathlib import Path

# List what the training run wrote. The directory is read from the config
# instead of repeating the path literal, so it cannot drift out of sync.
log_dir = Path(rl_config.log_path)
if log_dir.exists():
    # Underscore-prefixed loop variable keeps it local to this notebook cell.
    for _entry in sorted(log_dir.iterdir()):
        print(f"  {_entry.name}")
else:
    print("(Log directory not found -- training may not have run)")
Output
  checkpoints.jsonl
  code.diff
  config.json
  iteration_000000
  iteration_000001
  iteration_000002
  iteration_000003
  iteration_000004
  iteration_000005
  iteration_000006
  iteration_000007
  iteration_000008
  iteration_000009
  logs.log
  metrics.jsonl
  timing_spans.jsonl

Summary

The rl.train.Config + rl.train.main() pattern handles:

  • Rollout collection with do_group_rollout
  • Advantage centering via compute_advantages
  • Pipelined forward_backward + optim_step
  • Optional KL penalty, async mode, and streaming minibatches

For custom RL loops, use the lower-level abstractions from tutorials 104, 301, and 302.