Math RL

Train LLMs to solve math problems using reinforcement learning with correctness-based rewards.

What you'll build

A math-reasoning model trained with RL on arithmetic, MATH, or GSM8K datasets. The model learns to produce correct answers in \boxed{} format, scored by exact-match grading.
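As a rough illustration of exact-match grading on \boxed{} answers, the sketch below pulls out the last boxed expression from a response and compares it to the ground truth as a string. The function names extract_boxed and exact_match_reward are hypothetical and are not taken from tinker_cookbook; the recipe's own grader may normalize answers (whitespace, equivalent LaTeX forms) in ways this sketch does not.

```python
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)
    depth = 1  # we are inside the opening brace of \boxed{
    chars = []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace found: everything collected is the answer.
                return "".join(chars).strip()
        chars.append(ch)
        i += 1
    return None  # braces never balanced


def exact_match_reward(response: str, ground_truth: str) -> float:
    """Binary reward: 1.0 if the boxed answer equals the ground truth string."""
    answer = extract_boxed(response)
    if answer is not None and answer == ground_truth.strip():
        return 1.0
    return 0.0
```

Tracking brace depth lets nested expressions such as \boxed{\frac{1}{2}} be extracted whole. A response with no boxed answer, or with unbalanced braces, simply scores 0.0.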

Prerequisites

uv pip install 'tinker-cookbook[math-rl]'

Key concepts

  • GRPO — group relative policy optimization, comparing multiple rollouts per prompt to estimate advantages
  • Exact-match reward — binary reward based on whether the extracted answer matches the ground truth
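The group-relative idea can be sketched in a few lines: each rollout's advantage is its reward minus the mean reward of the rollouts sampled for the same prompt. This is a minimal sketch under stated assumptions, not the recipe's implementation. The function name group_relative_advantages is hypothetical, and the optional division by the group's standard deviation is a common GRPO variant that this page does not confirm the recipe uses.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(
    rewards: List[float], normalize: bool = False, eps: float = 1e-6
) -> List[float]:
    """Advantages for one prompt's group of rollouts.

    rewards: one reward per rollout of the same prompt (e.g. 0.0 or 1.0).
    normalize: if True, also divide by the group's standard deviation
               (an assumed variant, not confirmed by the source).
    """
    baseline = mean(rewards)  # the group mean acts as the baseline
    centered = [r - baseline for r in rewards]
    if not normalize:
        return centered
    scale = pstdev(rewards) + eps  # eps avoids division by zero
    return [c / scale for c in centered]


# With binary rewards, correct rollouts get positive advantages and
# incorrect ones negative.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

When every rollout in a group receives the same reward (all correct or all wrong), the centered advantages are all zero, so that group contributes no learning signal. This is one reason group_size matters for problems the model solves only occasionally.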

Run it

Arithmetic (fast sanity check)

python -m tinker_cookbook.recipes.math_rl.train \
    model_name=meta-llama/Llama-3.2-1B \
    group_size=4 \
    groups_per_batch=100 \
    learning_rate=1e-4

MATH dataset

python -m tinker_cookbook.recipes.math_rl.train \
    env=math \
    model_name=Qwen/Qwen3-8B \
    group_size=16 \
    groups_per_batch=64 \
    learning_rate=2e-5 \
    max_tokens=512

GSM8K

python -m tinker_cookbook.recipes.math_rl.train \
    env=gsm8k \
    model_name=meta-llama/Llama-3.1-8B-Instruct \
    group_size=64 \
    groups_per_batch=32 \
    learning_rate=8e-5 \
    max_tokens=1024

Expected results

Dataset      Model                   Steps   Accuracy
Arithmetic   Llama-3.2-1B            ~5      Reward 0.66 → 1.0
MATH         Qwen3-8B                180     76.8%
GSM8K        Llama-3.1-8B-Instruct   220     90.9%

For GSM8K, a smaller group_size=8 with a larger groups_per_batch=64 reaches 88.2% accuracy in roughly a quarter of the time taken by the configuration above (group_size=64, groups_per_batch=32).
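One way to see where a roughly fourfold difference could come from is to count rollouts sampled per batch, which is group_size times groups_per_batch. The sketch below does that arithmetic for the two GSM8K configurations. It assumes that time scales roughly with the number of rollouts and that both runs take a similar number of steps; the page states neither, so treat this as a plausibility check and not as an explanation from the authors.

```python
def rollouts_per_batch(group_size: int, groups_per_batch: int) -> int:
    """Number of sampled completions in one training batch."""
    return group_size * groups_per_batch


# GSM8K configuration from the command above.
large = rollouts_per_batch(group_size=64, groups_per_batch=32)

# Smaller-group alternative mentioned in the results.
small = rollouts_per_batch(group_size=8, groups_per_batch=64)

ratio = small / large
print(f"{large} vs {small} rollouts per batch, ratio {ratio:.2f}")
```

Under those assumptions the smaller configuration samples 512 rollouts per batch against 2048, a ratio of 0.25.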

Example output

After training on MATH, the model produces chain-of-thought reasoning followed by a boxed answer. Here is an example from a Qwen3-8B model after 180 steps:

<|im_start|>user
The numbers 2, 3, 5, 7, 11, 13 are arranged in a multiplication table, with three along the top and the other three down the left.  The multiplication table is completed and the sum of the nine entries is tabulated.  What is the largest possible sum of the nine entries?

\[
\begin{array}{c||c|c|c|}
\times & a & b & c \\ \hline \hline
d & & & \\ \hline
e & & & \\ \hline
f & & & \\ \hline
\end{array}
\] Write your answer in \boxed{} format.<|im_end|>
<|im_start|>assistant
<think>
Okay, so we have to arrange the primes 2,3,5,7,11,13 into a 3x3 multiplication table. The sum of all entries is (a+b+c)*(d+e+f) - because the sum of the products is the product of the sums. To maximize the total sum, we need to maximize the product of the sums of the two groups. So, split the numbers into two groups of three, one group for the top row (a,b,c) and one for the left column (d,e,f). The product (sum1)*(sum2) is maximized when the sums are as close as possible? Or maybe when one is as large as possible and the other as small? Wait, for a fixed sum, the product is maximized when they're equal. But since the total sum of all numbers is 2+3+5+7+11+13= 41. So sum1 + sum2 =41. To maximize sum1*sum2, we need sum1 and sum2 as close as possible to 41/2=20.5. So find two subsets of three numbers that add up to around 20.5.

Let me try combinations. Let's see:

Possible sums for three numbers:

- 2+3+13=18

- 2+5+13=20
- 2+7+11=20

- 3+5+13=21

- 3+7+11=21

- 5+7+11=23

So the best possible is sum1=20 and sum2=21, product 420. So the maximum sum is 420. \boxed{420}</think>

\boxed{420}<|im_end|>
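The boxed answer in this transcript can be checked independently by brute force. The sketch below tries every way of choosing three of the six numbers for the top row, puts the rest down the left, and uses the identity from the model's reasoning: the sum of all nine products equals (a+b+c)(d+e+f). The helper name best_table_sum is hypothetical.

```python
from itertools import combinations

NUMBERS = (2, 3, 5, 7, 11, 13)


def best_table_sum(numbers=NUMBERS):
    """Return (max_sum, top_row, left_column) over all 3/3 splits."""
    total = sum(numbers)
    best = None
    for top in combinations(numbers, 3):
        top_sum = sum(top)
        # Sum of all nine entries = (sum of top row) * (sum of left column).
        value = top_sum * (total - top_sum)
        if best is None or value > best[0]:
            left = tuple(n for n in numbers if n not in top)
            best = (value, top, left)
    return best


max_sum, top, left = best_table_sum()

# Cross-check the identity by summing the nine products directly.
direct = sum(a * d for a in top for d in left)
print(max_sum, top, left, direct)
```

The search confirms 420, reached when the two groups sum to 20 and 21, which matches the model's conclusion.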

Metrics are logged to disk at /tmp/tinker-examples/math_rl/math-Qwen_Qwen3-8B-32rank-2e-05lr-${DATE}/metrics.jsonl.

Learn more