Vibed Learning
Welcome to Vibed Learning — a collection of self-guided courses and technical references on programming, computer science, and software engineering.
What This Is
This site is an experiment in generating custom learning resources using large language models. Every course and reference here was created with Claude Code, using Anthropic’s Opus 4.6 and Sonnet 4.6 models. The goal is to explore how AI-assisted authorship can produce educational content that is clear, accurate, and genuinely useful.
The content is a work in progress. New courses are added as topics are explored, and existing material is revised as needed.
Content
Browse the table of contents on the left to find courses. Current topics include:
- Markov Chains — a self-guided introduction to the math and intuition behind Markov chains
- Vector Databases — how vector search works and when to use it
- Git Worktrees — using multiple working trees with a single Git repository
- Writing a Lisp-to-C Compiler in Rust — building a compiler from scratch
License
All content on this site is released under CC0 (Public Domain). You may use, copy, modify, and distribute it freely, without asking for permission and without attribution.
Contact
Questions, corrections, or suggestions: vibebooks@elijah.run
Markov Chain Self-Guided Course
This document is a self-guided course on Markov chains. It is organized into four parts: conceptual foundations, first Rust implementations, text generation, and deeper theory. Each section is either a reading lesson or a hands-on Rust programming exercise. Sections marked 🚧 are stubs whose full content is tracked in an nbd ticket — follow the ticket ID to find the detailed learning objectives and instructions.
Table of Contents
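Part 1 — Foundations
- What Is a Markov Chain?
- States and Transitions
- Transition Probabilities and Matrices
Part 2 — First Implementation
- Exercise 1 — Weather Model
- Exercise 2 — Simulating a Random Walk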
Part 3 — Text Generation
- Text Generation with Markov Chains
- Exercise 3 — Bigram Text Generator
- Exercise 4 — N-gram Generalization
Part 4 — Deeper Concepts
Part 1 — Foundations
1. What Is a Markov Chain?
A Markov chain is a mathematical model describing a sequence of events where the probability of each event depends only on the state reached in the previous event — not on the full history. This “memoryless” property is called the Markov property. You will learn where Markov chains appear in the real world and develop intuition for why the memoryless property is both a useful simplification and a meaningful assumption.
The core idea: only now matters. Imagine you are tracking today’s weather. Intuitively, you might think yesterday’s weather, the week before, and the entire season all influence what tomorrow will bring. A Markov chain says: forget all of that. Given that you know today’s weather, knowledge of every earlier day adds nothing to your prediction of tomorrow. The present state captures everything relevant from the past. This is the Markov property — colloquially “memorylessness” — and it is a surprisingly powerful modelling assumption.
Formally, let X₀, X₁, X₂, … be a sequence of random variables each taking values in some set of states. The sequence is a Markov chain if, for every time step n and every state s:
P(Xₙ₊₁ = s | X₀, X₁, …, Xₙ) = P(Xₙ₊₁ = s | Xₙ)
The left-hand side conditions on the entire history up to step n. The right-hand side conditions on only the current state Xₙ. The equation says these two quantities are always equal: no matter how you got to the current state, your distribution over the next state is the same.
A worked example — the weather model. Suppose a city has two kinds of days: Sunny and Rainy. You observe that:
- After a Sunny day, there is an 80 % chance of another Sunny day and a 20 % chance of Rain.
- After a Rainy day, there is a 40 % chance of Sun and a 60 % chance of Rain.
Under the Markov assumption, these two rules are all you need. If today is Sunny, the chance of rain tomorrow is 20 % — regardless of whether the preceding week was a drought or a monsoon. The model is simple because it deliberately ignores deep history, and it is useful precisely because that history often adds little predictive power once you know the current state.
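Here is a minimal Rust sketch of those two rules. The helper name prob_rain_tomorrow is hypothetical. The function receives the whole history but reads only its last entry, which is exactly what the Markov assumption allows:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Weather {
    Sunny,
    Rainy,
}

/// Hypothetical helper: probability that tomorrow is Rainy.
/// It is handed the full history, but under the Markov assumption
/// only the most recent day (today) affects the answer.
fn prob_rain_tomorrow(history: &[Weather]) -> f64 {
    match history.last() {
        Some(Weather::Sunny) => 0.2, // Sunny -> Rainy
        Some(Weather::Rainy) => 0.6, // Rainy -> Rainy
        None => panic!("history must contain at least today"),
    }
}

fn main() {
    // A dry week and a wet week that both end on a Sunny day.
    let drought = [Weather::Sunny; 7];
    let mut monsoon = [Weather::Rainy; 7];
    monsoon[6] = Weather::Sunny;
    // Only today (the last entry) matters, so both lines print 0.2.
    println!("{}", prob_rain_tomorrow(&drought));
    println!("{}", prob_rain_tomorrow(&monsoon));
}
```

The drought week and the monsoon week receive the same 20 % answer because they end on the same day.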
Where Markov chains appear in the real world. Once you recognise the Markov property you will spot it everywhere:
- Board games. In Snakes and Ladders the only thing that matters is which square you are on right now. The sequence of rolls that brought you there is irrelevant — your future depends only on your current position.
- Web surfing (PageRank). Google’s original PageRank algorithm modelled a hypothetical random web surfer who, at each page, clicks a link chosen uniformly at random. Where the surfer goes next depends only on the current page, not the path taken to reach it. The long-run fraction of time spent at each page is the page’s rank.
- Genetics. In simple population-genetics models, the number of copies of a gene in the current generation determines the distribution of copies in the next generation. The frequencies in prior generations, once summarised in the current count, carry no additional information.
- Text generation. Given the last word (or last few words) of a sentence, the probability of the next word can be estimated from a corpus. The full sentence history is ignored — only the recent context matters. Sections 6–8 of this course build exactly this kind of model in Rust.
Why the Markov property is a useful assumption. Real systems are almost never perfectly memoryless — yesterday’s weather genuinely does carry a whisper of information beyond today’s. So why use Markov models? Because they strike a remarkable balance between tractability and expressiveness. A model that conditions on the entire history is usually intractable; one that ignores history entirely is too crude. The Markov property is the sweet spot: it allows rigorous mathematical analysis (stationary distributions, convergence theorems, efficient simulation) while still capturing the essential dynamics of many real processes. When the assumption is too crude, you can extend it by enlarging the state space — for instance, tracking the last two days of weather instead of one — and the Markov property holds again at that richer level of description.
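The sketch below shows that enlargement in Rust, using the hypothetical names all_pair_states and is_valid_transition. Each new state is a (yesterday, today) pair. A step is only possible if it shifts that two-day window forward by one day. The pair chain's transition probabilities would still have to be estimated from data, and none are assumed here:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum Weather {
    Sunny,
    Rainy,
}

/// Enlarged state: (yesterday, today).
type PairState = (Weather, Weather);

/// All four (yesterday, today) combinations.
fn all_pair_states() -> Vec<PairState> {
    let days = [Weather::Sunny, Weather::Rainy];
    let mut states = Vec::new();
    for &yesterday in &days {
        for &today in &days {
            states.push((yesterday, today));
        }
    }
    states
}

/// Moving from (yesterday, today) to (today', tomorrow) is only possible
/// when the new state's "yesterday" equals the old state's "today".
fn is_valid_transition(from: PairState, to: PairState) -> bool {
    from.1 == to.0
}

fn main() {
    let states = all_pair_states();
    for &from in &states {
        let successors: Vec<PairState> = states
            .iter()
            .copied()
            .filter(|&to| is_valid_transition(from, to))
            .collect();
        println!("{:?} -> {:?}", from, successors);
    }
}
```

Each of the four pair states has exactly two possible successors (tomorrow is Sunny or Rainy). The enlarged model is therefore an ordinary Markov chain again, just over four states instead of two.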
2. States and Transitions
Every Markov chain consists of a set of states and a rule for how the system moves between them. The set of all possible states is called the state space, commonly written S. In this course the state space is always discrete — it is either finite (e.g., {Sunny, Rainy}) or countably infinite (e.g., the non-negative integers {0, 1, 2, …}, as in a random walk on the number line). Discrete state spaces are by far the most common setting for introductory Markov chain theory and for the kinds of simulations you will build in Rust.
A transition is a single step from one state to another. At each discrete time step the chain occupies some state i ∈ S, then moves to state j ∈ S with probability P(i → j). Two things matter: (1) transitions are directed — going from i to j and going from j to i are distinct transitions with potentially different probabilities; and (2) every transition has an associated probability, and the probabilities of all transitions out of a given state must sum to 1. A transition from i back to i — a self-loop — is perfectly valid and simply means the chain stays put with some nonzero probability.
State-transition diagrams make these rules visual. Draw one node for each state and one labelled arrow for each possible transition; the label is the transition probability. For the two-state weather model from Section 1 the diagram looks like this:
SUNNY ──0.8──► SUNNY   (self-loop)
SUNNY ──0.2──► RAINY
RAINY ──0.4──► SUNNY
RAINY ──0.6──► RAINY   (self-loop)
The self-loops (Sunny → Sunny with 0.8, Rainy → Rainy with 0.6) show that the chain can stay in the same state. The cross-arrows capture the transitions between states. Every row of outgoing arrows sums to 1: from Sunny, 0.8 + 0.2 = 1; from Rainy, 0.4 + 0.6 = 1.
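In code, the diagram is just a list of labelled arrows. The following minimal sketch uses a hypothetical outgoing_sums helper to check the rule that each state's outgoing probabilities add up to 1:

```rust
use std::collections::HashMap;

/// One labelled arrow of the diagram: (from, to, probability).
type Edge = (&'static str, &'static str, f64);

/// Sum the outgoing probabilities of every state that has at least one arrow.
fn outgoing_sums(edges: &[Edge]) -> HashMap<&'static str, f64> {
    let mut sums = HashMap::new();
    for &(from, _to, p) in edges {
        *sums.entry(from).or_insert(0.0) += p;
    }
    sums
}

fn main() {
    // The weather diagram, arrow by arrow (self-loops included).
    let edges: [Edge; 4] = [
        ("Sunny", "Sunny", 0.8),
        ("Sunny", "Rainy", 0.2),
        ("Rainy", "Sunny", 0.4),
        ("Rainy", "Rainy", 0.6),
    ];
    for (state, total) in outgoing_sums(&edges) {
        // Allow a little slack: floating-point sums can round.
        let ok = (total - 1.0).abs() < 1e-9;
        println!("{}: outgoing sum = {} (valid: {})", state, total, ok);
    }
}
```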
States in a Markov chain differ in how they relate to the chain’s long-run behaviour. Three categories matter most:
- An absorbing state is one you can never leave: once the chain enters it, it stays there forever. A self-loop with probability 1 is the defining feature. A simple example is a gambler who reaches £0 in a gambling game with no credit — they are stuck.
- A transient state is one you are not guaranteed to return to. If you leave a transient state there is some positive probability of never coming back. In many models early states are transient — the chain passes through them and moves on.
- A recurrent state (also called a persistent state) is one you are guaranteed to return to eventually, with probability 1. In the weather model both Sunny and Rainy are recurrent: no matter which state you are in, you will eventually visit both states again and again.
The distinction between transient and recurrent states determines the chain’s long-run behaviour. The chain visits each transient state only finitely many times before leaving it for good, whereas a chain trapped among recurrent states cycles among them indefinitely. Understanding this classification is the foundation for studying stationary distributions, which Section 9 covers in detail.
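Absorbing states are the easiest class to detect mechanically. They are exactly the rows of the transition matrix (introduced in Section 3) whose self-loop entry is 1. The sketch below applies this check to a hypothetical gambler's chain. The player holds £0 to £3, bets £1 per round at even odds, and stops on reaching either £0 (no credit) or £3 (an assumed target):

```rust
/// Indices of absorbing states: rows whose self-loop probability is exactly 1.
fn absorbing_states(p: &[Vec<f64>]) -> Vec<usize> {
    (0..p.len()).filter(|&i| p[i][i] == 1.0).collect()
}

fn main() {
    // Hypothetical gambler's chain; states are £0, £1, £2, £3.
    let gambler = vec![
        vec![1.0, 0.0, 0.0, 0.0], // £0: stuck, no credit
        vec![0.5, 0.0, 0.5, 0.0], // £1: lose -> £0, win -> £2
        vec![0.0, 0.5, 0.0, 0.5], // £2: lose -> £1, win -> £3
        vec![0.0, 0.0, 0.0, 1.0], // £3: target reached, stops playing
    ];
    println!("gambler absorbing: {:?}", absorbing_states(&gambler)); // [0, 3]

    // The weather chain has no absorbing state.
    let weather = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    println!("weather absorbing: {:?}", absorbing_states(&weather)); // []
}
```

In this chain £1 and £2 are transient, because from either one the gambler can be absorbed and never return. Separating transient from recurrent states in general takes more than a one-line check, since it depends on which states can reach which.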
3. Transition Probabilities and Matrices
The rules governing how a Markov chain moves are captured in a transition matrix P, where P[i][j] is the probability of moving from state i to state j in one step. This section covers how to construct P, the constraints it must satisfy (rows sum to 1), and how to use matrix multiplication to compute multi-step probabilities.
Defining the transition matrix. Label the states 0, 1, …, n−1. The transition matrix P is an n × n array where entry P[i][j] gives the probability of moving to state j on the very next step, given that you are currently in state i. Because these are probabilities of mutually exclusive, exhaustive outcomes (from state i you must go somewhere), every row must sum to exactly 1 and every entry must lie between 0 and 1 inclusive. A matrix satisfying these two constraints is called a stochastic matrix (or row-stochastic matrix). Each row is itself a probability distribution over the next state.
The stochastic-matrix constraints, stated precisely. For an n-state chain:
P[i][j] >= 0 for all i, j
sum_j P[i][j] = 1 for every row i
A zero entry means the transition is impossible; a one means it is certain. Columns have no such constraint — column sums need not equal 1.
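The two constraints translate directly into a validation function. The sketch below uses the hypothetical name is_stochastic. It also requires the matrix to be square and compares each row sum with a tolerance, because floating-point addition can round:

```rust
/// True if `p` is square, every entry lies in [0, 1], and every row sums
/// to 1 within `tol`. Column sums are deliberately not checked.
fn is_stochastic(p: &[Vec<f64>], tol: f64) -> bool {
    let n = p.len();
    p.iter().all(|row| {
        row.len() == n
            && row.iter().all(|&x| (0.0..=1.0).contains(&x))
            && (row.iter().sum::<f64>() - 1.0).abs() <= tol
    })
}

fn main() {
    let weather = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    // Columns sum to 1.8 and 0.2, which is allowed: only rows are constrained.
    let lopsided = vec![vec![0.9, 0.1], vec![0.9, 0.1]];
    // Row 0 sums to 1.1, so this is not a valid transition matrix.
    let broken = vec![vec![0.5, 0.6], vec![0.4, 0.6]];
    println!("{}", is_stochastic(&weather, 1e-9)); // true
    println!("{}", is_stochastic(&lopsided, 1e-9)); // true
    println!("{}", is_stochastic(&broken, 1e-9)); // false
}
```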
Multi-step probabilities via matrix multiplication. Suppose you start in state i at time 0. After one step the probability of being in state j is P[i][j]. After two steps you pass through some intermediate state k, so:
P^2[i][j] = sum_k P[i][k] * P[k][j]
This is exactly the (i, j) entry of P × P = P². In general, the probability of going from state i to state j in exactly k steps is the (i, j) entry of P^k. If you encode your current uncertainty as a row vector π₀ — a probability distribution over all states — then after k steps your updated distribution is:
π_k = π₀ · P^k
Each right-multiplication by P advances the clock one tick and blends probabilities according to the transition rules.
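Each tick of π_{k+1} = π_k · P is a small double loop. The sketch below assumes matrices are stored as Vec<Vec<f64>> and uses the hypothetical name step_distribution:

```rust
/// One tick of the chain: returns the row vector pi · P.
fn step_distribution(pi: &[f64], p: &[Vec<f64>]) -> Vec<f64> {
    let n = p.len();
    let mut next = vec![0.0; n];
    for i in 0..n {
        for j in 0..n {
            // Probability mass currently at i flows to j with weight P[i][j].
            next[j] += pi[i] * p[i][j];
        }
    }
    next
}

fn main() {
    let p = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    let mut pi = vec![1.0, 0.0]; // certain that today is Sunny
    for k in 1..=3 {
        pi = step_distribution(&pi, &p);
        println!("pi_{} = {:?}", k, pi);
    }
}
```

Each call costs O(n²) for an n-state chain. Repeating it k times gives π₀ · P^k without ever forming P^k.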
Worked example — a two-state weather chain. Consider a model with two states: Sunny (state 0) and Rainy (state 1). From data:
- If today is Sunny, tomorrow is Sunny with probability 0.8 and Rainy with probability 0.2.
- If today is Rainy, tomorrow is Sunny with probability 0.4 and Rainy with probability 0.6.
Writing this as a matrix:
| From \ To | Sunny | Rainy |
|---|---|---|
| Sunny | 0.8 | 0.2 |
| Rainy | 0.4 | 0.6 |
Row 0 sums to 1.0; row 1 sums to 1.0. All entries are non-negative. P is a valid stochastic matrix.
One step. Start with certainty in Sunny: π₀ = [1, 0].
π₁ = π₀ · P = [1, 0] · [[0.8, 0.2], [0.4, 0.6]] = [0.8, 0.2]
Tomorrow: 80 % Sunny, 20 % Rainy.
Two steps. Apply P again:
π₂ = π₁ · P = [0.8, 0.2] · [[0.8, 0.2], [0.4, 0.6]]
= [0.8*0.8 + 0.2*0.4, 0.8*0.2 + 0.2*0.6]
= [0.72, 0.28]
Equivalently, compute P² once and read off the row for state 0:
P^2 = [[0.8*0.8 + 0.2*0.4, 0.8*0.2 + 0.2*0.6],
[0.4*0.8 + 0.6*0.4, 0.4*0.2 + 0.6*0.6]]
= [[0.72, 0.28],
[0.56, 0.44]]
P²[0] = [0.72, 0.28] — matching the step-by-step result. Starting from Rainy gives P²[1] = [0.56, 0.44]; the two rows are already noticeably closer to each other than the original [0.8, 0.2] vs [0.4, 0.6]. As k grows, both rows converge toward the same limiting vector — the stationary distribution that Section 9 analyses in depth. The matrix-multiplication perspective makes this convergence precise and computable.
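To watch the rows of P^k converge, the sketch below raises P to a power by repeated multiplication. mat_mul, identity, and mat_pow are hypothetical helpers. Repeated squaring would be faster for large k, but plain multiplication keeps the code short:

```rust
type Matrix = Vec<Vec<f64>>;

/// Matrix product: (A·B)[i][j] = sum_k A[i][k] * B[k][j].
fn mat_mul(a: &Matrix, b: &Matrix) -> Matrix {
    let n = a.len();
    let m = b[0].len();
    let mut c = vec![vec![0.0; m]; n];
    for i in 0..n {
        for k in 0..b.len() {
            for j in 0..m {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

/// n × n identity matrix (P^0).
fn identity(n: usize) -> Matrix {
    (0..n)
        .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect()
}

/// P^k by k repeated multiplications, starting from the identity.
fn mat_pow(p: &Matrix, k: u32) -> Matrix {
    let mut result = identity(p.len());
    for _ in 0..k {
        result = mat_mul(&result, p);
    }
    result
}

fn main() {
    let p: Matrix = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    for k in [1, 2, 5, 20] {
        println!("P^{} = {:?}", k, mat_pow(&p, k));
    }
}
```

For this matrix the gap between the two rows shrinks by a factor of 0.4 each step (0.4, then 0.16, …). By k = 20 both rows agree with [2/3, 1/3] to well within 10⁻⁶.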
Part 2 — First Implementation
4. Exercise 1 — Weather Model
Goal: Build a two-state Markov chain in Rust that models daily weather as either Sunny or Rainy, driven by a transition matrix, and simulate 30 days of weather.
Setup
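Create a new Cargo binary project and add the rand crate. The last command pins rand 0.9 and enables its small_rng feature. The seeded SmallRng used in Step 4 requires that feature, and it is not enabled by default:
cargo new weather-chain
cd weather-chain
cargo add rand@0.9 --features small_rng
In rand 0.9, Rng::gen was renamed to Rng::random and gen_range to random_range. The old name gen is a reserved keyword in the Rust 2024 edition, which cargo new uses by default on recent toolchains.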
Starter Code
Replace the contents of src/main.rs with the following skeleton. Do not change the struct layout or function signatures — your task is to fill in the todo!() bodies and write main.
use rand::Rng;
#[derive(Debug, Clone, Copy, PartialEq)]
enum Weather { Sunny, Rainy }
struct WeatherChain {
/// transition[current][next] = probability
transition: [[f64; 2]; 2],
}
impl WeatherChain {
fn step(&self, current: Weather, rng: &mut impl Rng) -> Weather { todo!() }
fn simulate(&self, start: Weather, steps: usize, rng: &mut impl Rng) -> Vec<Weather> { todo!() }
}
fn main() { todo!() }
Step 1 — Index conversion
step needs to index into self.transition using the current state, and later convert an index back to a Weather value. Add two helper methods to the Weather enum:
- fn index(self) -> usize: returns 0 for Sunny, 1 for Rainy
- fn from_index(i: usize) -> Self: returns Sunny for 0, Rainy for anything else
A simple match expression handles both. These are the only tools you need to bridge the enum and the matrix.
Step 2 — Implement WeatherChain::step
step must sample the next state from the probability row for the current state. The technique is a cumulative-probability walk:
- Retrieve the transition row: let row = self.transition[current.index()];
- Draw a uniform float in [0, 1): let r: f64 = rng.random(); (rand 0.8 and earlier called this method gen)
- Walk the row, accumulating probability. When the running total first exceeds r, return that state.
For the two-state case this reduces to a single comparison — if r < row[0] return Sunny, otherwise return Rainy — but implementing the loop works for any number of states and is worth practising.
Add a fallback Weather::from_index(row.len() - 1) after the loop to satisfy the compiler; floating-point rounding can in rare cases leave the loop without returning.
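Pulling the cumulative walk out into its own function makes it easy to check by hand. In this sketch, the hypothetical sample_index takes the uniform draw r as a parameter instead of calling the RNG itself. Inside step, you would pass it the value drawn from rng:

```rust
/// Cumulative-probability walk: given one row of a transition matrix and a
/// uniform draw `r` in [0, 1), return the index of the sampled next state.
fn sample_index(row: &[f64], r: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, &prob) in row.iter().enumerate() {
        cumulative += prob;
        if r < cumulative {
            return i;
        }
    }
    // Rounding can leave `cumulative` slightly below 1.0; fall back to the last state.
    row.len() - 1
}

fn main() {
    let row = [0.8, 0.2];
    for r in [0.1, 0.79, 0.8, 0.95] {
        println!("r = {:.2} -> state {}", r, sample_index(&row, r));
    }
    // The same loop handles a three-state row with no changes.
    let row3 = [0.1, 0.2, 0.7];
    println!("r = 0.25 -> state {}", sample_index(&row3, 0.25));
}
```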
Step 3 — Implement WeatherChain::simulate
simulate runs the chain for steps transitions, collecting every state visited including the start:
- Allocate a Vec with capacity steps + 1.
- Push start and set current = start.
- Loop steps times: call self.step(current, rng), update current, and push it.
- Return the Vec.
The returned Vec has length steps + 1 (the initial state plus one state per step).
Step 4 — Run the simulation
In main, create a WeatherChain with the two-state matrix from Section 3:
Sunny [0.8, 0.2]
Rainy [0.4, 0.6]
Seed a repeatable RNG with rand::rngs::SmallRng::seed_from_u64(42) (add use rand::SeedableRng;). Simulate 30 steps starting from Sunny, which records 31 days because the starting day is included. Print the resulting sequence and count how many days were sunny vs rainy.
Expected output structure (exact numbers vary by seed):
[Sunny, Sunny, Rainy, Sunny, ...]
Sunny days: 21 (67.7%)
Rainy days: 10 (32.3%)
Step 5 — Compare to the stationary distribution
The stationary distribution π satisfies π = πP, meaning once the chain reaches it, the distribution no longer changes. For this matrix, solve the two equations:
π₀ = 0.8·π₀ + 0.4·π₁
π₁ = 0.2·π₀ + 0.6·π₁
π₀ + π₁ = 1
The first equation simplifies to 0.2·π₀ = 0.4·π₁, giving π₀ = 2·π₁. Substituting into the normalisation constraint: π₀ = 2/3 ≈ 66.7%, π₁ = 1/3 ≈ 33.3%.
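The same algebra works for any two-state chain. Write a = P(Sunny → Rainy) and b = P(Rainy → Sunny). The first equation then becomes a·π₀ = b·π₁, so π = [b/(a+b), a/(a+b)]. The small helper below, with the hypothetical name two_state_stationary, lets you print the stationary percentages next to your simulated counts:

```rust
/// Stationary distribution of a two-state chain with
/// P = [[1 - a, a], [b, 1 - b]], where a = P(0 -> 1) and b = P(1 -> 0).
/// Follows from a * pi_0 = b * pi_1 together with pi_0 + pi_1 = 1.
fn two_state_stationary(a: f64, b: f64) -> [f64; 2] {
    assert!(a + b > 0.0, "a + b must be positive");
    [b / (a + b), a / (a + b)]
}

fn main() {
    let pi = two_state_stationary(0.2, 0.4);
    // prints: Stationary: Sunny = 66.7%, Rainy = 33.3%
    println!(
        "Stationary: Sunny = {:.1}%, Rainy = {:.1}%",
        100.0 * pi[0],
        100.0 * pi[1]
    );
}
```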
Print the stationary percentages alongside your simulated counts. With only 31 data points the match will be rough; re-run with 1 000 or 10 000 steps to see the empirical frequencies converge.
Reference Solution
use rand::{Rng, SeedableRng, rngs::SmallRng};
#[derive(Debug, Clone, Copy, PartialEq)]
enum Weather { Sunny, Rainy }
impl Weather {
fn index(self) -> usize {
match self {
Weather::Sunny => 0,
Weather::Rainy => 1,
}
}
fn from_index(i: usize) -> Self {
match i {
0 => Weather::Sunny,
_ => Weather::Rainy,
}
}
}
struct WeatherChain {
/// transition[current][next] = probability
transition: [[f64; 2]; 2],
}
impl WeatherChain {
fn step(&self, current: Weather, rng: &mut impl Rng) -> Weather {
let row = self.transition[current.index()];
let r: f64 = rng.random(); // uniform draw in [0, 1)
let mut cumulative = 0.0;
for (i, &prob) in row.iter().enumerate() {
cumulative += prob;
if r < cumulative {
return Weather::from_index(i);
}
}
Weather::from_index(row.len() - 1)
}
fn simulate(&self, start: Weather, steps: usize, rng: &mut impl Rng) -> Vec<Weather> {
let mut states = Vec::with_capacity(steps + 1);
states.push(start);
let mut current = start;
for _ in 0..steps {
current = self.step(current, rng);
states.push(current);
}
states
}
}
fn main() {
let mut rng = SmallRng::seed_from_u64(42);
let chain = WeatherChain {
transition: [[0.8, 0.2], [0.4, 0.6]],
};
let states = chain.simulate(Weather::Sunny, 30, &mut rng);
println!("{:?}", states);
let total = states.len() as f64;
let sunny = states.iter().filter(|&&s| s == Weather::Sunny).count();
let rainy = states.len() - sunny;
println!("Sunny days: {} ({:.1}%)", sunny, 100.0 * sunny as f64 / total);
println!("Rainy days: {} ({:.1}%)", rainy, 100.0 * rainy as f64 / total);
println!("Stationary: Sunny ≈ 66.7%, Rainy ≈ 33.3%");
}
5. Exercise 2 — Simulating a Random Walk
Goal: Implement a one-dimensional random walk on the integers (states −N … +N) with reflecting boundaries, then measure the empirical distribution of positions after T steps.
A random walk is one of the simplest Markov chains: the state is a single integer position, and at each step the walker moves left or right according to a fixed probability. Adding reflecting boundaries means the walker cannot escape the interval — a step that would leave the interval is clamped to the nearest boundary.
Learning objectives
- Model a finite integer state space in Rust
- Implement reflecting (clamped) boundary conditions
- Aggregate many simulation runs into a histogram
- Visualise the distribution as an ASCII bar chart
- Observe how asymmetric step probabilities skew the stationary distribution
Setup
You can extend the weather-chain project from Exercise 1, or create a fresh binary:
cargo new random-walk
cd random-walk
cargo add rand@0.9 --features small_rng
You will also need use std::collections::HashMap; at the top of src/main.rs.
Starter Code
Replace the contents of src/main.rs with the following skeleton. Do not change the struct layout or function signatures — your task is to fill in the todo!() bodies.
use rand::Rng;
use std::collections::HashMap;
struct RandomWalk {
min: i32,
max: i32,
/// prob_right[i] = probability of stepping right from position `min + i as i32`
prob_right: Vec<f64>,
}
impl RandomWalk {
fn step(&self, pos: i32, rng: &mut impl Rng) -> i32 { todo!() }
fn simulate(&self, start: i32, steps: usize, rng: &mut impl Rng) -> Vec<i32> { todo!() }
fn histogram(&self, start: i32, steps: usize, trials: usize, rng: &mut impl Rng)
-> HashMap<i32, usize> { todo!() }
}
fn print_histogram(hist: &HashMap<i32, usize>, min: i32, max: i32, trials: usize) { todo!() }
fn main() { todo!() }
Step 1 — Constructor and index conversion
The prob_right vec is indexed from 0, but positions range from min to max. Add two helpers to RandomWalk:
- fn new(min: i32, max: i32, p: f64) -> Self: creates a uniform walk where every position has the same step probability p. Fill prob_right with vec![p; (max - min + 1) as usize].
- fn idx(&self, pos: i32) -> usize: converts a position to a vec index with (pos - self.min) as usize. Use this everywhere you index into prob_right.
Step 2 — Implement RandomWalk::step
step must sample the next position:
- Draw a uniform float r in [0, 1): let r: f64 = rng.random();
- Look up the step probability: let p = self.prob_right[self.idx(pos)];
- Choose direction: if r < p, move right (pos + 1); otherwise move left (pos - 1).
- Apply the reflecting boundary by clamping: moved.clamp(self.min, self.max)
let moved = if r < p { pos + 1 } else { pos - 1 };
moved.clamp(self.min, self.max)
Clamping is the simplest reflecting rule: a step that would leave the interval is silently held at the boundary. An alternative is strict reflection (a step right from max lands at max - 1), but clamping is sufficient here.
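The two boundary rules differ only in where an out-of-range step lands. The sketch below compares them side by side. step_clamped and step_reflected are hypothetical helpers that take the chosen direction (r < p in the exercise) as a bool:

```rust
/// Clamping: a step past the boundary is held on the boundary.
fn step_clamped(pos: i32, go_right: bool, min: i32, max: i32) -> i32 {
    let moved = if go_right { pos + 1 } else { pos - 1 };
    moved.clamp(min, max)
}

/// Strict reflection: a step past the boundary bounces back one position.
fn step_reflected(pos: i32, go_right: bool, min: i32, max: i32) -> i32 {
    let moved = if go_right { pos + 1 } else { pos - 1 };
    if moved > max {
        max - 1
    } else if moved < min {
        min + 1
    } else {
        moved
    }
}

fn main() {
    let (min, max) = (-5, 5);
    for (pos, right) in [(5, true), (-5, false), (0, true)] {
        println!(
            "pos {:>2}, step {}: clamped -> {:>2}, reflected -> {:>2}",
            pos,
            if right { "right" } else { "left" },
            step_clamped(pos, right, min, max),
            step_reflected(pos, right, min, max),
        );
    }
}
```

Strict reflection removes the self-loops at the boundaries, so every step changes the parity of the position. A symmetric walk that starts at 0 and takes an even number of steps can then only end on even positions. That is one practical reason clamping suits this histogram experiment better.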
Step 3 — Implement RandomWalk::simulate and RandomWalk::histogram
simulate runs the chain for steps transitions and collects every position visited:
- Allocate a Vec with capacity steps + 1.
- Push start and set current = start.
- Loop steps times: call self.step(current, rng), update current, and push it.
- Return the vec.
histogram runs many independent trials and records where each trial ends:
- Allocate an empty HashMap<i32, usize>.
- Loop trials times: call self.simulate(start, steps, rng), take the final position with let final_pos = *positions.last().unwrap();, and increment its count with *hist.entry(final_pos).or_insert(0) += 1.
- Return the map.
The histogram answers: after steps steps starting from start, what fraction of the time does the walker end at each position?
Step 4 — Print an ASCII bar chart
Write print_histogram to display the distribution. For each position from min to max:
- Look up count = hist.get(&pos).copied().unwrap_or(0).
- Scale to a bar: bar_len = count * bar_width / max_count (use bar_width = 40 and max_count = hist.values().copied().max().unwrap_or(1)).
- Print the position, a bar of # characters, and the percentage.
let bar: String = "#".repeat(bar_len);
println!("{:4}: {:40} ({:.1}%)", pos, bar, 100.0 * count as f64 / trials as f64);
Step 5 — Compare symmetric vs asymmetric walks
In main, run two experiments. Seed a repeatable RNG with SmallRng::seed_from_u64(42) (add use rand::{SeedableRng, rngs::SmallRng};).
Symmetric walk (p = 0.5, range −5 … 5):
let walk = RandomWalk::new(-5, 5, 0.5);
let hist = walk.histogram(0, 200, 10_000, &mut rng);
print_histogram(&hist, -5, 5, 10_000);
Asymmetric walk (p = 0.7, same range):
let walk = RandomWalk::new(-5, 5, 0.7);
let hist = walk.histogram(0, 200, 10_000, &mut rng);
print_histogram(&hist, -5, 5, 10_000);
Observe:
- Symmetric (p = 0.5): after enough steps the distribution is approximately uniform — each of the 11 positions appears with roughly 9% frequency. This is the stationary distribution for a symmetric walk with clamping boundaries.
- Asymmetric (p = 0.7): the walker drifts toward the right boundary, so positions near +5 accumulate much higher frequencies than those near −5. The stationary distribution is geometric and rises steeply toward max: each position is p / (1 − p) = 0.7 / 0.3 ≈ 2.33 times as likely as its left neighbour.
Try lowering steps from 200 to 10 and observe that the distribution has not yet converged — it is still concentrated near start. This illustrates the mixing time of the chain.
Reference Solution
use rand::{Rng, SeedableRng, rngs::SmallRng};
use std::collections::HashMap;
struct RandomWalk {
min: i32,
max: i32,
/// prob_right[i] = probability of stepping right from position `min + i as i32`
prob_right: Vec<f64>,
}
impl RandomWalk {
fn new(min: i32, max: i32, p: f64) -> Self {
let n = (max - min + 1) as usize;
Self { min, max, prob_right: vec![p; n] }
}
fn idx(&self, pos: i32) -> usize {
(pos - self.min) as usize
}
fn step(&self, pos: i32, rng: &mut impl Rng) -> i32 {
let r: f64 = rng.random(); // uniform draw in [0, 1)
let p = self.prob_right[self.idx(pos)];
let moved = if r < p { pos + 1 } else { pos - 1 };
moved.clamp(self.min, self.max)
}
fn simulate(&self, start: i32, steps: usize, rng: &mut impl Rng) -> Vec<i32> {
let mut positions = Vec::with_capacity(steps + 1);
let mut current = start;
positions.push(current);
for _ in 0..steps {
current = self.step(current, rng);
positions.push(current);
}
positions
}
fn histogram(
&self,
start: i32,
steps: usize,
trials: usize,
rng: &mut impl Rng,
) -> HashMap<i32, usize> {
let mut hist = HashMap::new();
for _ in 0..trials {
let positions = self.simulate(start, steps, rng);
let final_pos = *positions.last().unwrap();
*hist.entry(final_pos).or_insert(0) += 1;
}
hist
}
}
fn print_histogram(hist: &HashMap<i32, usize>, min: i32, max: i32, trials: usize) {
let max_count = hist.values().copied().max().unwrap_or(1);
let bar_width = 40;
for pos in min..=max {
let count = hist.get(&pos).copied().unwrap_or(0);
let bar_len = count * bar_width / max_count;
let bar: String = "#".repeat(bar_len);
println!("{:4}: {:40} ({:.1}%)", pos, bar, 100.0 * count as f64 / trials as f64);
}
}
fn main() {
let mut rng = SmallRng::seed_from_u64(42);
println!("=== Symmetric walk (p = 0.5) ===");
let walk = RandomWalk::new(-5, 5, 0.5);
let hist = walk.histogram(0, 200, 10_000, &mut rng);
print_histogram(&hist, -5, 5, 10_000);
println!();
println!("=== Asymmetric walk (p = 0.7) ===");
let walk = RandomWalk::new(-5, 5, 0.7);
let hist = walk.histogram(0, 200, 10_000, &mut rng);
print_histogram(&hist, -5, 5, 10_000);
}
Expected output structure: each experiment prints a header line (=== Symmetric walk (p = 0.5) === or === Asymmetric walk (p = 0.7) ===), followed by one line per position from -5 to 5 in the form pos: ##### (x.x%). The most frequent position gets the full 40-character bar. The eleven percentages add up to 100% (up to rounding), because every trial ends at exactly one position. Exact percentages vary by seed, but with 200 steps and 10 000 trials they should land close to the chain’s long-run (stationary) probabilities:
- Symmetric walk (p = 0.5): 1/11 ≈ 9.1% at every position.
- Asymmetric walk (p = 0.7), positions −5 to +5: 0.0%, 0.0%, 0.1%, 0.2%, 0.4%, 0.8%, 1.9%, 4.5%, 10.5%, 24.5%, 57.1%.
The symmetric walk converges to a nearly flat distribution, ending at each position in roughly 1/11 of trials. The asymmetric walk piles up at the right boundary. Positions +3 … +5 capture about 92% of the probability mass, and +5 alone holds more than half.
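To compare the histograms with theory, you can compute the long-run distribution of the clamped walk directly. The walk only moves between neighbours, so in the long run the probability flow from i to i + 1 balances the flow back (detailed balance): π[i]·p = π[i+1]·(1 − p). The sketch below builds the distribution from that relation, using the hypothetical name clamped_walk_stationary:

```rust
/// Long-run (stationary) distribution of the clamped walk on min..=max with
/// constant right-step probability p (0 < p < 1). Detailed balance between
/// neighbours gives pi[i + 1] = pi[i] * p / (1 - p).
fn clamped_walk_stationary(min: i32, max: i32, p: f64) -> Vec<f64> {
    let n = (max - min + 1) as usize;
    let ratio = p / (1.0 - p);
    // Unnormalised weights 1, ratio, ratio^2, ...
    let mut weights = Vec::with_capacity(n);
    let mut w = 1.0;
    for _ in 0..n {
        weights.push(w);
        w *= ratio;
    }
    let total: f64 = weights.iter().sum();
    weights.iter().map(|w| w / total).collect()
}

fn main() {
    for p in [0.5, 0.7] {
        println!("p = {}", p);
        let pi = clamped_walk_stationary(-5, 5, p);
        for (i, prob) in pi.iter().enumerate() {
            println!("{:4}: {:.1}%", -5 + i as i32, 100.0 * prob);
        }
    }
}
```

For p = 0.5 the ratio is 1, so every position gets 1/11. For p = 0.7 the weights grow by 7/3 per step to the right, which leaves about 57 % of the mass at +5.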
Part 3 — Text Generation
6. Text Generation with Markov Chains
Text can be modeled as a Markov chain where each state is a word (or sequence of words) and transitions represent which words tend to follow. This section explains the bigram model, why it produces surprisingly coherent short sequences, and what its limitations reveal about the relationship between statistical models and language. It is a reading lesson: the full implementations are left to Exercises 3 and 4.
Words as states. To model text as a Markov chain, treat each word as a state and the act of writing the next word as a transition. Given the word “cat,” what word comes next? A Markov model answers that question by consulting statistics drawn from a training corpus: it scans every occurrence of “cat” and records which words followed it, then samples from that empirical distribution. The result is a sequence of words generated one step at a time, each word chosen probabilistically based only on the current word — not on the full sentence or paragraph that came before. This is the Markov property applied to language.
Bigrams and transition tables. A bigram is an ordered pair of adjacent words. To build a bigram model, scan the corpus left to right and record every consecutive word pair. Count how many times each pair appears, then for each word w express those counts as a probability distribution over successor words. This distribution forms one row of the transition table: a lookup from every word to a weighted list of what can follow it. The table is learned entirely from data — no grammar rules, no meaning, just co-occurrence statistics.
Worked example. Consider this two-sentence corpus:
“the cat sat on the mat. the cat sat on the hat.”
Scanning the corpus yields the following bigrams and their counts. Each sentence is treated as a separate word sequence and the periods are ignored, so no pair spans a sentence boundary. Dividing each count by the row total gives the transition probability:
| Current word | Next word | Count | Probability |
|---|---|---|---|
| the | cat | 2 | 0.50 |
| the | mat | 1 | 0.25 |
| the | hat | 1 | 0.25 |
| cat | sat | 2 | 1.00 |
| sat | on | 2 | 1.00 |
| on | the | 2 | 1.00 |
Notice that “cat,” “sat,” and “on” each have only one possible successor — the small corpus left them no choice. Only “the” branches: half the time it is followed by “cat,” a quarter of the time by “mat,” and a quarter by “hat.” Starting from “the,” one plausible generated sequence is: the → cat → sat → on → the → hat — and then “hat” would be an end-of-sentence state.
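The table can be reproduced in a few lines of Rust. This is a sketch, not the design used in Exercise 3. It splits on '.' to keep sentences separate (matching the table above), uses BTreeMap so rows print in a stable order, and converts counts to probabilities. The function name bigram_table is hypothetical:

```rust
use std::collections::BTreeMap;

/// Count bigrams sentence by sentence (so no pair crosses a period),
/// then turn each row of counts into probabilities.
fn bigram_table(corpus: &str) -> BTreeMap<String, BTreeMap<String, f64>> {
    let mut counts: BTreeMap<String, BTreeMap<String, usize>> = BTreeMap::new();
    for sentence in corpus.split('.') {
        let words: Vec<&str> = sentence.split_whitespace().collect();
        for pair in words.windows(2) {
            *counts
                .entry(pair[0].to_string())
                .or_default()
                .entry(pair[1].to_string())
                .or_insert(0) += 1;
        }
    }
    counts
        .into_iter()
        .map(|(word, nexts)| {
            let total: usize = nexts.values().sum();
            let probs: BTreeMap<String, f64> = nexts
                .into_iter()
                .map(|(next, c)| (next, c as f64 / total as f64))
                .collect();
            (word, probs)
        })
        .collect()
}

fn main() {
    let table = bigram_table("the cat sat on the mat. the cat sat on the hat.");
    for (word, nexts) in &table {
        for (next, prob) in nexts {
            println!("{:>4} -> {:<4} {:.2}", word, next, prob);
        }
    }
}
```

Neither “mat” nor “hat” gets a row. Both only ever end a sentence, which is why the generated sequence stops there.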
Why the text sounds strange. Short stretches of output are surprisingly readable: every adjacent pair the model produces appeared in the training data, so each local transition is both grammatically and semantically plausible. The problem surfaces over longer spans. Because the model has no memory beyond the current word, it cannot maintain a topic, complete a thought, or avoid contradictions introduced a few sentences back. The result resembles text written by someone who half-knows the language: each individual step looks right, but the destination keeps shifting without purpose.
Order-n chains. A bigram model is an order-1 Markov chain — the next state depends on exactly one previous word. An order-2 model (a trigram model) conditions on the last two words; order n uses the last n words as context. Increasing n brings sharply improved local coherence — at n = 3 or 4 the output begins to reproduce full phrases from the corpus verbatim — but at the cost of novelty. A high-order model has likely seen most of its context only once, so it often has no real choice but to copy the source text rather than recombine it. The right tradeoff depends on corpus size: larger corpora can support higher n without the model simply memorising its training data.
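Generalising to order n only changes the key: the state becomes the last n words, and the corpus is scanned with windows of n + 1 words. The sketch below uses the hypothetical name order_n_counts and, for simplicity, treats the whole corpus as one word sequence with periods removed:

```rust
use std::collections::BTreeMap;

/// Order-n model: the state is the last `n` words, and we count which word
/// follows each n-word context, using windows of length n + 1.
fn order_n_counts(words: &[&str], n: usize) -> BTreeMap<Vec<String>, BTreeMap<String, usize>> {
    let mut table: BTreeMap<Vec<String>, BTreeMap<String, usize>> = BTreeMap::new();
    for window in words.windows(n + 1) {
        let context: Vec<String> = window[..n].iter().map(|w| w.to_string()).collect();
        let next = window[n].to_string();
        *table.entry(context).or_default().entry(next).or_insert(0) += 1;
    }
    table
}

fn main() {
    let words: Vec<&str> = "the cat sat on the mat the cat sat on the hat"
        .split_whitespace()
        .collect();
    for (context, nexts) in order_n_counts(&words, 2) {
        println!("{:?} -> {:?}", context, nexts);
    }
}
```

On this tiny corpus with n = 2, (on, the) is the only context with more than one successor. Every other context has exactly one, which is the copy-the-source effect in miniature.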
7. Exercise 3 — Bigram Text Generator
Goal: Build a Markov chain over words from an input text. Each state is a single word; transitions are learned from the corpus. Generate novel word sequences of a given length from a chosen seed word.
Setup
You can extend the random-walk project from Exercise 2, or create a fresh binary:
cargo new bigram-text
cd bigram-text
cargo add rand@0.9 --features small_rng
You will also need use std::collections::HashMap; at the top of src/main.rs.
Starter Code
Replace the contents of src/main.rs with the following skeleton. Do not change the struct layout or function signatures — your task is to fill in the todo!() bodies and write main.
use rand::Rng;
use std::collections::HashMap;
struct BigramModel {
/// transitions[word] = list of (next_word, weight) pairs
transitions: HashMap<String, Vec<(String, usize)>>,
}
impl BigramModel {
fn train(corpus: &str) -> Self { todo!() }
fn generate(&self, seed: &str, length: usize, rng: &mut impl Rng) -> Vec<String> { todo!() }
}
fn main() { todo!() }
Step 1 — Tokenize the corpus
Inside train, split the corpus into lowercase words:
let words: Vec<String> = corpus
    .split_whitespace()
    .map(|s| s.to_lowercase())
    .collect();
split_whitespace handles any run of spaces, newlines, or tabs. Lowercasing ensures “The” and “the” are treated as the same state. For this exercise we leave punctuation attached to words (so “sat.” and “sat” are distinct) to keep the code simple — stripping it is a good stretch goal.
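For the punctuation stretch goal, one possible approach (an assumption, not required behaviour for the exercise) is to trim non-alphanumeric characters from both ends of each token. This keeps inner apostrophes such as don't intact. The sketch below uses the hypothetical name tokenize:

```rust
/// Stretch-goal tokenizer (one possible approach): lowercase each
/// whitespace-separated token, trim non-alphanumeric characters from
/// both ends, and drop tokens that become empty.
fn tokenize(corpus: &str) -> Vec<String> {
    corpus
        .split_whitespace()
        .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_lowercase())
        .filter(|s| !s.is_empty())
        .collect()
}

fn main() {
    // ["the", "cat", "sat", "the", "cat", "sat", "on", "the", "mat"]
    println!("{:?}", tokenize("The cat sat. The cat sat on the MAT!"));
}
```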
Step 2 — Build the transition table
Iterate every consecutive word pair using windows(2). The first word is the current state; the second is the successor to record:
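// Maps each word to its list of (successor, count) pairs.
let mut transitions: HashMap<String, Vec<(String, usize)>> = HashMap::new();
for window in words.windows(2) {
    let current = window[0].clone();
    let next = window[1].clone();
    let entry = transitions.entry(current).or_default();
    // Bump the count if this successor was seen before; otherwise add it.
    if let Some(pair) = entry.iter_mut().find(|(w, _)| w == &next) {
        pair.1 += 1;
    } else {
        entry.push((next, 1));
    }
}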
For each word, keep a list of (successor_word, count) pairs. If the successor already appears in the list, increment its count; otherwise push a new entry with count 1. After scanning the whole corpus, every word that is followed by at least one other word maps to a weighted list of what can follow it. That list is one row of the bigram transition table, stored as counts rather than probabilities.
Return BigramModel { transitions }.
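The count lists hold each row as integers. You may want to see a row as actual transition probabilities, for example to compare it with the matrices from Part 1. A small helper can divide each count by the row total. This is an optional sketch: `row_probabilities` is a hypothetical name and is not part of the exercise's required API.

```rust
/// Converts one row of (successor, count) pairs into (successor, probability)
/// pairs. For a non-empty row the probabilities sum to 1.
fn row_probabilities(row: &[(String, usize)]) -> Vec<(String, f64)> {
    let total: usize = row.iter().map(|(_, count)| count).sum();
    row.iter()
        .map(|(word, count)| (word.clone(), *count as f64 / total as f64))
        .collect()
}

fn main() {
    // Successors of "the" in the short Alice corpus used later in this exercise.
    let row = vec![
        ("bank".to_string(), 1),
        ("book".to_string(), 1),
        ("use".to_string(), 1),
    ];
    for (word, p) in row_probabilities(&row) {
        println!("the -> {word}: {p:.3}");
    }
}
```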
Step 3 — Implement weighted random sampling
generate must sample the next word with probability proportional to its observed count in the successor list. Add a helper function outside impl BigramModel that performs integer weighted sampling. Integer sampling is exact and avoids the rounding error of accumulating floating-point probabilities:
fn sample_weighted<'a>(choices: &'a [(String, usize)], rng: &mut impl Rng) -> &'a str {
    let total: usize = choices.iter().map(|(_, w)| w).sum();
    let mut r = rng.gen_range(0..total);
    for (word, weight) in choices {
        if r < *weight {
            return word;
        }
        r -= weight;
    }
    &choices.last().unwrap().0 // fallback: satisfies the compiler
}
The r -= weight trick avoids floating-point comparisons: draw a random index in [0, total), then walk the list subtracting each bucket’s weight until the remaining value lands inside the current entry. This is the discrete-distribution equivalent of the cumulative-probability walk used in Exercise 1, rewritten in integer arithmetic.
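To see the bucket walk without any randomness, the sketch below pulls the loop into a function that takes r as a parameter. `pick_bucket` is a hypothetical helper name, not part of the exercise. Take a toy row with counts [("sitting", 2), ("having", 1)], so total is 3. Then r = 0 and r = 1 land in the "sitting" bucket and r = 2 lands in "having", so "sitting" is chosen two times out of three.

```rust
/// Deterministic core of weighted sampling: maps an index r in [0, total)
/// to the entry whose bucket contains it.
fn pick_bucket(choices: &[(String, usize)], mut r: usize) -> &str {
    for (word, weight) in choices {
        if r < *weight {
            return word;
        }
        r -= weight; // skip past this bucket
    }
    panic!("r must be less than the total weight");
}

fn main() {
    let choices = vec![("sitting".to_string(), 2), ("having".to_string(), 1)];
    let total: usize = choices.iter().map(|(_, w)| w).sum();
    for r in 0..total {
        println!("r = {r} -> {}", pick_bucket(&choices, r));
    }
}
```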
Step 4 — Implement BigramModel::generate
generate starts from a seed word and appends one word at a time:
- Push the seed onto output and set current = seed.to_string().
- Loop up to length times: look up current in self.transitions. If it is missing, stop early with break. This is a dead end: the word only ever appeared as the last token of the corpus, or the seed never appeared at all.
- Otherwise call sample_weighted on the successor list, push the sampled word onto output, and advance current to it.
let mut output = vec![seed.to_string()];
let mut current = seed.to_string();
for _ in 0..length {
    match self.transitions.get(&current) {
        None => break, // dead end: no recorded successors
        Some(choices) => {
            let next = sample_weighted(choices, rng).to_string();
            output.push(next.clone());
            current = next;
        }
    }
}
output
Step 5 — Run the generator
In main, train on a short corpus and generate several sequences from different seed words. Use a seeded RNG for reproducibility (add use rand::{SeedableRng, rngs::SmallRng};):
const CORPUS: &str =
"alice was beginning to get very tired of sitting by her sister on the \
bank and of having nothing to do once or twice she had peeped into the \
book her sister was reading but it had no pictures or conversations in \
it and what is the use of a book thought alice without pictures or \
conversations alice was beginning to get very tired of sitting";
fn main() {
let mut rng = SmallRng::seed_from_u64(42);
let model = BigramModel::train(CORPUS);
for seed in &["alice", "the", "book"] {
let words = model.generate(seed, 15, &mut rng);
println!("{}", words.join(" "));
}
}
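Illustrative output (the exact words depend on the RNG seed and the rand crate version; a real run prints the seed word followed by 15 generated words on each line):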
alice was beginning to get very tired of sitting by her sister was beginning
the use of a book her sister on the bank and of sitting by her sister
book thought alice without pictures or conversations alice was beginning to get
Observe that every adjacent pair in the output appeared in the training corpus — each local transition is faithful to the source text. But over longer spans the topic shifts erratically, because the model has no memory beyond the immediately preceding word.
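You can check the claim about adjacent pairs mechanically, which makes a useful sanity test for your own generate. The sketch below uses a hypothetical helper, `all_pairs_seen`, with the same tokenization as train. It collects every bigram from the corpus into a HashSet and then confirms that each consecutive pair in a generated sequence is in that set.

```rust
use std::collections::HashSet;

/// Returns true if every adjacent pair in `output` occurs somewhere in `corpus`
/// (after the same lowercase, whitespace-split tokenization used in train).
fn all_pairs_seen(corpus: &str, output: &[String]) -> bool {
    let words: Vec<String> = corpus
        .split_whitespace()
        .map(|s| s.to_lowercase())
        .collect();
    let seen: HashSet<(&str, &str)> = words
        .windows(2)
        .map(|w| (w[0].as_str(), w[1].as_str()))
        .collect();
    output
        .windows(2)
        .all(|w| seen.contains(&(w[0].as_str(), w[1].as_str())))
}

fn main() {
    let corpus = "alice was beginning to get very tired of sitting by her sister";
    let good: Vec<String> = "her sister".split(' ').map(String::from).collect();
    let bad: Vec<String> = "sister alice".split(' ').map(String::from).collect();
    println!("{} {}", all_pairs_seen(corpus, &good), all_pairs_seen(corpus, &bad));
}
```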
Step 6 — Try a larger corpus
The small Alice snippet produces repetitive output because many words have only one possible successor in a short text. To see genuine branching behaviour, download the first chapter of Alice’s Adventures in Wonderland from Project Gutenberg (plain text, freely available) and feed the full text to BigramModel::train. With more data:
- Common words like “the” and “and” branch to dozens of successors.
- Generated sequences recombine phrases from many parts of the text instead of cycling through the same few sentences.
- The difference between the single-word context used here and the longer contexts in Exercise 4 becomes much easier to see.
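You can embed the text file directly in the binary with Rust’s include_str! macro. The path is resolved relative to the file containing the macro call (here src/main.rs), so this expects the text at corpus/alice_ch1.txt in the project root:
const CORPUS: &str = include_str!("../corpus/alice_ch1.txt");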
Reference Solution
use rand::{Rng, SeedableRng, rngs::SmallRng};
use std::collections::HashMap;
struct BigramModel {
/// transitions[word] = list of (next_word, weight) pairs
transitions: HashMap<String, Vec<(String, usize)>>,
}
impl BigramModel {
fn train(corpus: &str) -> Self {
let words: Vec<String> = corpus
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
let mut transitions: HashMap<String, Vec<(String, usize)>> = HashMap::new();
for window in words.windows(2) {
let current = window[0].clone();
let next = window[1].clone();
let entry = transitions.entry(current).or_default();
if let Some(pair) = entry.iter_mut().find(|(w, _)| w == &next) {
pair.1 += 1;
} else {
entry.push((next, 1));
}
}
BigramModel { transitions }
}
fn generate(&self, seed: &str, length: usize, rng: &mut impl Rng) -> Vec<String> {
let mut output = vec![seed.to_string()];
let mut current = seed.to_string();
for _ in 0..length {
match self.transitions.get(&current) {
None => break,
Some(choices) => {
let next = sample_weighted(choices, rng).to_string();
output.push(next.clone());
current = next;
}
}
}
output
}
}
fn sample_weighted<'a>(choices: &'a [(String, usize)], rng: &mut impl Rng) -> &'a str {
let total: usize = choices.iter().map(|(_, w)| w).sum();
let mut r = rng.gen_range(0..total);
for (word, weight) in choices {
if r < *weight {
return word;
}
r -= weight;
}
&choices.last().unwrap().0
}
const CORPUS: &str =
"alice was beginning to get very tired of sitting by her sister on the \
bank and of having nothing to do once or twice she had peeped into the \
book her sister was reading but it had no pictures or conversations in \
it and what is the use of a book thought alice without pictures or \
conversations alice was beginning to get very tired of sitting";
fn main() {
let mut rng = SmallRng::seed_from_u64(42);
let model = BigramModel::train(CORPUS);
for seed in &["alice", "the", "book"] {
let words = model.generate(seed, 15, &mut rng);
println!("{}", words.join(" "));
}
}
8. Exercise 4 — N-gram Generalization
Goal: Generalize the bigram model to an n-gram model where each state is a window of n consecutive words. This is an order-n chain, so n = 1 reproduces the bigram model from Exercise 3. Compare the output quality for n = 1, 2, 3, and 4 on the same corpus.
Setup
Extend the project from Exercise 3, or create a fresh one:
cargo new ngram-chain
cd ngram-chain
cargo add rand
The NgramModel stores transitions keyed by Vec<String> — a window of n consecutive words. Vec<String> implements Hash and Eq, so it works directly as a HashMap key with no extra wrapping.
Starter Code
Replace (or extend) src/main.rs with the following skeleton:
use rand::Rng;
use std::collections::HashMap;
struct NgramModel {
n: usize,
transitions: HashMap<Vec<String>, Vec<(String, usize)>>,
}
impl NgramModel {
fn train(corpus: &str, n: usize) -> Self { todo!() }
fn generate(&self, seed: Vec<String>, length: usize, rng: &mut impl Rng) -> Vec<String> { todo!() }
}
fn main() { todo!() }
Step 1 — Tokenize and build the transition table
Inside train, convert the corpus into a flat list of lowercase words:
let words: Vec<String> = corpus
    .split_whitespace()
    .map(|s| s.to_lowercase())
    .collect();
Then iterate over sliding windows of size n + 1. The first n elements form the key (the context); the last element is the successor word to record:
for window in words.windows(n + 1) {
    let key: Vec<String> = window[..n].to_vec();
    let next: String = window[n].clone();
    // insert (key → next) into the transition table
}
Use HashMap::entry(...).or_default() to get-or-insert the successor list. Scan for an existing entry for next and increment its count, or push (next, 1) if it is new.
Step 2 — Weighted sampling helper
Reuse the same cumulative-weight technique from Exercise 3. The helper below accepts any slice of (word, count) pairs and returns a randomly chosen word with probability proportional to its count:
fn sample<'a>(choices: &'a [(String, usize)], rng: &mut impl Rng) -> &'a str {
    let total: usize = choices.iter().map(|(_, w)| w).sum();
    let mut r = rng.gen_range(0..total);
    for (word, weight) in choices {
        if r < *weight {
            return word;
        }
        r -= weight;
    }
    &choices.last().unwrap().0
}
Step 3 — Implement NgramModel::generate
generate is called with a seed — a Vec<String> of exactly n words. It extends that seed word by word, up to length additional words, using a sliding window to track the current context:
- Copy the seed into output and into a VecDeque<String> called window.
- Each step: collect window into a Vec<String> to use as the lookup key.
- If the key is missing from self.transitions, stop early: the chain has reached a dead end.
- Otherwise sample the next word, push it onto output, pop the oldest word from the front of window, and push the new word onto the back.
use std::collections::VecDeque; // at the top of src/main.rs

// inside generate:
let mut window: VecDeque<String> = VecDeque::from(seed.clone());
let mut output = seed;
for _ in 0..length {
    let key: Vec<String> = window.iter().cloned().collect();
    match self.transitions.get(&key) {
        None => break,
        Some(choices) => {
            let next = sample(choices, rng).to_string();
            output.push(next.clone());
            window.pop_front();
            window.push_back(next);
        }
    }
}
output
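Collecting window into a fresh Vec<String> on every step allocates and clones n strings. The sketch below shows an optional refinement, assuming you keep the VecDeque approach. VecDeque::make_contiguous rearranges the deque so its elements form one slice. HashMap<Vec<String>, _>::get then accepts that &[String] directly, because Vec<T> implements Borrow<[T]> and hashes the same way as its slice. `lookup_context` is a hypothetical helper name.

```rust
use std::collections::{HashMap, VecDeque};

/// Looks up the successor list for the current context without allocating
/// a new Vec<String> key.
fn lookup_context<'a>(
    transitions: &'a HashMap<Vec<String>, Vec<(String, usize)>>,
    window: &mut VecDeque<String>,
) -> Option<&'a Vec<(String, usize)>> {
    // make_contiguous returns the deque's elements (in order) as one slice.
    let key: &[String] = window.make_contiguous();
    transitions.get(key)
}

fn main() {
    let mut transitions: HashMap<Vec<String>, Vec<(String, usize)>> = HashMap::new();
    transitions.insert(
        vec!["her".to_string(), "sister".to_string()],
        vec![("on".to_string(), 1), ("was".to_string(), 1)],
    );
    let mut window: VecDeque<String> =
        VecDeque::from(vec!["her".to_string(), "sister".to_string()]);
    println!("{:?}", lookup_context(&transitions, &mut window));
}
```

Inside generate you would call it as lookup_context(&self.transitions, &mut window); the rest of the loop stays the same.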
Step 4 — Compare n = 1, 2, 3, 4
In main, train one model per value of n and generate 50 additional words from each. Use a fixed RNG seed for reproducibility. The short Alice corpus below is enough to observe the trend; swap in a larger public-domain text (e.g., the first chapter of Alice’s Adventures in Wonderland from Project Gutenberg) for more interesting output.
const CORPUS: &str =
"alice was beginning to get very tired of sitting by her sister on the \
bank and of having nothing to do once or twice she had peeped into the \
book her sister was reading but it had no pictures or conversations in \
it and what is the use of a book thought alice without pictures or \
conversations alice was beginning to get very tired of sitting";
fn main() {
use rand::{SeedableRng, rngs::SmallRng};
for n in 1..=4 {
let model = NgramModel::train(CORPUS, n);
let mut rng = SmallRng::seed_from_u64(42);
// seed = the first n words of the corpus
let seed: Vec<String> = CORPUS
.split_whitespace()
.take(n)
.map(|s| s.to_lowercase())
.collect();
let words = model.generate(seed, 50, &mut rng);
println!("n={}: {}", n, words.join(" "));
println!();
}
}
Expected observations:
- n = 1: one word of context, which is exactly the bigram model from Exercise 3. Every adjacent pair appeared in the corpus, so individual transitions feel plausible, but the topic shifts erratically over longer runs.
- n = 2 (trigram): longer coherent stretches emerge, and you will start to recognise verbatim phrases from the corpus.
- n = 3: on this small corpus almost every 3-word context has a single successor, so the output is mostly a copy of the training text.
- n = 4: most 4-word contexts appear only once, leaving the model no real choice but to reproduce the training text nearly verbatim. Try a larger corpus to see higher orders produce novel output.
Step 5 — Memorisation vs novelty
The pattern above is the central tension in all n-gram language models:
- Small n: short context → many plausible continuations → high novelty, low coherence.
- Large n: long context → typically a unique continuation → low novelty, high local fidelity to the corpus.
The corpus size determines the crossover point. For a paragraph-sized text, n = 1 (the bigram model) is usually the maximum useful order; on the Alice snippet, most 2-word contexts already have a single successor. For a novel-length corpus, n = 2 or 3 can produce readable, novel output without simply transcribing the source.
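One way to locate that crossover for your own corpus is to measure how often a context has only one observed successor. The sketch below uses a hypothetical helper, `context_stats`, with the same tokenization as train. For a given n, it counts the distinct n-word contexts and how many of them are deterministic. When nearly every context is deterministic, generation at that order will mostly copy the source.

```rust
use std::collections::{HashMap, HashSet};

/// Returns (distinct contexts, contexts with exactly one distinct successor)
/// for order-n contexts over a lowercase, whitespace-tokenized corpus.
fn context_stats(corpus: &str, n: usize) -> (usize, usize) {
    let words: Vec<String> = corpus
        .split_whitespace()
        .map(|s| s.to_lowercase())
        .collect();
    let mut successors: HashMap<&[String], HashSet<&str>> = HashMap::new();
    for window in words.windows(n + 1) {
        successors
            .entry(&window[..n])
            .or_default()
            .insert(window[n].as_str());
    }
    let deterministic = successors.values().filter(|s| s.len() == 1).count();
    (successors.len(), deterministic)
}

fn main() {
    const CORPUS: &str = "alice was beginning to get very tired of sitting by her sister on the \
        bank and of having nothing to do once or twice she had peeped into the \
        book her sister was reading but it had no pictures or conversations in \
        it and what is the use of a book thought alice without pictures or \
        conversations alice was beginning to get very tired of sitting";
    for n in 1..=4 {
        let (contexts, deterministic) = context_stats(CORPUS, n);
        println!("n={n}: {deterministic} of {contexts} contexts have a single successor");
    }
}
```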
Stretch Goal — Character-level n-grams
Swap words for individual characters: tokenize with .chars() instead of .split_whitespace(). The model and sampling logic are unchanged — only the “token” definition shifts from words to characters:
// Character-level tokenization for train:
let chars: Vec<String> = corpus.chars().map(|c| c.to_string()).collect();
// Replace `words` with `chars` and proceed identically.
Generate 200 characters at n = 3, 5, and 8. Seed each run with the first n characters of the corpus, and join the output with an empty string instead of a space. As n grows, the output moves roughly from invented but pronounceable fragments, to mostly real words, to long verbatim phrases from the corpus. This also demonstrates that the n-gram approach is domain-agnostic: the same code works for words, characters, DNA bases, MIDI note sequences, or any discrete token stream.
Reference Solution
use rand::{Rng, SeedableRng, rngs::SmallRng};
use std::collections::{HashMap, VecDeque};
struct NgramModel {
n: usize,
transitions: HashMap<Vec<String>, Vec<(String, usize)>>,
}
impl NgramModel {
fn train(corpus: &str, n: usize) -> Self {
let words: Vec<String> = corpus
.split_whitespace()
.map(|s| s.to_lowercase())
.collect();
let mut transitions: HashMap<Vec<String>, Vec<(String, usize)>> = HashMap::new();
for window in words.windows(n + 1) {
let key: Vec<String> = window[..n].to_vec();
let next = window[n].clone();
let entry = transitions.entry(key).or_default();
if let Some(pair) = entry.iter_mut().find(|(w, _)| w == &next) {
pair.1 += 1;
} else {
entry.push((next, 1));
}
}
NgramModel { n, transitions }
}
fn generate(&self, seed: Vec<String>, length: usize, rng: &mut impl Rng) -> Vec<String> {
assert_eq!(seed.len(), self.n, "seed must have exactly n words");
let mut window: VecDeque<String> = VecDeque::from(seed.clone());
let mut output = seed;
for _ in 0..length {
let key: Vec<String> = window.iter().cloned().collect();
match self.transitions.get(&key) {
None => break,
Some(choices) => {
let next = sample(choices, rng).to_string();
output.push(next.clone());
window.pop_front();
window.push_back(next);
}
}
}
output
}
}
fn sample<'a>(choices: &'a [(String, usize)], rng: &mut impl Rng) -> &'a str {
let total: usize = choices.iter().map(|(_, w)| w).sum();
let mut r = rng.gen_range(0..total);
for (word, weight) in choices {
if r < *weight {
return word;
}
r -= weight;
}
&choices.last().unwrap().0
}
const CORPUS: &str =
"alice was beginning to get very tired of sitting by her sister on the \
bank and of having nothing to do once or twice she had peeped into the \
book her sister was reading but it had no pictures or conversations in \
it and what is the use of a book thought alice without pictures or \
conversations alice was beginning to get very tired of sitting";
fn main() {
for n in 1..=4 {
let model = NgramModel::train(CORPUS, n);
let mut rng = SmallRng::seed_from_u64(42);
let seed: Vec<String> = CORPUS
.split_whitespace()
.take(n)
.map(|s| s.to_lowercase())
.collect();
let words = model.generate(seed, 50, &mut rng);
println!("n={}: {}", n, words.join(" "));
println!();
}
}
Part 4 — Deeper Concepts
9. Stationary Distributions
A stationary distribution π is a probability distribution over states that is unchanged by one step of the chain: π P = π. This section covers how to find stationary distributions analytically (for small chains) and via power iteration, and explains when they exist and are unique — introducing the concepts of irreducibility and aperiodicity.
The stationary equation. Write the state distribution as a row vector π = [π₀, π₁, …, π_{n−1}]. Then π is stationary if it satisfies the following conditions:
πP = π (fixed-point equation)
Σ πᵢ = 1 (normalisation)
πᵢ ≥ 0 for all i
The fixed-point equation says that right-multiplying π by the transition matrix returns exactly π: one step of the chain leaves the distribution unchanged. In terms of individual entries, the condition expands to:
πⱼ = Σᵢ πᵢ · P[i][j] for every j
The probability mass flowing into state j from every state i — weighted by how likely the chain is to be in each i — exactly equals the current probability already in state j. No state is gaining or losing probability over time; the distribution is in balance.
Long-run frequency interpretation. The simulations in Exercise 1 showed this balance in practice: after thousands of steps, the fraction of time the chain occupied each state stabilised near 2/3 Sunny and 1/3 Rainy, regardless of the starting state. This is the ergodic theorem for Markov chains: under appropriate conditions, the time-average frequency of visiting state i converges to πᵢ almost surely. The stationary distribution is not merely a mathematical fixed point — it is the long-run proportion of time the chain spends in each state, and it is precisely what your simulations were converging toward.
Irreducibility. A chain is irreducible if every state is reachable from every other state in some finite number of steps. Formally, for every pair (i, j) there exists k ≥ 1 with P^k[i][j] > 0. An irreducible chain cannot get permanently trapped in a proper subset of states; all states form a single communicating class. A chain with an absorbing state (one where P[i][i] = 1 and the chain can enter but never leave) is not irreducible. Irreducibility is what rules out a chain splitting into separate “islands” each with its own local stationary distribution.
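For a small chain you can check irreducibility straight from the definition. Treat each positive entry P[i][j] as a directed edge, then confirm that a graph search from every state reaches every other state. The sketch below uses a plain depth-first search; `is_irreducible` is a hypothetical helper. With a dense matrix this is O(n³) overall, which is fine for course-sized examples.

```rust
/// Returns true if every state can reach every other state, treating each
/// entry p[i][j] > 0 as a directed edge i -> j.
fn is_irreducible(p: &[Vec<f64>]) -> bool {
    let n = p.len();
    (0..n).all(|start| {
        let mut seen = vec![false; n];
        let mut stack = vec![start];
        seen[start] = true;
        while let Some(i) = stack.pop() {
            for j in 0..n {
                if p[i][j] > 0.0 && !seen[j] {
                    seen[j] = true;
                    stack.push(j);
                }
            }
        }
        seen.iter().all(|&s| s)
    })
}

fn main() {
    let weather = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    // State 1 is absorbing: once entered, it is never left.
    let absorbing = vec![vec![0.5, 0.5], vec![0.0, 1.0]];
    println!("{} {}", is_irreducible(&weather), is_irreducible(&absorbing));
}
```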
Aperiodicity. A state has period d if returns to that state can only happen in a number of steps that is a multiple of d. Formally, d = gcd{k ≥ 1 : P^k[i][i] > 0}. A state is aperiodic if d = 1, and a chain is aperiodic if every state is. An aperiodic chain does not oscillate on a fixed cycle: there is no rhythm forcing the chain to return to state i only every 2 or every 5 steps. A self-loop (P[i][i] > 0) immediately makes state i aperiodic, because the chain can return after 1 step, so the gcd collapses to 1. A finite, irreducible, aperiodic chain is called ergodic. Irreducibility alone already guarantees that a finite chain has exactly one stationary distribution. Aperiodicity adds convergence: starting from any initial distribution μ₀, the sequence μ₀ P^k converges to π as k → ∞. The weather chain is both irreducible (Sunny → Rainy and Rainy → Sunny are both reachable in one step) and aperiodic (both states have positive self-loop probabilities: P[0][0] = 0.8 and P[1][1] = 0.6).
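For small chains the gcd definition of the period can be evaluated directly. The sketch below uses a hypothetical helper, `period`. It tracks, as a boolean vector, which states are reachable from i in exactly k steps, and takes the gcd of every k ≤ max_steps at which state i is revisited. Truncating at max_steps is an assumption of this sketch: the function only reports the gcd of the return times it saw, so pick a bound comfortably larger than the number of states.

```rust
/// gcd of two non-negative integers (gcd(0, k) = k).
fn gcd(a: usize, b: usize) -> usize {
    if b == 0 { a } else { gcd(b, a % b) }
}

/// Period of state `i`: gcd of all step counts k in 1..=max_steps with
/// P^k[i][i] > 0. Returns 0 if no return was observed within max_steps.
fn period(p: &[Vec<f64>], i: usize, max_steps: usize) -> usize {
    let n = p.len();
    // reach[j] is true exactly when P^k[i][j] > 0 for the current k.
    let mut reach = vec![false; n];
    reach[i] = true; // k = 0
    let mut d = 0;
    for k in 1..=max_steps {
        let mut next = vec![false; n];
        for from in 0..n {
            if reach[from] {
                for to in 0..n {
                    if p[from][to] > 0.0 {
                        next[to] = true;
                    }
                }
            }
        }
        reach = next;
        if reach[i] {
            d = gcd(d, k);
        }
    }
    d
}

fn main() {
    let weather = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    let flip_flop = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
    println!("{} {}", period(&weather, 0, 16), period(&flip_flop, 0, 16));
}
```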
Analytical solution for the weather chain. To solve π P = π for the 2-state chain, write out the equation component by component. Use π = [π₀, π₁] and the weather transition matrix P = [[0.8, 0.2], [0.4, 0.6]], where row and column 0 are Sunny and row and column 1 are Rainy:
πP = π expands to:
π₀ · P[0][0] + π₁ · P[1][0] = π₀ → 0.8·π₀ + 0.4·π₁ = π₀
π₀ · P[0][1] + π₁ · P[1][1] = π₁ → 0.2·π₀ + 0.6·π₁ = π₁
The two equations are not independent: adding them gives π₀ + π₁ = π₀ + π₁, because each row of P sums to 1. So keep only the first and add the normalisation constraint:
0.8·π₀ + 0.4·π₁ = π₀
0.4·π₁ = 0.2·π₀
π₀ = 2·π₁
Normalisation: π₀ + π₁ = 1
2·π₁ + π₁ = 1
π₁ = 1/3 ≈ 0.333
π₀ = 2/3 ≈ 0.667
The unique stationary distribution is π ≈ [0.667, 0.333]: about two-thirds of all days are Sunny and one-third are Rainy in the long run. This matches the 66.7 % / 33.3 % figures from the Exercise 1 simulation.
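The same elimination works for any two-state chain. Write a = P[0][1] (the probability of leaving state 0) and b = P[1][0] (the probability of leaving state 1). The first equation then becomes b·π₁ = a·π₀, and normalisation gives π = [b/(a+b), a/(a+b)] whenever a + b > 0. For the weather chain a = 0.2 and b = 0.4, so π₀ = 0.4/0.6 = 2/3. Below is a short Rust sketch of the formula, with a check that one step of the chain leaves the result unchanged. Both function names are hypothetical.

```rust
/// Stationary distribution of a two-state chain with P[0][1] = a, P[1][0] = b.
/// Assumes a + b > 0 (if a = b = 0, every distribution is stationary).
fn two_state_stationary(a: f64, b: f64) -> [f64; 2] {
    [b / (a + b), a / (a + b)]
}

/// One step of the chain: returns pi · P for a 2x2 matrix.
fn step(pi: [f64; 2], p: [[f64; 2]; 2]) -> [f64; 2] {
    [
        pi[0] * p[0][0] + pi[1] * p[1][0],
        pi[0] * p[0][1] + pi[1] * p[1][1],
    ]
}

fn main() {
    let p = [[0.8, 0.2], [0.4, 0.6]];
    let pi = two_state_stationary(p[0][1], p[1][0]);
    let after = step(pi, p);
    println!("pi = {:?}, pi·P = {:?}", pi, after);
}
```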
Power iteration. For chains with many states, setting up and solving the linear system π P = π directly is impractical. Power iteration is a standard numerical alternative: start from any distribution and repeatedly apply P until successive distributions are indistinguishable. Convergence to the unique π is guaranteed for ergodic chains.
// Pseudocode: power iteration for stationary distribution
let π = uniform distribution over all n states
loop:
let π_next = π · P // one step of the chain
if max |π_next[i] - π[i]| < tolerance:
break
π = π_next
// π now approximates the stationary distribution
A concrete Rust sketch for the 2-state weather chain:
fn power_iteration(p: [[f64; 2]; 2], tolerance: f64) -> [f64; 2] {
let mut pi = [0.5_f64, 0.5];
loop {
let pi_next = [
pi[0] * p[0][0] + pi[1] * p[1][0],
pi[0] * p[0][1] + pi[1] * p[1][1],
];
let diff = (pi_next[0] - pi[0]).abs().max((pi_next[1] - pi[1]).abs());
pi = pi_next;
if diff < tolerance {
break;
}
}
pi
}
fn main() {
let p = [[0.8, 0.2], [0.4, 0.6]];
let pi = power_iteration(p, 1e-9);
println!("π ≈ [{:.6}, {:.6}]", pi[0], pi[1]);
// Output: π ≈ [0.666667, 0.333333]
}
Running this for the weather chain converges in about 20 iterations and agrees with the analytical result; the distance to π shrinks by a factor of 0.4 on every step. For an n-state chain, each iteration costs O(n²). The number of iterations needed grows roughly in inverse proportion to the spectral gap, 1 − |λ₂|. Here λ₂ is the eigenvalue of P with the second-largest absolute value (the largest is always 1); for the weather chain λ₂ = 0.8 + 0.6 − 1 = 0.4. A large spectral gap means fast convergence; a gap close to zero means the chain mixes slowly and power iteration requires many steps.
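The 2-state sketch generalises directly to any number of states. Below is one way to write it over a Vec<Vec<f64>> matrix; `stationary_power` is a hypothetical name. It uses the largest per-entry change as the stopping test, as in the pseudocode. A max_iters cap is an added assumption, so that a chain that never settles (possible for periodic chains) cannot loop forever.

```rust
/// Power iteration for the stationary distribution of an n-state chain.
/// p[i][j] is the probability of moving from state i to state j.
/// Returns the distribution and the number of iterations performed.
fn stationary_power(p: &[Vec<f64>], tolerance: f64, max_iters: usize) -> (Vec<f64>, usize) {
    let n = p.len();
    let mut pi = vec![1.0 / n as f64; n]; // start from the uniform distribution
    for iter in 1..=max_iters {
        // One step of the chain: pi_next[j] = sum over i of pi[i] * p[i][j].
        let pi_next: Vec<f64> = (0..n)
            .map(|j| (0..n).map(|i| pi[i] * p[i][j]).sum::<f64>())
            .collect();
        let diff = pi_next
            .iter()
            .zip(&pi)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max);
        pi = pi_next;
        if diff < tolerance {
            return (pi, iter);
        }
    }
    (pi, max_iters)
}

fn main() {
    let weather = vec![vec![0.8, 0.2], vec![0.4, 0.6]];
    let (pi, iters) = stationary_power(&weather, 1e-9, 1_000);
    println!("π ≈ {:?} after {} iterations", pi, iters);
}
```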
10. Applications and Further Reading
Markov chains appear throughout computer science and mathematics: PageRank, MCMC sampling, hidden Markov models, reinforcement learning, and more. This section surveys these applications at a high level and points to books, papers, and courses for learners who want to go deeper on any thread.
Application Survey
PageRank. Google’s original ranking algorithm modeled the web as a Markov chain. Each page is a state, and a random surfer on a page follows each of its outgoing links with equal probability: one divided by the page’s number of outgoing links. A small “teleportation” probability was added so the surfer occasionally jumps to a uniformly random page rather than following a link, ensuring the chain is irreducible and has a unique stationary distribution. The stationary probability of each page (the fraction of time a random surfer spends there in the long run) becomes its rank. Because the web had billions of pages, computing the stationary distribution via power iteration rather than direct matrix inversion was essential; the same convergence guarantee from Section 9 applies at planetary scale.
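A toy version makes the construction concrete. The sketch below assumes every page has at least one outgoing link; real implementations must also handle pages with none. It uses a damping factor of 0.85, the value suggested in the original PageRank paper. The graph and the function name `pagerank` are illustrative.

```rust
/// PageRank by power iteration. links[i] lists the pages that page i links to.
/// Assumes every page has at least one outgoing link.
fn pagerank(links: &[Vec<usize>], damping: f64, iterations: usize) -> Vec<f64> {
    let n = links.len();
    let mut rank = vec![1.0 / n as f64; n];
    for _ in 0..iterations {
        // Teleportation: with probability 1 - damping, jump to a uniform random page.
        let mut next = vec![(1.0 - damping) / n as f64; n];
        for (page, outs) in links.iter().enumerate() {
            // Follow each outgoing link with equal probability 1 / out-degree.
            let share = damping * rank[page] / outs.len() as f64;
            for &target in outs {
                next[target] += share;
            }
        }
        rank = next;
    }
    rank
}

fn main() {
    // Page 0 links to pages 1 and 2; pages 1 and 2 each link back to page 0.
    let links = vec![vec![1, 2], vec![0], vec![0]];
    let rank = pagerank(&links, 0.85, 100);
    println!("{:?}", rank);
}
```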
Markov Chain Monte Carlo (MCMC). Bayesian inference often requires integrating over high-dimensional parameter spaces where the posterior distribution has no closed form. MCMC methods solve this by constructing a Markov chain whose stationary distribution is the target posterior, then running the chain long enough that its samples approximate draws from that distribution. The Metropolis-Hastings algorithm is the foundational recipe: propose a move to a new state, accept it with a probability that preserves detailed balance, and reject it otherwise. Variants such as Gibbs sampling, Hamiltonian Monte Carlo, and the No-U-Turn Sampler (NUTS) power nearly every modern probabilistic programming framework, from Stan to PyMC.
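For a discrete target, the accept/reject rule is short enough to show in full. The sketch below runs Metropolis, the symmetric-proposal special case of Metropolis-Hastings. The target is three states with unnormalised weights 1, 2, 3, so the target probabilities are 1/6, 2/6, 3/6. To stay dependency-free, it uses a tiny xorshift generator as a stand-in for the rand crate. That generator and the function names are assumptions of this sketch, and it needs at least two states.

```rust
/// Minimal xorshift64 generator (a stand-in for the rand crate; not for serious use).
struct XorShift(u64);

impl XorShift {
    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }
    /// Uniform f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Runs Metropolis on states 0..weights.len() and returns visit frequencies.
/// Assumes at least two states and a non-zero seed.
fn metropolis(weights: &[f64], steps: usize, seed: u64) -> Vec<f64> {
    let k = weights.len();
    let mut rng = XorShift(seed);
    let mut state = 0;
    let mut visits = vec![0usize; k];
    for _ in 0..steps {
        // Symmetric proposal: pick one of the other k - 1 states uniformly.
        let mut proposal = (rng.next_f64() * (k - 1) as f64) as usize;
        if proposal >= state {
            proposal += 1;
        }
        // Accept with probability min(1, target(proposal) / target(state)).
        let ratio = weights[proposal] / weights[state];
        if rng.next_f64() < ratio {
            state = proposal;
        }
        visits[state] += 1; // rejected moves count the current state again
    }
    visits.iter().map(|&v| v as f64 / steps as f64).collect()
}

fn main() {
    let freqs = metropolis(&[1.0, 2.0, 3.0], 200_000, 42);
    println!("{:?}", freqs); // should approach [1/6, 2/6, 3/6]
}
```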
Hidden Markov Models (HMMs). An HMM separates a Markov chain into two layers: a hidden state sequence that evolves according to a transition matrix, and an observation sequence where each hidden state emits an observable symbol with some probability. The key insight is that the true states are never directly seen — only the observations are. HMMs were the dominant approach in speech recognition for decades (phonemes as hidden states, acoustic features as observations) and remain central in bioinformatics for gene prediction and sequence segmentation. The Viterbi algorithm finds the most likely hidden state path for a given observation sequence in time linear in the sequence length; the Baum-Welch algorithm trains HMMs from unlabelled data using expectation-maximisation.
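For every hidden state, the Viterbi recursion keeps the probability of the best path ending there, plus a back-pointer to the state that path came from. The sketch below runs it on a common introductory toy HMM: hidden states Healthy/Fever, observations normal/cold/dizzy. The numbers and the function signature are illustrative. The sketch multiplies raw probabilities, which is fine for short sequences; real implementations usually work with log-probabilities to avoid underflow.

```rust
/// Viterbi decoding. start[s], trans[s][t], emit[s][o] are probabilities;
/// obs is a non-empty observation sequence. Returns the most likely hidden
/// path and its probability.
fn viterbi(start: &[f64], trans: &[Vec<f64>], emit: &[Vec<f64>], obs: &[usize]) -> (Vec<usize>, f64) {
    let n = start.len();
    // best[s] = probability of the best path that ends in state s at the current step
    let mut best: Vec<f64> = (0..n).map(|s| start[s] * emit[s][obs[0]]).collect();
    let mut back: Vec<Vec<usize>> = Vec::new(); // back[step - 1][t] = best predecessor of t
    for &o in &obs[1..] {
        let mut next = vec![0.0; n];
        let mut ptr = vec![0; n];
        for t in 0..n {
            for s in 0..n {
                let p = best[s] * trans[s][t];
                if p > next[t] {
                    next[t] = p;
                    ptr[t] = s;
                }
            }
            next[t] *= emit[t][o];
        }
        best = next;
        back.push(ptr);
    }
    // Pick the best final state, then follow back-pointers to recover the path.
    let mut state = (0..n).max_by(|&a, &b| best[a].total_cmp(&best[b])).unwrap();
    let prob = best[state];
    let mut path = vec![state];
    for ptr in back.iter().rev() {
        state = ptr[state];
        path.push(state);
    }
    path.reverse();
    (path, prob)
}

fn main() {
    // Hidden states: 0 = Healthy, 1 = Fever. Observations: 0 = normal, 1 = cold, 2 = dizzy.
    let start = [0.6, 0.4];
    let trans = vec![vec![0.7, 0.3], vec![0.4, 0.6]];
    let emit = vec![vec![0.5, 0.4, 0.1], vec![0.1, 0.3, 0.6]];
    let (path, prob) = viterbi(&start, &trans, &emit, &[0, 1, 2]);
    println!("path = {:?}, probability = {:.5}", path, prob);
    // path = [0, 0, 1], probability = 0.01512
}
```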
Reinforcement Learning. Most reinforcement learning problems are formulated as Markov Decision Processes (MDPs), which augment a Markov chain with a set of actions and a scalar reward signal. At each step the agent chooses an action, the environment transitions to a new state according to a transition distribution that depends on both the current state and the chosen action, and the agent receives a reward. The goal is to find a policy — a rule mapping states to actions — that maximises cumulative discounted reward. Because the next state depends only on the current state and action (not the history), the Markov property is what makes value-function algorithms like Q-learning and policy-gradient methods tractable. Sutton and Barto’s textbook (listed below) treats this connection in full rigour.
Bioinformatics — Sequence Analysis. DNA and protein sequences are naturally modelled as Markov chains over an alphabet of bases or amino acids. A simple k-th order Markov model assigns probabilities to short subsequences and can tell different kinds of region apart, such as coding from non-coding DNA. CpG islands in mammalian genomes, for example, have transition probabilities (notably C followed by G) measurably different from the genomic background. Profile HMMs generalise this to align whole families of sequences. Each column in a multiple sequence alignment becomes a hidden state with its own emission distribution, allowing robust database search even for distantly related proteins.
Queueing Theory. The classic M/M/1 queue (arrivals according to a Poisson process, exponential service times, a single server) is a continuous-time Markov chain on the non-negative integers, where the state is the number of customers in the system. Provided the arrival rate is below the service rate, its stationary distribution is geometric, giving simple closed-form expressions for average queue length and waiting time. More complex queueing networks (multiple servers, priorities, finite buffers) extend the same framework and are used to size data-centre infrastructure, analyse hospital emergency departments, and design network switches. Continuous-time Markov chains replace the transition matrix with a generator matrix Q whose off-diagonal entries are transition rates rather than probabilities.
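The closed forms are short enough to compute directly. Let λ be the arrival rate, μ the service rate, and ρ = λ/μ < 1 the utilisation. Then the stationary probability of k customers is (1 − ρ)ρᵏ and the mean number in the system is L = ρ/(1 − ρ). Little's law (L = λW) gives the mean time in the system as W = 1/(μ − λ). A small sketch follows; the function names are hypothetical.

```rust
/// Steady-state (L, W) of an M/M/1 queue with arrival rate `lambda` > 0 and
/// service rate `mu`. Returns None when lambda >= mu (no stationary distribution).
fn mm1_metrics(lambda: f64, mu: f64) -> Option<(f64, f64)> {
    if lambda >= mu {
        return None;
    }
    let rho = lambda / mu;
    let mean_in_system = rho / (1.0 - rho); // L
    let mean_time_in_system = mean_in_system / lambda; // W = L / lambda (Little's law)
    Some((mean_in_system, mean_time_in_system))
}

/// Stationary probability of exactly k customers: (1 - rho) * rho^k.
fn mm1_prob(lambda: f64, mu: f64, k: i32) -> f64 {
    let rho = lambda / mu;
    (1.0 - rho) * rho.powi(k)
}

fn main() {
    // Two arrivals per second, four services per second.
    let (l, w) = mm1_metrics(2.0, 4.0).unwrap();
    println!("L = {l}, W = {w} s, P(empty) = {}", mm1_prob(2.0, 4.0, 0));
}
```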
Language Models. The n-gram models from Exercises 3 and 4 are finite-order Markov chains over word tokens, and they directly preceded modern neural language models. In the 1990s and 2000s, trigram and 4-gram models with smoothing (Kneser-Ney, Witten-Bell) were the state of the art for machine translation and speech recognition. Neural language models replaced explicit Markov structure with learned representations, but the conceptual scaffolding is the same: predict the next token from a bounded window of context. Understanding the Markov chain view of language — its local coherence, its lack of long-range memory, its reliance on corpus statistics — clarifies both what earlier systems got right and what neural approaches had to learn to transcend.
Further Reading
Books
- Blitzstein & Hwang — Introduction to Probability (2nd ed.). A beautifully written undergraduate probability text with a full chapter on Markov chains. The authors’ lecture videos and course materials are freely available; a free PDF of the book is offered on the book’s companion site. Start here if your probability background is thin or if you want every concept illustrated with concrete examples before the formalism arrives.
- Norris — Markov Chains. The standard rigorous treatment at the advanced-undergraduate / early-graduate level. Covers discrete- and continuous-time chains, convergence, reversibility, and applications with full proofs. Dense but thorough; worth working through if you intend to read research papers that use Markov chains as a theoretical tool.
- Sutton & Barto — Reinforcement Learning: An Introduction (2nd ed.). The canonical RL textbook, freely available as a PDF from the authors. Chapters 3 and 4 formalise MDPs and dynamic programming using exactly the Markov chain machinery developed in this course. Reading those chapters after completing this course is a natural next step toward understanding how modern game-playing and robotics agents are designed.
Articles
- Metropolis-Hastings algorithm — Wikipedia. A well-maintained article that covers the algorithm statement, acceptance-ratio derivation, intuition for why detailed balance implies the correct stationary distribution, and pseudocode. A good companion to the original 1953 Metropolis et al. paper (five pages, freely available) and the 1970 Hastings generalisation.
Rust Ecosystem Pointers
The exercises in this course used rand for sampling. A few other crates are useful as you build more serious probabilistic programs in Rust:
- nalgebra: a comprehensive linear-algebra library covering vectors, dense matrices, and decompositions. Use it to compute P^k via repeated matrix multiplication or to solve the stationary-distribution equations π P = π, Σ πᵢ = 1 as a linear system.
- petgraph: graph data structures and algorithms. Markov chains are directed weighted graphs, and petgraph lets you represent them as such, run graph-theoretic algorithms (strongly connected components give you irreducible sub-chains), and visualise the structure via its Graphviz export.
- statrs: a statistics library providing common probability distributions with their PDFs, CDFs, and samplers. Useful when building emission distributions for HMMs or when you need chi-squared tests to check whether simulated chain frequencies match theoretical stationary probabilities.
Vector Database Self-Guided Course
This document is a self-guided course on vector databases. It is organized into four parts: conceptual foundations, the internals of vector search systems, hands-on Rust exercises with Turso and sqlite-vec, and real-world application pipelines. Each section is either a reading lesson or a hands-on Rust programming exercise. Sections marked 🚧 are stubs whose full content is tracked in an nbd ticket — follow the ticket ID to find the detailed learning objectives and instructions.
Table of Contents
Part 1 — Foundations
Part 2 — Vector Databases
Part 3 — Turso + sqlite-vec Basics
Part 4 — Real Applications
- Generating Embeddings in Rust
- Exercise 3 — Semantic Document Search
- Exercise 4 — Recommendation Engine
- Exercise 5 — Retrieval-Augmented Generation
Part 1 — Foundations
1. What Is a Vector?
A vector is an ordered list of numbers. That is the entire definition — nothing more exotic than a list where position matters. A two-element list [3.0, 4.0] is a vector; so is a 1 536-element list of floating-point values produced by a language model. What makes vectors useful is that the numbers have a geometric interpretation: each element is a coordinate along one axis of a space, and the vector as a whole names a point (or an arrow from the origin to that point) in that space.
Geometric intuition in two and three dimensions. Start with the familiar. A 2-dimensional vector [x, y] is a point in the plane — the kind you plot on graph paper. The vector [3.0, 4.0] sits three units to the right of the origin and four units up. An arrow drawn from [0, 0] to [3, 4] has a magnitude (length) of √(3² + 4²) = 5 and points in a specific direction. Magnitude and direction together completely characterise the vector; change either one and you have a different vector.
A 3-dimensional vector [x, y, z] extends this to physical space: three coordinates, three axes, one point. You can still compute a magnitude — √(x² + y² + z²) — and you can still talk about direction. Two 3D vectors point in the same direction if one is a positive scalar multiple of the other; they are perpendicular (orthogonal) if their dot product is zero.
High-dimensional spaces. Nothing in the definition of a vector limits it to two or three elements. A d-dimensional vector [x₁, x₂, …, x_d] is a point in d-dimensional space. The geometry extends perfectly: magnitude is √(x₁² + x₂² + … + x_d²), the dot product of two vectors is Σᵢ aᵢ · bᵢ, and you can compute angles and distances between points just as you would in 2D or 3D.
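These definitions translate line for line into code. Below is a minimal sketch over plain f32 slices; the function names are illustrative. It computes the dot product, the magnitude as the square root of a vector's dot product with itself, and the angle from their ratio.

```rust
/// Dot product: sum over i of a[i] * b[i]. Panics if the lengths differ.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Magnitude (Euclidean length): square root of the dot product of a with itself.
fn magnitude(a: &[f32]) -> f32 {
    dot(a, a).sqrt()
}

/// Angle between two non-zero vectors, in radians.
fn angle(a: &[f32], b: &[f32]) -> f32 {
    let cos = dot(a, b) / (magnitude(a) * magnitude(b));
    cos.clamp(-1.0, 1.0).acos() // clamp guards against rounding just outside [-1, 1]
}

fn main() {
    let v: [f32; 2] = [3.0, 4.0];
    println!("|v| = {}", magnitude(&v));
    let x: [f32; 3] = [1.0, 0.0, 0.0];
    let y: [f32; 3] = [0.0, 1.0, 0.0];
    println!("x·y = {}, angle = {} rad", dot(&x, &y), angle(&x, &y));
}
```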
High-dimensional geometry is counterintuitive in subtle ways that are worth knowing:
- The curse of dimensionality. In high-dimensional spaces, most of the volume of a ball lies in a thin shell near its surface rather than in its interior. Two vectors drawn at random from a standard distribution tend to be nearly orthogonal: their cosine similarity is close to zero, even when you have not deliberately constructed them that way. Distances concentrate too, so the nearest and farthest points from a query end up at similar distances, and simple distance comparisons discriminate less well than 2D or 3D intuition suggests. This is why “nearest neighbour” in high dimensions is a harder problem than it sounds.
- Normalisation changes the geometry. A unit vector has magnitude exactly 1. Dividing a vector by its magnitude — normalisation — projects all vectors onto the surface of the unit hypersphere. On that sphere, distance and angle are equivalent measures of similarity, which simplifies many computations. Embedding models often output unit-normalised vectors precisely to exploit this equivalence.
- Dimensions are not independent features. When people say a language model embeds words into a 768-dimensional space, they do not mean “dimension 42 encodes the concept of colour.” The axes of an embedding space are rarely interpretable on their own. Meaning is encoded in the relative positions of points (which vectors are close to which others), not in the values along any single axis.
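The claim that distance and angle are equivalent on the unit sphere follows from expanding the squared distance. For unit vectors ‖a − b‖² = ‖a‖² + ‖b‖² − 2(a · b) = 2 − 2 cos θ, so ranking by Euclidean distance and ranking by cosine similarity give the same order. The sketch below checks this numerically. The helper names are illustrative.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Divide a vector by its magnitude so that the result has length 1.
/// Panics on the zero vector, which has no direction.
fn normalise(v: &[f32]) -> Vec<f32> {
    let len = dot(v, v).sqrt();
    assert!(len > 0.0, "cannot normalise the zero vector");
    v.iter().map(|x| x / len).collect()
}

/// Squared Euclidean distance between two vectors.
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

fn main() {
    let a = normalise(&[3.0, 4.0, 0.0]);
    let b = normalise(&[1.0, 2.0, 2.0]);
    // Both vectors have length 1, so their dot product is cos θ.
    let cos = dot(&a, &b);
    println!("‖a − b‖²   = {:.6}", l2_squared(&a, &b));
    println!("2 − 2cos θ = {:.6}", 2.0 - 2.0 * cos);
}
```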
Vectors as representations. The key insight that makes vector databases useful is that real-world objects — documents, images, audio clips, products, users — can be represented as vectors such that similarity in meaning or content corresponds to proximity in the vector space. Two documents that discuss the same topic will, if embedded well, produce vectors that are close together. Two documents on unrelated topics will produce vectors that are far apart.
This is not magic; it is the result of training a model to produce embeddings where similar inputs cluster near each other. Once you have such a model, every search or comparison problem reduces to a geometric problem: find the vectors closest to a query vector. The rest of this course is about how to do that efficiently at scale.
A note on notation. Throughout this course, vectors are written in bold or with subscripts: v, q, or v₁. The i-th element of a vector v is written v[i] or vᵢ. The magnitude of v is written |v| or ‖v‖. Dimension is written d and the number of stored vectors is written n.
2. Embeddings
What an embedding is. An embedding is a function — learned from data, not hand-crafted — that maps an input to a fixed-size vector of floating-point numbers. The input can be a word, a sentence, an image, a product listing, or anything else that can be fed into a neural network. The output is always the same shape: a Vec<f32> of some predetermined length d. The function is trained so that inputs with similar meaning produce vectors that are close together in the d-dimensional space, while unrelated inputs produce vectors that are far apart. Once you have such a function, comparing the meaning of two inputs reduces to comparing their vectors — which is exactly the geometric problem that vector databases are built to solve.
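The interface described above, a string in and a fixed-length `Vec<f32>` out, fits in a small Rust trait. Everything here is hypothetical. `Embedder` and `HashingEmbedder` are illustrative names, and the hashing embedder is a toy feature-hashing scheme rather than a learned model. It only notices shared words, not shared meaning, but it can stand in for a real model while you wire up storage and search.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hypothetical interface: anything that maps text to a fixed-size vector.
trait Embedder {
    /// Output dimensionality d; every call to `embed` returns this many floats.
    fn dim(&self) -> usize;
    fn embed(&self, input: &str) -> Vec<f32>;
}

/// Toy stand-in (NOT a learned model): hashes each lowercased word into one of
/// `d` buckets, counts occurrences, then scales the result to unit length.
struct HashingEmbedder {
    d: usize,
}

impl Embedder for HashingEmbedder {
    fn dim(&self) -> usize {
        self.d
    }

    fn embed(&self, input: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.d];
        for word in input.split_whitespace() {
            let mut h = DefaultHasher::new();
            word.to_lowercase().hash(&mut h);
            v[(h.finish() % self.d as u64) as usize] += 1.0;
        }
        let len = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if len > 0.0 {
            v.iter_mut().for_each(|x| *x /= len);
        }
        v
    }
}

fn main() {
    let model = HashingEmbedder { d: 8 };
    let v = model.embed("The cat sat on the mat");
    assert_eq!(v.len(), model.dim());
    println!("{v:?}");
}
```

Swapping in a real model means writing another `impl Embedder` that calls it. The code that stores and searches vectors does not need to change.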
Word embeddings and a brief history. The idea that meaning can live in a vector took hold in 2013 when Mikolov et al. published Word2Vec. Word2Vec trains on raw text and assigns every word in its vocabulary a single static vector, typically of 100 to 300 dimensions. The striking result was that vector arithmetic captured semantic relationships: the vector for king minus the vector for man plus the vector for woman produced a vector closest to queen. GloVe (2014) and fastText (2016) refined the approach, but the core limitation remained — each word gets exactly one vector regardless of context. The word bank has the same embedding whether it refers to a riverbank or a financial institution. Static word embeddings are largely a historical curiosity today, but they established the foundational principle: meaning can be encoded as geometry.
Contextual embeddings from encoder models. Modern embedding models solve the polysemy problem by reading the entire input before producing a vector. Models such as those in the sentence-transformers library or OpenAI’s text-embedding-3-small take a full sentence (or paragraph) as input, process it through a transformer encoder, and output one vector that represents the whole input. Internally these models produce a vector for every token; the single sentence-level vector is obtained either by averaging all token vectors (mean-pooling) or by taking the vector at a special [CLS] token position. You do not need to understand transformer internals to use these models — the interface is simple: input is a string, output is a Vec<f32> of fixed length. Because the model sees the full context, the same word in different sentences yields different final embeddings, correctly distinguishing river bank from investment bank.
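Mean-pooling, one of the two pooling strategies mentioned above, averages each dimension across all token vectors. The sketch below assumes the encoder has already produced one `Vec<f32>` per token. The name `mean_pool` and the 4-dimensional token vectors are made up for illustration. [CLS] pooling would instead take a single row.

```rust
/// Mean-pool per-token vectors into one sentence vector by averaging each
/// dimension across tokens. Every row must have the same length d.
fn mean_pool(token_vectors: &[Vec<f32>]) -> Vec<f32> {
    assert!(!token_vectors.is_empty(), "need at least one token vector");
    let d = token_vectors[0].len();
    let mut sum = vec![0.0_f32; d];
    for tv in token_vectors {
        assert_eq!(tv.len(), d, "all token vectors must share dimension d");
        for (s, x) in sum.iter_mut().zip(tv) {
            *s += x;
        }
    }
    let n = token_vectors.len() as f32;
    sum.into_iter().map(|s| s / n).collect()
}

fn main() {
    // Three made-up 4-dimensional token vectors (real encoders emit d = 384, 768, ...).
    let tokens: Vec<Vec<f32>> = vec![
        vec![1.0, 0.0, 2.0, 0.0],
        vec![3.0, 0.0, 0.0, 0.0],
        vec![2.0, 3.0, 1.0, 0.0],
    ];
    println!("{:?}", mean_pool(&tokens));
}
```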
What makes a good embedding model. Embedding models are trained with a contrastive objective: given a pair of inputs known to be similar (a question and its answer, two paraphrases, a caption and its image), the loss function pulls their vectors closer together; given a dissimilar pair, it pushes them apart. The quality of the training data — how many pairs, how diverse, how accurately labelled — matters as much as model size. Models are evaluated on the MTEB (Massive Text Embedding Benchmark), which measures performance across retrieval, classification, clustering, and semantic similarity tasks. In general, larger models produce better embeddings but cost more compute per input and return higher-dimensional vectors that consume more storage.
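One classic way to write “pull similar pairs together, push dissimilar pairs apart” is the margin-based pairwise contrastive loss. For a pair at distance D, the loss is D² if the pair is similar and max(0, m − D)² if it is dissimilar, where m is a margin. The following is a sketch of that one form only. Modern embedding models often use other contrastive variants, such as losses that compare each positive against many in-batch negatives, so do not read this as any particular model's training recipe.

```rust
/// Euclidean distance between two vectors.
fn distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

/// Margin-based contrastive loss for one pair.
/// similar = true: loss is D², so minimising it pulls the pair together.
/// similar = false: loss is max(0, margin − D)², so the pair is pushed apart
/// until it is at least `margin` away, after which it contributes nothing.
fn contrastive_loss(a: &[f32], b: &[f32], similar: bool, margin: f32) -> f32 {
    let d = distance(a, b);
    if similar {
        d * d
    } else {
        (margin - d).max(0.0).powi(2)
    }
}

fn main() {
    let anchor = [1.0_f32, 0.0];
    let close = [0.9_f32, 0.1];
    let far = [-1.0_f32, 0.0];
    // A similar pair that is already close costs little.
    println!("similar, close:    {:.4}", contrastive_loss(&anchor, &close, true, 1.0));
    // A dissimilar pair inside the margin is penalised.
    println!("dissimilar, close: {:.4}", contrastive_loss(&anchor, &close, false, 1.0));
    // A dissimilar pair beyond the margin costs nothing.
    println!("dissimilar, far:   {:.4}", contrastive_loss(&anchor, &far, false, 1.0));
}
```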
Practical dimensionalities. Different models produce different vector sizes, and the choice affects speed, memory, and quality. Common dimensions include 384 (MiniLM — fast inference, model size around 80–130 MB, a good default for prototyping), 768 (BERT-base and many sentence-transformers models — the most common open-source default), 1 536 (OpenAI text-embedding-3-small — a strong hosted option balancing quality and cost), and 3 072 (OpenAI text-embedding-3-large, OpenAI's highest-quality embedding model, at a substantially higher per-token price than text-embedding-3-small). Higher dimensionality is not always better: on small datasets or narrow domains, a 384-dimensional model may match or outperform a 1 536-dimensional one while using a quarter of the storage and running faster at query time. Choose based on your task, your latency budget, and empirical evaluation, not on the assumption that bigger is automatically better.
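The storage side of this trade-off is simple arithmetic: each `f32` takes 4 bytes, so n vectors of dimension d need n × d × 4 bytes before any index overhead. The sketch below computes that raw figure for the dimensions listed above. The function name is illustrative, and real databases add per-row and index storage on top.

```rust
/// Raw storage for n float32 vectors of dimension d: 4 bytes per element.
/// Ignores index structures and per-row overhead, which add more.
fn raw_vector_bytes(n: u64, d: u64) -> u64 {
    n * d * 4
}

fn main() {
    let n = 1_000_000;
    for d in [384, 768, 1536, 3072] {
        let gb = raw_vector_bytes(n, d) as f64 / 1e9;
        println!("d = {d:>4}: {gb:.3} GB for {n} vectors");
    }
}
```

For one million vectors this gives about 1.5 GB at d = 384 and about 6.1 GB at d = 1 536, which is the factor-of-four difference mentioned above.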
Embeddings for non-text data. Vectors are not limited to language. CLIP (Contrastive Language-Image Pretraining) trains a text encoder and an image encoder jointly so that their output vectors inhabit the same space — a photo of a dog and the sentence “a photograph of a dog” end up near each other, enabling text-to-image and image-to-text search with no modality-specific logic. Product embeddings can be learned from purchase co-occurrence: items frequently bought together are trained to have nearby vectors, powering recommendation engines. Audio, code, and molecular structures have their own embedding models. The vector database does not care what produced the floats — it stores arrays of f32 and computes distances. This modality-agnostic storage is one of the reasons vector databases have become a general-purpose building block in modern AI systems.
3. Vector Similarity
Once you have two vectors, how do you measure how alike they are? This section covers the three most common similarity and distance functions used in vector search — their formulas, geometric interpretations, and trade-offs — then works through a concrete example so the arithmetic is familiar before you encounter these functions in SQL.
Cosine similarity. The cosine similarity of two vectors a and b is defined as cos(θ) = (a · b) / (‖a‖ · ‖b‖), where a · b is the dot product and ‖a‖ is the magnitude of a. The result ranges from −1 to 1: a value of 1 means the vectors point in exactly the same direction, 0 means they are orthogonal (perpendicular), and −1 means they point in exactly opposite directions. The critical property of cosine similarity is that it measures only the angle between vectors, ignoring their magnitudes entirely. This makes it ideal for text embeddings: a short document and a long document on the same topic may produce vectors that differ in magnitude but point in nearly the same direction, and cosine similarity correctly identifies them as similar.
Cosine distance. Cosine distance is simply 1 − cosine_similarity. Its range is 0 to 2, where 0 means the vectors point in the same direction and 2 means they point in opposite directions. This is what libSQL's vector_distance_cos function, used throughout this course, returns. Pay attention to the naming: the function name contains “cos”, but it returns a distance, not a similarity, so smaller values mean more similar vectors. This is a common source of confusion when writing queries for the first time.
Dot product. The dot product of two vectors a and b is a · b = Σᵢ aᵢbᵢ — multiply corresponding elements and sum the results. For unit-normalised vectors (vectors whose magnitude is exactly 1), the dot product equals cosine similarity, because the denominator ‖a‖ · ‖b‖ = 1 · 1 = 1 and cancels out. For unnormalised vectors, the dot product conflates magnitude and direction: a longer vector will produce a larger dot product even if the angle is the same. Some embedding models are trained specifically for maximum inner product search (MIPS), meaning their vectors are not unit-normalised and the raw dot product is the intended similarity metric. The model’s documentation or model card will say so when this is the case.
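The way the dot product conflates magnitude and direction is easy to see with two vectors that point the same way but have different lengths. In the sketch below, the names and values are illustrative. The dot product rewards the longer vector, while cosine similarity scores both the same.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    dot(a, b) / (dot(a, a).sqrt() * dot(b, b).sqrt())
}

fn main() {
    let q = [1.0_f32, 1.0];
    let short = [1.0_f32, 1.0];
    let long = [5.0_f32, 5.0]; // same direction as `short`, five times the length
    // The dot product grows with length even though the angle is identical.
    println!("dot:    short = {}, long = {}", dot(&q, &short), dot(&q, &long));
    // Cosine similarity depends only on the angle, so both score about 1.
    println!(
        "cosine: short = {:.4}, long = {:.4}",
        cosine_similarity(&q, &short),
        cosine_similarity(&q, &long)
    );
}
```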
Euclidean (L2) distance. The Euclidean distance between two vectors is ‖a − b‖ = √(Σᵢ (aᵢ − bᵢ)²) — the straight-line distance between two points in d-dimensional space. Its range is 0 to ∞, with 0 meaning the vectors are identical. Unlike cosine similarity, L2 distance is sensitive to vector magnitude: two vectors pointing in the same direction but with different lengths will have a non-zero L2 distance. L2 is most appropriate for low-dimensional geometric or tabular data where absolute coordinate values carry meaning — for example, geographic coordinates or sensor readings.
When to use each. For text and sentence embeddings, use cosine similarity (or equivalently, dot product if your model outputs unit-normalised vectors, which many do). When in doubt, follow the recommendation on the model card. For low-dimensional geometric features where absolute position matters, use L2 distance.
Worked example. Let a = [1, 0, 1] and b = [1, 1, 0]. Compute all three metrics by hand:
Dot product: a · b = (1)(1) + (0)(1) + (1)(0) = 1 + 0 + 0 = 1.
Magnitudes: ‖a‖ = √(1² + 0² + 1²) = √2 ≈ 1.414. ‖b‖ = √(1² + 1² + 0²) = √2 ≈ 1.414.
Cosine similarity: cos(θ) = 1 / (√2 · √2) = 1 / 2 = 0.5. The cosine distance is 1 − 0.5 = 0.5, which is what vector_distance_cos would return.
Euclidean distance: ‖a − b‖ = √((1−1)² + (0−1)² + (1−0)²) = √(0 + 1 + 1) = √2 ≈ 1.414.
These three numbers — dot product = 1, cosine similarity = 0.5, L2 distance ≈ 1.414 — describe different aspects of the relationship between a and b. In the exercises that follow, you will see these same computations expressed as SQL function calls over stored vectors.
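Before moving to SQL, you can reproduce the hand calculation in plain Rust to check your arithmetic. The following sketch implements each formula from this section directly. `cosine_distance` is written to match the 1 − cosine similarity definition that vector_distance_cos uses, but it is an independent reimplementation, not libSQL's code.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn magnitude(v: &[f32]) -> f32 {
    dot(v, v).sqrt()
}

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    dot(a, b) / (magnitude(a) * magnitude(b))
}

/// 1 − cosine similarity, the quantity vector_distance_cos reports.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    1.0 - cosine_similarity(a, b)
}

/// Straight-line (L2) distance: √(Σᵢ (aᵢ − bᵢ)²).
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>().sqrt()
}

fn main() {
    let a = [1.0_f32, 0.0, 1.0];
    let b = [1.0_f32, 1.0, 0.0];
    println!("dot product       = {}", dot(&a, &b));
    println!("cosine similarity = {:.3}", cosine_similarity(&a, &b));
    println!("cosine distance   = {:.3}", cosine_distance(&a, &b));
    println!("L2 distance       = {:.3}", l2_distance(&a, &b));
}
```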
Part 2 — Vector Databases
4. What Is a Vector Database?
A vector database is a data store built around one core operation: given a query vector q, return the k stored vectors most similar to q. Every other feature — indexing, filtering, replication, APIs — exists to make that single operation fast, accurate, and convenient at scale. This section explains why that operation is hard, what problems it solves, and how vector databases compare to the data systems you already know.
The core operation. Given a query vector q and n stored vectors, find the k vectors most similar to q. This is the k-nearest-neighbour (KNN) problem. Exact KNN requires computing the distance from q to every stored vector — O(n · d) work per query. At n = 1 000 000 and d = 768, that is 768 million floating-point operations for a single query, far too slow for interactive use. Vector databases solve this by using approximate nearest-neighbour (ANN) algorithms (covered in §5) that trade a small accuracy loss for orders-of-magnitude speed gains. An ANN index can answer the same query in milliseconds by examining only a tiny fraction of the stored vectors.
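Exact KNN is short to write. It is the O(n · d) baseline that ANN indexes approximate, and it is also the ground truth you compare them against when measuring recall. The following is a minimal brute-force sketch using cosine distance. The function names are illustrative, and it returns positions in the `stored` slice rather than database IDs.

```rust
/// Cosine distance (1 − cosine similarity) between two equal-length vectors.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    1.0 - dot / (na * nb)
}

/// Exact k-nearest-neighbour search: compute the distance from `query` to every
/// stored vector (O(n · d) work), then keep the k smallest.
/// Returns (index, distance) pairs sorted by ascending distance.
fn exact_knn(query: &[f32], stored: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = stored
        .iter()
        .enumerate()
        .map(|(i, v)| (i, cosine_distance(query, v)))
        .collect();
    // total_cmp is a total order on f32, so sorting never panics even on NaN.
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    scored
}

fn main() {
    let stored: Vec<Vec<f32>> = vec![
        vec![1.0, 0.0],
        vec![0.0, 1.0],
        vec![0.7, 0.7],
        vec![-1.0, 0.0],
    ];
    let hits = exact_knn(&[1.0, 0.1], &stored, 2);
    println!("{hits:?}");
}
```

Sorting every score costs O(n log n). A bounded `BinaryHeap` of size k would avoid that, but the full sort keeps the sketch short.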
Use cases. The ability to find “semantically similar” items powers a wide range of applications:
- Semantic search: find documents that match the meaning of a query, not just its keywords — a search for “how to fix a flat tyre” retrieves results about “changing a punctured wheel” even though no words overlap.
- Recommendation: given an item a user just viewed or purchased, return the k most similar items from the catalogue (§11), or surface content preferred by users with similar taste profiles.
- Retrieval-Augmented Generation (RAG): retrieve the most relevant passages from a knowledge base before prompting a large language model, so the model’s answer is grounded in real documents rather than its training data alone (§12).
- Duplicate and near-duplicate detection: identify items that are semantically identical or extremely close to a given item — useful for deduplicating support tickets, detecting plagiarism, or clustering similar product listings.
- Anomaly detection: items whose vectors are far from all stored vectors are likely anomalous, enabling outlier detection without hand-crafted rules.
- Multi-modal search: find images matching a text description, or vice versa, by storing CLIP-style joint embeddings where text and image vectors share the same space.
vs. relational databases. SQL WHERE clauses perform exact matches and range queries on scalar values: equality, greater-than, LIKE, IN. There is no built-in notion of “nearest” for an array of floats. You cannot write ORDER BY similarity(embedding, ?) in standard SQL, because the concept does not exist in the relational model. Add-ons such as pgvector (PostgreSQL) and sqlite-vec (SQLite), and the built-in vector support in libSQL (Turso), add vector column types and distance functions to relational databases. Some of them also add ANN indexes. This lets you combine vector search with traditional filtering in a single query. This course uses libSQL's built-in vector functions via the libsql crate, which means you get vector search without leaving the SQLite ecosystem you may already know.
vs. full-text search (BM25 / TF-IDF). Traditional keyword search scores documents by how often query terms appear, weighted by rarity across the corpus. It works well when users know the exact vocabulary of the documents they want, but it cannot handle synonymy — “car” and “automobile” are unrelated tokens unless you maintain an explicit synonym list — and it has no concept of sentence-level meaning. Vector search captures both synonymy and broader conceptual similarity because the embedding model learns those relationships from data. In practice, hybrid search — combining a BM25 keyword score with an ANN vector score — outperforms either method alone and is a common pattern in production systems.
Key metrics. When evaluating a vector database or an ANN index, four numbers matter:
- Recall@k: the fraction of the true k nearest neighbours that the ANN algorithm actually returns. A recall@10 of 0.95 means 95 out of every 100 true top-10 results are found; the other 5 are replaced by slightly less similar vectors.
- QPS (queries per second): how many queries the index can serve per second at a given recall target. Higher is better; this is the throughput you care about in production.
- Index build time: the one-time cost paid to construct the search index from raw vectors. HNSW indexes, for example, require inserting each vector into a multi-layer graph, which can take minutes to hours for large datasets.
- Memory footprint: HNSW stores graph edges in RAM alongside the vectors themselves, which limits how large the index can grow on a single machine. Quantisation and disk-backed indexes reduce memory at the cost of recall or latency.
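Recall@k, the first metric above, is computed by comparing an ANN result list against the exact top-k from brute-force search. The following is a minimal sketch. It assumes both lists contain distinct IDs, and the IDs in `main` are hypothetical.

```rust
use std::collections::HashSet;

/// Recall@k: the fraction of the true k nearest neighbours (from exact search)
/// that also appear in the approximate result list. Assumes IDs are distinct.
fn recall_at_k(true_ids: &[u64], approx_ids: &[u64]) -> f64 {
    let truth: HashSet<u64> = true_ids.iter().copied().collect();
    let found = approx_ids.iter().filter(|&id| truth.contains(id)).count();
    found as f64 / true_ids.len() as f64
}

fn main() {
    // Hypothetical IDs: exact search found neighbours 1..=10 ...
    let exact: Vec<u64> = (1..=10).collect();
    // ... and an ANN index returned 9 of them plus one slightly worse vector (42).
    let approx: Vec<u64> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 42];
    println!("recall@10 = {:.2}", recall_at_k(&exact, &approx));
}
```

In practice you would average this over many sample queries.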
Where sqlite-vec and Turso fit. sqlite-vec is an excellent choice for embedded applications, local development, prototyping, and small-to-medium corpora — up to a few million vectors. It runs inside your application process with no separate server, and Turso adds cloud hosting, replication, and edge caching on top of the same SQLite foundation. For larger-scale deployments — tens of millions of vectors, multi-tenancy, complex filtered search, or distributed indexing — dedicated vector databases such as Pinecone, Qdrant, or Weaviate provide additional infrastructure. The concepts you learn in this course transfer directly: the same embeddings, distance functions, and query patterns apply regardless of which engine you choose.
5. Under the Hood: ANN Algorithms
Why not exact search? Brute-force KNN computes the distance from the query vector to every stored vector — O(n · d) work per query. At n = 1 000 000 vectors, d = 768 dimensions, and 1 000 queries per second, that is roughly 768 billion floating-point operations per second — infeasible on a commodity CPU. Approximate nearest-neighbour (ANN) algorithms find results in O(log n) or sub-linear time at the cost of occasionally missing a few true nearest neighbours. The two dominant families are HNSW and IVFFlat.
HNSW — Hierarchical Navigable Small World. HNSW is the dominant algorithm for in-memory ANN and is used by systems such as Qdrant, Weaviate, and pgvector.
Imagine a multi-level skip list where each level is a proximity graph. The top level is sparse, containing only a small subset of nodes connected by long-range edges that enable fast coarse navigation across the dataset. Each subsequent level adds more nodes and shorter-range edges, increasing density. The bottom level contains every vector, connected to its nearest neighbours by short-range edges that enable precise local search. When a query arrives, the algorithm starts at an entry point on the top level and greedily moves to whichever neighbour is closest to the query vector. When no neighbour on the current level is closer than the current node, the algorithm descends one level and repeats the greedy walk with the denser graph. At the bottom level, it collects the k nearest candidates encountered during traversal and returns them as the result.
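The greedy walk described above can be sketched for a single layer. This is a simplified illustration with hypothetical names (`greedy_search`, a `neighbours` adjacency list), not a full HNSW implementation. Real HNSW keeps a candidate list of size ef instead of a single current node, and it repeats the walk on every layer from the top down. Squared Euclidean distance is used here for simplicity.

```rust
/// One layer of a proximity graph: `neighbours[i]` lists the node IDs linked to node i.
/// Starting from `entry`, repeatedly move to whichever neighbour is closest to
/// `query`, and stop when no neighbour improves on the current node.
fn greedy_search(
    vectors: &[Vec<f32>],
    neighbours: &[Vec<usize>],
    query: &[f32],
    entry: usize,
) -> usize {
    // Squared Euclidean distance from node i to the query.
    let dist = |i: usize| -> f32 {
        vectors[i].iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum()
    };
    let mut current = entry;
    let mut current_dist = dist(current);
    loop {
        let mut best = current;
        let mut best_dist = current_dist;
        for &n in &neighbours[current] {
            let d = dist(n);
            if d < best_dist {
                best = n;
                best_dist = d;
            }
        }
        // Local minimum reached: in HNSW this is where the search descends a layer.
        if best == current {
            return current;
        }
        current = best;
        current_dist = best_dist;
    }
}

fn main() {
    // Five points on a line, each linked only to its immediate neighbours.
    let vectors: Vec<Vec<f32>> = (0..5).map(|i| vec![i as f32, 0.0]).collect();
    let neighbours: Vec<Vec<usize>> = vec![vec![1], vec![0, 2], vec![1, 3], vec![2, 4], vec![3]];
    let nearest = greedy_search(&vectors, &neighbours, &[3.2, 0.0], 0);
    println!("greedy walk from node 0 ends at node {nearest}");
}
```

Because each step strictly reduces the distance, the walk always terminates. On a poorly connected graph, however, it can stop at a local minimum that is not the true nearest neighbour, which is one reason HNSW adds long-range upper layers and the ef candidate list.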
HNSW key parameters:
- M: the number of bidirectional connections each node maintains per layer. Higher M improves recall (the algorithm has more paths to explore) but increases memory consumption and slows down inserts, because more edges must be evaluated and updated. A typical default is 16.
- ef_construction: the size of the dynamic candidate list used when inserting a new vector into the graph. Higher values produce a higher-quality index (a better-connected graph) at the cost of slower index construction. Defaults vary by library; values between 64 and 200 are common.
- ef_search: the size of the candidate list used during query-time traversal. It should be at least k, and higher values improve recall at the cost of higher query latency. Increasing it is the easiest way to trade latency for accuracy at query time.
HNSW supports incremental inserts with no full rebuild. Each new vector is linked into the existing graph structure, so no separate training step is required. The graph costs roughly O(n · M) stored neighbour IDs (about 4 bytes each) on top of the vectors themselves.
IVFFlat — Inverted File with flat (uncompressed) vector storage. IVFFlat is a widely used partition-based alternative, available in Faiss (as IndexIVFFlat) and in pgvector (as the ivfflat index type).
The idea is to partition the dataset into nlist Voronoi cells using k-means clustering during a one-time training step. Each cell is defined by a centroid vector, and every stored vector is assigned to the cell whose centroid is closest. At query time, the algorithm computes the distance from the query to all nlist centroids, selects the nprobe nearest centroids, and then performs exact brute-force search only within those cells — skipping the vast majority of the dataset entirely.
IVFFlat key parameters:
- nlist: the number of clusters (Voronoi cells). A common heuristic is to set nlist ≈ √n. More clusters mean each cell is smaller, so query-time search within a cell is faster, but training takes longer, and very small cells increase the risk of a query's true neighbours falling in an unsearched cell.
- nprobe: the number of clusters examined at query time. Higher nprobe improves recall at the cost of higher latency. Setting nprobe = nlist degenerates to exact search; setting nprobe = 1 checks only the single most likely cluster.
Unlike HNSW, IVFFlat requires a training step (the k-means clustering) before any data can be inserted. Incremental inserts require assigning each new vector to an existing cluster, which can degrade quality over time as the data distribution drifts from the original centroids — periodic retraining is recommended for heavily updated datasets. IVFFlat uses less memory than HNSW for the same n because it does not store graph edges.
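The query side of IVFFlat is sketched below. It ranks the centroids, keeps the nprobe closest, and brute-forces only those cells. The sketch assumes the k-means training step has already produced `centroids` and assigned every vector to a list. The names and data are hypothetical, and distances are squared Euclidean.

```rust
fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Query an IVF-style index. `centroids[c]` is cell c's centroid and `lists[c]`
/// holds the (id, vector) pairs assigned to it. Only the `nprobe` cells with the
/// closest centroids are scanned. Returns up to k (id, distance) pairs.
fn ivf_search(
    centroids: &[Vec<f32>],
    lists: &[Vec<(u64, Vec<f32>)>],
    query: &[f32],
    nprobe: usize,
    k: usize,
) -> Vec<(u64, f32)> {
    // Step 1: rank all nlist centroids by distance to the query.
    let mut order: Vec<usize> = (0..centroids.len()).collect();
    order.sort_by(|&a, &b| {
        sq_dist(query, &centroids[a]).total_cmp(&sq_dist(query, &centroids[b]))
    });
    // Step 2: exact search inside the nprobe closest cells only.
    let mut hits: Vec<(u64, f32)> = Vec::new();
    for &c in order.iter().take(nprobe) {
        for (id, v) in &lists[c] {
            hits.push((*id, sq_dist(query, v)));
        }
    }
    hits.sort_by(|a, b| a.1.total_cmp(&b.1));
    hits.truncate(k);
    hits
}

fn main() {
    // Two hypothetical cells: one around (0, 0) and one around (10, 10).
    let centroids: Vec<Vec<f32>> = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
    let lists: Vec<Vec<(u64, Vec<f32>)>> = vec![
        vec![(1, vec![0.5, 0.0]), (2, vec![0.0, 1.0])],
        vec![(3, vec![9.0, 10.0]), (4, vec![10.0, 11.0])],
    ];
    // nprobe = 1: only the cell around (0, 0) is scanned.
    println!("{:?}", ivf_search(&centroids, &lists, &[0.2, 0.1], 1, 2));
}
```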
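What libSQL uses. The libsql_vector_idx index you will create in §6 is neither HNSW nor IVFFlat. libSQL bases its vector index on DiskANN, another graph-based ANN algorithm. Like HNSW, it links each new vector into an existing graph, which is why rows can be inserted incrementally with no training step. Its tuning options are passed as optional settings to libsql_vector_idx rather than as HNSW's M and ef parameters, and defaults are used when you pass none.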
Summary table.
| Property | HNSW | IVFFlat |
|---|---|---|
| Query time | O(log n) | O(nlist + nprobe · n / nlist) |
| Insert | Incremental | Batch (requires training) |
| Memory | Higher (graph edges) | Lower |
| Recall@10 at defaults | ~0.95+ | ~0.90+ (depends on nprobe) |
| Used by | Qdrant, Weaviate, pgvector, Faiss | pgvector, Faiss |
Part 3 — Turso + sqlite-vec Basics
6. Setting Up
This section walks through everything you need before writing your first vector query: adding the right crates, opening a local Turso connection, and creating a table and index that use libSQL's built-in vector support.
What You Are Building
Turso is built on libSQL, a SQLite-compatible database with native support for vector similarity search, so no extension needs to be loaded. In local development you use a file-backed database; in production the same code points at a Turso cloud database. The libsql crate (the Rust client for Turso) speaks the Turso wire protocol and also handles local database files transparently.
Cargo.toml
Create a new binary project and add the following dependencies:
cargo new vec-demo
cd vec-demo
Replace the [dependencies] section of Cargo.toml with:
[dependencies]
libsql = "0.9"
tokio = { version = "1", features = ["full"] }
libsql is the official Rust client for Turso / libSQL databases. It supports both local SQLite files and remote Turso connections with the same API, making it straightforward to develop locally and deploy to the cloud. tokio provides the async runtime — all libsql operations are async.
Add the release-build optimisation profile from the project conventions:
[profile.release]
opt-level = "z"
lto = true
strip = true
codegen-units = 1
Opening a Local Connection
Replace src/main.rs with the following:
use libsql::{Builder, Database};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db: Database = Builder::new_local("vectors.db").build().await?;
let conn = db.connect()?;
// Verify the connection works
let mut rows = conn.query("SELECT sqlite_version()", ()).await?;
if let Some(row) = rows.next().await? {
let version: String = row.get(0)?;
println!("SQLite version: {version}");
}
Ok(())
}
Run it with cargo run. You should see output like:
SQLite version: 3.46.0
A file named vectors.db will appear in the current directory. This is a standard SQLite database — you can open it with any SQLite client to inspect its contents.
Enabling Vector Support
Vector support is built into libSQL itself rather than provided by a separately loaded extension, so the libsql crate needs no extra installation. The vector functions are available as soon as you use the right column types and functions in your SQL.
The key types and functions you will use throughout this course:
| Construct | Purpose |
|---|---|
| F32_BLOB(d) | Column type for storing a d-dimensional float32 vector |
| vector(json_array) | Creates a vector from a JSON array literal |
| vector_extract(blob) | Converts a stored vector blob back to a JSON array |
| vector_distance_cos(a, b) | Cosine distance between two vectors (0 = identical, 2 = opposite) |
| libsql_vector_idx(col) | Used inside CREATE INDEX to build an approximate nearest-neighbour index on col |
| vector_top_k(index_name, query, k) | Table-valued function: returns the rowids of the k nearest rows to a query vector, using the named index |
Creating a Vector Table
Extend main to create a table that stores 3-dimensional float32 vectors:
conn.execute(
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY,
label TEXT NOT NULL,
embedding F32_BLOB(3) NOT NULL
)",
(),
).await?;
F32_BLOB(3) declares a column that holds a 3-dimensional float32 vector stored as a binary blob. The 3 is the dimensionality — use the actual size of your embedding model’s output (e.g., F32_BLOB(768) for a 768-dimensional model) in real projects.
Creating a Vector Index
Without an index, nearest-neighbour search performs a full table scan — computing the distance from the query to every stored vector. For small tables this is fine; at scale you need an index:
conn.execute(
"CREATE INDEX IF NOT EXISTS items_vec_idx
ON items (libsql_vector_idx(embedding))",
(),
).await?;
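This creates an approximate nearest-neighbour index over the embedding column. libSQL builds it as a graph-based (DiskANN-derived) index. Queries use it through vector_top_k, which takes the index name (here items_vec_idx) as its first argument. The index is updated incrementally as rows are inserted or deleted, so no manual rebuild is required.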
Putting It Together
At this point your main.rs should look like this:
use libsql::{Builder, Database};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db: Database = Builder::new_local("vectors.db").build().await?;
let conn = db.connect()?;
// Verify connection
let mut rows = conn.query("SELECT sqlite_version()", ()).await?;
if let Some(row) = rows.next().await? {
let version: String = row.get(0)?;
println!("SQLite version: {version}");
}
// Create vector table
conn.execute(
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY,
label TEXT NOT NULL,
embedding F32_BLOB(3) NOT NULL
)",
(),
).await?;
// Create vector (ANN) index
conn.execute(
"CREATE INDEX IF NOT EXISTS items_vec_idx
ON items (libsql_vector_idx(embedding))",
(),
).await?;
println!("Database ready.");
Ok(())
}
cargo run should print something like the following (the exact SQLite version depends on your libsql release):
SQLite version: 3.46.0
Database ready.
You now have a working local vector database. Exercises 1 through 5 build on this foundation, adding data, querying it, and connecting the full embedding-to-search pipeline.
7. Exercise 1 — Storing and Retrieving Vectors
Goal: Insert 6 labelled 3-dimensional vectors into the items table created in §6, then SELECT all rows and print each label alongside its deserialized Vec<f32>.
The Dataset
We use a tiny hand-crafted set of 3D vectors so the results are easy to verify by inspection. The vectors are designed so that items in the same category cluster together — animals near [high, low, low], vehicles near [low, high, low], and programming languages near [low, low, high]:
| id | label | embedding |
|---|---|---|
| 1 | “cat” | [0.9, 0.1, 0.2] |
| 2 | “dog” | [0.8, 0.2, 0.3] |
| 3 | “car” | [0.1, 0.9, 0.1] |
| 4 | “truck” | [0.2, 0.8, 0.2] |
| 5 | “python” | [0.15, 0.1, 0.95] |
| 6 | “rust” | [0.1, 0.05, 0.9] |
In later exercises you will query these vectors to see how cosine distance naturally separates the three clusters.
Step 1 — Formatting a Vector for INSERT
libSQL's vector(?) SQL function accepts a JSON array string, for example "[0.9,0.1,0.2]". You pass this string as a text parameter, and vector() converts it into the F32_BLOB format used for storage.
A small helper keeps the conversion in one place:
fn vec_to_json(v: &[f32]) -> String {
format!("[{}]", v.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
}
Calling vec_to_json(&[0.9, 0.1, 0.2]) returns the string "[0.9,0.1,0.2]", ready to bind as a SQL parameter.
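You can check the helper without touching the database by round-tripping a vector through it. `json_to_vec` below is a hypothetical std-only inverse that handles only the flat "[a,b,c]" shape that vec_to_json produces. Step 3 uses serde_json for real parsing. The round trip is exact because Rust's `f32` `to_string` emits the shortest decimal that parses back to the same value.

```rust
/// Convert a float slice to a JSON array string, e.g. "[0.9,0.1,0.2]".
fn vec_to_json(v: &[f32]) -> String {
    format!("[{}]", v.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","))
}

/// Hypothetical inverse for the flat "[a,b,c]" shape only; not a general JSON parser.
fn json_to_vec(s: &str) -> Result<Vec<f32>, std::num::ParseFloatError> {
    let inner = s.trim().trim_start_matches('[').trim_end_matches(']');
    if inner.trim().is_empty() {
        return Ok(Vec::new());
    }
    inner.split(',').map(|x| x.trim().parse::<f32>()).collect()
}

fn main() {
    let v = vec![0.9_f32, 0.1, 0.2];
    let json = vec_to_json(&v);
    println!("{json}");
    // Parsing the string back yields exactly the original floats.
    assert_eq!(json_to_vec(&json).unwrap(), v);
}
```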
Step 2 — Inserting Rows
Use INSERT OR IGNORE so the program is idempotent — running it twice does not produce duplicate-key errors or duplicate data:
INSERT OR IGNORE INTO items (id, label, embedding) VALUES (?, ?, vector(?))
Define the dataset as a Vec<(i64, &str, Vec<f32>)> and loop over it:
let data: Vec<(i64, &str, Vec<f32>)> = vec![
(1, "cat", vec![0.9, 0.1, 0.2]),
(2, "dog", vec![0.8, 0.2, 0.3]),
(3, "car", vec![0.1, 0.9, 0.1]),
(4, "truck", vec![0.2, 0.8, 0.2]),
(5, "python", vec![0.15, 0.1, 0.95]),
(6, "rust", vec![0.1, 0.05, 0.9]),
];
for (id, label, embedding) in &data {
conn.execute(
"INSERT OR IGNORE INTO items (id, label, embedding) VALUES (?, ?, vector(?))",
libsql::params![*id, *label, vec_to_json(embedding)],
).await?;
}
println!("Inserted {} rows.", data.len());
Step 3 — Selecting and Deserializing
Query all rows back out. The vector_extract function converts the stored F32_BLOB back into a JSON array string that you can parse in Rust:
SELECT id, label, vector_extract(embedding) FROM items ORDER BY id
Add serde_json to your Cargo.toml dependencies for JSON parsing:
serde_json = "1"
Then fetch and deserialize:
let mut rows = conn
.query("SELECT id, label, vector_extract(embedding) FROM items ORDER BY id", ())
.await?;
while let Some(row) = rows.next().await? {
let id: i64 = row.get(0)?;
let label: String = row.get(1)?;
let json_str: String = row.get(2)?;
let embedding: Vec<f32> = serde_json::from_str(&json_str)?;
println!("{id:<3}{label:<10}{embedding:?}");
}
Step 4 — Expected Output
Running cargo run should print:
SQLite version: 3.46.0
Database ready.
Inserted 6 rows.
1  cat       [0.9, 0.1, 0.2]
2  dog       [0.8, 0.2, 0.3]
3  car       [0.1, 0.9, 0.1]
4  truck     [0.2, 0.8, 0.2]
5  python    [0.15, 0.1, 0.95]
6  rust      [0.1, 0.05, 0.9]
Every vector round-trips through the database intact: Rust Vec<f32> → JSON string → vector() → F32_BLOB storage → vector_extract() → JSON string → serde_json → Rust Vec<f32>.
Cargo.toml Additions
Your full [dependencies] section should now be:
[dependencies]
libsql = "0.9"
tokio = { version = "1", features = ["full"] }
serde_json = "1"
Reference Solution
Cargo.toml (dependencies only):
[dependencies]
libsql = "0.9"
tokio = { version = "1", features = ["full"] }
serde_json = "1"
src/main.rs:
use libsql::{Builder, Database};
/// Convert a float slice to a JSON array string for sqlite-vec's `vector()` function.
fn vec_to_json(v: &[f32]) -> String {
format!(
"[{}]",
v.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",")
)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// --- Open database ---
let db: Database = Builder::new_local("vectors.db").build().await?;
let conn = db.connect()?;
// Verify connection
let mut rows = conn.query("SELECT sqlite_version()", ()).await?;
if let Some(row) = rows.next().await? {
let version: String = row.get(0)?;
println!("SQLite version: {version}");
}
// --- Create table (from §6) ---
conn.execute(
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY,
label TEXT NOT NULL,
embedding F32_BLOB(3) NOT NULL
)",
(),
)
.await?;
// --- Create vector (ANN) index (from §6) ---
conn.execute(
"CREATE INDEX IF NOT EXISTS items_vec_idx
ON items (libsql_vector_idx(embedding))",
(),
)
.await?;
println!("Database ready.");
// --- Insert 6 labelled vectors ---
let data: Vec<(i64, &str, Vec<f32>)> = vec![
(1, "cat", vec![0.9, 0.1, 0.2]),
(2, "dog", vec![0.8, 0.2, 0.3]),
(3, "car", vec![0.1, 0.9, 0.1]),
(4, "truck", vec![0.2, 0.8, 0.2]),
(5, "python", vec![0.15, 0.1, 0.95]),
(6, "rust", vec![0.1, 0.05, 0.9]),
];
for (id, label, embedding) in &data {
conn.execute(
"INSERT OR IGNORE INTO items (id, label, embedding) VALUES (?, ?, vector(?))",
libsql::params![*id, *label, vec_to_json(embedding)],
)
.await?;
}
println!("Inserted {} rows.", data.len());
// --- Select and deserialize ---
let mut rows = conn
.query(
"SELECT id, label, vector_extract(embedding) FROM items ORDER BY id",
(),
)
.await?;
while let Some(row) = rows.next().await? {
let id: i64 = row.get(0)?;
let label: String = row.get(1)?;
let json_str: String = row.get(2)?;
let embedding: Vec<f32> = serde_json::from_str(&json_str)?;
println!("{id:<3}{label:<10}{embedding:?}");
}
Ok(())
}
8. Exercise 2 — K-Nearest Neighbor Search
Goal: Given a query vector, use vector_top_k to find the 3 most similar items, join with the items table to retrieve labels and exact cosine distances, and display the results ranked by distance.
Step 1 — Introduce vector_top_k
vector_top_k is a table-valued function (TVF) that returns the row IDs of approximate nearest neighbours without performing a full table scan. It uses the vector index created in §6 to navigate directly to the neighbourhood of the query vector. The syntax is:
SELECT id FROM vector_top_k('items_vec_idx', vector(?), ?)
The three arguments are:
- Index name (string literal): the name of the vector index to search (items_vec_idx here), not the name of the table.
- Query vector: passed through vector() as a JSON array string, just like when inserting data.
- k: the number of nearest neighbours to return.
The function returns a single column, id, containing the rowids of the matching rows; it does not return labels, embeddings, or distances. To access other columns you must JOIN the result back to the original table on rowid. This design keeps the TVF focused on index traversal and lets you choose exactly which columns to retrieve.
Step 2 — Full KNN Query
Combine the TVF with a JOIN and an exact distance computation to get labelled, ranked results:
SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
JOIN items ON items.rowid = knn.id
ORDER BY dist ASC
Notice that the query vector must be passed twice — once as the second argument to vector_top_k (for index traversal to find candidate rows) and once as the second argument to vector_distance_cos (for exact distance computation on those candidates). Both are the same JSON array string bound to separate SQL parameters.
Why two passes? vector_top_k uses the vector index to quickly identify which rows are likely nearest neighbours, but it does not return distance values. vector_distance_cos then computes the exact cosine distance for each candidate row, which you use for ranking and display.
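To see exactly what vector_distance_cos computes for each candidate, here is a minimal pure-Rust sketch of cosine distance: 1 minus cosine similarity, which has the same 0-to-2 range as the SQL function. The name cosine_distance is ours, not a libSQL API; it is simply handy for checking distances by hand.

```rust
/// Cosine distance = 1 - cos(theta): 0 means same direction, 2 means opposite.
/// Illustrative helper for hand-checking results; not part of libSQL.
/// Returns NaN if either vector is all zeros (the norm would be 0).
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
    let mut dot = 0.0f64;
    let mut norm_a = 0.0f64;
    let mut norm_b = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        // Accumulate in f64 to keep rounding error small.
        let (x, y) = (x as f64, y as f64);
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    1.0 - dot / (norm_a.sqrt() * norm_b.sqrt())
}
```

For example, cosine_distance(&[0.85, 0.15, 0.25], &[0.9, 0.1, 0.2]) compares the animal-cluster query with the cat vector and comes out at roughly 0.0040.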
Step 3 — Run Three Queries and Print Results
Define a helper function that runs the KNN query for a given query vector and prints the results:
async fn knn_query(
conn: &libsql::Connection,
query: &[f32],
k: i32,
) -> Result<(), Box<dyn std::error::Error>> {
let q = vec_to_json(query);
let mut rows = conn
.query(
"SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
JOIN items ON items.rowid = knn.id
ORDER BY dist ASC",
libsql::params![q.clone(), q.clone(), k],
)
.await?;
println!("Query: {q}");
let mut rank = 1;
while let Some(row) = rows.next().await? {
let label: String = row.get(1)?;
let dist: f64 = row.get(2)?;
println!(" {rank}. {label:<10} dist={dist:.4}");
rank += 1;
}
println!();
Ok(())
}
Run three queries, each probing one of the three clusters from the §7 dataset:
// Animal cluster
knn_query(&conn, &[0.85, 0.15, 0.25], 3).await?;
// Vehicle cluster
knn_query(&conn, &[0.15, 0.85, 0.15], 3).await?;
// Language cluster
knn_query(&conn, &[0.1, 0.05, 0.92], 3).await?;
Expected output (exact distances depend on floating-point precision):
Query: [0.85,0.15,0.25]
1. cat dist=0.0040
2. dog dist=0.0045
3. truck dist=0.5541
Query: [0.15,0.85,0.15]
1. car dist=0.0039
2. truck dist=0.0045
3. dog dist=0.5642
Query: [0.1,0.05,0.92]
1. rust dist=0.0000
2. python dist=0.0024
3. dog dist=0.5499
Each query correctly identifies the two items in its target cluster as the closest matches. The third result is always from a different cluster with a noticeably larger distance.
Step 4 — ANN vs. Exact Search
For the 6-row dataset used in these exercises, the index is so small that the search ends up visiting every vector anyway, so in practice the results are identical to brute-force KNN.
At scale (millions of rows), vector_top_k returns approximate results. The index is a proximity graph that the search walks greedily (libSQL's index is based on DiskANN rather than HNSW, but both are navigated by greedy graph search), so some true nearest neighbours may be missed if the search never reaches them. This is the recall-vs-speed trade-off discussed in §5: the index answers queries in milliseconds instead of seconds, but recall@k is typically somewhat below 1.0 (around 0.95 is a common figure, depending on index settings and data).
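You can measure recall@k on your own data by running the same query through vector_top_k and through an exact scan, then comparing the two id lists. A minimal sketch, assuming you have already collected both lists; recall_at_k is an illustrative name:

```rust
use std::collections::HashSet;

/// Fraction of the exact top-k ids that the approximate search also returned.
/// `approx`: ids from the ANN index (e.g. vector_top_k).
/// `exact`: ids from a brute-force scan for the same query and k.
fn recall_at_k(approx: &[i64], exact: &[i64]) -> f64 {
    if exact.is_empty() {
        return 1.0; // nothing to find, so nothing was missed
    }
    let found: HashSet<i64> = approx.iter().copied().collect();
    let hits = exact.iter().filter(|id| found.contains(*id)).count();
    hits as f64 / exact.len() as f64
}
```

Averaging this value over a representative set of queries gives the recall figure for your index.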
vector_distance_cos, by contrast, always gives the exact cosine distance for any specific pair of vectors. It is a pure computation with no approximation. The approximation lives only in the selection of which candidates to evaluate — that is the job of the index.
In practice this means: trust vector_top_k for fast retrieval, but understand that at scale a small fraction of true nearest neighbours may not appear in the result set. If perfect recall is required, you can widen the index search (HNSW engines call this parameter ef_search; check your engine's index settings for the equivalent) or skip the index and run an exact scan, for example ORDER BY vector_distance_cos(embedding, vector(?)) LIMIT k, over a filtered subset of rows.
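For intuition about the exact fallback, here is what brute-force search over a filtered subset does, sketched in plain Rust over an in-memory list rather than in SQL: score every row that passes the filter, sort by cosine distance, keep the first k. Item, cosine_distance, and exact_top_k are illustrative names; the data in the usage mirrors the items table from §7.

```rust
/// One stored row (illustrative; mirrors the items table).
struct Item {
    id: i64,
    label: &'static str,
    embedding: Vec<f32>,
}

/// 1 - cosine similarity, computed in f64.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    let dot: f64 = a.iter().zip(b).map(|(&x, &y)| x as f64 * y as f64).sum();
    let na = a.iter().map(|&x| (x as f64).powi(2)).sum::<f64>().sqrt();
    let nb = b.iter().map(|&y| (y as f64).powi(2)).sum::<f64>().sqrt();
    1.0 - dot / (na * nb)
}

/// Exact k-nearest-neighbour search over the items for which `keep` is true.
fn exact_top_k<'a>(
    items: &'a [Item],
    query: &[f32],
    k: usize,
    keep: impl Fn(&Item) -> bool,
) -> Vec<(&'a Item, f64)> {
    let mut scored: Vec<(&'a Item, f64)> = items
        .iter()
        .filter(|item| keep(item)) // apply the filter first
        .map(|item| (item, cosine_distance(&item.embedding, query)))
        .collect();
    // Ascending distance; total_cmp gives a total order even if a NaN appears.
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    scored
}
```

Because every candidate is scored, recall is 1.0 by construction; the cost grows linearly with the number of rows that pass the filter, which is why a selective filter matters.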
Reference Solution
Cargo.toml (dependencies only):
[dependencies]
libsql = "0.9"
tokio = { version = "1", features = ["full"] }
serde_json = "1"
src/main.rs:
use libsql::{Builder, Database};
/// Convert a float slice to a JSON array string for libSQL's `vector()` function.
fn vec_to_json(v: &[f32]) -> String {
format!(
"[{}]",
v.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",")
)
}
/// Run a KNN query and print the top-k results with labels and distances.
async fn knn_query(
conn: &libsql::Connection,
query: &[f32],
k: i32,
) -> Result<(), Box<dyn std::error::Error>> {
let q = vec_to_json(query);
let mut rows = conn
.query(
"SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
JOIN items ON items.rowid = knn.id
ORDER BY dist ASC",
libsql::params![q.clone(), q.clone(), k],
)
.await?;
println!("Query: {q}");
let mut rank = 1;
while let Some(row) = rows.next().await? {
let label: String = row.get(1)?;
let dist: f64 = row.get(2)?;
println!(" {rank}. {label:<10} dist={dist:.4}");
rank += 1;
}
println!();
Ok(())
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// --- Open database ---
let db: Database = Builder::new_local("vectors.db").build().await?;
let conn = db.connect()?;
// Verify connection
let mut rows = conn.query("SELECT sqlite_version()", ()).await?;
if let Some(row) = rows.next().await? {
let version: String = row.get(0)?;
println!("SQLite version: {version}");
}
// --- Create table (from §6) ---
conn.execute(
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY,
label TEXT NOT NULL,
embedding F32_BLOB(3) NOT NULL
)",
(),
)
.await?;
// --- Create vector index (from §6) ---
conn.execute(
"CREATE INDEX IF NOT EXISTS items_vec_idx
ON items (libsql_vector_idx(embedding))",
(),
)
.await?;
println!("Database ready.");
// --- Insert 6 labelled vectors (from §7) ---
let data: Vec<(i64, &str, Vec<f32>)> = vec![
(1, "cat", vec![0.9, 0.1, 0.2]),
(2, "dog", vec![0.8, 0.2, 0.3]),
(3, "car", vec![0.1, 0.9, 0.1]),
(4, "truck", vec![0.2, 0.8, 0.2]),
(5, "python", vec![0.15, 0.1, 0.95]),
(6, "rust", vec![0.1, 0.05, 0.9]),
];
for (id, label, embedding) in &data {
conn.execute(
"INSERT OR IGNORE INTO items (id, label, embedding) VALUES (?, ?, vector(?))",
libsql::params![*id, *label, vec_to_json(embedding)],
)
.await?;
}
println!("Inserted {} rows.", data.len());
// --- KNN queries ---
// Animal cluster
knn_query(&conn, &[0.85, 0.15, 0.25], 3).await?;
// Vehicle cluster
knn_query(&conn, &[0.15, 0.85, 0.15], 3).await?;
// Language cluster
knn_query(&conn, &[0.1, 0.05, 0.92], 3).await?;
Ok(())
}
Part 4 — Real Applications
9. Generating Embeddings in Rust
Before you can search by meaning, you need a way to convert text into vectors. This section covers two approaches available in Rust: running a local embedding model with fastembed-rs (no API key, works offline, suited for smaller models) and calling an HTTP embedding API such as the OpenAI Embeddings endpoint (larger, higher-quality models at the cost of latency and a network dependency).
Option A — fastembed-rs (local, recommended for exercises). The fastembed crate wraps ONNX Runtime and ships pre-trained sentence-transformer models. No API key is required, it works fully offline once the model has been downloaded, inference runs on the CPU by default, and results are deterministic: all properties that make it ideal for the exercises in §10–§12. Add it to your project:
fastembed = "4"
The default model is BGE-Small-EN-v1.5, which produces 384-dimensional vectors. On first use, the model files (roughly 130 MB) are downloaded into fastembed's cache directory (by default .fastembed_cache in the current working directory) and reused on subsequent runs. Here is the minimal code to embed two strings:
use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::BGESmallENV15)
.with_show_download_progress(true),
)?;
let docs = vec!["hello world", "Rust is fast"];
let embeddings: Vec<Vec<f32>> = model.embed(docs, None)?;
// embeddings[0].len() == 384
Batch embedding matters. Passing multiple strings in a single model.embed() call is significantly more efficient than embedding one string at a time, because the runtime can batch tensor operations. Always collect your corpus into a Vec and embed it in one shot rather than looping with individual calls.
Option B — HTTP API (OpenAI-compatible). When you need a specific production-grade model — or your deployment already relies on an external embeddings service — you can call an OpenAI-compatible endpoint instead. You will need three additional crates:
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Define request and response types that match the API schema:
#[derive(serde::Serialize)]
struct EmbedRequest {
model: String,
input: Vec<String>,
}
#[derive(serde::Deserialize)]
struct EmbedResponse {
data: Vec<EmbedData>,
}
#[derive(serde::Deserialize)]
struct EmbedData {
embedding: Vec<f32>,
}
async fn embed_texts(texts: Vec<String>) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error>> {
let api_key = std::env::var("OPENAI_API_KEY")?;
let client = reqwest::Client::new();
let res: EmbedResponse = client
.post("https://api.openai.com/v1/embeddings")
.bearer_auth(&api_key)
.json(&EmbedRequest {
model: "text-embedding-3-small".into(),
input: texts,
})
.send()
.await?
// Turn HTTP 4xx/5xx responses into errors before parsing the body.
.error_for_status()?
.json()
.await?;
Ok(res.data.into_iter().map(|d| d.embedding).collect())
}
Choosing between them. For the remaining exercises in this course (§10–§12), use fastembed for embeddings. It requires no API key, needs the network only to download the model once, and produces deterministic results, so your assertions will be stable across runs. CPU inference is comfortably fast for the small corpora used here. Reach for the HTTP approach when you need a specific production-grade model, when your application already communicates with an embeddings service, or when you need multilingual support beyond what the local models offer.
Dimensionality note. The F32_BLOB(d) column type in your schema must match the model's output dimension exactly; you cannot mix dimensions within a single column. The toy examples in §6–§8 used F32_BLOB(3) for hand-written 3-D vectors. With real models, change that declaration to F32_BLOB(384) for BGE-Small-EN-v1.5 (or all-MiniLM-L6-v2, which is also 384-dimensional), F32_BLOB(768) for a 768-dimensional model such as BGE-Base-EN-v1.5, or F32_BLOB(1536) for OpenAI's text-embedding-3-small. If you switch to a model with a different dimension, you must drop and recreate the table (SQLite cannot change a column's declared type in place) together with its vector index, then re-embed and re-insert your data; libSQL cannot convert stored vectors from one dimension to another.
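Rather than relying on the database to catch a mismatch, you can check each vector's length in Rust before binding it. A minimal sketch; embedding_to_json is a hypothetical helper that combines the vec_to_json formatting used in these exercises with a dimension check, and EMBEDDING_DIM assumes the F32_BLOB(384) schema:

```rust
/// Dimension declared in the schema, e.g. F32_BLOB(384) for BGE-Small-EN-v1.5.
const EMBEDDING_DIM: usize = 384;

/// Hypothetical helper: reject vectors of the wrong length, otherwise format
/// the vector as the JSON array string that libSQL's `vector()` accepts.
fn embedding_to_json(v: &[f32], expected_dim: usize) -> Result<String, String> {
    if v.len() != expected_dim {
        return Err(format!(
            "embedding has {} dimensions, but the column expects {}",
            v.len(),
            expected_dim
        ));
    }
    let parts: Vec<String> = v.iter().map(|x| x.to_string()).collect();
    Ok(format!("[{}]", parts.join(",")))
}
```

Propagating the Err with ? at insert time surfaces the mismatch with a clear message instead of a less specific SQL error.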
10. Exercise 3 — Semantic Document Search
Goal: Embed a corpus of 15 short text passages with fastembed-rs, store the embeddings in Turso, then accept a natural-language query, embed it, and return the top-5 most semantically relevant passages — with no keyword matching.
Setup
Create a new project (or extend your existing vec-demo crate). Your Cargo.toml dependencies:
[dependencies]
libsql = "0.9"
fastembed = "4"
tokio = { version = "1", features = ["full"] }
serde_json = "1"
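The table schema uses F32_BLOB(384) because BGE-Small-EN-v1.5 produces 384-dimensional embeddings. The queries in Steps 3 and 4 search a vector index by name, so also create one called docs_idx on the embedding column, as the reference solution does with CREATE INDEX IF NOT EXISTS docs_idx ON docs (libsql_vector_idx(embedding)). The table itself: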
CREATE TABLE IF NOT EXISTS docs (
id INTEGER PRIMARY KEY,
passage TEXT NOT NULL,
embedding F32_BLOB(384) NOT NULL
)
Corpus
Use these 15 passages spanning three topics.
Rust programming (5):
- “Rust uses an ownership system to guarantee memory safety without a garbage collector.”
- “The borrow checker enforces that references do not outlive the data they point to.”
- “Cargo is Rust’s build system and package manager, used to manage dependencies and run tests.”
- “Rust’s trait system enables zero-cost abstractions and compile-time polymorphism.”
- “Async Rust uses futures and the tokio runtime to handle concurrent I/O efficiently.”
Astronomy (5):
- “A black hole is a region of spacetime where gravity is so strong that nothing can escape.”
- “The Milky Way galaxy contains an estimated 100 to 400 billion stars.”
- “Neutron stars are the collapsed cores of massive stars, with densities exceeding atomic nuclei.”
- “The cosmic microwave background is the thermal radiation left over from the early universe.”
- “Exoplanets are planets outside our solar system, detected via transit photometry or radial velocity.”
Cooking (5):
- “Maillard reaction gives browned foods their distinctive flavour through amino acid and sugar reactions.”
- “Sous vide cooking involves sealing food in vacuum bags and cooking at precise low temperatures.”
- “Emulsification combines two immiscible liquids, such as oil and water, using an emulsifier like lecithin.”
- “Fermentation converts sugars to acids or alcohol using microorganisms, used in bread, beer, and yogurt.”
- “Knife skills — julienne, brunoise, chiffonade — determine the surface area and cooking time of vegetables.”
Step 1 — Embed the corpus
Use fastembed::TextEmbedding with the default model (BGE-Small-EN-v1.5) to embed all 15 passages in a single model.embed() call. This returns a Vec<Vec<f32>> — one 384-dimensional vector per passage.
use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::BGESmallENV15)
.with_show_download_progress(true),
)?;
let embeddings = model.embed(passages.clone(), None)?;
Step 2 — Insert into Turso
Loop over the passages and their corresponding embeddings. Convert each Vec<f32> to a JSON string so it can be passed to the vector(?) SQL function. Use INSERT OR IGNORE so re-runs are idempotent.
fn vec_to_json(v: &[f32]) -> String {
let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
format!("[{}]", parts.join(","))
}
for (i, (passage, emb)) in passages.iter().zip(embeddings.iter()).enumerate() {
let json = vec_to_json(emb);
conn.execute(
"INSERT OR IGNORE INTO docs (id, passage, embedding) VALUES (?, ?, vector(?))",
libsql::params![i as i64, passage.as_str(), json.as_str()],
)
.await?;
}
Step 3 — Embed the query and search
Embed the query string the same way you embedded the corpus, using model.embed() with a single-element Vec. Then run vector_top_k('docs_idx', vector(?), 5), join back to the docs table on rowid to retrieve the passage text, and compute the cosine distance with vector_distance_cos (vector_top_k itself returns only the id column, not distances).
let query = "memory safety in systems programming";
let q_emb = model.embed(vec![query.to_string()], None)?;
let q_json = vec_to_json(&q_emb[0]);
// ?1 is bound once and used twice: for the index search and for the exact distance.
let mut rows = conn
.query(
"SELECT d.passage, vector_distance_cos(d.embedding, vector(?1)) AS distance
FROM vector_top_k('docs_idx', vector(?1), 5) AS v
JOIN docs AS d ON d.rowid = v.id
ORDER BY distance",
libsql::params![q_json.as_str()],
)
.await?;
Step 4 — Run three queries and verify
Run the following queries and confirm the results cluster by topic:
| Query | Expected top results |
|---|---|
| "memory safety in systems programming" | Rust passages |
| "stars and galaxies" | Astronomy passages |
| "fermentation and cooking techniques" | Cooking passages |
Print each result ranked by distance, showing the passage text and the cosine distance score:
let queries = vec![
"memory safety in systems programming",
"stars and galaxies",
"fermentation and cooking techniques",
];
for query in &queries {
println!("\n=== Query: \"{query}\" ===\n");
let q_emb = model.embed(vec![query.to_string()], None)?;
let q_json = vec_to_json(&q_emb[0]);
let mut rows = conn
.query(
"SELECT d.passage, vector_distance_cos(d.embedding, vector(?1)) AS distance
FROM vector_top_k('docs_idx', vector(?1), 5) AS v
JOIN docs AS d ON d.rowid = v.id
ORDER BY distance",
libsql::params![q_json.as_str()],
)
.await?;
let mut rank = 1;
while let Some(row) = rows.next().await? {
let passage: String = row.get(0)?;
let distance: f64 = row.get(1)?;
println!(" {rank}. [{distance:.4}] {passage}");
rank += 1;
}
}
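To turn "confirm the results cluster by topic" into a checkable assertion, you can map each result back to its topic. A minimal sketch, assuming you also select d.id in the query and keep the corpus order above, so that the enumerate() in Step 2 assigns ids 0–4 to Rust, 5–9 to astronomy, and 10–14 to cooking. Topic, topic_of, and topic_precision are illustrative names, not part of the exercise code:

```rust
/// The three topics of the corpus, five passages each.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Topic {
    Rust,
    Astronomy,
    Cooking,
}

/// Map a passage id (as assigned by `enumerate()` during insertion) to its topic.
fn topic_of(id: i64) -> Option<Topic> {
    match id {
        0..=4 => Some(Topic::Rust),
        5..=9 => Some(Topic::Astronomy),
        10..=14 => Some(Topic::Cooking),
        _ => None,
    }
}

/// Fraction of retrieved ids whose topic matches the expected one.
fn topic_precision(retrieved_ids: &[i64], expected: Topic) -> f64 {
    if retrieved_ids.is_empty() {
        return 0.0;
    }
    let hits = retrieved_ids
        .iter()
        .filter(|&&id| topic_of(id) == Some(expected))
        .count();
    hits as f64 / retrieved_ids.len() as f64
}
```

With five passages per topic and k = 5, a perfectly clustered result scores 1.0; since no embedding model clusters perfectly, a threshold such as 0.6 makes a more forgiving assertion.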
Reference Solution
// src/main.rs — Semantic Document Search (Exercise 3)
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use libsql::Builder;
fn vec_to_json(v: &[f32]) -> String {
let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
format!("[{}]", parts.join(","))
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// ── 1. Connect to Turso (local file) ──
let db = Builder::new_local("semantic_search.db").build().await?;
let conn = db.connect()?;
// ── 2. Create the docs table ──
conn.execute(
"CREATE TABLE IF NOT EXISTS docs (
id INTEGER PRIMARY KEY,
passage TEXT NOT NULL,
embedding F32_BLOB(384) NOT NULL
)",
(),
)
.await?;
// ── 3. Create the vector index ──
conn.execute(
"CREATE INDEX IF NOT EXISTS docs_idx ON docs (libsql_vector_idx(embedding))",
(),
)
.await?;
// ── 4. Define the corpus ──
let passages: Vec<String> = vec![
// Rust programming
"Rust uses an ownership system to guarantee memory safety without a garbage collector.",
"The borrow checker enforces that references do not outlive the data they point to.",
"Cargo is Rust's build system and package manager, used to manage dependencies and run tests.",
"Rust's trait system enables zero-cost abstractions and compile-time polymorphism.",
"Async Rust uses futures and the tokio runtime to handle concurrent I/O efficiently.",
// Astronomy
"A black hole is a region of spacetime where gravity is so strong that nothing can escape.",
"The Milky Way galaxy contains an estimated 100 to 400 billion stars.",
"Neutron stars are the collapsed cores of massive stars, with densities exceeding atomic nuclei.",
"The cosmic microwave background is the thermal radiation left over from the early universe.",
"Exoplanets are planets outside our solar system, detected via transit photometry or radial velocity.",
// Cooking
"Maillard reaction gives browned foods their distinctive flavour through amino acid and sugar reactions.",
"Sous vide cooking involves sealing food in vacuum bags and cooking at precise low temperatures.",
"Emulsification combines two immiscible liquids, such as oil and water, using an emulsifier like lecithin.",
"Fermentation converts sugars to acids or alcohol using microorganisms, used in bread, beer, and yogurt.",
"Knife skills — julienne, brunoise, chiffonade — determine the surface area and cooking time of vegetables.",
]
.into_iter()
.map(String::from)
.collect();
// ── 5. Embed the corpus ──
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true),
)?;
let embeddings = model.embed(passages.clone(), None)?;
// ── 6. Insert passages + embeddings ──
for (i, (passage, emb)) in passages.iter().zip(embeddings.iter()).enumerate() {
let json = vec_to_json(emb);
conn.execute(
"INSERT OR IGNORE INTO docs (id, passage, embedding) VALUES (?, ?, vector(?))",
libsql::params![i as i64, passage.as_str(), json.as_str()],
)
.await?;
}
println!("Inserted {} passages.\n", passages.len());
// ── 7. Run three queries ──
let queries = vec![
"memory safety in systems programming",
"stars and galaxies",
"fermentation and cooking techniques",
];
for query in &queries {
println!("=== Query: \"{query}\" ===\n");
let q_emb = model.embed(vec![query.to_string()], None)?;
let q_json = vec_to_json(&q_emb[0]);
let mut rows = conn
.query(
"SELECT d.passage, vector_distance_cos(d.embedding, vector(?1)) AS distance
FROM vector_top_k('docs_idx', vector(?1), 5) AS v
JOIN docs AS d ON d.rowid = v.id
ORDER BY distance",
libsql::params![q_json.as_str()],
)
.await?;
let mut rank = 1;
while let Some(row) = rows.next().await? {
let passage: String = row.get(0)?;
let distance: f64 = row.get(1)?;
println!(" {rank}. [{distance:.4}] {passage}");
rank += 1;
}
println!();
}
Ok(())
}
11. Exercise 4 — Recommendation Engine
Goal: Build an item-based recommendation engine. Store item feature vectors in Turso, then given a target item, find the k most similar items using KNN and exclude the query item from the results.
We will use hand-crafted 5-dimensional feature vectors for a product catalogue (no fastembed dependency — this keeps the focus on the recommendation logic itself). The five dimensions represent affinity scores for: [electronics, clothing, sports, food, books].
Catalogue (10 items):
| id | name | embedding |
|---|---|---|
| 1 | Laptop | [0.95, 0.0, 0.1, 0.0, 0.2] |
| 2 | Mechanical Keyboard | [0.85, 0.0, 0.0, 0.0, 0.1] |
| 3 | USB-C Hub | [0.9, 0.0, 0.0, 0.0, 0.0] |
| 4 | Running Shoes | [0.0, 0.6, 0.9, 0.0, 0.0] |
| 5 | Yoga Mat | [0.0, 0.2, 0.95, 0.0, 0.0] |
| 6 | Water Bottle | [0.1, 0.1, 0.7, 0.0, 0.0] |
| 7 | T-Shirt | [0.0, 0.95, 0.1, 0.0, 0.0] |
| 8 | Cookbook | [0.0, 0.0, 0.0, 0.6, 0.9] |
| 9 | Protein Bar | [0.0, 0.0, 0.3, 0.95, 0.0] |
| 10 | Novel | [0.0, 0.0, 0.0, 0.1, 0.95] |
Step 1 — Schema
Create a products table and a vector index on its embedding column:
conn.execute(
"CREATE TABLE IF NOT EXISTS products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
embedding F32_BLOB(5) NOT NULL
)",
(),
)
.await?;
conn.execute(
"CREATE INDEX IF NOT EXISTS products_idx
ON products (libsql_vector_idx(embedding))",
(),
)
.await?;
Step 2 — Insert items
Use the same pattern as Exercise 1: format each Vec<f32> as a JSON array string and insert with INSERT OR IGNORE:
let products: Vec<(i64, &str, Vec<f32>)> = vec![
(1, "Laptop", vec![0.95, 0.0, 0.1, 0.0, 0.2]),
(2, "Mechanical Keyboard", vec![0.85, 0.0, 0.0, 0.0, 0.1]),
(3, "USB-C Hub", vec![0.9, 0.0, 0.0, 0.0, 0.0]),
(4, "Running Shoes", vec![0.0, 0.6, 0.9, 0.0, 0.0]),
(5, "Yoga Mat", vec![0.0, 0.2, 0.95, 0.0, 0.0]),
(6, "Water Bottle", vec![0.1, 0.1, 0.7, 0.0, 0.0]),
(7, "T-Shirt", vec![0.0, 0.95, 0.1, 0.0, 0.0]),
(8, "Cookbook", vec![0.0, 0.0, 0.0, 0.6, 0.9]),
(9, "Protein Bar", vec![0.0, 0.0, 0.3, 0.95, 0.0]),
(10, "Novel", vec![0.0, 0.0, 0.0, 0.1, 0.95]),
];
for (id, name, emb) in &products {
let emb_json = serde_json::to_string(emb)?;
conn.execute(
"INSERT OR IGNORE INTO products (id, name, embedding)
VALUES (?, ?, vector(?))",
libsql::params![*id, *name, emb_json.as_str()],
)
.await?;
}
Step 3 — Recommend function
Write a helper that retrieves recommendations for a given item:
async fn recommend(
conn: &libsql::Connection,
item_id: i64,
k: usize,
) -> Result<Vec<(String, f64)>, Box<dyn std::error::Error>> {
// 1. Get the query item's embedding as a JSON string.
let mut stmt = conn
.prepare("SELECT vector_extract(embedding) FROM products WHERE id = ?")
.await?;
let mut rows = stmt.query(libsql::params![item_id]).await?;
let row = rows
.next()
.await?
.ok_or("item not found")?;
let query_vec: String = row.get(0)?;
// 2. Use vector_top_k with k+1 to leave room for the query item itself.
// vector_top_k returns a column named `id`, and products also has an `id`
// column, so alias the TVF as knn and qualify knn.id to avoid ambiguity.
let sql = format!(
"SELECT products.id, products.name,
vector_distance_cos(products.embedding, vector(?1)) AS distance
FROM vector_top_k('products_idx', vector(?1), {limit}) AS knn
JOIN products ON products.rowid = knn.id
WHERE products.id != ?2
ORDER BY distance
LIMIT ?3",
limit = k + 1
);
let mut stmt = conn.prepare(&sql).await?;
let mut rows = stmt
.query(libsql::params![query_vec.as_str(), item_id, k as i64])
.await?;
// 3. Collect (name, distance) pairs.
let mut results = Vec::new();
while let Some(row) = rows.next().await? {
let name: String = row.get(1)?;
let distance: f64 = row.get(2)?;
results.push((name, distance));
}
Ok(results)
}
The key ideas:
- Retrieve the query vector: vector_extract returns the stored embedding as a JSON string, which can be bound straight back in as the query vector for vector_top_k.
- Over-fetch by one: request k + 1 candidates because vector_top_k will return the query item itself (distance ≈ 0). The WHERE products.id != ?2 clause filters it out.
- Cosine distance: vector_distance_cos returns a value between 0 (same direction) and 2 (opposite direction). Lower means more similar.
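The over-fetch-and-exclude logic can also be checked without a database. Here is a minimal in-memory sketch that does an exact search (no index) over the same catalogue, ranking every other product by cosine distance to the target; recommend_in_memory and cosine_distance are illustrative names, not part of the exercise code.

```rust
/// Exact cosine distance (1 - cosine similarity), computed in f64.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    let dot: f64 = a.iter().zip(b).map(|(&x, &y)| x as f64 * y as f64).sum();
    let na = a.iter().map(|&x| (x as f64).powi(2)).sum::<f64>().sqrt();
    let nb = b.iter().map(|&y| (y as f64).powi(2)).sum::<f64>().sqrt();
    1.0 - dot / (na * nb)
}

/// Rank all products except `item_id` by distance to it and keep the top k.
/// Returns None if `item_id` is not in the catalogue.
fn recommend_in_memory(
    catalogue: &[(i64, &str, Vec<f32>)],
    item_id: i64,
    k: usize,
) -> Option<Vec<(String, f64)>> {
    // Look up the target's embedding; `?` returns None for an unknown id.
    let (_, _, target) = catalogue.iter().find(|(id, _, _)| *id == item_id)?;
    let mut scored: Vec<(String, f64)> = catalogue
        .iter()
        .filter(|(id, _, _)| *id != item_id) // exclude the query item itself
        .map(|(_, name, emb)| (name.to_string(), cosine_distance(target, emb)))
        .collect();
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    Some(scored)
}
```

Because nothing here is approximate, its rankings make a useful reference to compare against the SQL version's output.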
Step 4 — Print recommendations
Request recommendations for three representative items and verify the clusters make sense:
let queries = vec![
(1, "Laptop"),
(4, "Running Shoes"),
(8, "Cookbook"),
];
for (id, name) in &queries {
let recs = recommend(&conn, *id, 2).await?;
let rec_str: Vec<String> = recs
.iter()
.map(|(n, d)| format!("{n} ({d:.3})"))
.collect();
println!(
"Customers who liked {name} also liked: {}",
rec_str.join(", ")
);
}
Expected output (distances are approximate):
Customers who liked Laptop also liked: Mechanical Keyboard (0.009), USB-C Hub (0.027)
Customers who liked Running Shoes also liked: Yoga Mat (0.072), Water Bottle (0.107)
Customers who liked Cookbook also liked: Novel (0.114), Protein Bar (0.471)
- Laptop → electronics cluster (Mechanical Keyboard, USB-C Hub)
- Running Shoes → sports cluster (Yoga Mat, Water Bottle)
- Cookbook → food/books cluster (Novel, Protein Bar)
Reference Solution
use libsql::Builder;
/// Find the k most similar products to the given item, excluding the item itself.
async fn recommend(
conn: &libsql::Connection,
item_id: i64,
k: usize,
) -> Result<Vec<(String, f64)>, Box<dyn std::error::Error>> {
// Retrieve the query item's embedding as a JSON string.
let mut stmt = conn
.prepare("SELECT vector_extract(embedding) FROM products WHERE id = ?")
.await?;
let mut rows = stmt.query(libsql::params![item_id]).await?;
let row = rows.next().await?.ok_or("item not found")?;
let query_vec: String = row.get(0)?;
// KNN search: fetch k+1 to account for the query item appearing in its
// own results, then filter it out.
let sql = format!(
"SELECT products.id, products.name,
vector_distance_cos(products.embedding, vector(?1)) AS distance
FROM vector_top_k('products_idx', vector(?1), {limit}) AS knn
JOIN products ON products.rowid = knn.id
WHERE products.id != ?2
ORDER BY distance
LIMIT ?3",
limit = k + 1
);
let mut stmt = conn.prepare(&sql).await?;
let mut rows = stmt
.query(libsql::params![query_vec.as_str(), item_id, k as i64])
.await?;
let mut results = Vec::new();
while let Some(row) = rows.next().await? {
let name: String = row.get(1)?;
let distance: f64 = row.get(2)?;
results.push((name, distance));
}
Ok(results)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db = Builder::new_local(":memory:").build().await?;
let conn = db.connect()?;
// --- Schema ---
conn.execute(
"CREATE TABLE IF NOT EXISTS products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
embedding F32_BLOB(5) NOT NULL
)",
(),
)
.await?;
conn.execute(
"CREATE INDEX IF NOT EXISTS products_idx
ON products (libsql_vector_idx(embedding))",
(),
)
.await?;
// --- Seed data ---
let products: Vec<(i64, &str, Vec<f32>)> = vec![
(1, "Laptop", vec![0.95, 0.0, 0.1, 0.0, 0.2]),
(2, "Mechanical Keyboard", vec![0.85, 0.0, 0.0, 0.0, 0.1]),
(3, "USB-C Hub", vec![0.9, 0.0, 0.0, 0.0, 0.0]),
(4, "Running Shoes", vec![0.0, 0.6, 0.9, 0.0, 0.0]),
(5, "Yoga Mat", vec![0.0, 0.2, 0.95, 0.0, 0.0]),
(6, "Water Bottle", vec![0.1, 0.1, 0.7, 0.0, 0.0]),
(7, "T-Shirt", vec![0.0, 0.95, 0.1, 0.0, 0.0]),
(8, "Cookbook", vec![0.0, 0.0, 0.0, 0.6, 0.9]),
(9, "Protein Bar", vec![0.0, 0.0, 0.3, 0.95, 0.0]),
(10, "Novel", vec![0.0, 0.0, 0.0, 0.1, 0.95]),
];
for (id, name, emb) in &products {
let emb_json = serde_json::to_string(emb)?;
conn.execute(
"INSERT OR IGNORE INTO products (id, name, embedding)
VALUES (?, ?, vector(?))",
libsql::params![*id, *name, emb_json.as_str()],
)
.await?;
}
// --- Recommendations ---
let queries = vec![
(1, "Laptop"),
(4, "Running Shoes"),
(8, "Cookbook"),
];
for (id, name) in &queries {
let recs = recommend(&conn, *id, 2).await?;
let rec_str: Vec<String> = recs
.iter()
.map(|(n, d)| format!("{n} ({d:.3})"))
.collect();
println!(
"Customers who liked {name} also liked: {}",
rec_str.join(", ")
);
}
Ok(())
}
12. Exercise 5 — Retrieval-Augmented Generation
Goal: Build a retrieval-augmented generation (RAG) pipeline that:
- Stores the 15-passage corpus from §10 in Turso
- Accepts a natural-language question
- Retrieves the top-3 most relevant passages using vector KNN
- Injects the passages into a prompt as context
- Sends the prompt to an OpenAI-compatible LLM API
- Prints the grounded answer
Setup:
[dependencies]
libsql = "0.9"
fastembed = "4"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["full"] }
You will need an API key stored in the OPENAI_API_KEY environment variable. This exercise works with any OpenAI-compatible provider: OpenAI itself, Groq, Together AI, or a local Ollama instance (base URL http://localhost:11434/v1, model llama3.2). The code below hardcodes OpenAI's endpoint and the gpt-4o-mini model in call_llm, so change both if you use another provider. Ollama ignores the key, but the variable must still be set to some value because the code reads it.
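Because call_llm in Step 3 hardcodes OpenAI's URL, one way to switch providers without editing code is to derive the endpoint from a configurable base URL. A minimal sketch; the OPENAI_BASE_URL variable name and both function names are assumptions of this example, not something the exercise code reads:

```rust
use std::env;

/// Build the chat-completions endpoint from an OpenAI-compatible base URL,
/// e.g. "https://api.openai.com/v1" or Ollama's "http://localhost:11434/v1".
fn chat_completions_url(base_url: &str) -> String {
    format!("{}/chat/completions", base_url.trim_end_matches('/'))
}

/// Read the base URL from OPENAI_BASE_URL (an assumed variable name),
/// falling back to OpenAI's public endpoint when it is unset.
fn endpoint_from_env() -> String {
    let base = env::var("OPENAI_BASE_URL")
        .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
    chat_completions_url(&base)
}
```

You would then pass the resulting URL to client.post(...) in place of the hardcoded string.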
Step 1 — Retrieval function
Reuse the semantic search logic from §10. Write a function that embeds the query, runs a KNN search, and returns the top-k passage texts:
use fastembed::TextEmbedding;
async fn retrieve(
conn: &libsql::Connection,
model: &TextEmbedding,
query: &str,
k: usize,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let q_emb = model.embed(vec![query.to_string()], None)?;
let q_json = vec_to_json(&q_emb[0]);
// vector_top_k returns only ids, so rank by an explicit cosine distance.
let mut rows = conn
.query(
"SELECT d.passage
FROM vector_top_k('docs_idx', vector(?1), ?2) AS v
JOIN docs AS d ON d.rowid = v.id
ORDER BY vector_distance_cos(d.embedding, vector(?1))",
libsql::params![q_json.as_str(), k as i64],
)
.await?;
let mut passages = Vec::new();
while let Some(row) = rows.next().await? {
let passage: String = row.get(0)?;
passages.push(passage);
}
Ok(passages)
}
Step 2 — Prompt construction
Build a prompt string that instructs the model to answer using only the retrieved context:
fn build_prompt(context_passages: &[String], question: &str) -> String {
let mut prompt = String::from(
"You are a helpful assistant. Answer the question using only the provided context.\n\
If the context does not contain enough information, say so.\n\n\
Context:\n",
);
for passage in context_passages {
prompt.push_str(passage);
prompt.push_str("\n\n");
}
prompt.push_str(&format!("Question: {question}\n\nAnswer:"));
prompt
}
Step 3 — LLM API call
POST to the chat completions endpoint. Define request and response structs with serde, then send the prompt as a user message:
#[derive(serde::Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
}
#[derive(serde::Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(serde::Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(serde::Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(serde::Deserialize)]
struct ResponseMessage {
content: String,
}
async fn call_llm(
client: &reqwest::Client,
api_key: &str,
prompt: &str,
) -> Result<String, Box<dyn std::error::Error>> {
let request = ChatRequest {
model: "gpt-4o-mini".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}],
};
let resp = client
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(api_key)
.json(&request)
.send()
.await?
.error_for_status()?
.json::<ChatResponse>()
.await?;
// Take the first choice without indexing, so an empty list is an error rather than a panic.
let choice = resp
.choices
.into_iter()
.next()
.ok_or("LLM response contained no choices")?;
Ok(choice.message.content)
}
Step 4 — Wire it together and run
Set up the database and corpus exactly as in §10, then run three example questions that exercise each topic cluster:
// These statements run inside `async fn main`, after the setup step has
// created `conn` (the libsql connection) and `model` (the embedding model).
let questions = vec![
    "How does Rust ensure memory safety?",
    "What is a black hole?",
    "What is the Maillard reaction?",
];
let client = reqwest::Client::new();
let api_key = std::env::var("OPENAI_API_KEY")?;
for question in &questions {
    println!("=== Question: \"{question}\" ===\n");
    let passages = retrieve(&conn, &model, question, 3).await?;
    println!("Retrieved passages:");
    for (i, p) in passages.iter().enumerate() {
        println!(" {}: {p}", i + 1);
    }
    println!();
    let prompt = build_prompt(&passages, question);
    let answer = call_llm(&client, &api_key, &prompt).await?;
    println!("Answer: {answer}\n");
}
Each question should pull passages from the matching cluster — Rust passages for the first, astronomy for the second, and cooking for the third. The LLM’s answer will be grounded in those passages rather than relying on its own parametric knowledge.
Step 5 — Discussion: RAG patterns
Chunk size and overlap. The 15-passage corpus used here is already conveniently pre-chunked into single sentences, but real documents are rarely so tidy. In practice, long documents are split into overlapping chunks — typically 200–500 tokens with a 50–100 token overlap between consecutive chunks. The overlap ensures that sentences near a chunk boundary are not orphaned from their surrounding context, which would hurt retrieval quality. Choosing the right chunk size is a trade-off: smaller chunks yield more precise retrieval but lose broader context, while larger chunks retain context at the cost of noisier matches.
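A minimal sketch of overlapping chunking follows. It assumes whitespace-separated words as a stand-in for model tokens; a real pipeline would count tokens with the embedding model's tokenizer. The `chunk_words` function is hypothetical. It slides a window of `size` words forward by `size - overlap` words at a time, so consecutive chunks share `overlap` words.

```rust
/// Split `text` into chunks of at most `size` words, where consecutive
/// chunks share `overlap` words. Words approximate tokens here.
fn chunk_words(text: &str, size: usize, overlap: usize) -> Vec<String> {
    assert!(size > 0 && overlap < size, "need size > 0 and overlap < size");
    let words: Vec<&str> = text.split_whitespace().collect();
    let step = size - overlap;
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < words.len() {
        let end = (start + size).min(words.len());
        chunks.push(words[start..end].join(" "));
        if end == words.len() {
            break; // this window reached the end of the text
        }
        start += step;
    }
    chunks
}

fn main() {
    // With size 4 and overlap 1, each chunk repeats the last word of the previous one.
    for chunk in chunk_words("a b c d e f g h i j", 4, 1) {
        println!("{chunk}");
    }
}
```

Here the ten words produce three chunks, `a b c d`, `d e f g`, and `g h i j`. The boundary words `d` and `g` each appear in two chunks.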
Re-ranking. The ANN index returns approximate nearest neighbors quickly, but the ranking is based on a single embedding similarity score. A cross-encoder re-ranker — a model that takes (query, passage) pairs as input and produces a relevance score — can re-order the top-k candidates for significantly better precision. The typical pattern is to retrieve a larger set (e.g., top-20) with ANN and then re-rank to the final top-3 or top-5 with the cross-encoder.
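The retrieve-then-rerank pattern can be sketched with the scorer passed in as a parameter. The names `rerank` and `overlap_score` are hypothetical. The word-overlap scorer is only a toy stand-in for a real cross-encoder, which would score each (query, passage) pair with a model.

```rust
/// Re-rank `candidates` with a (query, passage) scoring function and keep the
/// best `top_n`. In a real pipeline `score` would call a cross-encoder model.
fn rerank<F>(query: &str, candidates: Vec<String>, top_n: usize, score: F) -> Vec<String>
where
    F: Fn(&str, &str) -> f32,
{
    let mut scored: Vec<(f32, String)> = candidates
        .into_iter()
        .map(|p| (score(query, p.as_str()), p))
        .collect();
    // Highest score first; total_cmp gives a total order even if a score is NaN.
    scored.sort_by(|a, b| b.0.total_cmp(&a.0));
    scored.into_iter().take(top_n).map(|(_, p)| p).collect()
}

/// Toy stand-in scorer: fraction of query words that occur in the passage.
fn overlap_score(query: &str, passage: &str) -> f32 {
    let passage_lower = passage.to_lowercase();
    let words: Vec<String> = query.split_whitespace().map(|w| w.to_lowercase()).collect();
    if words.is_empty() {
        return 0.0;
    }
    let hits = words.iter().filter(|w| passage_lower.contains(w.as_str())).count();
    hits as f32 / words.len() as f32
}

fn main() {
    // Pretend these came back from the ANN index (e.g. a top-20 query).
    let candidates = vec![
        "Cargo manages dependencies.".to_string(),
        "The borrow checker enforces memory safety rules.".to_string(),
        "Rust guarantees memory safety without a garbage collector.".to_string(),
    ];
    let best = rerank("rust memory safety", candidates, 2, overlap_score);
    println!("{best:?}");
}
```

Keeping the scorer as a closure lets the same retrieval code switch between a cheap heuristic and a model-backed re-ranker.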
Hybrid search. Semantic (ANN) search excels at matching meaning but can miss exact keywords, while keyword-based search (BM25) is good at exact term matching but blind to synonyms. Combining both, often called hybrid search, frequently outperforms either approach alone. A common fusion strategy is Reciprocal Rank Fusion (RRF), which scores each result by summing 1/(k + rank) over every ranked list it appears in, where rank is 1-based and k is a smoothing constant (60 is a common default).
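Below is a minimal sketch of Reciprocal Rank Fusion. It assumes each input list is already sorted best-first and that documents are identified by string ids. The `reciprocal_rank_fusion` helper is hypothetical and is not part of libsql or fastembed. A document at 1-based rank r in a list contributes 1/(k + r) to its fused score.

```rust
use std::collections::HashMap;

/// Reciprocal Rank Fusion over several best-first ranked lists.
fn reciprocal_rank_fusion(lists: &[Vec<&str>], k: f64) -> Vec<(String, f64)> {
    let mut scores: HashMap<String, f64> = HashMap::new();
    for list in lists {
        for (i, doc) in list.iter().enumerate() {
            let rank = (i + 1) as f64; // enumerate is 0-based; RRF ranks are 1-based
            *scores.entry(doc.to_string()).or_insert(0.0) += 1.0 / (k + rank);
        }
    }
    let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
    // Highest fused score first; break ties by id so the output is deterministic.
    fused.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}

fn main() {
    let semantic = vec!["doc_a", "doc_b", "doc_c"]; // ANN results, best first
    let keyword = vec!["doc_b", "doc_d", "doc_a"]; // BM25 results, best first
    for (doc, score) in reciprocal_rank_fusion(&[semantic, keyword], 60.0) {
        println!("{doc}: {score:.5}");
    }
}
```

Here `doc_b` wins because it ranks high in both lists (1/62 + 1/61), edging out `doc_a` (1/61 + 1/63). The scores only use ranks, so the two retrievers' raw similarity values never need to be put on a common scale.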
Context window limits. The number of passages you can inject depends on the model’s context length and the average passage length. GPT-4o-mini supports 128k tokens, but stuffing the entire context window with retrieved passages introduces noise and increases latency and cost. A good heuristic is to inject only enough passages to cover the question — typically 3 to 5 short passages or 1 to 2 longer chunks — and to place the most relevant passages first.
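The heuristic of putting the most relevant passages first and stopping at a budget can be sketched as follows. The sketch assumes passages arrive sorted by relevance and approximates tokens by whitespace-separated words. A real implementation would count tokens with the model's tokenizer. `select_within_budget` is a hypothetical helper.

```rust
/// Take passages in relevance order (best first) until the next one would push
/// the total past `budget` approximate tokens (counted here as words).
fn select_within_budget(ranked: &[String], budget: usize) -> Vec<&str> {
    let mut used = 0;
    let mut selected = Vec::new();
    for passage in ranked {
        let cost = passage.split_whitespace().count();
        if used + cost > budget {
            break; // stop here so the selection stays a prefix of the ranking
        }
        used += cost;
        selected.push(passage.as_str());
    }
    selected
}

fn main() {
    let ranked = vec![
        "one two three".to_string(),       // 3 words
        "four five six seven".to_string(), // 4 words
        "eight".to_string(),               // 1 word
    ];
    println!("{:?}", select_within_budget(&ranked, 7));
}
```

Stopping at the first passage that does not fit keeps the prompt ordered strictly by relevance. A variant could skip an oversized passage and keep trying shorter ones, trading ranking order for fuller use of the budget.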
Reference Solution
// src/main.rs — Retrieval-Augmented Generation (Exercise 5)
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use libsql::Builder;
fn vec_to_json(v: &[f32]) -> String {
let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
format!("[{}]", parts.join(","))
}
/// Retrieve the top-k passages most relevant to `query` using vector KNN.
async fn retrieve(
conn: &libsql::Connection,
model: &TextEmbedding,
query: &str,
k: usize,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let q_emb = model.embed(vec![query.to_string()], None)?;
let q_json = vec_to_json(&q_emb[0]);
let mut rows = conn
.query(
"SELECT d.passage
FROM vector_top_k('docs_idx', vector(?), ?) AS v
JOIN docs AS d ON d.rowid = v.id
ORDER BY v.distance",
libsql::params![q_json.as_str(), k as i64],
)
.await?;
let mut passages = Vec::new();
while let Some(row) = rows.next().await? {
let passage: String = row.get(0)?;
passages.push(passage);
}
Ok(passages)
}
/// Build a RAG prompt that instructs the model to answer from context only.
fn build_prompt(context_passages: &[String], question: &str) -> String {
let mut prompt = String::from(
"You are a helpful assistant. Answer the question using only the provided context.\n\
If the context does not contain enough information, say so.\n\n\
Context:\n",
);
for passage in context_passages {
prompt.push_str(passage);
prompt.push_str("\n\n");
}
prompt.push_str(&format!("Question: {question}\n\nAnswer:"));
prompt
}
#[derive(serde::Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
}
#[derive(serde::Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(serde::Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(serde::Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(serde::Deserialize)]
struct ResponseMessage {
content: String,
}
/// Send the prompt to an OpenAI-compatible chat completions API.
async fn call_llm(
client: &reqwest::Client,
api_key: &str,
prompt: &str,
) -> Result<String, Box<dyn std::error::Error>> {
let request = ChatRequest {
model: "gpt-4o-mini".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}],
};
let resp = client
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(api_key)
.json(&request)
.send()
.await?
.error_for_status()?
.json::<ChatResponse>()
.await?;
Ok(resp.choices[0].message.content.clone())
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// ── 1. Connect to Turso (local file) ──
let db = Builder::new_local("rag_search.db").build().await?;
let conn = db.connect()?;
// ── 2. Create the docs table ──
conn.execute(
"CREATE TABLE IF NOT EXISTS docs (
id INTEGER PRIMARY KEY,
passage TEXT NOT NULL,
embedding F32_BLOB(384) NOT NULL
)",
(),
)
.await?;
// ── 3. Create the vector index ──
conn.execute(
"CREATE INDEX IF NOT EXISTS docs_idx ON docs (libsql_vector_idx(embedding))",
(),
)
.await?;
// ── 4. Define the corpus ──
let passages: Vec<String> = vec![
// Rust programming
"Rust uses an ownership system to guarantee memory safety without a garbage collector.",
"The borrow checker enforces that references do not outlive the data they point to.",
"Cargo is Rust's build system and package manager, used to manage dependencies and run tests.",
"Rust's trait system enables zero-cost abstractions and compile-time polymorphism.",
"Async Rust uses futures and the tokio runtime to handle concurrent I/O efficiently.",
// Astronomy
"A black hole is a region of spacetime where gravity is so strong that nothing can escape.",
"The Milky Way galaxy contains an estimated 100 to 400 billion stars.",
"Neutron stars are the collapsed cores of massive stars, with densities exceeding atomic nuclei.",
"The cosmic microwave background is the thermal radiation left over from the early universe.",
"Exoplanets are planets outside our solar system, detected via transit photometry or radial velocity.",
// Cooking
"Maillard reaction gives browned foods their distinctive flavour through amino acid and sugar reactions.",
"Sous vide cooking involves sealing food in vacuum bags and cooking at precise low temperatures.",
"Emulsification combines two immiscible liquids, such as oil and water, using an emulsifier like lecithin.",
"Fermentation converts sugars to acids or alcohol using microorganisms, used in bread, beer, and yogurt.",
"Knife skills — julienne, brunoise, chiffonade — determine the surface area and cooking time of vegetables.",
]
.into_iter()
.map(String::from)
.collect();
// ── 5. Embed the corpus ──
let model = TextEmbedding::try_new(InitOptions {
model_name: EmbeddingModel::BGESmallENV15,
show_download_progress: true,
..Default::default()
})?;
let embeddings = model.embed(passages.clone(), None)?;
// ── 6. Insert passages + embeddings ──
for (i, (passage, emb)) in passages.iter().zip(embeddings.iter()).enumerate() {
let json = vec_to_json(emb);
conn.execute(
"INSERT OR IGNORE INTO docs (id, passage, embedding) VALUES (?, ?, vector(?))",
libsql::params![i as i64, passage.as_str(), json.as_str()],
)
.await?;
}
println!("Inserted {} passages.\n", passages.len());
// ── 7. RAG pipeline ──
let api_key = std::env::var("OPENAI_API_KEY")?;
let client = reqwest::Client::new();
let questions = vec![
"How does Rust ensure memory safety?",
"What is a black hole?",
"What is the Maillard reaction?",
];
for question in &questions {
println!("=== Question: \"{question}\" ===\n");
let context = retrieve(&conn, &model, question, 3).await?;
println!("Retrieved passages:");
for (i, p) in context.iter().enumerate() {
println!(" {}: {p}", i + 1);
}
println!();
let prompt = build_prompt(&context, question);
let answer = call_llm(&client, &api_key, &prompt).await?;
println!("Answer: {answer}\n");
}
Ok(())
}
Git Worktrees
What Are They?
A git worktree is a checked-out copy of a branch that lives in a separate directory on your filesystem, but shares the same underlying git repository (.git database) as your main clone.
Normally, a single git clone has one working directory — the files you see and edit. A worktree lets you have multiple working directories at once, each on a different branch, all tied to the same repository. You don’t clone the repo a second time; you project a second (or third, or fourth) working directory from the one clone.
~/projects/myapp/.git/      ← shared: all objects, refs, config
~/projects/myapp/           ← main worktree (branch: main)
~/projects/myapp-feature/   ← linked worktree (branch: feature/auth)
~/projects/myapp-hotfix/    ← linked worktree (branch: hotfix/crash)
All three directories above read and write to the same .git object store. Commits made in any worktree are immediately visible to all others.
Core Commands
Add a worktree
# Check out an existing branch into a new directory
git worktree add ../myapp-feature feature/auth
# Create a new branch and check it out in one step
git worktree add -b feature/new-thing ../myapp-new-thing main
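In the first form, the first argument is the path for the new directory and the second is the branch (or commit) to check out. In the second form, -b names the branch to create, and the final argument is the commit it starts from.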
List worktrees
git worktree list
Example output:
/home/you/projects/myapp abc1234 [main]
/home/you/projects/myapp-feature def5678 [feature/auth]
Remove a worktree
# Remove the directory and deregister the worktree
git worktree remove ../myapp-feature
# Force-remove even if there are untracked or modified files
git worktree remove --force ../myapp-feature
After deleting a linked worktree directory manually (e.g. with rm -rf), git still knows about it. Clean up the stale reference with:
git worktree prune
Move a worktree
git worktree move ../myapp-feature ../myapp-auth-feature
Typical Workflows
Work on two branches simultaneously
You’re deep in a feature branch when an urgent bug report comes in. Instead of stashing your work or making a WIP commit, you add a worktree:
git worktree add ../myapp-hotfix main
cd ../myapp-hotfix
# fix the bug, commit, push — then come back to your feature
Your feature working directory is untouched.
Run tests against a different branch without switching
git worktree add /tmp/myapp-test origin/release/2.0
cd /tmp/myapp-test
cargo test
You keep your editor open on main while tests run elsewhere.
Review a colleague’s PR locally
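# GitHub exposes each PR under refs/pull/<number>/head; fetch it into a local branch
git fetch origin pull/123/head:pr-123
git worktree add ../review-pr-123 pr-123
cd ../review-pr-123
# read, run, evaluate; then go back and clean up
cd -
git worktree remove ../review-pr-123
git branch -D pr-123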
How It Works Internally
When you run git worktree add, git:
- Creates a subdirectory .git/worktrees/<name>/ to store the worktree’s private state (its HEAD, index, and a back-reference to the working directory path).
- Writes a .git file (not a directory) into the new working directory that points back to .git/worktrees/<name>/.
- Checks out the requested branch into the new directory.
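The object database (all commits, trees, blobs) and nearly all refs, including branches and tags, are shared. Per-worktree state is essentially HEAD, the index, and a few special refs such as those used by git bisect.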
Limitations and When to Avoid Them
Each branch can only be checked out once
The most important constraint: a single branch cannot be checked out in two worktrees at the same time. Git enforces this because a commit made in one worktree would move the branch out from under the other, silently leaving that worktree’s index and files out of sync with its HEAD. If you try, you’ll get an error:
fatal: 'feature/auth' is already checked out at '/home/you/projects/myapp'
Workaround: check out the same commit with a detached HEAD (git worktree add --detach <path> <branch>), or create a new branch at that commit with -b.
Tools that look for .git as a directory can break
Some tools assume .git is a directory, not a file. A linked worktree has a .git file instead, which confuses:
- Older git GUIs that scan for .git/ directories
- Some shell prompts and plugins that detect git repos naively
- Scripts that do test -d .git
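A tool that needs to find the repository can handle both layouts instead of testing for a directory. Below is a minimal sketch that assumes the `gitdir: <path>` format git writes into a linked worktree's .git file. `find_git_dir` is a hypothetical helper, not part of any git library. For anything beyond a quick check, asking git itself with `git rev-parse --git-dir` is more robust.

```rust
use std::fs;
use std::path::{Path, PathBuf};

/// Locate the git directory for a working tree root.
/// - Main worktree: `.git` is a directory, returned as-is.
/// - Linked worktree: `.git` is a file whose first line is `gitdir: <path>`;
///   a relative path is resolved against the worktree root.
fn find_git_dir(worktree_root: &Path) -> Option<PathBuf> {
    let dot_git = worktree_root.join(".git");
    if dot_git.is_dir() {
        return Some(dot_git);
    }
    if dot_git.is_file() {
        let contents = fs::read_to_string(&dot_git).ok()?;
        let target = contents.lines().next()?.strip_prefix("gitdir:")?.trim();
        return Some(worktree_root.join(target)); // join leaves absolute paths unchanged
    }
    None
}

fn main() -> std::io::Result<()> {
    // Build a throwaway layout that mimics a main worktree plus one linked worktree.
    let root = std::env::temp_dir().join("find_git_dir_demo");
    let _ = fs::remove_dir_all(&root);
    let main_tree = root.join("myapp");
    let linked_tree = root.join("myapp-feature");
    let gitdir = main_tree.join(".git/worktrees/myapp-feature");
    fs::create_dir_all(&gitdir)?;
    fs::create_dir_all(&linked_tree)?;
    fs::write(linked_tree.join(".git"), format!("gitdir: {}\n", gitdir.display()))?;

    println!("{:?}", find_git_dir(&main_tree));   // the main .git directory
    println!("{:?}", find_git_dir(&linked_tree)); // .git/worktrees/myapp-feature
    println!("{:?}", find_git_dir(&root));        // None: not a working tree
    Ok(())
}
```

For a linked worktree this returns the per-worktree directory under .git/worktrees/<name>/. The shared object store lives in the main .git directory, which that per-worktree directory records in its commondir file.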
Build systems that use the working directory path
Tools that keep build output inside the working directory generally work fine. Cargo, for example, places target/ under the workspace root, so each worktree gets its own target/ directory (at the cost of a full build in each one). Problems arise when a tool bakes absolute source paths into cached artifacts: switching between worktrees may then cause spurious rebuilds or cache invalidation.
Hooks run in the worktree context, not the main repo
Git hooks in .git/hooks/ apply to all worktrees, but a hook triggered from a linked worktree runs with that worktree’s root as its working directory, not the main repo’s. Hook scripts that assume a fixed working directory can behave unexpectedly.
Not a substitute for branches in CI
Worktrees are a local developer convenience. Don’t model CI pipelines around them — use proper branch checkouts in ephemeral CI environments instead.
Submodules require extra care
Submodules are not automatically initialized in a new worktree. You need to run git submodule update --init inside each linked worktree where you need them.
Avoid for long-lived parallel development
If you find yourself maintaining two worktrees of the same repo for weeks at a time, that’s a signal you might want two separate clones or a proper branch strategy — not worktrees. Worktrees shine for short-lived parallel work (hotfixes, reviews, quick tests), not as a permanent parallel development setup.
Quick Reference
| Task | Command |
|---|---|
| Add a worktree for an existing branch | git worktree add <path> <branch> |
| Add a worktree and create a new branch | git worktree add -b <branch> <path> <start-point> |
| List all worktrees | git worktree list |
| Remove a worktree | git worktree remove <path> |
| Clean up stale worktree records | git worktree prune |
| Move a worktree | git worktree move <old-path> <new-path> |
Writing a Lisp-to-C Compiler in Rust
This course walks you through building a complete, working compiler from scratch. You will write every component yourself (a parser that also performs lexical analysis, a semantic analyser, and a code generator), ending with a program that reads MiniLisp source code and emits valid C. The compiler is written in Rust and uses the nom parser-combinator library for all parsing work.
Table of Contents
Part 1 — Foundations
- Introduction: What We’re Building
- MiniLisp Language Specification
- Compiler Architecture: The Pipeline
Part 2 — Parsing with nom
- Introduction to nom: Parser Combinators
- Setting Up the Project
- Recognizing Atoms: Integers, Booleans, Strings, Symbols
- The Abstract Syntax Tree
- Parsing Atoms with nom
- Parsing S-Expressions and Special Forms
Part 3 — Semantic Analysis
Part 4 — Code Generation
- The C Runtime Preamble
- Generating C: Atoms and Expressions
- Generating C: Definitions and Functions
- Generating C: Control Flow and Sequencing
Part 5 — Putting It Together
Part 1 — Foundations
1. Introduction: What We’re Building
A compiler is a program that transforms source code written in one language into equivalent code in another. By the end of this course you will have written one that accepts MiniLisp — a small, clean dialect of Lisp — and produces human-readable C that you can compile and run with any standard C compiler. Along the way you will implement each classic compiler stage from scratch: lexical analysis, parsing, semantic analysis, and code generation.
Why Lisp-to-C?
Lisp is the ideal source language for a first compiler because its syntax is almost trivial to parse. Every expression is either an atom (a number, string, boolean, or name) or a list of expressions wrapped in parentheses, and a program is simply a sequence of expressions. There are no operator-precedence rules, no statement-versus-expression distinctions, and no ambiguous grammars. This lets you focus on what a compiler does rather than fighting with syntax.
C is the ideal target language because it is close to the machine — you can see exactly how each Lisp construct translates to something concrete — yet still readable. You will not be writing assembly; the C compiler handles that final step.
What you will learn
By completing this course you will understand:
- Lexical analysis and parsing — how raw text becomes structured data.
- Abstract syntax trees — how compilers represent programs internally.
- Semantic analysis — how a compiler validates that a program makes sense before generating code.
- Code generation — how high-level constructs map to lower-level constructs in a target language.
- nom parser combinators — a practical, idiomatic Rust approach to parsing.
What you will build
The finished compiler is a single Rust binary. You pipe MiniLisp source in and get C source out:
$ echo '(define (square x) (* x x)) (display (square 5))' | minilisp-cc > out.c
$ cc -o out out.c && ./out
25
The compiler handles integers, booleans, strings, arithmetic, comparisons, conditionals, local bindings, sequencing, first-class named functions, and recursion. It is deliberately minimal — roughly 800 lines of Rust — so you can read and understand every line.
Prerequisites
You should be comfortable reading and writing Rust. You do not need prior experience with compilers, Lisp, or nom — this course introduces all three from scratch. You should have a working Rust toolchain (rustc, cargo) and a C compiler (cc or gcc) installed.
Course structure
The course is divided into five parts:
- Foundations — what MiniLisp looks like, what the compiler’s architecture is.
- Parsing with nom — turning source text into an AST.
- Semantic Analysis — validating the AST.
- Code Generation — turning the AST into C.
- Putting It Together — wiring the stages into a working tool and testing it.
Each section builds directly on the previous one. Code samples are cumulative: by the end, every snippet fits together into the complete compiler.
2. MiniLisp Language Specification
MiniLisp is the source language of our compiler. It is a minimal Lisp dialect with integers, booleans, strings, first-class functions, lexical scope, and a small set of built-in operators. This section defines every syntactic form precisely, gives the grammar in EBNF, and shows a complete example program so you know exactly what the compiler must handle before you write a single line of Rust.
Data types
MiniLisp has four data types:
| Type | Examples | Notes |
|---|---|---|
| Integer | 42, -7, 0 | Signed 64-bit integers |
| Boolean | #t, #f | True and false |
| String | "hello", "line\n" | Double-quoted, with \\, \", \n, \t escapes |
| Function | (lambda (x) (* x x)) | First-class, lexically scoped |
There are no floating-point numbers, characters, lists-as-data, or nil/null values. This keeps the compiler simple while still being expressive enough to write interesting programs.
Expressions
Everything in MiniLisp is an expression — there are no statements. Every form evaluates to a value.
Atoms are the simplest expressions:
42 ; integer
#t ; boolean true
#f ; boolean false
"hello" ; string
x ; symbol (variable reference)
Function calls use prefix notation inside parentheses. The first element is the operator or function; the rest are arguments:
(+ 1 2) ; => 3
(* 3 (+ 1 2)) ; => 9
(display "hi") ; prints "hi", returns void
Built-in operators:
| Operator | Arity | Description |
|---|---|---|
| +, -, *, / | 2 | Arithmetic (integer division for /) |
| =, <, >, <=, >= | 2 | Comparison (returns #t or #f) |
| not | 1 | Boolean negation |
| and, or | 2 | Boolean connectives (not short-circuiting) |
| display | 1 | Print a value to stdout |
| newline | 0 | Print a newline character |
Special forms
Special forms look like function calls but have special evaluation rules. The compiler recognises them by their leading keyword.
define — top-level definitions
;; Define a variable
(define pi 3)
;; Define a function (syntactic sugar)
(define (square x) (* x x))
;; The function form is equivalent to:
(define square (lambda (x) (* x x)))
if — conditional
(if (> x 0) x (- 0 x)) ; absolute value
Always three sub-expressions: condition, then-branch, else-branch.
lambda — anonymous function
(lambda (x y) (+ x y))
Parameter list followed by one or more body expressions. The value of the last body expression is the return value.
let — local bindings
(let ((x 10)
(y 20))
(+ x y))
A list of (name value) bindings followed by one or more body expressions. Bindings are not mutually visible (this is let, not let*).
begin — sequencing
(begin
(display "hello")
(newline)
42)
Evaluates each expression in order, returns the value of the last one.
Comments
A semicolon ; starts a comment that extends to the end of the line:
(define x 10) ; this is a comment
Grammar (EBNF)
program = { expr } ;
expr = atom | list ;
atom = integer | boolean | string | symbol ;
integer = [ "-" ] digit { digit } ;
boolean = "#t" | "#f" ;
string = '"' { string_char } '"' ;
string_char = escape | any character except '"' and '\\' ;
escape = '\\' ( '"' | '\\' | 'n' | 't' ) ;
symbol = symbol_start { symbol_cont } ;
symbol_start = letter | '_' | '+' | '-' | '*' | '/' | '='
| '<' | '>' | '!' | '?' ;
symbol_cont = symbol_start | digit ;
list = '(' { ws } { expr { ws } } ')' ;
ws = ' ' | '\t' | '\n' | '\r' | comment ;
comment = ';' { any character except '\n' } '\n' ;
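The symbol rules translate almost directly into character predicates. The sketch below assumes "letter" means an ASCII letter, which the grammar does not specify. `is_symbol_start`, `is_symbol_cont`, and `is_symbol` are hypothetical helpers, not the nom parser the course builds later.

```rust
/// `symbol_start`: a letter, '_', or one of + - * / = < > ! ?
fn is_symbol_start(c: char) -> bool {
    c.is_ascii_alphabetic() || "_+-*/=<>!?".contains(c)
}

/// `symbol_cont`: anything allowed at the start, plus digits.
fn is_symbol_cont(c: char) -> bool {
    is_symbol_start(c) || c.is_ascii_digit()
}

/// True if `s` matches `symbol = symbol_start { symbol_cont }`.
fn is_symbol(s: &str) -> bool {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => is_symbol_start(first) && chars.all(is_symbol_cont),
        None => false,
    }
}

fn main() {
    for s in ["square", "<=", "null?", "x1", "1x", "-7", ""] {
        println!("{s:?}: {}", is_symbol(s));
    }
}
```

Note that `is_symbol("-7")` returns true. The text -7 matches both the integer rule and the symbol rule, so the parser has to try integers before symbols, while a lone - must still parse as a symbol.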
Complete example program
;; Recursive factorial
(define (fact n)
(if (= n 0)
1
(* n (fact (- n 1)))))
;; Compute and display 10!
(display (fact 10))
(newline)
This program defines a recursive factorial function, computes 10!, and prints 3628800 followed by a newline.
3. Compiler Architecture: The Pipeline
Our compiler is a classic multi-stage pipeline: source text passes through a parser, producing an AST; the AST passes through a semantic analyser, which validates scope and form usage; the validated AST passes through a code generator, which emits C. This section maps that pipeline onto the module structure you will build and explains how data and errors flow between stages.
The four stages
MiniLisp source text
         │
         ▼
  ┌──────────────┐
  │ Parser       │  (nom combinators)
  │ src/parser   │
  └──────┬───────┘
         │ Vec<Expr>
         ▼
  ┌──────────────┐
  │ Analyser     │  (scope checking, form validation)
  │ src/analysis │
  └──────┬───────┘
         │ Vec<Expr> (same type, now validated)
         ▼
  ┌──────────────┐
  │ Code Gen     │  (AST → C strings)
  │ src/codegen  │
  └──────┬───────┘
         │ String
         ▼
   C source code
- Parser: converts raw text into a Vec<Expr>, where Expr is the AST node type. Uses nom parser combinators. Reports syntax errors with position information.
- Semantic Analyser: walks the AST, builds a symbol table to track variable scopes, and validates that special forms have the correct shape (e.g. if has exactly three sub-expressions). Reports semantic errors. The analyser does not transform the AST; it only checks it.
- Code Generator: recursively traverses the validated AST and produces a String containing the complete C program. Emits the runtime preamble first, then forward declarations, then function definitions, then main.
- Driver: the main function wires the stages together (read input, parse, analyse, generate, write output). It also handles CLI arguments and error display.
Module layout
minilisp-cc/
├── Cargo.toml
└── src/
├── main.rs # CLI entry point, pipeline driver
├── ast.rs # Expr enum and Display impl
├── parser.rs # nom parsers
├── analysis.rs # symbol table, scope checking, form validation
└── codegen.rs # C code generation
Each module has a single, clear responsibility. Data flows in one direction: main calls parser, passes the result to analysis, passes that to codegen, and writes the output.
Data flow
The key data type is Expr — the AST node. It is defined once in ast.rs and used by all stages. The parser produces Vec<Expr>. The analyser consumes Vec<Expr> and returns Vec<Expr> (unchanged, but validated). The code generator consumes Vec<Expr> and returns String.
Error handling
Errors are kept simple. A minimal error type carries just a human-readable message:
/// A compiler error with a human-readable message.
#[derive(Debug)]
pub struct CompileError {
    pub message: String,
}
In a production compiler you would track source positions (line and column) in every AST node. For simplicity, our compiler attaches position information only to parser errors and uses descriptive messages for semantic errors. nom does not report lines and columns directly, but its errors carry the input that was still unconsumed at the point of failure, and the position can be computed from that.
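Turning that unconsumed input into a line and column takes only a few lines. This is a sketch under one assumption: `remaining` is a suffix of the original source, which holds for the `input` field of nom's default `error::Error<&str>`. `line_col` is a hypothetical helper. It works on plain strings, so it needs no nom dependency.

```rust
/// Convert the unconsumed input from a parse error into a 1-based
/// (line, column) pair. `remaining` must be a suffix of `source`.
fn line_col(source: &str, remaining: &str) -> (usize, usize) {
    let offset = source.len() - remaining.len();
    let consumed = &source[..offset];
    let line = consumed.matches('\n').count() + 1;
    // Column counts characters since the last newline (or the start of input).
    let line_start = consumed.rfind('\n').map_or(0, |i| i + 1);
    let column = consumed[line_start..].chars().count() + 1;
    (line, column)
}

fn main() {
    let source = "(define x 10)\n(foo ]";
    // Pretend the parser failed with only "]" left unconsumed.
    let remaining = &source[source.len() - 1..];
    let (line, col) = line_col(source, remaining);
    println!("syntax error at line {line}, column {col}");
}
```

The example prints `syntax error at line 2, column 6`, pointing at the stray `]`.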
The driver collects errors from each stage. If parsing fails, the compiler stops. If analysis finds errors, it reports all of them before stopping (so you see every problem at once, not one at a time). Code generation assumes a valid AST and does not produce errors.
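The error policy above can be sketched as a small driver. The three stage functions are hypothetical stubs. The parser treats each line as one expression, and the analyser flags anything mentioning `undefined`. Their signatures loosely mirror the placeholders parse, analyse, and generate used later in the course. Only `compile` reflects the intended control flow.

```rust
// Stub parser: one expression per non-empty line; rejects unbalanced parentheses.
fn parse(src: &str) -> Result<Vec<String>, String> {
    let mut exprs = Vec::new();
    for (n, line) in src.lines().enumerate() {
        let line = line.trim();
        if line.matches('(').count() != line.matches(')').count() {
            return Err(format!("line {}: unbalanced parentheses", n + 1));
        }
        if !line.is_empty() {
            exprs.push(line.to_string());
        }
    }
    Ok(exprs)
}

// Stub analyser: collects an error for every expression mentioning `undefined`.
fn analyse(exprs: &[String]) -> Result<(), Vec<String>> {
    let errors: Vec<String> = exprs
        .iter()
        .filter(|e| e.contains("undefined"))
        .map(|e| format!("unbound symbol in {e}"))
        .collect();
    if errors.is_empty() { Ok(()) } else { Err(errors) }
}

// Stub generator: infallible, because it only ever sees validated input.
fn generate(exprs: &[String]) -> String {
    format!("int main(void) {{ /* forms: {} */ return 0; }}", exprs.len())
}

/// Driver policy: stop at the first syntax error, but report every
/// semantic error before giving up.
fn compile(src: &str) -> Result<String, Vec<String>> {
    let exprs = parse(src).map_err(|e| vec![format!("syntax error: {e}")])?;
    analyse(&exprs)?;
    Ok(generate(&exprs))
}

fn main() {
    match compile("(display undefined)\n(display also_undefined)") {
        Ok(c) => println!("{c}"),
        Err(errors) => {
            for e in &errors {
                eprintln!("{e}"); // both semantic errors are printed, not just the first
            }
        }
    }
}
```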
Why this architecture?
Separating the compiler into independent stages has practical benefits:
- Testability — you can unit-test each stage in isolation. Feed a string to the parser and assert on the AST. Feed an AST to the analyser and assert on the errors. Feed a valid AST to the code generator and assert on the C output.
- Debuggability — when something goes wrong, you know which stage to look at.
- Extensibility — adding a new language feature means adding an AST variant, a parser case, an analyser check, and a code generation case. Each change is localised to one module.
Part 2 — Parsing with nom
4. Introduction to nom: Parser Combinators
nom is a parser-combinator library: instead of writing a grammar file and running a generator, you write small Rust functions that each recognise a fragment of input, then combine them into larger parsers. This section introduces the core IResult<I, O, E> type, walks through the essential combinators (tag, char, alt, many0, map, tuple, delimited, preceded), and shows how to write, compose, and test parsers before you apply any of this to MiniLisp.
What is a parser combinator?
A parser is a function that takes some input and either succeeds (returning the parsed value and the remaining input) or fails (returning an error). A combinator is a function that takes one or more parsers and returns a new parser. By combining small parsers you build complex ones without writing a single regular expression or grammar file.
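The idea can be made concrete without any library. The following hand-rolled sketch is an illustration only. The names `literal`, `either`, and `pair` are hypothetical; they play the roles of nom's `tag`, `alt`, and `tuple`. nom's real signatures differ, and it returns `IResult` with an error type rather than `Option`.

```rust
/// A parser here is any function from input to `Some((rest, value))` on
/// success or `None` on failure.
fn literal<'a>(expected: &'static str) -> impl Fn(&'a str) -> Option<(&'a str, &'a str)> {
    move |input: &'a str| {
        input
            .strip_prefix(expected)
            .map(|rest| (rest, &input[..expected.len()]))
    }
}

/// Combinator: try `first`; if it fails, try `second` on the same input.
fn either<'a, O>(
    first: impl Fn(&'a str) -> Option<(&'a str, O)>,
    second: impl Fn(&'a str) -> Option<(&'a str, O)>,
) -> impl Fn(&'a str) -> Option<(&'a str, O)> {
    move |input: &'a str| first(input).or_else(|| second(input))
}

/// Combinator: run `first`, then `second` on what is left, returning both values.
fn pair<'a, A, B>(
    first: impl Fn(&'a str) -> Option<(&'a str, A)>,
    second: impl Fn(&'a str) -> Option<(&'a str, B)>,
) -> impl Fn(&'a str) -> Option<(&'a str, (A, B))> {
    move |input: &'a str| {
        let (rest, a) = first(input)?;
        let (rest, b) = second(rest)?;
        Some((rest, (a, b)))
    }
}

fn main() {
    let boolean = either(literal("#t"), literal("#f"));
    println!("{:?}", boolean("#f rest")); // Some((" rest", "#f"))
    println!("{:?}", boolean("yes")); // None
    let open_then_bool = pair(literal("("), either(literal("#t"), literal("#f")));
    println!("{:?}", open_then_bool("(#t)")); // Some((")", ("(", "#t")))
}
```

Every combinator takes parsers and returns a new parser, so they nest freely. nom applies the same principle at scale.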
In nom, every parser has this signature:
fn parser(input: &str) -> IResult<&str, Output>
In nom 7, IResult is defined as a Result whose error type defaults to nom’s own error::Error:
pub type IResult<I, O, E = error::Error<I>> = Result<(I, O), Err<E>>;
On success, you get a tuple of (remaining_input, parsed_value). On failure, you get an error indicating where parsing failed.
Essential combinators
Here are the combinators you will use most. Each is a function in the nom crate.
tag — match a literal string
#![allow(unused)]
fn main() {
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_hello(input: &str) -> IResult<&str, &str> {
tag("hello")(input)
}
assert_eq!(parse_hello("hello world"), Ok((" world", "hello")));
}
char — match a single character
#![allow(unused)]
fn main() {
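use nom::character::complete::char;
use nom::IResult;
// Matches the character '('. The annotation pins nom's generic error type,
// which the compiler cannot infer on its own here.
let result: IResult<&str, char> = char('(')("(abc)");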
assert_eq!(result, Ok(("abc)", '(')));
}
alt — try alternatives in order
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_bool(input: &str) -> IResult<&str, &str> {
alt((tag("#t"), tag("#f")))(input)
}
assert_eq!(parse_bool("#t rest"), Ok((" rest", "#t")));
assert_eq!(parse_bool("#f rest"), Ok((" rest", "#f")));
}
alt tries each parser in order and returns the result of the first one that succeeds.
map — transform a parser’s output
#![allow(unused)]
fn main() {
use nom::combinator::map;
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_true(input: &str) -> IResult<&str, bool> {
map(tag("#t"), |_| true)(input)
}
assert_eq!(parse_true("#t"), Ok(("", true)));
}
tuple — sequence multiple parsers
#![allow(unused)]
fn main() {
use nom::sequence::tuple;
use nom::character::complete::char;
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_pair(input: &str) -> IResult<&str, (char, &str, char)> {
tuple((char('('), tag("hi"), char(')')))(input)
}
assert_eq!(parse_pair("(hi)"), Ok(("", ('(', "hi", ')'))));
}
delimited — match something between two delimiters
#![allow(unused)]
fn main() {
use nom::sequence::delimited;
use nom::character::complete::char;
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_parens(input: &str) -> IResult<&str, &str> {
delimited(char('('), tag("hi"), char(')'))(input)
}
assert_eq!(parse_parens("(hi)"), Ok(("", "hi")));
}
delimited discards the opening and closing delimiters and returns only the inner value.
preceded — match a prefix and discard it
#![allow(unused)]
fn main() {
use nom::sequence::preceded;
use nom::character::complete::char;
use nom::bytes::complete::tag;
use nom::IResult;
fn parse_hash_t(input: &str) -> IResult<&str, &str> {
preceded(char('#'), tag("t"))(input)
}
assert_eq!(parse_hash_t("#t"), Ok(("", "t")));
}
many0 — match zero or more repetitions
#![allow(unused)]
fn main() {
use nom::multi::many0;
use nom::character::complete::char;
use nom::IResult;
fn parse_stars(input: &str) -> IResult<&str, Vec<char>> {
many0(char('*'))(input)
}
assert_eq!(parse_stars("***abc"), Ok(("abc", vec!['*', '*', '*'])));
assert_eq!(parse_stars("abc"), Ok(("abc", vec![])));
}
Composing parsers
The power of combinators is composition. Here is a parser that recognises a parenthesised, comma-separated pair of integers:
#![allow(unused)]
fn main() {
use nom::character::complete::{char, digit1};
use nom::IResult;
use nom::combinator::map_res;
use nom::sequence::{delimited, separated_pair};
fn parse_int(input: &str) -> IResult<&str, i64> {
map_res(digit1, |s: &str| s.parse::<i64>())(input)
}
fn parse_pair(input: &str) -> IResult<&str, (i64, i64)> {
delimited(
char('('),
separated_pair(parse_int, char(','), parse_int),
char(')'),
)(input)
}
assert_eq!(parse_pair("(3,7)"), Ok(("", (3, 7))));
}
Each piece is a small, testable function. You build up complexity by snapping pieces together.
Testing nom parsers
Testing a nom parser is straightforward: call the function with a test input and assert on the result.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_parse_int() {
        assert_eq!(parse_int("42 rest"), Ok((" rest", 42)));
    }
    #[test]
    fn test_parse_int_fails_on_alpha() {
        assert!(parse_int("abc").is_err());
    }
}
Always test both the happy path and failure cases. Verify that the remaining input is what you expect — a parser that accidentally consumes too much or too little is a common bug.
Exercises
- Write a parser that recognises the string "true" and returns the Rust bool value true.
- Write a parser that recognises either "yes" or "no" and returns a bool.
- Write a parser for a parenthesised integer like (42) that returns the i64 value inside.
5. Setting Up the Project
You will create a new Rust binary crate for the compiler, add nom and any other dependencies to Cargo.toml, and lay out the module structure that the rest of the course fills in. By the end of this section you will have a project that compiles, a src/main.rs that reads from stdin, and placeholder modules for each compiler stage.
Create the crate
cargo new minilisp-cc
cd minilisp-cc
This gives you a standard Rust binary project with src/main.rs and Cargo.toml.
Add dependencies
Edit Cargo.toml:
[package]
name = "minilisp-cc"
version = "0.1.0"
edition = "2021"
[dependencies]
nom = "7"
[profile.release]
opt-level = "z"
lto = true
strip = true
codegen-units = 1
This course targets nom 7. nom 8 (released in 2025) reworked the combinator API, so keep the dependency pinned to version 7 to follow along. The release profile settings minimise binary size.
Create the module files
Create one file per compiler stage:
touch src/ast.rs src/parser.rs src/analysis.rs src/codegen.rs
Your src/ directory should now look like this:
src/
├── main.rs
├── ast.rs
├── parser.rs
├── analysis.rs
└── codegen.rs
Wire up modules in main.rs
Replace src/main.rs with:
mod ast;
mod parser;
mod analysis;
mod codegen;
use std::io::{self, Read};
fn main() {
    // Read all of stdin into a string
    let mut input = String::new();
    io::stdin()
        .read_to_string(&mut input)
        .expect("failed to read stdin");
    // TODO: parse, analyse, generate
    println!("Read {} bytes", input.len());
}
Add placeholder content to each module
In src/ast.rs:
/// A MiniLisp expression — the core AST node type.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    // Variants will be added in §7
}
In src/parser.rs:
use crate::ast::Expr;
/// Parse MiniLisp source code into a list of expressions.
pub fn parse(_input: &str) -> Result<Vec<Expr>, String> {
    todo!("implement in §8–§9")
}
In src/analysis.rs:
use crate::ast::Expr;
/// Validate a list of parsed expressions.
pub fn analyse(_exprs: &[Expr]) -> Result<(), Vec<String>> {
    todo!("implement in §10–§11")
}
In src/codegen.rs:
use crate::ast::Expr;
/// Generate C source code from a list of validated expressions.
pub fn generate(_exprs: &[Expr]) -> String {
    todo!("implement in §12–§15")
}
Verify it compiles
cargo check
You should see a clean compilation with no errors. Expect dead-code warnings, for example that parse, analyse, and generate are never used. That is fine: the warnings disappear once main.rs calls the stages and the implementations are filled in.
Checkpoint
You now have a project skeleton with:
- A binary crate that reads from stdin.
- Four modules with clear responsibilities and placeholder signatures.
- nom as a dependency, ready to use.
The rest of the course fills in these modules one at a time.
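For orientation, the TODO in main will eventually chain the three stages. The sketch below shows one possible shape for that chain, using the placeholder signatures from above. The module bodies are stand-ins so the sketch compiles on its own; for example, the stand-in parser only reads integers. compile is a hypothetical helper name, not something the course defines.

```rust
// Stand-ins for the course's modules. The signatures match the placeholders;
// the bodies are illustrative only.
mod ast {
    #[derive(Debug, Clone, PartialEq)]
    pub enum Expr {
        Int(i64),
    }
}

mod parser {
    use super::ast::Expr;
    /// Stand-in: accepts whitespace-separated integers only.
    pub fn parse(input: &str) -> Result<Vec<Expr>, String> {
        input
            .split_whitespace()
            .map(|tok| {
                tok.parse::<i64>()
                    .map(Expr::Int)
                    .map_err(|e| format!("bad token {:?}: {}", tok, e))
            })
            .collect()
    }
}

mod analysis {
    use super::ast::Expr;
    /// Stand-in: accepts every program.
    pub fn analyse(_exprs: &[Expr]) -> Result<(), Vec<String>> {
        Ok(())
    }
}

mod codegen {
    use super::ast::Expr;
    /// Stand-in: emits a C comment instead of real code.
    pub fn generate(exprs: &[Expr]) -> String {
        format!("/* {} top-level expressions */\n", exprs.len())
    }
}

/// Run parse, analyse and generate in order, stopping at the first failure.
fn compile(source: &str) -> Result<String, String> {
    let exprs = parser::parse(source)?;
    // Join all analysis errors so every problem is reported at once.
    analysis::analyse(&exprs).map_err(|errors| errors.join("\n"))?;
    Ok(codegen::generate(&exprs))
}
```

main could then print the Ok value, or write the Err value to stderr and exit with a non-zero status.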
6. Recognizing Atoms: Integers, Booleans, Strings, Symbols
Before building the full parser, you need nom parsers for each atomic value in MiniLisp: signed integers, boolean literals #t and #f, double-quoted strings with escape sequences, and symbol identifiers. This section develops each atom parser in isolation, explains the nom combinators used, and provides exercises to test your understanding before the parts are assembled into the full parser.
Parsing integers
A MiniLisp integer is an optional minus sign followed by one or more digits. We use recognize to capture the matched text as a single slice, then parse it into an i64:
use nom::IResult;
use nom::branch::alt;
use nom::character::complete::{char, digit1};
use nom::combinator::{map_res, opt, recognize};
use nom::sequence::pair;
/// Parse a signed integer literal.
fn parse_integer(input: &str) -> IResult<&str, i64> {
    map_res(
        recognize(pair(opt(char('-')), digit1)),
        |s: &str| s.parse::<i64>(),
    )(input)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn positive_integer() {
        assert_eq!(parse_integer("42 rest"), Ok((" rest", 42)));
    }
    #[test]
    fn negative_integer() {
        assert_eq!(parse_integer("-7"), Ok(("", -7)));
    }
    #[test]
    fn zero() {
        assert_eq!(parse_integer("0"), Ok(("", 0)));
    }
}
How it works:
- opt(char('-')) optionally matches a leading minus sign.
- digit1 matches one or more ASCII digits.
- pair sequences them so that -42 is matched as a unit.
- recognize captures the entire matched span as a &str instead of the pair tuple.
- map_res applies str::parse::<i64>() to convert the string to an integer, propagating parse errors as nom errors.
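To see what each combinator contributes, the chain above can be unrolled into plain std Rust. The helper below, parse_integer_by_hand, is hypothetical. It returns an Option instead of nom's IResult, but it follows the same steps in the same order.

```rust
/// Hand-written equivalent of
/// map_res(recognize(pair(opt(char('-')), digit1)), |s| s.parse::<i64>()).
/// Returns the remaining input and the value, or None on failure.
fn parse_integer_by_hand(input: &str) -> Option<(&str, i64)> {
    // opt(char('-')): consume a leading minus sign if there is one.
    let sign_len = if input.starts_with('-') { 1 } else { 0 };
    // digit1: count the ASCII digits after the sign; at least one is required.
    let digit_len = input[sign_len..]
        .bytes()
        .take_while(|b| b.is_ascii_digit())
        .count();
    if digit_len == 0 {
        return None;
    }
    // recognize: the match is the sign plus the digits, as one slice.
    let (matched, rest) = input.split_at(sign_len + digit_len);
    // map_res: a failed conversion (such as overflow) is a parse failure.
    let value = matched.parse::<i64>().ok()?;
    Some((rest, value))
}
```

One consequence of the map_res step is that an integer literal too large for i64, such as 99999999999999999999, is rejected as a parse error rather than silently wrapping.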
Parsing booleans
Booleans are the literals #t (true) and #f (false):
use nom::bytes::complete::tag;
use nom::combinator::map;
/// Parse a boolean literal: `#t` or `#f`.
fn parse_boolean(input: &str) -> IResult<&str, bool> {
    alt((
        map(tag("#t"), |_| true),
        map(tag("#f"), |_| false),
    ))(input)
}
#[test]
fn booleans() {
    assert_eq!(parse_boolean("#t"), Ok(("", true)));
    assert_eq!(parse_boolean("#f rest"), Ok((" rest", false)));
}
How it works:
- tag("#t") matches the literal string #t.
- map transforms the matched string into the Rust bool value true.
- alt tries #t first, then #f.
Parsing strings
Strings are enclosed in double quotes and support escape sequences \\, \", \n, and \t:
// IResult, alt, char, tag and map are already imported above.
use nom::character::complete::none_of;
use nom::multi::many0;
use nom::sequence::delimited;
/// Parse a single character inside a string (handling escapes).
fn string_char(input: &str) -> IResult<&str, char> {
    alt((
        // Escape sequences
        map(tag("\\\\"), |_| '\\'),
        map(tag("\\\""), |_| '"'),
        map(tag("\\n"), |_| '\n'),
        map(tag("\\t"), |_| '\t'),
        // Any character that is not a quote or backslash
        none_of("\"\\"),
    ))(input)
}
/// Parse a double-quoted string literal.
fn parse_string(input: &str) -> IResult<&str, String> {
    delimited(
        char('"'),
        map(many0(string_char), |chars| chars.into_iter().collect()),
        char('"'),
    )(input)
}
#[test]
fn simple_string() {
    assert_eq!(parse_string("\"hello\""), Ok(("", "hello".to_string())));
}
#[test]
fn string_with_escapes() {
    assert_eq!(
        parse_string("\"line\\none\""),
        Ok(("", "line\none".to_string()))
    );
}
#[test]
fn empty_string() {
    assert_eq!(parse_string("\"\""), Ok(("", String::new())));
}
How it works:
- string_char matches either an escape sequence or a literal character.
- The escape branches are the only way to consume a backslash. none_of("\"\\") rejects both the quote and the backslash, so an unrecognised escape such as \q ends the character run, the closing-quote check fails, and the whole string fails to parse.
- many0(string_char) collects zero or more characters into a Vec<char>.
- map converts the Vec<char> into a String.
- delimited strips the surrounding double quotes.
Parsing symbols
A symbol is an identifier that starts with a letter, underscore, or one of the operator characters (+, -, *, /, =, <, >, !, ?), followed by zero or more of those characters or digits:
// IResult, pair, map and recognize are already imported above.
use nom::character::complete::one_of;
use nom::multi::many0_count;
/// Parse a symbol identifier.
fn parse_symbol(input: &str) -> IResult<&str, String> {
    map(
        recognize(pair(
            one_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_+-*/=<>!?"),
            many0_count(one_of(
                "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+-*/=<>!?",
            )),
        )),
        String::from,
    )(input)
}
#[test]
fn simple_symbol() {
    assert_eq!(parse_symbol("foo"), Ok(("", "foo".to_string())));
}
#[test]
fn operator_symbol() {
    assert_eq!(parse_symbol("+ rest"), Ok((" rest", "+".to_string())));
}
#[test]
fn hyphenated_symbol() {
    assert_eq!(
        parse_symbol("my-var"),
        Ok(("", "my-var".to_string()))
    );
}
How it works:
one_of(...)matches a single character from the given set.- The first
one_ofensures the symbol starts with a valid leading character (not a digit). many0_countmatches zero or more continuation characters but only counts them (we userecognizeto capture the whole span).recognizegives us the entire matched text as a&str, which we convert to aString.
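The two long character-set strings in parse_symbol are easy to let drift apart. One alternative is to name the rules as predicates, sketched here in plain std Rust. is_symbol_start, is_symbol_char and symbol_len are illustrative names, not part of the course's code. In nom 7, char predicates like these can be passed to satisfy and take_while, for example recognize(pair(satisfy(is_symbol_start), take_while(is_symbol_char))).

```rust
/// Operator characters allowed anywhere in a symbol.
const OPERATOR_CHARS: &str = "_+-*/=<>!?";

/// Can `c` start a symbol? (ASCII letter, underscore, or operator character.)
fn is_symbol_start(c: char) -> bool {
    c.is_ascii_alphabetic() || OPERATOR_CHARS.contains(c)
}

/// Can `c` continue a symbol? (Anything that can start one, plus ASCII digits.)
fn is_symbol_char(c: char) -> bool {
    is_symbol_start(c) || c.is_ascii_digit()
}

/// Length of the symbol at the start of `input`, or 0 if there is none.
/// All symbol characters are ASCII, so this is both a char and a byte count.
fn symbol_len(input: &str) -> usize {
    let mut chars = input.chars();
    match chars.next() {
        Some(c) if is_symbol_start(c) => {
            1 + chars.take_while(|&c| is_symbol_char(c)).count()
        }
        _ => 0,
    }
}
```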
Important ordering note
When combining these parsers later, the order matters. parse_integer and parse_symbol can both start with -, and on its own parse_symbol would accept -42 as a symbol because digits are valid continuation characters. The rule is: try parse_integer first, and only fall back to parse_symbol if it fails. We handle this with alt in §8.
Exercises
- Extend parse_string to support the escape sequence \0 (null character).
- Write tests for edge cases: the symbol - (just a minus sign, a valid symbol), the integer -0, and the string "\\\"" (a string containing a backslash followed by a quote).
- What happens if you try parse_integer on the input "abc"? Read the nom error and explain what it means.
7. The Abstract Syntax Tree
The parser’s output is an Abstract Syntax Tree — a Rust data structure that captures the meaning of a MiniLisp program without the syntactic noise of parentheses and whitespace. This section defines the Expr enum and its variants, discusses why the tree is structured the way it is, and implements Display so you can inspect parse results during development.
What is an AST?
An Abstract Syntax Tree strips away syntactic details — parentheses, whitespace, comments — and represents only the structural meaning of a program. Two MiniLisp programs that differ only in whitespace or comments produce identical ASTs. The AST is the compiler’s internal representation from the parser onwards; every subsequent stage operates on it.
Designing the Expr enum
There are two common approaches for representing Lisp in an AST:
Option A — Generic lists. Everything is either an atom or a List(Vec<Expr>). Special forms like if and define are not distinguished during parsing; they are recognised later during semantic analysis.
Option B — Specific variants. Each special form gets its own enum variant with typed fields. The parser does more work, but later stages deal with well-typed data.
We use Option B. The benefit is exhaustive pattern matching: if you add a new variant, the Rust compiler forces you to handle it in the analyser and code generator. Bugs become compile errors instead of runtime surprises.
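To make the trade-off concrete, here is a small std-only sketch of the work Option A defers. Both enums are trimmed to a few variants, and Sexp and lower are illustrative names, not part of the course's code. The sketch turns a generic list tree into specific variants. Under Option A, every stage that cares about if would repeat the head-symbol and arity checks in lower. Under Option B, the parser performs them once, and the If variant's three fields guarantee the shape afterwards.

```rust
/// Option A: a generic S-expression tree (trimmed).
#[derive(Debug, Clone, PartialEq)]
enum Sexp {
    Int(i64),
    Bool(bool),
    Symbol(String),
    List(Vec<Sexp>),
}

/// Option B: specific variants (trimmed to the ones this sketch needs).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Bool(bool),
    Symbol(String),
    Call(Box<Expr>, Vec<Expr>),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

/// Convert an Option A tree into Option B, checking the shape of `if` once.
fn lower(s: &Sexp) -> Result<Expr, String> {
    match s {
        Sexp::Int(n) => Ok(Expr::Int(*n)),
        Sexp::Bool(b) => Ok(Expr::Bool(*b)),
        Sexp::Symbol(name) => Ok(Expr::Symbol(name.clone())),
        Sexp::List(items) => match items.as_slice() {
            // Right keyword, right number of parts: build a typed If node.
            [Sexp::Symbol(head), cond, then, els] if head == "if" => Ok(Expr::If(
                Box::new(lower(cond)?),
                Box::new(lower(then)?),
                Box::new(lower(els)?),
            )),
            // Right keyword, wrong number of parts.
            [Sexp::Symbol(head), ..] if head == "if" => Err(format!(
                "if: expected 3 sub-expressions, got {}",
                items.len() - 1
            )),
            // Anything else is an ordinary call.
            [func, args @ ..] => Ok(Expr::Call(
                Box::new(lower(func)?),
                args.iter().map(lower).collect::<Result<Vec<_>, String>>()?,
            )),
            [] => Err("empty list".to_string()),
        },
    }
}
```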
The complete Expr enum
Add this to src/ast.rs:
/// A MiniLisp expression — the core AST node type.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// Integer literal: `42`, `-7`
    Int(i64),
    /// Boolean literal: `#t`, `#f`
    Bool(bool),
    /// String literal: `"hello"`
    Str(String),
    /// Symbol (variable or function name): `x`, `+`, `my-func`
    Symbol(String),
    /// Function call: `(f arg1 arg2 ...)`
    /// The first element is the function; the rest are arguments.
    Call(Box<Expr>, Vec<Expr>),
    /// Variable definition: `(define name value)`
    Define(String, Box<Expr>),
    /// Function definition: `(define (name params...) body...)`
    DefineFunc(String, Vec<String>, Vec<Expr>),
    /// Conditional: `(if condition then else)`
    If(Box<Expr>, Box<Expr>, Box<Expr>),
    /// Anonymous function: `(lambda (params...) body...)`
    Lambda(Vec<String>, Vec<Expr>),
    /// Local bindings: `(let ((name val) ...) body...)`
    Let(Vec<(String, Expr)>, Vec<Expr>),
    /// Sequencing: `(begin expr1 expr2 ...)`
    Begin(Vec<Expr>),
}
Design notes:
- Box<Expr> is used wherever a single sub-expression is required. Without Box, the enum would be infinitely sized because Expr contains Expr.
- Vec<Expr> is used for variable-length lists of sub-expressions (function arguments, body expressions). A Vec already stores its elements on the heap, so no Box is needed there.
- Define and DefineFunc are separate variants because they have different structures. (define x 10) binds a name to a value; (define (f x) ...) binds a name to a function with parameters and a body.
- Call stores the function expression (which could be a symbol, a lambda, or even another call) as a Box<Expr>, and the arguments as a Vec<Expr>.
How MiniLisp maps to the AST
| MiniLisp | AST |
|---|---|
| 42 | Expr::Int(42) |
| #t | Expr::Bool(true) |
| "hi" | Expr::Str("hi".into()) |
| x | Expr::Symbol("x".into()) |
| (+ 1 2) | Expr::Call(Box::new(Expr::Symbol("+".into())), vec![Expr::Int(1), Expr::Int(2)]) |
| (define x 10) | Expr::Define("x".into(), Box::new(Expr::Int(10))) |
| (if #t 1 0) | Expr::If(Box::new(Expr::Bool(true)), Box::new(Expr::Int(1)), Box::new(Expr::Int(0))) |
Implementing Display
A Display implementation lets you print ASTs in a readable Lisp-like format, which is invaluable during debugging:
use std::fmt;
impl fmt::Display for Expr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Expr::Int(n) => write!(f, "{}", n),
            Expr::Bool(b) => write!(f, "{}", if *b { "#t" } else { "#f" }),
            Expr::Str(s) => write!(f, "\"{}\"", s),
            Expr::Symbol(s) => write!(f, "{}", s),
            Expr::Call(func, args) => {
                write!(f, "({}", func)?;
                for arg in args {
                    write!(f, " {}", arg)?;
                }
                write!(f, ")")
            }
            Expr::Define(name, val) => write!(f, "(define {} {})", name, val),
            Expr::DefineFunc(name, params, body) => {
                write!(f, "(define ({}", name)?;
                for p in params {
                    write!(f, " {}", p)?;
                }
                write!(f, ")")?;
                for expr in body {
                    write!(f, " {}", expr)?;
                }
                write!(f, ")")
            }
            Expr::If(cond, then, els) => {
                write!(f, "(if {} {} {})", cond, then, els)
            }
            Expr::Lambda(params, body) => {
                write!(f, "(lambda (")?;
                for (i, p) in params.iter().enumerate() {
                    if i > 0 { write!(f, " ")?; }
                    write!(f, "{}", p)?;
                }
                write!(f, ")")?;
                for expr in body {
                    write!(f, " {}", expr)?;
                }
                write!(f, ")")
            }
            Expr::Let(bindings, body) => {
                write!(f, "(let (")?;
                for (i, (name, val)) in bindings.iter().enumerate() {
                    if i > 0 { write!(f, " ")?; }
                    write!(f, "({} {})", name, val)?;
                }
                write!(f, ")")?;
                for expr in body {
                    write!(f, " {}", expr)?;
                }
                write!(f, ")")
            }
            Expr::Begin(exprs) => {
                write!(f, "(begin")?;
                for expr in exprs {
                    write!(f, " {}", expr)?;
                }
                write!(f, ")")
            }
        }
    }
}
You can now println!("{}", expr) any AST node and get readable output.
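The Str arm above prints the string's contents verbatim. As a result, a string containing a newline, quote or backslash does not print back as a valid MiniLisp literal. A small helper that reverses exactly the escapes string_char understands would fix that. escape_str is an illustrative name, not part of the course's code.

```rust
/// Re-escape a string so it reads back as the same MiniLisp literal.
/// Handles exactly the escapes the string parser accepts: \\ \" \n \t.
fn escape_str(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\t' => out.push_str("\\t"),
            other => out.push(other),
        }
    }
    out
}
```

With this helper, the Str arm would become Expr::Str(s) => write!(f, "\"{}\"", escape_str(s)).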
8. Parsing Atoms with nom
With atom parsers and the AST defined, this section assembles them into a single parse_atom function that recognises any MiniLisp atom and returns the corresponding Expr variant. You will use alt to try each alternative in turn, learn how nom reports errors and how to interpret them, and write unit tests that verify correct parsing of every atom type.
Connecting parsers to the AST
In §6 we wrote standalone parsers that returned Rust primitives (i64, bool, String). Now we wrap each one with map to produce Expr variants.
Add these to src/parser.rs:
use nom::IResult;
use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::character::complete::{char, digit1, none_of, one_of};
use nom::combinator::{map, map_res, opt, recognize};
use nom::multi::{many0, many0_count};
use nom::sequence::{delimited, pair};
use crate::ast::Expr;
/// Parse an integer literal and return `Expr::Int`.
fn parse_int(input: &str) -> IResult<&str, Expr> {
    map(
        map_res(
            recognize(pair(opt(char('-')), digit1)),
            |s: &str| s.parse::<i64>(),
        ),
        Expr::Int,
    )(input)
}
/// Parse a boolean literal and return `Expr::Bool`.
fn parse_bool(input: &str) -> IResult<&str, Expr> {
    alt((
        map(tag("#t"), |_| Expr::Bool(true)),
        map(tag("#f"), |_| Expr::Bool(false)),
    ))(input)
}
/// Parse a string literal and return `Expr::Str`.
fn parse_str(input: &str) -> IResult<&str, Expr> {
    map(
        delimited(
            char('"'),
            map(many0(string_char), |chars| {
                chars.into_iter().collect::<String>()
            }),
            char('"'),
        ),
        Expr::Str,
    )(input)
}
fn string_char(input: &str) -> IResult<&str, char> {
    alt((
        map(tag("\\\\"), |_| '\\'),
        map(tag("\\\""), |_| '"'),
        map(tag("\\n"), |_| '\n'),
        map(tag("\\t"), |_| '\t'),
        none_of("\"\\"),
    ))(input)
}
/// Parse a symbol and return `Expr::Symbol`.
fn parse_symbol(input: &str) -> IResult<&str, Expr> {
    map(
        map(
            recognize(pair(
                one_of(
                    "abcdefghijklmnopqrstuvwxyz\
                     ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                     _+-*/=<>!?",
                ),
                many0_count(one_of(
                    "abcdefghijklmnopqrstuvwxyz\
                     ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                     0123456789_+-*/=<>!?",
                )),
            )),
            String::from,
        ),
        Expr::Symbol,
    )(input)
}
Combining atoms with alt
The parse_atom function tries each atom parser in order:
/// Parse any atom: integer, boolean, string, or symbol.
pub fn parse_atom(input: &str) -> IResult<&str, Expr> {
    alt((parse_bool, parse_str, parse_int, parse_symbol))(input)
}
Ordering matters. Booleans and strings can safely go first, because their leading characters (# and ") cannot start any other atom. The important constraint is that integers come before symbols. parse_symbol accepts - as a leading character and digits as continuation characters, so on its own it would read -42 as the symbol -42. Trying parse_int first makes -42 an integer. If the integer parser fails (a bare - with no digits after it), alt backtracks and the symbol parser gets a chance.
Understanding nom errors
When all alternatives in alt fail, nom returns an error pointing to the position where parsing failed. For example:
#![allow(unused)]
fn main() {
let result = parse_atom("@invalid");
// Err(Err::Error(Error { input: "@invalid", code: ErrorKind::OneOf }))
}
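The error says that parsing failed at the very start of the input "@invalid": the OneOf combinator (inside parse_symbol) could not match any expected character. With nom's default error type, alt returns only the error from the last alternative it tried, which here is parse_symbol.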
During development, you can use nom::error::VerboseError instead of the default error type to get a stack trace of which combinators were attempted. This is helpful for debugging but adds overhead, so switch back to the default for production.
Unit tests
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn atom_integer() {
        assert_eq!(parse_atom("42"), Ok(("", Expr::Int(42))));
        assert_eq!(parse_atom("-7 rest"), Ok((" rest", Expr::Int(-7))));
    }
    #[test]
    fn atom_boolean() {
        assert_eq!(parse_atom("#t"), Ok(("", Expr::Bool(true))));
        assert_eq!(parse_atom("#f"), Ok(("", Expr::Bool(false))));
    }
    #[test]
    fn atom_string() {
        assert_eq!(
            parse_atom("\"hello\""),
            Ok(("", Expr::Str("hello".to_string())))
        );
    }
    #[test]
    fn atom_symbol() {
        assert_eq!(
            parse_atom("foo"),
            Ok(("", Expr::Symbol("foo".to_string())))
        );
        assert_eq!(
            parse_atom("+"),
            Ok(("", Expr::Symbol("+".to_string())))
        );
    }
    #[test]
    fn minus_is_symbol_not_integer() {
        // A bare "-" with no digits is the symbol "-"
        assert_eq!(
            parse_atom("- rest"),
            Ok((" rest", Expr::Symbol("-".to_string())))
        );
    }
}
The last test is important: it verifies that a bare minus sign is parsed as a symbol, not as a failed integer parse. This is the ordering logic from alt at work.
9. Parsing S-Expressions and Special Forms
S-expressions are parenthesised lists: the heart of Lisp syntax. This section extends the parser to handle arbitrarily nested lists, whitespace between elements, and comments. It then lifts special forms — define, if, lambda, let, begin — out of the generic list parser so they become distinct AST variants, and covers how to handle recursive parsers in nom without running into borrow-checker problems.
Whitespace and comments
Before parsing lists, we need a combinator that skips whitespace and comments:
use nom::character::complete::{multispace0, not_line_ending};
/// Skip whitespace and comments.
fn ws(input: &str) -> IResult<&str, ()> {
    let mut remaining = input;
    loop {
        // Skip whitespace
        let (r, _) = multispace0(remaining)?;
        remaining = r;
        // If we see a comment, skip it
        if remaining.starts_with(';') {
            let (r, _) = not_line_ending(remaining)?;
            remaining = r;
        } else {
            break;
        }
    }
    Ok((remaining, ()))
}
This loop handles multiple comments and whitespace in any order.
Parsing a generic list
A list is an opening parenthesis, zero or more whitespace-separated expressions, and a closing parenthesis. The key challenge is that parse_expr (which parses any expression, including lists) must call itself recursively, because a list can contain lists.
In nom, recursive parsers need a named function boundary. A closure cannot refer to itself by name, and a function that returns a combinator built from its own result would need an infinitely nested type. Plain fn items avoid both problems:
/// Parse any expression: atom or list.
pub fn parse_expr(input: &str) -> IResult<&str, Expr> {
    let (input, _) = ws(input)?;
    alt((parse_list, parse_atom))(input)
}
/// Parse a parenthesised list of expressions.
fn parse_list(input: &str) -> IResult<&str, Expr> {
    let (input, _) = char('(')(input)?;
    let (input, _) = ws(input)?;
    // Try each special-form parser first; parse_call is the fallback
    // for ordinary function calls.
    let result = alt((
        parse_define,
        parse_if,
        parse_lambda,
        parse_let,
        parse_begin,
        parse_call,
    ))(input);
    let (input, expr) = result?;
    let (input, _) = ws(input)?;
    let (input, _) = char(')')(input)?;
    Ok((input, expr))
}
Parsing special forms
Each special form parser matches the keyword and then parses the form’s specific structure. These parsers consume everything between the parentheses — the opening and closing parens are handled by parse_list.
define — two forms:
use nom::sequence::preceded;
/// Parse the inside of a `define` form.
fn parse_define(input: &str) -> IResult<&str, Expr> {
    let (input, _) = tag("define")(input)?;
    let (input, _) = ws(input)?;
    // Try function definition form: (define (name params...) body...)
    if input.starts_with('(') {
        let (input, _) = char('(')(input)?;
        let (input, _) = ws(input)?;
        let (input, name) = parse_symbol_name(input)?;
        let (input, params) = many0(preceded(ws, parse_symbol_name))(input)?;
        let (input, _) = ws(input)?;
        let (input, _) = char(')')(input)?;
        let (input, body) = many1_expr(input)?;
        return Ok((input, Expr::DefineFunc(name, params, body)));
    }
    // Variable definition form: (define name value)
    let (input, name) = parse_symbol_name(input)?;
    let (input, _) = ws(input)?;
    let (input, value) = parse_expr(input)?;
    Ok((input, Expr::Define(name, Box::new(value))))
}
/// Parse a bare symbol name (returns String, not Expr).
fn parse_symbol_name(input: &str) -> IResult<&str, String> {
    map(
        recognize(pair(
            one_of(
                "abcdefghijklmnopqrstuvwxyz\
                 ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                 _+-*/=<>!?",
            ),
            many0_count(one_of(
                "abcdefghijklmnopqrstuvwxyz\
                 ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                 0123456789_+-*/=<>!?",
            )),
        )),
        String::from,
    )(input)
}
/// Parse one or more expressions (used for body forms).
fn many1_expr(input: &str) -> IResult<&str, Vec<Expr>> {
    let (input, first) = preceded(ws, parse_expr)(input)?;
    let (input, mut rest) = many0(preceded(ws, parse_expr))(input)?;
    rest.insert(0, first);
    Ok((input, rest))
}
if — exactly three sub-expressions:
fn parse_if(input: &str) -> IResult<&str, Expr> {
    let (input, _) = tag("if")(input)?;
    let (input, _) = ws(input)?;
    let (input, cond) = parse_expr(input)?;
    let (input, _) = ws(input)?;
    let (input, then) = parse_expr(input)?;
    let (input, _) = ws(input)?;
    let (input, els) = parse_expr(input)?;
    Ok((
        input,
        Expr::If(Box::new(cond), Box::new(then), Box::new(els)),
    ))
}
lambda — parameter list and body:
fn parse_lambda(input: &str) -> IResult<&str, Expr> {
    let (input, _) = tag("lambda")(input)?;
    let (input, _) = ws(input)?;
    let (input, _) = char('(')(input)?;
    let (input, params) = many0(preceded(ws, parse_symbol_name))(input)?;
    let (input, _) = ws(input)?;
    let (input, _) = char(')')(input)?;
    let (input, body) = many1_expr(input)?;
    Ok((input, Expr::Lambda(params, body)))
}
let — binding list and body:
fn parse_let(input: &str) -> IResult<&str, Expr> {
    let (input, _) = tag("let")(input)?;
    let (input, _) = ws(input)?;
    let (input, _) = char('(')(input)?;
    let (input, bindings) = many0(preceded(ws, parse_binding))(input)?;
    let (input, _) = ws(input)?;
    let (input, _) = char(')')(input)?;
    let (input, body) = many1_expr(input)?;
    Ok((input, Expr::Let(bindings, body)))
}
fn parse_binding(input: &str) -> IResult<&str, (String, Expr)> {
    let (input, _) = char('(')(input)?;
    let (input, _) = ws(input)?;
    let (input, name) = parse_symbol_name(input)?;
    let (input, _) = ws(input)?;
    let (input, val) = parse_expr(input)?;
    let (input, _) = ws(input)?;
    let (input, _) = char(')')(input)?;
    Ok((input, (name, val)))
}
begin — one or more expressions:
fn parse_begin(input: &str) -> IResult<&str, Expr> {
    let (input, _) = tag("begin")(input)?;
    let (input, exprs) = many1_expr(input)?;
    Ok((input, Expr::Begin(exprs)))
}
Generic function call — anything else:
fn parse_call(input: &str) -> IResult<&str, Expr> {
    let (input, func) = parse_expr(input)?;
    let (input, args) = many0(preceded(ws, parse_expr))(input)?;
    Ok((input, Expr::Call(Box::new(func), args)))
}
Parsing a program
A program is zero or more top-level expressions:
/// Parse a complete MiniLisp program.
pub fn parse_program(input: &str) -> IResult<&str, Vec<Expr>> {
    let (input, _) = ws(input)?;
    let (input, exprs) = many0(preceded(ws, parse_expr))(input)?;
    let (input, _) = ws(input)?;
    Ok((input, exprs))
}
Handling recursion and the borrow checker
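The recursive call from parse_list → parse_expr → parse_list works because both are plain fn items, not closures. nom processes &str slices, so there are no ownership issues: each parser borrows the input, consumes some of it, and returns the remainder.
If you find yourself wanting a closure for a recursive parser, move the recursive part into a plain fn helper instead. A separate tool, nom::combinator::cut, does not help with recursion but can improve error messages. Wrapping the parser that follows a keyword in cut turns its failure into Err::Failure, which stops alt from backtracking into other alternatives. Without it, a malformed form such as (if #t 1) fails in parse_if and then succeeds in parse_call as an ordinary call to a function named if.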
Tests
#[test]
fn parse_simple_call() {
    let (_, expr) = parse_expr("(+ 1 2)").unwrap();
    assert_eq!(
        expr,
        Expr::Call(
            Box::new(Expr::Symbol("+".to_string())),
            vec![Expr::Int(1), Expr::Int(2)],
        )
    );
}
#[test]
fn parse_nested() {
    let (_, expr) = parse_expr("(* 3 (+ 1 2))").unwrap();
    match expr {
        Expr::Call(_, args) => assert_eq!(args.len(), 2),
        _ => panic!("expected Call"),
    }
}
#[test]
fn parse_define_var() {
    let (_, expr) = parse_expr("(define x 42)").unwrap();
    assert_eq!(
        expr,
        Expr::Define("x".to_string(), Box::new(Expr::Int(42)))
    );
}
#[test]
fn parse_define_func() {
    let (_, expr) = parse_expr("(define (f x) x)").unwrap();
    match expr {
        Expr::DefineFunc(name, params, _) => {
            assert_eq!(name, "f");
            assert_eq!(params, vec!["x".to_string()]);
        }
        _ => panic!("expected DefineFunc"),
    }
}
#[test]
fn parse_if_expr() {
    let (_, expr) = parse_expr("(if #t 1 0)").unwrap();
    match expr {
        Expr::If(_, _, _) => {} // ok
        _ => panic!("expected If"),
    }
}
#[test]
fn parse_with_comments() {
    let input = "; comment\n(+ 1 2)";
    let (_, expr) = parse_expr(input).unwrap();
    match expr {
        Expr::Call(_, _) => {} // ok
        _ => panic!("expected Call"),
    }
}
Part 3 — Semantic Analysis
10. Symbol Tables and Scope
A symbol table maps names to their definitions. This section walks through a scope-aware traversal of the AST that builds a symbol table, resolves every symbol reference to its definition, and reports helpful errors for undefined names or names used outside their scope. You will implement a simple environment chain — the standard technique for representing nested lexical scopes.
What is a symbol table?
A symbol table answers the question: “when the program says x, what does x refer to?” In MiniLisp, x could be a top-level definition, a function parameter, or a local let binding. The symbol table tracks every name that is in scope at every point in the program.
Environment chains
An environment chain is a stack of scopes. Each scope is a set of names. When you look up a name, you search from the innermost scope outward. If no scope contains the name, it is undefined.
use std::collections::HashSet;
/// A chain of lexical scopes.
pub struct Env {
    scopes: Vec<HashSet<String>>,
}
impl Env {
    /// Create a new environment with a single empty global scope.
    pub fn new() -> Self {
        Env {
            scopes: vec![HashSet::new()],
        }
    }
    /// Push a new scope (entering a function, let, etc.).
    pub fn push_scope(&mut self) {
        self.scopes.push(HashSet::new());
    }
    /// Pop the innermost scope (leaving a function, let, etc.).
    pub fn pop_scope(&mut self) {
        self.scopes.pop();
    }
    /// Define a name in the current (innermost) scope.
    pub fn define(&mut self, name: &str) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.to_string());
        }
    }
    /// Check whether a name is defined in any enclosing scope.
    pub fn is_defined(&self, name: &str) -> bool {
        self.scopes.iter().rev().any(|scope| scope.contains(name))
    }
}
Why a Vec<HashSet>? Each HashSet is one scope. Pushing and popping is O(1). Lookup walks the vector from back to front, which is O(number of scopes) — for typical programs with a few levels of nesting, this is trivially fast.
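The Env above answers only "is this name defined?". The section opens by describing a symbol table as mapping names to their definitions. A variant that stores a value per name can also answer "what is it?", which later stages may find useful. The sketch below stores a hypothetical Binding kind for that purpose; both SymbolTable and Binding are illustrative names. Lookup still walks innermost-first, so shadowing falls out naturally.

```rust
use std::collections::HashMap;

/// What kind of definition a name refers to (hypothetical, for illustration).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Binding {
    Builtin,
    Global,
    Param,
    Local,
}

/// An environment chain that records what each name is bound to.
struct SymbolTable {
    scopes: Vec<HashMap<String, Binding>>,
}

impl SymbolTable {
    fn new() -> Self {
        SymbolTable { scopes: vec![HashMap::new()] }
    }
    fn push_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }
    fn pop_scope(&mut self) {
        self.scopes.pop();
    }
    /// Record `name` in the innermost scope.
    fn define(&mut self, name: &str, kind: Binding) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.to_string(), kind);
        }
    }
    /// Innermost-first lookup, so inner bindings shadow outer ones.
    fn lookup(&self, name: &str) -> Option<Binding> {
        self.scopes.iter().rev().find_map(|scope| scope.get(name).copied())
    }
}
```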
Built-in names
The built-in operators (+, -, *, /, =, <, >, <=, >=, not, and, or, display, newline) are always in scope. Define them in the global scope at startup:
impl Env {
    /// Create an environment with built-in names pre-defined.
    pub fn with_builtins() -> Self {
        let mut env = Env::new();
        for name in &[
            "+", "-", "*", "/", "=", "<", ">", "<=", ">=",
            "not", "and", "or", "display", "newline",
        ] {
            env.define(name);
        }
        env
    }
}
Walking the AST
The analyser traverses the AST recursively, maintaining the environment as it goes. For each node:
- Symbol: check that the name is defined in the current environment. If not, record an error.
- Define: check the value first, then add the name to the current scope, so (define x x) reports x as undefined.
- DefineFunc: add the name to the current scope first so the body can call it recursively, then push a new scope, add the parameters, check the body, and pop the scope.
- Lambda: push a new scope, add parameters, check the body, pop the scope.
- Let: check each binding's value in the outer scope (this is let, not let*), then push a new scope, add all binding names to it, check the body, and pop the scope.
- If, Call, Begin: recursively check all sub-expressions.
use crate::ast::Expr;
/// Check all expressions for scope errors.
/// Returns a list of error messages (empty means success).
pub fn check_scope(exprs: &[Expr]) -> Vec<String> {
    let mut env = Env::with_builtins();
    let mut errors = Vec::new();
    for expr in exprs {
        check_expr(expr, &mut env, &mut errors);
    }
    errors
}
fn check_expr(expr: &Expr, env: &mut Env, errors: &mut Vec<String>) {
    match expr {
        Expr::Int(_) | Expr::Bool(_) | Expr::Str(_) => {
            // Atoms are always valid
        }
        Expr::Symbol(name) => {
            if !env.is_defined(name) {
                errors.push(format!("undefined symbol: {}", name));
            }
        }
        Expr::Call(func, args) => {
            check_expr(func, env, errors);
            for arg in args {
                check_expr(arg, env, errors);
            }
        }
        Expr::Define(name, value) => {
            check_expr(value, env, errors);
            env.define(name);
        }
        Expr::DefineFunc(name, params, body) => {
            env.define(name); // allow recursion
            env.push_scope();
            for p in params {
                env.define(p);
            }
            for expr in body {
                check_expr(expr, env, errors);
            }
            env.pop_scope();
        }
        Expr::If(cond, then, els) => {
            check_expr(cond, env, errors);
            check_expr(then, env, errors);
            check_expr(els, env, errors);
        }
        Expr::Lambda(params, body) => {
            env.push_scope();
            for p in params {
                env.define(p);
            }
            for expr in body {
                check_expr(expr, env, errors);
            }
            env.pop_scope();
        }
        Expr::Let(bindings, body) => {
            // Evaluate binding values in the outer scope
            for (_, val) in bindings {
                check_expr(val, env, errors);
            }
            // Then introduce binding names in a new scope
            env.push_scope();
            for (name, _) in bindings {
                env.define(name);
            }
            for expr in body {
                check_expr(expr, env, errors);
            }
            env.pop_scope();
        }
        Expr::Begin(exprs) => {
            for expr in exprs {
                check_expr(expr, env, errors);
            }
        }
    }
}
Error reporting
Notice that DefineFunc defines the function name before checking the body. This allows recursive functions to refer to themselves. Without this, (define (fact n) ... (fact (- n 1))) would report fact as undefined inside its own body.
The analyser collects all errors before stopping, so the programmer sees every problem at once:
undefined symbol: foo
undefined symbol: bar
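Because check_scope walks top-level forms in order, a function can call itself but not a function defined later in the file. For example, if (define (even? n) ... (odd? ...)) comes before (define (odd? n) ...), the call to odd? is reported as undefined. If MiniLisp should allow mutual recursion, one common remedy is a pre-pass that hoists every top-level DefineFunc name into the global scope before any body is checked. The sketch below shows this under that assumption, with the AST and Env trimmed to what it needs. check_program, sym and call are illustrative names.

```rust
use std::collections::HashSet;

// Trimmed-down AST: just enough variants to show the problem.
enum Expr {
    Int(i64),
    Symbol(String),
    Call(Box<Expr>, Vec<Expr>),
    DefineFunc(String, Vec<String>, Vec<Expr>),
}

// Compact version of this section's Env chain.
struct Env {
    scopes: Vec<HashSet<String>>,
}

impl Env {
    fn define(&mut self, name: &str) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.to_string());
        }
    }
    fn is_defined(&self, name: &str) -> bool {
        self.scopes.iter().rev().any(|s| s.contains(name))
    }
}

/// Scope-check a program, hoisting top-level function names first.
fn check_program(exprs: &[Expr]) -> Vec<String> {
    let mut env = Env { scopes: vec![HashSet::new()] };
    for builtin in ["+", "-", "=", "not"] {
        env.define(builtin);
    }
    // Pass 1: every top-level function name becomes visible to every body.
    for expr in exprs {
        if let Expr::DefineFunc(name, _, _) = expr {
            env.define(name);
        }
    }
    // Pass 2: the usual traversal.
    let mut errors = Vec::new();
    for expr in exprs {
        check_expr(expr, &mut env, &mut errors);
    }
    errors
}

fn check_expr(expr: &Expr, env: &mut Env, errors: &mut Vec<String>) {
    match expr {
        Expr::Int(_) => {}
        Expr::Symbol(name) => {
            if !env.is_defined(name) {
                errors.push(format!("undefined symbol: {}", name));
            }
        }
        Expr::Call(func, args) => {
            check_expr(func, env, errors);
            for arg in args {
                check_expr(arg, env, errors);
            }
        }
        Expr::DefineFunc(name, params, body) => {
            env.define(name); // still needed for nested definitions
            env.scopes.push(HashSet::new());
            for p in params {
                env.define(p);
            }
            for e in body {
                check_expr(e, env, errors);
            }
            env.scopes.pop();
        }
    }
}

fn sym(name: &str) -> Expr {
    Expr::Symbol(name.to_string())
}

fn call(func: &str, args: Vec<Expr>) -> Expr {
    Expr::Call(Box::new(sym(func)), args)
}
```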
Tests
#[test]
fn undefined_symbol() {
    let exprs = vec![Expr::Symbol("x".to_string())];
    let errors = check_scope(&exprs);
    assert_eq!(errors, vec!["undefined symbol: x"]);
}
#[test]
fn defined_symbol() {
    let exprs = vec![
        Expr::Define("x".to_string(), Box::new(Expr::Int(1))),
        Expr::Symbol("x".to_string()),
    ];
    let errors = check_scope(&exprs);
    assert!(errors.is_empty());
}
#[test]
fn builtin_is_defined() {
    let exprs = vec![Expr::Symbol("+".to_string())];
    let errors = check_scope(&exprs);
    assert!(errors.is_empty());
}
#[test]
fn lambda_scope() {
    // (lambda (x) x) — x is defined inside the lambda
    let exprs = vec![Expr::Lambda(
        vec!["x".to_string()],
        vec![Expr::Symbol("x".to_string())],
    )];
    let errors = check_scope(&exprs);
    assert!(errors.is_empty());
}
#[test]
fn lambda_scope_does_not_leak() {
    // (lambda (x) x) then reference x — x should be undefined
    let exprs = vec![
        Expr::Lambda(
            vec!["x".to_string()],
            vec![Expr::Symbol("x".to_string())],
        ),
        Expr::Symbol("x".to_string()),
    ];
    let errors = check_scope(&exprs);
    assert_eq!(errors, vec!["undefined symbol: x"]);
}
11. Checking Special Forms
Special forms have fixed shapes: if needs exactly three sub-expressions; define needs a name and a body; lambda needs a parameter list and at least one body expression. This section adds arity and shape checks for each special form so that malformed programs produce clear error messages rather than mysterious C output.
Why check form shapes?
The parser already enforces some structure — it will not produce an If node with two sub-expressions because the parser expects exactly three. However, the parser’s error messages for malformed input can be cryptic (nom reports parsing failures, not semantic problems). The analyser provides a second line of defence with clear, domain-specific error messages.
More importantly, if the AST is ever constructed programmatically (for example, by a macro expander), the analyser catches structural errors that the parser would otherwise have rejected.
Checks to implement
| Form | Rule | Error message |
|---|---|---|
| Define | must have a value | (enforced by AST structure) |
| DefineFunc | must have at least one body expression | define: function '<name>' body is empty |
| If | must have exactly condition, then, else | (enforced by AST structure) |
| Lambda | must have at least one body expression | lambda: body is empty |
| Let | must have at least one body expression | let: body is empty |
| Let | each binding must have exactly a name and value | (enforced by AST structure) |
| Begin | must have at least one expression | begin: empty begin block |
Several of these are enforced by the AST’s type structure — for example, If always has exactly three Box<Expr> fields, so you cannot construct a two-armed If. The checks below handle the cases that the type system does not catch.
Implementation
Add form validation to src/analysis.rs, alongside the scope checker:
#![allow(unused)]
fn main() {
/// Validate the shape of special forms.
fn check_forms(expr: &Expr, errors: &mut Vec<String>) {
match expr {
Expr::DefineFunc(name, _, body) => {
if body.is_empty() {
errors.push(format!(
"define: function '{}' body is empty", name
));
}
for e in body {
check_forms(e, errors);
}
}
Expr::Lambda(_, body) => {
if body.is_empty() {
errors.push("lambda: body is empty".to_string());
}
for e in body {
check_forms(e, errors);
}
}
Expr::Let(bindings, body) => {
if body.is_empty() {
errors.push("let: body is empty".to_string());
}
for (_, val) in bindings {
check_forms(val, errors);
}
for e in body {
check_forms(e, errors);
}
}
Expr::Begin(exprs) => {
if exprs.is_empty() {
errors.push("begin: empty begin block".to_string());
}
for e in exprs {
check_forms(e, errors);
}
}
Expr::Define(_, val) => {
check_forms(val, errors);
}
Expr::If(cond, then, els) => {
check_forms(cond, errors);
check_forms(then, errors);
check_forms(els, errors);
}
Expr::Call(func, args) => {
check_forms(func, errors);
for a in args {
check_forms(a, errors);
}
}
Expr::Int(_) | Expr::Bool(_) | Expr::Str(_) | Expr::Symbol(_) => {}
}
}
}
Combining scope and form checks
The public analyse function runs both passes:
#![allow(unused)]
fn main() {
/// Run all semantic checks on a parsed program.
pub fn analyse(exprs: &[Expr]) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
// Pass 1: scope checking
errors.extend(check_scope(exprs));
// Pass 2: form validation
for expr in exprs {
check_forms(expr, &mut errors);
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
Tests
#![allow(unused)]
fn main() {
#[test]
fn empty_begin() {
let exprs = vec![Expr::Begin(vec![])];
let result = analyse(&exprs);
assert!(result.is_err());
let errs = result.unwrap_err();
assert!(errs.iter().any(|e| e.contains("empty begin block")));
}
#[test]
fn empty_lambda_body() {
let exprs = vec![Expr::Lambda(vec!["x".to_string()], vec![])];
let result = analyse(&exprs);
assert!(result.is_err());
let errs = result.unwrap_err();
assert!(errs.iter().any(|e| e.contains("lambda: body is empty")));
}
#[test]
fn valid_program() {
let exprs = vec![
Expr::Define("x".to_string(), Box::new(Expr::Int(10))),
Expr::Call(
Box::new(Expr::Symbol("display".to_string())),
vec![Expr::Symbol("x".to_string())],
),
];
assert!(analyse(&exprs).is_ok());
}
}
Part 4 — Code Generation
12. The C Runtime Preamble
Every MiniLisp program compiles to a C file that begins with a standard preamble: #include directives, type aliases, boolean constants, and thin wrappers for built-in operations like display and newline. This section designs the preamble, explains why each piece is there, and shows how the code generator emits it before any user-defined code.
Why a preamble?
MiniLisp has built-in operations — display, newline, arithmetic, comparisons — that the generated C code calls. These need C implementations. Rather than linking a separate runtime library, we emit the implementations directly at the top of the C file. This makes every generated file self-contained: you compile it with cc -o out out.c and it works.
The complete preamble
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Boolean type */
typedef int bool_t;
#define TRUE 1
#define FALSE 0
/* Display functions */
static void display_int(long long x) {
printf("%lld", x);
}
static void display_bool(bool_t x) {
printf("%s", x ? "#t" : "#f");
}
static void display_str(const char *s) {
printf("%s", s);
}
static void newline_() {
printf("\n");
}
Line by line:
- #include <stdio.h>: for printf.
- #include <stdlib.h>: for exit, used by error-handling code.
- #include <string.h>: for any string operations.
- typedef int bool_t: MiniLisp booleans become C integers. TRUE and FALSE are macros for readability.
- display_int, display_bool, display_str: MiniLisp’s display is polymorphic (it can print any value). Integers, booleans, and strings are the only printable types, so the preamble provides one C function per type, and the code generator picks which one to call by inspecting the argument’s AST (§13).
- newline_: the trailing underscore avoids a clash with any newline identifier the C environment might already define. Note that mangle (§13) only rewrites punctuation, so a MiniLisp name that is also a C keyword or library function (for example int or printf) would still clash.
Emitting the preamble
In src/codegen.rs, define the preamble as a constant:
// The delimiter is r##"..."## rather than r#"..."#: the C code below
// contains "#t" and "#f", and the two characters "# would end an
// r#"..."# string early.
const PREAMBLE: &str = r##"#include <stdio.h>
#include <stdlib.h>
#include <string.h>
typedef int bool_t;
#define TRUE 1
#define FALSE 0
static void display_int(long long x) {
    printf("%lld", x);
}
static void display_bool(bool_t x) {
    printf("%s", x ? "#t" : "#f");
}
static void display_str(const char *s) {
    printf("%s", s);
}
static void newline_() {
    printf("\n");
}
"##;
The code generator begins every output file with this string. User-defined code follows immediately after.
Design decision: monomorphic display
A more sophisticated compiler would use a tagged union (a “value” type) so that display could handle any value at runtime. We avoid this because:
- It adds significant complexity (you need a value representation, tag checks, and memory management).
- For literal arguments (and for arithmetic, which always yields integers), the compiler can tell at the AST level whether an expression is an integer, boolean, or string, so it can emit the correct display_* call directly.
- The goal is simplicity and clarity; you can extend this later if you want dynamic typing.
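For comparison, the sketch below shows what the tagged-union alternative looks like. It is written in Rust to match the rest of the course; in the generated C it would be a struct holding a tag and a union. The Value type and display_value function are hypothetical names for illustration. The difference is that the tag travels with the data and display decides what to print at runtime, instead of the compiler choosing a display_* helper at compile time.

```rust
/// Hypothetical tagged value: the enum variant is the tag and travels
/// with the data.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Int(i64),
    Bool(bool),
    Str(String),
}

/// One display for every value, dispatched on the tag at runtime.
fn display_value(v: &Value) -> String {
    match v {
        Value::Int(n) => n.to_string(),
        Value::Bool(b) => if *b { "#t".to_string() } else { "#f".to_string() },
        Value::Str(s) => s.clone(),
    }
}

fn main() {
    let values = vec![Value::Int(42), Value::Bool(false), Value::Str("hi".to_string())];
    for v in &values {
        println!("{}", display_value(v));
    }
}
```

The price is visible even in this small version: every value now carries a tag, and every operation has to inspect it.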
13. Generating C: Atoms and Expressions
This section implements the expression code generator — the recursive function that turns an Expr into a C expression string. Integers become C integer literals; booleans become TRUE and FALSE; strings become string literals; arithmetic and comparison operations become C operators; function calls become C function-call syntax. You will also handle name-mangling: turning Lisp symbols like my-var into valid C identifiers.
Name mangling
MiniLisp symbols can contain characters that are not valid in C identifiers: -, +, *, /, =, <, >, !, ?. We replace each with an alphanumeric name surrounded by underscores:
#![allow(unused)]
fn main() {
/// Convert a MiniLisp symbol into a valid C identifier.
fn mangle(name: &str) -> String {
let mut out = String::new();
for ch in name.chars() {
match ch {
'-' => out.push_str("_dash_"),
'+' => out.push_str("_plus_"),
'*' => out.push_str("_star_"),
'/' => out.push_str("_slash_"),
'=' => out.push_str("_eq_"),
'<' => out.push_str("_lt_"),
'>' => out.push_str("_gt_"),
'!' => out.push_str("_bang_"),
'?' => out.push_str("_qmark_"),
c => out.push(c),
}
}
out
}
}
Examples:
| MiniLisp | C identifier |
|---|---|
| x | x |
| my-var | my_dash_var |
| <= | _lt__eq_ |
| zero? | zero_qmark_ |
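The table rows can be checked directly against mangle. The sketch below repeats mangle from above and also demonstrates one limitation: the mapping is not one-to-one. Assuming the parser accepts underscores in symbols, a program that uses both my-var and my_dash_var would get the same C identifier for both.

```rust
/// Copy of mangle from above: replace characters that are illegal in
/// C identifiers with alphanumeric names.
fn mangle(name: &str) -> String {
    let mut out = String::new();
    for ch in name.chars() {
        match ch {
            '-' => out.push_str("_dash_"),
            '+' => out.push_str("_plus_"),
            '*' => out.push_str("_star_"),
            '/' => out.push_str("_slash_"),
            '=' => out.push_str("_eq_"),
            '<' => out.push_str("_lt_"),
            '>' => out.push_str("_gt_"),
            '!' => out.push_str("_bang_"),
            '?' => out.push_str("_qmark_"),
            c => out.push(c),
        }
    }
    out
}

fn main() {
    // The rows of the table above.
    let rows = [("x", "x"), ("my-var", "my_dash_var"), ("<=", "_lt__eq_"), ("zero?", "zero_qmark_")];
    for &(lisp, c) in rows.iter() {
        assert_eq!(mangle(lisp), c);
    }
    // Two different MiniLisp names map to the same C identifier.
    println!("{} {}", mangle("my-var"), mangle("my_dash_var")); // my_dash_var my_dash_var
}
```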
Generating atoms
#![allow(unused)]
fn main() {
/// Generate a C expression from an AST node.
fn gen_expr(expr: &Expr) -> String {
match expr {
Expr::Int(n) => format!("{}LL", n),
Expr::Bool(b) => if *b { "TRUE".to_string() } else { "FALSE".to_string() },
Expr::Str(s) => format!("\"{}\"", s.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
.replace('\t', "\\t")),
Expr::Symbol(name) => mangle(name),
// ... other variants handled below
_ => todo!(),
}
}
}
Details:
- Integers get the LL suffix to ensure they are long long in C, matching our display_int signature.
- Booleans become the macros TRUE and FALSE from the preamble.
- Strings are re-escaped for C. The MiniLisp parser already resolved escape sequences into actual characters, so we must re-escape them for the C source.
- Symbols are mangled into valid C identifiers.
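Escape order matters when re-escaping strings. The sketch below pulls the Str arm into a standalone function (c_string_literal is a hypothetical name) to show the result for a string containing quotes and a newline. Backslashes are doubled first; if quotes were escaped first, the backslash step would then double the backslash just added in front of each quote.

```rust
/// Re-escape a string whose escape sequences the parser has already
/// resolved, producing a C string literal including the quotes.
fn c_string_literal(s: &str) -> String {
    // Backslashes first, so the backslashes added below are not doubled.
    let escaped = s
        .replace('\\', "\\\\")
        .replace('"', "\\\"")
        .replace('\n', "\\n")
        .replace('\t', "\\t");
    format!("\"{}\"", escaped)
}

fn main() {
    // After parsing, the text is: say "hi" followed by a newline character.
    let parsed = "say \"hi\"\n";
    println!("{}", c_string_literal(parsed)); // prints "say \"hi\"\n"
}
```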
Generating arithmetic and comparisons
Built-in binary operators compile to infix C operators:
#![allow(unused)]
fn main() {
Expr::Call(func, args) => {
if let Expr::Symbol(name) = func.as_ref() {
match name.as_str() {
// Arithmetic
"+" => format!("({} + {})", gen_expr(&args[0]), gen_expr(&args[1])),
"-" => format!("({} - {})", gen_expr(&args[0]), gen_expr(&args[1])),
"*" => format!("({} * {})", gen_expr(&args[0]), gen_expr(&args[1])),
"/" => format!("({} / {})", gen_expr(&args[0]), gen_expr(&args[1])),
// Comparisons
"=" => format!("({} == {})", gen_expr(&args[0]), gen_expr(&args[1])),
"<" => format!("({} < {})", gen_expr(&args[0]), gen_expr(&args[1])),
">" => format!("({} > {})", gen_expr(&args[0]), gen_expr(&args[1])),
"<=" => format!("({} <= {})", gen_expr(&args[0]), gen_expr(&args[1])),
">=" => format!("({} >= {})", gen_expr(&args[0]), gen_expr(&args[1])),
// Boolean operators
"not" => format!("(!{})", gen_expr(&args[0])),
"and" => format!("({} && {})", gen_expr(&args[0]), gen_expr(&args[1])),
"or" => format!("({} || {})", gen_expr(&args[0]), gen_expr(&args[1])),
// Display
"display" => gen_display(&args[0]),
"newline" => "newline_()".to_string(),
// User-defined function call
_ => {
let mangled = mangle(name);
let arg_strs: Vec<String> = args.iter().map(gen_expr).collect();
format!("{}({})", mangled, arg_strs.join(", "))
}
}
} else {
// Calling a non-symbol expression (e.g., a lambda)
let func_str = gen_expr(func);
let arg_strs: Vec<String> = args.iter().map(gen_expr).collect();
format!("{}({})", func_str, arg_strs.join(", "))
}
}
}
Each arithmetic and comparison operation is wrapped in parentheses to preserve correct precedence in the C output.
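To see why the parentheses matter, consider (* (+ 1 2) 3). Without them the output would be 1LL + 2LL * 3LL, which C evaluates as 7 rather than 9. The cut-down sketch below is not the course's Expr: it stores each operator already in its C spelling, and wraps every binary operation in parentheses so the C expression keeps the shape of the Lisp tree.

```rust
/// Cut-down AST: operators are stored in their C spelling.
enum Expr {
    Int(i64),
    BinOp(&'static str, Box<Expr>, Box<Expr>),
}

fn gen_expr(expr: &Expr) -> String {
    match expr {
        Expr::Int(n) => format!("{}LL", n),
        // Parenthesise every operation so C precedence cannot regroup it.
        Expr::BinOp(op, lhs, rhs) => format!("({} {} {})", gen_expr(lhs), op, gen_expr(rhs)),
    }
}

fn main() {
    // (* (+ 1 2) 3)
    let e = Expr::BinOp(
        "*",
        Box::new(Expr::BinOp("+", Box::new(Expr::Int(1)), Box::new(Expr::Int(2)))),
        Box::new(Expr::Int(3)),
    );
    println!("{}", gen_expr(&e)); // ((1LL + 2LL) * 3LL)
}
```

The extra parentheses cost nothing at runtime; the C compiler discards them.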
Generating display calls
display in MiniLisp is polymorphic, but for literal arguments the compiler can read the type directly from the AST. We dispatch on the argument’s AST node:
#![allow(unused)]
fn main() {
fn gen_display(expr: &Expr) -> String {
match expr {
Expr::Int(_) => format!("display_int({})", gen_expr(expr)),
Expr::Bool(_) => format!("display_bool({})", gen_expr(expr)),
Expr::Str(_) => format!("display_str({})", gen_expr(expr)),
// For expressions whose type we can't determine statically,
// default to display_int (integers are the most common case)
_ => format!("display_int({})", gen_expr(expr)),
}
}
}
This is a simplification. A production compiler would either track types through the AST or use a tagged union. Defaulting to display_int is correct for arithmetic such as (display (+ 1 2)), because all our arithmetic produces integers. It is not correct for comparisons and boolean operators: (display (< 1 2)) prints 1 rather than #t.
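One way to narrow that gap without a full type checker is a best-effort static_kind function that also recognises calls to the comparison and boolean built-ins. The sketch below uses a cut-down Expr, and Kind, static_kind, and display_helper are hypothetical names. Anything it cannot classify, such as a symbol or a call to a user function, still falls back to display_int.

```rust
/// Cut-down AST with only the variants this sketch needs.
#[allow(dead_code)]
enum Expr {
    Int(i64),
    Bool(bool),
    Str(String),
    Symbol(String),
    Call(Box<Expr>, Vec<Expr>),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

#[derive(Debug, PartialEq)]
enum Kind {
    Int,
    Bool,
    Str,
    Unknown,
}

/// Best-effort static kind of an expression, without a type checker.
fn static_kind(expr: &Expr) -> Kind {
    match expr {
        Expr::Int(_) => Kind::Int,
        Expr::Bool(_) => Kind::Bool,
        Expr::Str(_) => Kind::Str,
        Expr::Call(func, _) => match &**func {
            Expr::Symbol(op) => match op.as_str() {
                "+" | "-" | "*" | "/" => Kind::Int,
                "=" | "<" | ">" | "<=" | ">=" | "not" | "and" | "or" => Kind::Bool,
                _ => Kind::Unknown, // user function: return kind not tracked
            },
            _ => Kind::Unknown,
        },
        // An if has a known kind only when both branches agree.
        Expr::If(_, then, els) => {
            let (t, e) = (static_kind(then), static_kind(els));
            if t == e { t } else { Kind::Unknown }
        }
        Expr::Symbol(_) => Kind::Unknown,
    }
}

/// Choose the preamble helper; Unknown keeps the display_int fallback.
fn display_helper(expr: &Expr) -> &'static str {
    match static_kind(expr) {
        Kind::Bool => "display_bool",
        Kind::Str => "display_str",
        Kind::Int | Kind::Unknown => "display_int",
    }
}

fn main() {
    // (display (< 1 2)) should use display_bool, not display_int
    let cmp = Expr::Call(
        Box::new(Expr::Symbol("<".to_string())),
        vec![Expr::Int(1), Expr::Int(2)],
    );
    println!("{}", display_helper(&cmp)); // display_bool
}
```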
Tests
#![allow(unused)]
fn main() {
#[test]
fn gen_integer() {
assert_eq!(gen_expr(&Expr::Int(42)), "42LL");
}
#[test]
fn gen_bool() {
assert_eq!(gen_expr(&Expr::Bool(true)), "TRUE");
}
#[test]
fn gen_addition() {
let expr = Expr::Call(
Box::new(Expr::Symbol("+".to_string())),
vec![Expr::Int(1), Expr::Int(2)],
);
assert_eq!(gen_expr(&expr), "(1LL + 2LL)");
}
#[test]
fn mangle_hyphen() {
assert_eq!(mangle("my-var"), "my_dash_var");
}
}
14. Generating C: Definitions and Functions
Top-level define forms and lambda expressions compile to C function and variable declarations. This section covers how to emit forward declarations (so mutual recursion works), how to turn a MiniLisp parameter list into a C function signature, how lambda compiles to a named C function, and how top-level definitions are ordered in the output file.
The output file structure
The generated C file has four sections, in this order:
/* 1. Preamble (§12) */
#include <stdio.h>
// ...
/* 2. Forward declarations */
long long square(long long x);
long long fact(long long n);
/* 3. Function definitions */
long long square(long long x) {
return (x * x);
}
long long fact(long long n) {
return ((n == 0LL) ? 1LL : (n * fact((n - 1LL))));
}
/* 4. main() — top-level expressions */
int main() {
display_int(square(5LL));
newline_();
return 0;
}
Forward declarations allow functions to call each other regardless of definition order. This mirrors how MiniLisp works — order of definitions does not matter for function calls.
Generating function definitions
A DefineFunc compiles to a C function. All parameters and return values use long long for simplicity (MiniLisp does not have a type system, and integers are the primary data type):
#![allow(unused)]
fn main() {
fn gen_func_decl(name: &str, params: &[String]) -> String {
let mangled_name = mangle(name);
let param_list: Vec<String> = params
.iter()
.map(|p| format!("long long {}", mangle(p)))
.collect();
format!("long long {}({})", mangled_name, param_list.join(", "))
}
fn gen_func_def(name: &str, params: &[String], body: &[Expr]) -> String {
let decl = gen_func_decl(name, params);
let body_strs: Vec<String> = body.iter().map(gen_expr).collect();
// All body expressions except the last are statements;
// the last is the return value
let mut lines = Vec::new();
for (i, s) in body_strs.iter().enumerate() {
if i == body_strs.len() - 1 {
lines.push(format!(" return {};", s));
} else {
lines.push(format!(" {};", s));
}
}
format!("{} {{\n{}\n}}", decl, lines.join("\n"))
}
}
For example, (define (square x) (* x x)) becomes:
long long square(long long x) {
return (x * x);
}
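gen_func_def leaves the function without a return statement if the body is empty; the analyser’s empty-body check (§11) is what prevents that. A variant built on slice::split_last makes the rule (all but the last expression are statements, the last is returned) explicit and forces a decision about the empty case. To run on its own, this sketch takes the declaration and body as already-generated C strings; gen_func_def_from_strs is a hypothetical name, and returning 0LL for an empty body is an arbitrary assumption.

```rust
/// Hypothetical variant of gen_func_def working on pre-generated C strings.
/// Every expression but the last becomes a statement; the last is returned.
fn gen_func_def_from_strs(decl: &str, body: &[String]) -> String {
    let mut lines = Vec::new();
    match body.split_last() {
        Some((last, init)) => {
            for s in init {
                lines.push(format!("    {};", s));
            }
            lines.push(format!("    return {};", last));
        }
        // Assumption: an empty body returns 0 instead of falling off the end.
        None => lines.push("    return 0LL;".to_string()),
    }
    format!("{} {{\n{}\n}}", decl, lines.join("\n"))
}

fn main() {
    let body = vec!["display_int(x)".to_string(), "(x * x)".to_string()];
    println!("{}", gen_func_def_from_strs("long long show_square(long long x)", &body));
}
```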
Generating variable definitions
A Define (variable, not function) compiles to a global variable:
#![allow(unused)]
fn main() {
fn gen_var_def(name: &str, value: &Expr) -> String {
format!("long long {} = {};", mangle(name), gen_expr(value))
}
}
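For example, (define pi 3) becomes:
long long pi = 3LL;
This only works for constant values: C requires a file-scope initialiser to be a constant expression, so (define y (fact 3)) cannot become long long y = fact(3LL); at file scope. A variable like that has to be declared at file scope and assigned inside main().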
Lambda lifting
A lambda expression needs a name in C — C does not have anonymous functions. We generate a unique name for each lambda:
#![allow(unused)]
fn main() {
use std::sync::atomic::{AtomicUsize, Ordering};
static LAMBDA_COUNTER: AtomicUsize = AtomicUsize::new(0);
fn fresh_lambda_name() -> String {
let n = LAMBDA_COUNTER.fetch_add(1, Ordering::SeqCst);
format!("__lambda_{}", n)
}
}
When the code generator encounters a Lambda, it:
- Generates a named function definition using the fresh name.
- Adds it to a list of “lifted” functions that will be emitted before main.
- Returns the function’s name as the C expression.
#![allow(unused)]
fn main() {
fn gen_lambda(
params: &[String],
body: &[Expr],
lifted: &mut Vec<String>,
) -> String {
let name = fresh_lambda_name();
let def = gen_func_def(&name, params, body);
lifted.push(def);
name
}
}
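A global AtomicUsize works, but it keeps counting across calls to generate, so lambda names depend on what was compiled earlier in the same process (for example, in tests). An alternative is to keep the counter and the lifted definitions in a context value passed down the recursion, which also gives gen_expr a place to put lambdas it meets in expression position. The sketch below is hypothetical (Codegen is not part of the course’s code): it uses a cut-down Expr, single-expression lambda bodies, and no name mangling.

```rust
/// Cut-down AST: just enough to show a lambda being lifted.
enum Expr {
    Int(i64),
    Symbol(String),
    Lambda(Vec<String>, Box<Expr>),
    Call(Box<Expr>, Vec<Expr>),
}

/// Hypothetical code-generation context: owns the lambda counter and the
/// function definitions lifted out to file scope.
struct Codegen {
    next_lambda: usize,
    lifted: Vec<String>,
}

impl Codegen {
    fn new() -> Self {
        Codegen { next_lambda: 0, lifted: Vec::new() }
    }

    fn gen_expr(&mut self, expr: &Expr) -> String {
        match expr {
            Expr::Int(n) => format!("{}LL", n),
            Expr::Symbol(s) => s.clone(),
            Expr::Lambda(params, body) => {
                // Generate the body first: it may contain lambdas of its own.
                let body_c = self.gen_expr(body);
                let name = format!("__lambda_{}", self.next_lambda);
                self.next_lambda += 1;
                let params_c: Vec<String> =
                    params.iter().map(|p| format!("long long {}", p)).collect();
                self.lifted.push(format!(
                    "long long {}({}) {{\n    return {};\n}}",
                    name,
                    params_c.join(", "),
                    body_c
                ));
                // In expression position, the lambda is just its C name.
                name
            }
            Expr::Call(func, args) => {
                let f = self.gen_expr(func);
                let a: Vec<String> = args.iter().map(|e| self.gen_expr(e)).collect();
                format!("{}({})", f, a.join(", "))
            }
        }
    }
}

fn main() {
    // ((lambda (x) x) 5)
    let prog = Expr::Call(
        Box::new(Expr::Lambda(
            vec!["x".to_string()],
            Box::new(Expr::Symbol("x".to_string())),
        )),
        vec![Expr::Int(5)],
    );
    let mut cg = Codegen::new();
    let call = cg.gen_expr(&prog);
    for def in &cg.lifted {
        println!("{}", def);
    }
    println!("{};", call); // __lambda_0(5LL);
}
```

Because each Codegen starts at zero, generating the same program twice yields the same names, which keeps snapshot-style tests stable.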
Putting function generation together
The top-level generation function separates definitions from expressions:
pub fn generate(exprs: &[Expr]) -> String {
    let mut forward_decls = Vec::new();
    let mut func_defs = Vec::new();
    let mut main_stmts = Vec::new();
    // Lifted lambda definitions. Nothing pushes to this yet (see the
    // Lambda arm of the complete gen_expr in §15), so the element type
    // is written out explicitly.
    let lifted: Vec<String> = Vec::new();
    for expr in exprs {
        match expr {
            Expr::DefineFunc(name, params, body) => {
                forward_decls.push(format!("{};", gen_func_decl(name, params)));
                func_defs.push(gen_func_def(name, params, body));
            }
            Expr::Define(name, value) => {
                // Declare at file scope so functions can see the variable,
                // then assign in main(): C allows only constant initialisers
                // at file scope, and `value` may be any expression.
                forward_decls.push(format!("long long {};", mangle(name)));
                main_stmts.push(format!("{} = {};", mangle(name), gen_expr(value)));
            }
            _ => {
                main_stmts.push(format!("{};", gen_expr(expr)));
            }
        }
    }
    let mut output = String::new();
    output.push_str(PREAMBLE);
    for decl in &forward_decls {
        output.push_str(decl);
        output.push('\n');
    }
    output.push('\n');
    for def in &func_defs {
        output.push_str(def);
        output.push_str("\n\n");
    }
    for lf in &lifted {
        output.push_str(lf);
        output.push_str("\n\n");
    }
    output.push_str("int main() {\n");
    for stmt in &main_stmts {
        output.push_str("    ");
        output.push_str(stmt);
        output.push('\n');
    }
    output.push_str("    return 0;\n");
    output.push_str("}\n");
    output
}
15. Generating C: Control Flow and Sequencing
if, begin, and let each require their own code-generation strategy. if becomes a C ternary expression; begin becomes a sequence of C statements whose last value is the value of the whole form; let introduces a C block with local variable declarations. This section works through each form and resolves the practical question of when to emit expressions versus statements.
The expression vs. statement problem
C distinguishes expressions (produce a value) from statements (do not). MiniLisp does not — everything is an expression. This creates a tension: sometimes a MiniLisp form appears where a C expression is needed (e.g., as a function argument), and sometimes it appears where a C statement is fine (e.g., at the top level).
Our strategy: generate expressions wherever possible, falling back to statement blocks only when necessary. The ternary operator ? : lets us keep if as an expression. begin and let use GCC’s “statement expressions” extension ({ ... }) to remain in expression position.
Generating if
An if compiles to a C ternary expression:
#![allow(unused)]
fn main() {
Expr::If(cond, then, els) => {
format!(
"({} ? {} : {})",
gen_expr(cond),
gen_expr(then),
gen_expr(els),
)
}
}
MiniLisp (if (> x 0) x (- 0 x)) becomes C ((x > 0LL) ? x : (0LL - x)).
The ternary operator works because our if always has both branches. If we supported if without an else branch, we would need a different strategy.
Generating begin
begin evaluates a sequence of expressions and returns the value of the last one. We use GCC’s statement-expression extension:
Expr::Begin(exprs) => {
    // Every element, including the last, is terminated with `;`. GCC
    // requires the final expression of a statement expression to be a
    // full statement; its value becomes the value of the whole ({ ... }).
    let parts: Vec<String> = exprs
        .iter()
        .map(|e| format!("{};", gen_expr(e)))
        .collect();
    format!("({{ {} }})", parts.join(" "))
}
MiniLisp:
(begin
(display 1)
(newline)
42)
Becomes C:
({ display_int(1LL); newline_(); 42LL; })
The statement-expression ({ ... }) evaluates each statement in order and produces the value of the last one, which must be an expression statement ending in a semicolon: ({ 1LL; 2LL }) without the final semicolon is a syntax error. This is a GCC/Clang extension, not standard C. For portability, you could instead generate a helper function, but the extension keeps things simple.
Generating let
let introduces local variables. We generate a statement-expression with local declarations:
Expr::Let(bindings, body) => {
    let mut parts = Vec::new();
    // Local declarations first...
    for (name, val) in bindings {
        parts.push(format!(
            "long long {} = {};",
            mangle(name),
            gen_expr(val),
        ));
    }
    // ...then the body. As with begin, every expression, including the
    // last, ends with `;`.
    for e in body {
        parts.push(format!("{};", gen_expr(e)));
    }
    format!("({{ {} }})", parts.join(" "))
}
MiniLisp:
(let ((x 10)
(y 20))
(+ x y))
Becomes C:
({ long long x = 10LL; long long y = 20LL; (x + y); })
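Both begin and let end with the same step: wrapping a list of C fragments in a statement expression. A small helper keeps the semicolon rule in one place. statement_expr is a hypothetical name; it takes already-generated fragments (for let, the binding declarations followed by the body expressions) and terminates every one of them, the last included, with a semicolon. Its output for the two examples above matches the C shown for them.

```rust
/// Hypothetical helper: wrap already-generated C fragments in a GCC/Clang
/// statement expression. Every fragment, the last included, gets a `;`,
/// because the final expression must be a complete statement.
fn statement_expr(parts: &[String]) -> String {
    let stmts: Vec<String> = parts.iter().map(|p| format!("{};", p)).collect();
    format!("({{ {} }})", stmts.join(" "))
}

fn main() {
    // (let ((x 10) (y 20)) (+ x y)): bindings first, then the body
    let let_parts = vec![
        "long long x = 10LL".to_string(),
        "long long y = 20LL".to_string(),
        "(x + y)".to_string(),
    ];
    println!("{}", statement_expr(&let_parts));

    // (begin (display 1) (newline) 42)
    let begin_parts = vec![
        "display_int(1LL)".to_string(),
        "newline_()".to_string(),
        "42LL".to_string(),
    ];
    println!("{}", statement_expr(&begin_parts));
}
```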
Complete gen_expr
Here is the full gen_expr function with all variants:
#![allow(unused)]
fn main() {
fn gen_expr(expr: &Expr) -> String {
match expr {
Expr::Int(n) => format!("{}LL", n),
Expr::Bool(b) => if *b { "TRUE".to_string() } else { "FALSE".to_string() },
Expr::Str(s) => {
let escaped = s.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('\n', "\\n")
.replace('\t', "\\t");
format!("\"{}\"", escaped)
}
Expr::Symbol(name) => mangle(name),
Expr::Call(func, args) => {
// The Call arm from §13, factored out into a helper
// fn gen_call(func: &Expr, args: &[Expr]) -> String
gen_call(func, args)
}
Expr::If(cond, then, els) => {
format!("({} ? {} : {})",
gen_expr(cond), gen_expr(then), gen_expr(els))
}
Expr::Begin(exprs) => {
// Every element, including the last, ends with `;`
// (GCC statement-expression rule).
let parts: Vec<String> =
exprs.iter().map(|e| format!("{};", gen_expr(e))).collect();
format!("({{ {} }})", parts.join(" "))
}
Expr::Let(bindings, body) => {
let mut parts = Vec::new();
for (name, val) in bindings {
parts.push(format!("long long {} = {};",
mangle(name), gen_expr(val)));
}
for e in body {
parts.push(format!("{};", gen_expr(e)));
}
format!("({{ {} }})", parts.join(" "))
}
Expr::Lambda(_params, _body) => {
// Placeholder: the lambda is NOT lifted, and a comment is not a
// valid C value, so programs that use lambdas as values will not
// compile yet. Full support means passing the `lifted` vector from
// generate() into gen_expr and calling gen_lambda(params, body, lifted).
let name = fresh_lambda_name();
format!("/* lambda {} */", name)
}
Expr::Define(name, val) => {
format!("long long {} = {}", mangle(name), gen_expr(val))
}
Expr::DefineFunc(_, _, _) => {
// Function definitions are handled at the top level
String::new()
}
}
}
}
Tests
#![allow(unused)]
fn main() {
#[test]
fn gen_if() {
let expr = Expr::If(
Box::new(Expr::Bool(true)),
Box::new(Expr::Int(1)),
Box::new(Expr::Int(0)),
);
assert_eq!(gen_expr(&expr), "(TRUE ? 1LL : 0LL)");
}
#[test]
fn gen_let() {
let expr = Expr::Let(
vec![("x".to_string(), Expr::Int(10))],
vec![Expr::Symbol("x".to_string())],
);
assert_eq!(gen_expr(&expr), "({ long long x = 10LL; x; })");
}
#[test]
fn gen_begin() {
let expr = Expr::Begin(vec![Expr::Int(1), Expr::Int(2)]);
assert_eq!(gen_expr(&expr), "({ 1LL; 2LL; })");
}
}
Part 5 — Putting It Together
16. The Compilation Pipeline
With all stages implemented, this section wires them into a single compile function and builds a CLI entry point that reads MiniLisp from stdin and writes C to stdout (redirect it to a file from the shell). You will add basic error reporting on stderr and trace a complete example, a recursive factorial function, through every stage.
The compile function
// In src/main.rs
mod ast;
mod parser;
mod analysis;
mod codegen;
use std::io::{self, Read};
use std::process;
fn compile(source: &str) -> Result<String, Vec<String>> {
    // Stage 1: Parse
    let exprs = match parser::parse_program(source) {
        Ok(("", exprs)) => exprs,
        Ok((remaining, _)) => {
            // Show at most 40 characters of the leftover input; taking
            // chars rather than slicing bytes cannot split a UTF-8 character.
            let snippet: String = remaining.chars().take(40).collect();
            return Err(vec![format!(
                "parse error: unexpected trailing input: {}",
                snippet
            )]);
        }
        Err(e) => {
            return Err(vec![format!("parse error: {}", e)]);
        }
    };
    // Stage 2: Analyse
    analysis::analyse(&exprs)?;
    // Stage 3: Generate
    Ok(codegen::generate(&exprs))
}
The function takes MiniLisp source as a string and returns either the generated C code or a list of error messages.
The CLI entry point
fn main() {
let mut input = String::new();
io::stdin()
.read_to_string(&mut input)
.expect("failed to read stdin");
match compile(&input) {
Ok(c_code) => print!("{}", c_code),
Err(errors) => {
for err in &errors {
eprintln!("error: {}", err);
}
process::exit(1);
}
}
}
Usage:
# Compile MiniLisp to C
echo '(display 42) (newline)' | cargo run > out.c
# Compile and run the C
cc -o out out.c && ./out
42
Tracing an example: factorial
Let us trace (define (fact n) (if (= n 0) 1 (* n (fact (- n 1))))) (display (fact 10)) (newline) through every stage.
Stage 1: Parse
The parser produces:
#![allow(unused)]
fn main() {
[
DefineFunc("fact", ["n"], [
If(
Call(Symbol("="), [Symbol("n"), Int(0)]),
Int(1),
Call(Symbol("*"), [
Symbol("n"),
Call(Symbol("fact"), [
Call(Symbol("-"), [Symbol("n"), Int(1)])
])
])
)
]),
Call(Symbol("display"), [
Call(Symbol("fact"), [Int(10)])
]),
Call(Symbol("newline"), []),
]
}
Stage 2: Analyse
The analyser checks:
- fact is defined (by DefineFunc), so the recursive call is valid.
- n is a parameter of fact, so it is in scope inside the body.
- =, *, -, display, and newline are built-ins, always in scope.
- The if has three sub-expressions. All forms are well-shaped.
Result: no errors.
Stage 3: Generate
The code generator produces:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
typedef int bool_t;
#define TRUE 1
#define FALSE 0
// ... (display functions, etc.)
long long fact(long long n);
long long fact(long long n) {
return ((n == 0LL) ? 1LL : (n * fact((n - 1LL))));
}
int main() {
display_int(fact(10LL));
newline_();
return 0;
}
Compile and run:
$ cc -o fact out.c && ./fact
3628800
Error reporting
When the compiler encounters errors, it prints them to stderr and exits with code 1:
$ echo '(display x)' | cargo run
error: undefined symbol: x
$ echo '(if #t 1)' | cargo run
error: parse error: ...
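With nom’s default error type, the parse error message includes the input that could not be consumed, which shows where parsing failed; it does not give a line or column number. The semantic analyser’s errors identify the problematic symbol or form by name.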
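If you want a line and column in parse errors, one option is to compute them from how much input was consumed. nom parsers working on &str hand back the unparsed remainder, both in the Ok((remaining, _)) case of compile and, with nom’s default error type, in the error’s input field. Since that remainder is a suffix of the source, the difference in lengths gives a byte offset. line_col below is a hypothetical helper sketching that conversion.

```rust
/// Convert the unparsed remainder of `source` into a 1-based
/// (line, column). Assumes `remaining` is a suffix of `source`, which holds
/// for the &str slices a nom parser hands back.
fn line_col(source: &str, remaining: &str) -> (usize, usize) {
    let offset = source.len() - remaining.len();
    let consumed = &source[..offset];
    let line = consumed.matches('\n').count() + 1;
    // Column in characters (not bytes) since the last newline.
    let col = match consumed.rfind('\n') {
        Some(nl) => consumed[nl + 1..].chars().count() + 1,
        None => consumed.chars().count() + 1,
    };
    (line, col)
}

fn main() {
    let source = "(display 1)\n(if #t 1)";
    // Suppose parsing stopped at the start of the second line.
    let remaining = &source[12..];
    let (line, col) = line_col(source, remaining);
    println!("parse error at {}:{}", line, col); // parse error at 2:1
}
```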
17. Testing the Compiler
Good tests are what turn a working prototype into a reliable tool. This section adds unit tests for each compiler stage and integration tests that compile MiniLisp programs, feed the C output to cc, run the binary, and assert on stdout. You will build a small test corpus of MiniLisp programs covering all language features and ensure the compiler handles both valid and invalid input gracefully.
Unit test structure
Each module has inline unit tests in a #[cfg(test)] module. We have already seen examples throughout the course. Here is a summary of what to test in each module:
ast.rs — test Display:
#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_int() {
assert_eq!(format!("{}", Expr::Int(42)), "42");
}
#[test]
fn display_call() {
let expr = Expr::Call(
Box::new(Expr::Symbol("+".to_string())),
vec![Expr::Int(1), Expr::Int(2)],
);
assert_eq!(format!("{}", expr), "(+ 1 2)");
}
}
}
parser.rs — test parsing of every construct:
#![allow(unused)]
fn main() {
#[test]
fn parse_factorial() {
let input = "(define (fact n) (if (= n 0) 1 (* n (fact (- n 1)))))";
let (remaining, expr) = parse_expr(input).unwrap();
assert_eq!(remaining, "");
match expr {
Expr::DefineFunc(name, params, body) => {
assert_eq!(name, "fact");
assert_eq!(params, vec!["n".to_string()]);
assert_eq!(body.len(), 1); // single if expression
}
_ => panic!("expected DefineFunc"),
}
}
}
analysis.rs — test valid and invalid programs:
#![allow(unused)]
fn main() {
#[test]
fn recursive_function_is_valid() {
let exprs = vec![Expr::DefineFunc(
"f".to_string(),
vec!["x".to_string()],
vec![Expr::Call(
Box::new(Expr::Symbol("f".to_string())),
vec![Expr::Symbol("x".to_string())],
)],
)];
assert!(analyse(&exprs).is_ok());
}
}
codegen.rs — test C output for known inputs:
#![allow(unused)]
fn main() {
#[test]
fn gen_full_program() {
let exprs = vec![
Expr::Call(
Box::new(Expr::Symbol("display".to_string())),
vec![Expr::Int(42)],
),
];
let output = generate(&exprs);
assert!(output.contains("display_int(42LL)"));
assert!(output.contains("int main()"));
}
}
Integration tests
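Integration tests live in the tests/ directory. Each test compiles a MiniLisp program, runs the C compiler, executes the binary, and asserts on the output. The helper below uses the tempfile crate, so add it under [dev-dependencies] in Cargo.toml. It launches the compiler through cargo run; Cargo also exposes the path of the already-built binary to integration tests as the compile-time environment variable CARGO_BIN_EXE_<name>, which avoids the extra cargo invocation: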
#![allow(unused)]
fn main() {
// tests/integration.rs
use std::process::Command;
use std::io::Write;
fn compile_and_run(minilisp: &str) -> String {
// Run our compiler
let mut child = Command::new("cargo")
.args(["run", "--quiet"])
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.expect("failed to start compiler");
child
.stdin
.as_mut()
.unwrap()
.write_all(minilisp.as_bytes())
.unwrap();
let output = child.wait_with_output().unwrap();
assert!(
output.status.success(),
"compiler failed: {}",
String::from_utf8_lossy(&output.stderr)
);
let c_code = String::from_utf8(output.stdout).unwrap();
// Write C to a temp file
let dir = tempfile::tempdir().unwrap();
let c_path = dir.path().join("test.c");
let bin_path = dir.path().join("test");
std::fs::write(&c_path, &c_code).unwrap();
// Compile C
let cc = Command::new("cc")
.args([
c_path.to_str().unwrap(),
"-o",
bin_path.to_str().unwrap(),
])
.output()
.expect("failed to run cc");
assert!(
cc.status.success(),
"cc failed: {}",
String::from_utf8_lossy(&cc.stderr)
);
// Run the binary
let run = Command::new(bin_path)
.output()
.expect("failed to run binary");
String::from_utf8(run.stdout).unwrap()
}
#[test]
fn test_display_integer() {
let output = compile_and_run("(display 42) (newline)");
assert_eq!(output, "42\n");
}
#[test]
fn test_arithmetic() {
let output = compile_and_run("(display (+ 3 (* 4 5))) (newline)");
assert_eq!(output, "23\n");
}
#[test]
fn test_if_true() {
let output = compile_and_run("(display (if #t 1 0)) (newline)");
assert_eq!(output, "1\n");
}
#[test]
fn test_if_false() {
let output = compile_and_run("(display (if #f 1 0)) (newline)");
assert_eq!(output, "0\n");
}
#[test]
fn test_define_and_use() {
let output = compile_and_run("(define x 10) (display x) (newline)");
assert_eq!(output, "10\n");
}
#[test]
fn test_function_definition() {
let output = compile_and_run(
"(define (double x) (* x 2)) (display (double 21)) (newline)"
);
assert_eq!(output, "42\n");
}
#[test]
fn test_recursion() {
let output = compile_and_run(
"(define (fact n) (if (= n 0) 1 (* n (fact (- n 1))))) \
(display (fact 10)) (newline)"
);
assert_eq!(output, "3628800\n");
}
#[test]
fn test_let_binding() {
let output = compile_and_run(
"(display (let ((x 10) (y 20)) (+ x y))) (newline)"
);
assert_eq!(output, "30\n");
}
#[test]
fn test_begin() {
let output = compile_and_run(
"(begin (display 1) (display 2) (display 3)) (newline)"
);
assert_eq!(output, "123\n");
}
#[test]
fn test_string_display() {
let output = compile_and_run("(display \"hello world\") (newline)");
assert_eq!(output, "hello world\n");
}
#[test]
fn test_boolean_display() {
let output = compile_and_run("(display #t) (newline)");
assert_eq!(output, "#t\n");
}
#[test]
fn test_comparison() {
let output = compile_and_run(
"(display (if (< 3 5) 1 0)) (newline)"
);
assert_eq!(output, "1\n");
}
}
Test corpus checklist
Ensure you have tests covering:
- Integer literals (positive, negative, zero)
- Boolean literals
- String literals (with escape sequences)
- Arithmetic operators (+, -, *, /)
- Comparison operators (=, <, >, <=, >=)
- Boolean operators (not, and, or)
- display for each type
- newline
- define (variable)
- define (function, including recursive)
- if (both branches)
- lambda
- let bindings
- begin sequencing
- Nested expressions
- Comments in source
- Invalid programs (undefined symbols, malformed forms)
Running the tests
# All tests: unit and integration (the integration tests require cc)
cargo test
# Only the integration tests in tests/integration.rs
cargo test --test integration
18. What’s Next: Extensions and Further Reading
The compiler you have built is deliberately minimal — a solid foundation. This final section surveys the directions you can take it further: tail-call optimisation, closures and lambda lifting, a garbage collector, hygienic macros, a type system, an interactive REPL, and a self-hosting MiniLisp standard library. It closes with a curated reading list for going deeper into compiler theory and Lisp implementation.
Extensions
Tail-call optimisation. Recursive functions like fact currently grow the C call stack with every recursive call. In Lisp, tail calls are expected to run in constant space. You can implement this by detecting tail-position calls and compiling them as goto jumps or loops in the generated C.
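Detecting tail positions is a small recursive walk. In the sketch below, which uses a cut-down Expr, the two branches of an if and the last expression of a begin inherit the tail position of the form around them, while conditions, arguments, and earlier begin expressions never have it. self_calls_are_tail (a hypothetical name) reports whether every call a function makes to itself is a tail call; a function that passes could then be compiled to a loop. The factorial from §16 fails the check, because its recursive call is an argument to *.

```rust
/// Cut-down AST for the tail-position check.
#[allow(dead_code)]
enum Expr {
    Int(i64),
    Symbol(String),
    If(Box<Expr>, Box<Expr>, Box<Expr>),
    Begin(Vec<Expr>),
    Call(Box<Expr>, Vec<Expr>),
}

/// Is `expr` a call whose function position is the symbol `name`?
fn is_call_to(expr: &Expr, name: &str) -> bool {
    match expr {
        Expr::Call(f, _) => matches!(&**f, Expr::Symbol(s) if s.as_str() == name),
        _ => false,
    }
}

/// True if every call to `name` inside `expr` is in tail position.
/// `tail` says whether `expr` itself is in tail position.
fn self_calls_are_tail(expr: &Expr, name: &str, tail: bool) -> bool {
    match expr {
        Expr::Int(_) | Expr::Symbol(_) => true,
        // Both branches inherit the tail position; the condition never has it.
        Expr::If(c, t, e) => {
            self_calls_are_tail(c, name, false)
                && self_calls_are_tail(t, name, tail)
                && self_calls_are_tail(e, name, tail)
        }
        // Only the last expression of a begin can be in tail position.
        Expr::Begin(es) => es
            .iter()
            .enumerate()
            .all(|(i, e)| self_calls_are_tail(e, name, tail && i + 1 == es.len())),
        // A self-call is fine only in tail position; arguments never are.
        Expr::Call(f, args) => {
            (tail || !is_call_to(expr, name))
                && self_calls_are_tail(f, name, false)
                && args.iter().all(|a| self_calls_are_tail(a, name, false))
        }
    }
}

fn sym(s: &str) -> Expr {
    Expr::Symbol(s.to_string())
}

fn call(f: &str, args: Vec<Expr>) -> Expr {
    Expr::Call(Box::new(sym(f)), args)
}

fn main() {
    // fact body: (if (= n 0) 1 (* n (fact (- n 1))))
    let fact_body = Expr::If(
        Box::new(call("=", vec![sym("n"), Expr::Int(0)])),
        Box::new(Expr::Int(1)),
        Box::new(call("*", vec![sym("n"), call("fact", vec![call("-", vec![sym("n"), Expr::Int(1)])])])),
    );
    // loop body: (if (= n 0) acc (loop (- n 1) (* n acc)))
    let loop_body = Expr::If(
        Box::new(call("=", vec![sym("n"), Expr::Int(0)])),
        Box::new(sym("acc")),
        Box::new(call("loop", vec![call("-", vec![sym("n"), Expr::Int(1)]), call("*", vec![sym("n"), sym("acc")])])),
    );
    println!("fact: {}", self_calls_are_tail(&fact_body, "fact", true)); // fact: false
    println!("loop: {}", self_calls_are_tail(&loop_body, "loop", true)); // loop: true
}
```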
Closures and lambda lifting. Our lambda support is limited to lambdas that do not capture variables from their enclosing scope, and the complete gen_expr in §15 does not yet lift even those when they appear in expression position. True closures, lambdas that reference variables from outer scopes, require either lambda lifting (passing free variables as extra parameters) or a closure data structure that packages the function pointer with its captured environment.
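Whichever route you take, the first step is finding a lambda’s free variables: the symbols it uses but does not bind. A minimal sketch over a cut-down Expr follows; free_vars is a hypothetical name, and it treats let as plain let (binding values see the outer scope). Its result also includes built-ins and globals such as +, which a real implementation would filter out, because they need no capturing.

```rust
use std::collections::BTreeSet;

/// Cut-down AST with the two binding forms that matter here.
#[allow(dead_code)]
enum Expr {
    Int(i64),
    Symbol(String),
    Lambda(Vec<String>, Vec<Expr>),
    Let(Vec<(String, Expr)>, Vec<Expr>),
    Call(Box<Expr>, Vec<Expr>),
}

/// Add to `out` every symbol used in `expr` that is not in `bound`
/// and not bound by a lambda or let inside `expr`.
fn free_vars(expr: &Expr, bound: &BTreeSet<String>, out: &mut BTreeSet<String>) {
    match expr {
        Expr::Int(_) => {}
        Expr::Symbol(s) => {
            if !bound.contains(s) {
                out.insert(s.clone());
            }
        }
        Expr::Lambda(params, body) => {
            let mut inner = bound.clone();
            inner.extend(params.iter().cloned());
            for e in body {
                free_vars(e, &inner, out);
            }
        }
        Expr::Let(bindings, body) => {
            // Plain let: the values are evaluated in the outer scope.
            for (_, v) in bindings {
                free_vars(v, bound, out);
            }
            let mut inner = bound.clone();
            inner.extend(bindings.iter().map(|(n, _)| n.clone()));
            for e in body {
                free_vars(e, &inner, out);
            }
        }
        Expr::Call(f, args) => {
            free_vars(f, bound, out);
            for a in args {
                free_vars(a, bound, out);
            }
        }
    }
}

fn main() {
    // (lambda (x) (+ x y)): y must be captured; + is a built-in
    let lam = Expr::Lambda(
        vec!["x".to_string()],
        vec![Expr::Call(
            Box::new(Expr::Symbol("+".to_string())),
            vec![Expr::Symbol("x".to_string()), Expr::Symbol("y".to_string())],
        )],
    );
    let mut free = BTreeSet::new();
    free_vars(&lam, &BTreeSet::new(), &mut free);
    println!("{:?}", free); // {"+", "y"}
}
```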
Garbage collection. MiniLisp currently has only string literals, which live in static storage, so nothing needs freeing yet. Once you add heap-allocated values (strings built at runtime, lists, closures), long-running programs will need a garbage collector. A simple mark-and-sweep collector is a good starting point: maintain a list of all allocated objects, periodically mark the ones reachable from the stack and global variables, and free the rest.
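The mark and sweep phases can be sketched in a few dozen lines. The version below is written in Rust to match the rest of the course, although a real collector for MiniLisp would live in the C runtime. The Heap and Object types are hypothetical: objects live in a vector, refer to each other by index, and a freed slot becomes None. The roots passed to mark stand in for the stack slots and global variables a real collector would scan.

```rust
/// Hypothetical heap object: references to other objects by heap index
/// (for example, the two halves of a pair) plus a mark bit.
struct Object {
    refs: Vec<usize>,
    marked: bool,
}

/// A toy heap; a `None` slot has been freed.
struct Heap {
    objects: Vec<Option<Object>>,
}

impl Heap {
    fn alloc(&mut self, refs: Vec<usize>) -> usize {
        self.objects.push(Some(Object { refs, marked: false }));
        self.objects.len() - 1
    }

    /// Mark phase: flag everything reachable from the roots.
    fn mark(&mut self, roots: &[usize]) {
        let mut pending: Vec<usize> = roots.to_vec();
        while let Some(i) = pending.pop() {
            if let Some(obj) = self.objects[i].as_mut() {
                if !obj.marked {
                    obj.marked = true;
                    pending.extend(obj.refs.iter().copied());
                }
            }
        }
    }

    /// Sweep phase: free unmarked objects and clear the marks on survivors.
    /// Returns how many objects were freed.
    fn sweep(&mut self) -> usize {
        let mut freed = 0;
        for slot in self.objects.iter_mut() {
            let keep = match slot {
                Some(obj) => {
                    let was_marked = obj.marked;
                    obj.marked = false;
                    was_marked
                }
                None => true,
            };
            if !keep {
                *slot = None;
                freed += 1;
            }
        }
        freed
    }
}

fn main() {
    let mut heap = Heap { objects: Vec::new() };
    let a = heap.alloc(vec![]); // reachable through b
    let b = heap.alloc(vec![a]); // a root
    let _c = heap.alloc(vec![]); // unreachable
    heap.mark(&[b]);
    println!("freed {} object(s)", heap.sweep()); // freed 1 object(s)
}
```

Clearing the marks during the sweep leaves the heap ready for the next collection cycle.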
Hygienic macros. Lisp macros transform code before it is evaluated. A macro system would let MiniLisp users define their own special forms. Hygienic macros (as in Scheme’s syntax-rules) ensure that macro-generated code does not accidentally capture user variables.
A type system. Adding type annotations and inference would catch errors like passing a string to + at compile time, instead of generating C that the C compiler rejects or that misbehaves at runtime. Hindley-Milner type inference is the classic starting point for functional languages.
An interactive REPL. A Read-Eval-Print Loop would let users type MiniLisp expressions and see results immediately. You could implement this by compiling each expression to C, linking it dynamically, and calling it — or by adding an interpreter mode alongside the compiler.
A standard library. Write a MiniLisp standard library in MiniLisp: list operations (map, filter, fold), higher-order functions, and utility procedures. This exercises the compiler and reveals any limitations in the language.
Self-hosting. The ultimate test of a compiler: rewrite the compiler itself in MiniLisp. This requires adding enough features to MiniLisp (file I/O, string manipulation, data structures) to make it practical, then translating the Rust code into MiniLisp.
Further reading
Compiler construction:
- Compilers: Principles, Techniques, and Tools by Aho, Lam, Sethi, and Ullman (the “Dragon Book”) — the canonical compiler textbook.
- Engineering a Compiler by Cooper and Torczon — a modern, practical alternative.
- Crafting Interpreters by Robert Nystrom — an excellent, free online book that walks through building two complete interpreters.
Lisp implementation:
- Structure and Interpretation of Computer Programs (SICP) by Abelson and Sussman — the classic computer science textbook, centred on Scheme.
- Lisp in Small Pieces by Christian Queinnec — deep coverage of Lisp implementation strategies, from interpreters to compilers.
- An Incremental Approach to Compiler Construction by Abdulaziz Ghuloum — a paper describing how to build a Scheme-to-x86 compiler incrementally, one feature at a time.
Rust and parsing:
- The nom documentation and nom recipes — essential reference for parser combinators in Rust.
- Programming Rust by Blandy, Orendorff, and Tindall — comprehensive coverage of Rust, including traits, generics, and the borrow checker.
Congratulations
You have built a working compiler from scratch. It reads MiniLisp, validates it, and emits C. You understand the core pipeline — parsing, semantic analysis, code generation — that underlies every compiler, from GCC to the Rust compiler itself. The techniques you have learned here transfer directly to building compilers, interpreters, static analysers, linters, and code formatters for any language. The next step is yours: pick an extension, implement it, and keep building.
Shader Programming with wgpu and WGSL
This document is a self-guided course on GPU shader programming. It is organised into six parts: the GPU execution model, setting up with wgpu, vertex and fragment shaders, textures and samplers, compute shaders, and a look at where to go next. Each section is either a reading lesson or a hands-on Rust programming exercise.
Table of Contents
Part 1 — The GPU and the Graphics Pipeline
- CPU vs GPU: parallel execution model
- The programmable pipeline: vertex, fragment, compute shaders
- What is WGSL? Syntax overview
Part 2 — Setting Up with wgpu
- What is wgpu? Cross-platform graphics API in Rust
- Exercise 1: create a window and clear it to a colour
- The render loop: swap chains, frames, command encoders
Part 3 — Vertex and Fragment Shaders
- Vertices, buffers, and the vertex shader
- Interpolation and the fragment shader
- Exercise 2: draw a coloured triangle
- Exercise 3: animate the triangle using a time uniform
Part 4 — Textures and Samplers
Part 5 — Compute Shaders
- Compute pipelines: dispatching work groups
- Storage buffers and read/write access from WGSL
- Exercise 5: GPU-accelerate a particle simulation
Part 6 — Going Further
- Post-processing effects (bloom, blur): conceptual overview
- Signed Distance Fields for font rendering
- Resources: Learn WGPU, Shadertoy, The Book of Shaders
Part 1 — The GPU and the Graphics Pipeline
1. CPU vs GPU: parallel execution model
To understand shader programming, you first need to understand why GPUs exist and how they differ from CPUs. The core difference comes down to a design trade-off: latency vs throughput.
The CPU: a few powerful cores
A modern CPU has a small number of cores — typically 4 to 16 on a consumer chip. Each core is highly sophisticated: it has deep pipelines, branch predictors, out-of-order execution, and large caches. This design makes each individual core extremely fast at executing a single sequence of instructions.
CPU (8 cores)
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ Core 0 │ │ Core 1 │ │ Core 2 │ │ Core 3 │
│ (complex)│ │ (complex)│ │ (complex)│ │ (complex)│
│ OoO exec │ │ OoO exec │ │ OoO exec │ │ OoO exec │
│ Branch │ │ Branch │ │ Branch │ │ Branch │
│ pred. │ │ pred. │ │ pred. │ │ pred. │
│ L1/L2 │ │ L1/L2 │ │ L1/L2 │ │ L1/L2 │
└──────────┘ └──────────┘ └──────────┘ └──────────┘
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ Core 4 │ │ Core 5 │ │ Core 6 │ │ Core 7 │
│ (complex)│ │ (complex)│ │ (complex)│ │ (complex)│
└──────────┘ └──────────┘ └──────────┘ └──────────┘
CPUs are optimised for low latency — finishing any single task as quickly as possible. This makes them ideal for general-purpose programming: parsing JSON, running game logic, managing operating system tasks.
The GPU: thousands of simple cores
A GPU takes the opposite approach. It packs thousands of tiny, simple cores onto a single chip. Each individual core is much less powerful than a CPU core — no branch prediction, no out-of-order execution, minimal cache. But there are so many of them that the total throughput is enormous.
GPU (thousands of cores)
┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
├───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┼───┤
│ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │ · │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘
Each · is a simple core. Thousands execute in parallel.
GPUs are optimised for high throughput: processing billions of similar operations per second, and trillions of floating-point operations per second on modern cards. Each individual operation might be slower than on a CPU, but the sheer volume of parallel work makes up for it.
SIMD vs SIMT
You may have heard of SIMD (Single Instruction, Multiple Data) on CPUs — instructions like SSE or AVX that process 4 or 8 values at once in a single register. GPUs take this idea much further with SIMT (Single Instruction, Multiple Threads).
In SIMT, groups of threads (called warps on NVIDIA or wavefronts on AMD) execute the same instruction at the same time, but each thread operates on different data. NVIDIA warps are 32 threads wide. AMD wavefronts are 64 threads wide on older GCN hardware and 32 or 64 on newer RDNA hardware.
SIMT execution (one warp of 32 threads):
Instruction: multiply position by matrix
Thread 0: vertex[0].pos * matrix → result[0]
Thread 1: vertex[1].pos * matrix → result[1]
Thread 2: vertex[2].pos * matrix → result[2]
...
Thread 31: vertex[31].pos * matrix → result[31]
All 32 threads execute the same multiply instruction
at the same clock cycle, on different vertex data.
This is why GPUs are perfect for graphics: every pixel on screen needs the same computation (run the fragment shader), just with different input coordinates. The same applies to vertex transformations, physics simulations, and many other tasks.
When does the GPU win?
The GPU excels when your problem has these characteristics:
- Data parallelism: the same operation is applied to many independent data elements
- Arithmetic intensity: lots of math per memory access
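- Predictable control flow: minimal branching (if/else). When threads in the same warp take different paths, the warp runs both paths one after the other and masks off the threads that did not take each one, so a divergent branch costs the sum of both sides.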
Problems that are sequential, branch-heavy, or have complex data dependencies are better left on the CPU.
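A toy cost model makes the divergence rule concrete. It assumes a warp issues each side of an if/else whenever at least one of its lanes takes that side, with the other lanes masked off. The warp width and per-branch costs are made-up parameters, real hardware has further subtleties, and warp_branch_cost is an illustrative name.

```rust
/// Instruction slots one warp spends on an if/else in a toy model:
/// each side is issued (with non-participating lanes masked off) if at
/// least one lane takes it, so a divergent warp pays for both sides.
fn warp_branch_cost(conditions: &[bool], then_cost: u32, else_cost: u32) -> u32 {
    let any_then = conditions.iter().any(|&c| c);
    let any_else = conditions.iter().any(|&c| !c);
    let mut cost = 0;
    if any_then {
        cost += then_cost;
    }
    if any_else {
        cost += else_cost;
    }
    cost
}

fn main() {
    const WARP: usize = 32;
    // Uniform: every lane takes the `then` branch.
    let uniform = [true; WARP];
    // Divergent: even lanes take `then`, odd lanes take `else`.
    let divergent: Vec<bool> = (0..WARP).map(|lane| lane % 2 == 0).collect();
    println!("uniform:   {} slots", warp_branch_cost(&uniform, 10, 10));
    println!("divergent: {} slots", warp_branch_cost(&divergent, 10, 10));
}
```

In this model, even one lane on the other side of the branch makes the whole warp pay for both sides.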
Key takeaway: CPUs are fast race cars — great at finishing one task quickly. GPUs are cargo ships — slower per trip, but they move enormous amounts of freight in parallel. Shader programming is the art of loading that cargo ship efficiently.
2. The programmable pipeline: vertex, fragment, compute shaders
Modern GPUs run a programmable graphics pipeline — a fixed sequence of stages where some stages run programs you write (shaders) and others are handled automatically by the hardware. Understanding this pipeline is essential before writing any shader code.
The graphics pipeline
When you ask the GPU to draw a triangle, your data flows through several stages:
The Graphics Pipeline
=====================
CPU (your Rust code)
│
│ Vertex data + draw call
▼
┌─────────────────────┐
│ VERTEX SHADER │ ◄── Programmable (you write this)
│ Transforms each │ Runs once per vertex
│ vertex position │
└────────┬────────────┘
│
▼
┌─────────────────────┐
│ PRIMITIVE │ ◄── Fixed-function (hardware)
│ ASSEMBLY │ Connects vertices into
│ │ triangles, lines, or points
└────────┬────────────┘
│
▼
┌─────────────────────┐
│ RASTERISATION │ ◄── Fixed-function (hardware)
│ Determines which │ Converts triangles into
│ pixels a triangle │ fragments (candidate pixels)
│ covers │
└────────┬────────────┘
│
▼
┌─────────────────────┐
│ FRAGMENT SHADER │ ◄── Programmable (you write this)
│ Computes the │ Runs once per fragment
│ colour of each │ (potential pixel)
│ fragment │
└────────┬────────────┘
│
▼
┌─────────────────────┐
│ OUTPUT MERGER │ ◄── Fixed-function (hardware)
│ Depth test, blend │ Combines fragments into
│ with framebuffer │ the final image
└─────────────────────┘
The vertex shader
The vertex shader runs once for every vertex you submit. Its primary job is to transform vertex positions from model space (the coordinates you defined your mesh in) to clip space (the coordinate system the GPU uses to determine what is on screen).
A vertex shader typically receives input data — position, colour, texture coordinates — and outputs a transformed position plus any data that should be passed to the fragment shader.
For example, a vertex shader might:
- Multiply the vertex position by a model-view-projection matrix
- Pass the vertex colour through to the next stage
- Compute lighting values at each vertex
Rasterisation
After the vertex shader runs and the GPU assembles vertices into triangles, rasterisation determines which screen pixels each triangle covers. This is not programmable — the hardware handles it automatically.
For each pixel covered by a triangle, the rasteriser generates a fragment. A fragment is a candidate pixel — it carries interpolated values from the triangle’s vertices (we will explore interpolation in detail in section 8).
The fragment shader
The fragment shader runs once for every fragment produced by rasterisation. Its job is to determine the final colour of that pixel. This is where most of the visual magic happens: texturing, lighting, shadows, reflections, and special effects are all implemented in the fragment shader.
The fragment shader receives interpolated data from the vertex shader (like colour or texture coordinates) and outputs a colour value, typically as an RGBA (red, green, blue, alpha) tuple.
Compute shaders: a separate path
Compute shaders do not participate in the graphics pipeline at all. They are general-purpose programs that run on the GPU, independent of any rendering. The shader declares its work-group size with @workgroup_size, you dispatch an explicit number of work groups from the CPU, and the shader can read from and write to buffers and textures.
Compute Pipeline (independent of graphics)
==========================================
CPU (your Rust code)
│
│ Dispatch (work group counts)
▼
┌─────────────────────┐
│ COMPUTE SHADER │ ◄── Programmable (you write this)
│ General-purpose │ Runs once per invocation
│ parallel work │ across work groups
└─────────────────────┘
│
▼
Output buffers / textures
Compute shaders are used for physics simulations, image processing, machine learning inference, procedural generation, and any task that benefits from massive parallelism but does not need the rasterisation pipeline.
Key takeaway: The GPU has two paths for running your code. The graphics pipeline flows from vertex shader through rasterisation to fragment shader, producing pixels on screen. The compute pipeline is a separate, general-purpose path for parallel computation. You will write programs for all three shader types in this course.
3. What is WGSL? Syntax overview
WGSL (WebGPU Shading Language) is the shader language used by the WebGPU API — and by extension, by wgpu. If you have used GLSL or HLSL before, WGSL will feel familiar but with a more explicit, Rust-influenced syntax. If you are new to shader languages, this section covers everything you need to get started.
Scalar types
WGSL provides a small set of scalar types:
| Type | Description |
|---|---|
| f32 | 32-bit floating point |
| f16 | 16-bit floating point (optional: the device must support it and the shader must declare enable f16;) |
| i32 | 32-bit signed integer |
| u32 | 32-bit unsigned integer |
| bool | Boolean |
Vector types
Vectors are fundamental in shader programming. WGSL supports vectors of 2, 3, or 4 components:
var a: vec2<f32> = vec2<f32>(1.0, 2.0);
var b: vec3<f32> = vec3<f32>(1.0, 0.0, 0.0); // a red colour or a direction
var c: vec4<f32> = vec4<f32>(0.2, 0.4, 0.8, 1.0); // RGBA colour
// Shorthand constructors (type inference):
var d = vec3f(1.0, 0.0, 0.0); // vec3<f32>
var e = vec4f(0.0, 0.0, 0.0, 1.0);
You can access components with swizzling:
var color = vec4f(1.0, 0.5, 0.2, 1.0);
var rgb = color.rgb; // vec3f(1.0, 0.5, 0.2)
var rr = color.xx; // vec2f(1.0, 1.0)
Components can be accessed as x/y/z/w or r/g/b/a — they are interchangeable aliases.
Matrix types
Matrices are used for transformations (rotation, scaling, projection):
// A 4x4 matrix of f32 values (4 columns, 4 rows)
var transform: mat4x4<f32>;
// A 3x3 matrix
var rotation: mat3x3<f32>;
Matrix-vector multiplication uses the * operator: transform * vec4f(pos, 1.0).
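WGSL matrices are column-major: a mat4x4<f32> is four column vectors, and m[i] is column i. So transform * v is a weighted sum of the columns, with v's components as the weights. Below is a plain-Rust sketch of that product. It assumes the matrix is stored as [[f32; 4]; 4] in column order to mirror WGSL. Mat4, mul_mat4_vec4, and translation are illustrative names, not a wgpu or WGSL API.

```rust
/// A 4x4 matrix stored as four columns, mirroring WGSL's mat4x4<f32>.
type Mat4 = [[f32; 4]; 4];

/// Computes m * v as the sum of column i of m scaled by v[i].
fn mul_mat4_vec4(m: &Mat4, v: [f32; 4]) -> [f32; 4] {
    let mut out = [0.0; 4];
    for (col, &weight) in m.iter().zip(v.iter()) {
        for row in 0..4 {
            out[row] += col[row] * weight;
        }
    }
    out
}

/// A translation matrix: the offset lives in the last column.
fn translation(tx: f32, ty: f32, tz: f32) -> Mat4 {
    [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [tx, ty, tz, 1.0],
    ]
}

fn main() {
    let m = translation(2.0, 3.0, 4.0);
    // A point (w = 1.0) is moved by the translation...
    println!("{:?}", mul_mat4_vec4(&m, [1.0, 1.0, 1.0, 1.0])); // [3.0, 4.0, 5.0, 1.0]
    // ...but a direction (w = 0.0) is not.
    println!("{:?}", mul_mat4_vec4(&m, [1.0, 1.0, 1.0, 0.0])); // [1.0, 1.0, 1.0, 0.0]
}
```

This is also why positions are extended with w = 1.0, as in vec4f(pos, 1.0). The translation stored in the last column only takes effect when w is 1.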
Variables: let vs var
// `let` declares an immutable binding (like Rust's `let`)
let pi = 3.14159;
// `var` declares a mutable variable
var counter: u32 = 0u;
counter = counter + 1u;
Structs
Structs group related data, and they are used extensively for shader inputs and outputs:
struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) color: vec3<f32>,
}
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) color: vec3<f32>,
}
The @location(n) attribute links struct fields to specific slots in the vertex buffer layout or inter-stage communication. The @builtin(position) attribute tells the GPU this field is the clip-space position.
Functions and entry points
WGSL functions look like this:
fn add(a: f32, b: f32) -> f32 {
return a + b;
}
Entry points are functions marked with a stage attribute:
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
var out: VertexOutput;
out.clip_position = vec4f(in.position, 1.0);
out.color = in.color;
return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return vec4f(in.color, 1.0);
}
@compute @workgroup_size(64)
fn cs_main(@builtin(global_invocation_id) id: vec3<u32>) {
// compute work here
}
The @location(0) on the fragment shader return type means “write to the first colour attachment” (the render target).
Built-in attributes
Some commonly used built-in attributes:
| Attribute | Stage | Meaning |
|---|---|---|
| @builtin(position) | Vertex out / Fragment in | Clip-space position / fragment coordinates |
| @builtin(vertex_index) | Vertex | Index of the current vertex |
| @builtin(instance_index) | Vertex | Index of the current instance |
| @builtin(global_invocation_id) | Compute | 3D index of this thread in the dispatch |
| @builtin(local_invocation_id) | Compute | 3D index within the work group |
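For compute shaders, the built-ins in the table are related, per component, by global_invocation_id = workgroup_id * workgroup_size + local_invocation_id. Here workgroup_id is a further built-in giving the index of the work group within the dispatch. The sketch below shows that arithmetic in plain Rust for a one-dimensional dispatch. It also shows the usual way to choose how many work groups to pass to dispatch_workgroups for n elements. workgroup_count and global_id are illustrative helper names, and 64 assumes a shader declared with @workgroup_size(64).

```rust
/// Work groups needed so that groups * workgroup_size >= n (rounding up).
fn workgroup_count(n: u32, workgroup_size: u32) -> u32 {
    n.div_ceil(workgroup_size)
}

/// The x component of @builtin(global_invocation_id) for a given
/// @builtin(workgroup_id).x and @builtin(local_invocation_id).x.
fn global_id(workgroup_id: u32, workgroup_size: u32, local_id: u32) -> u32 {
    workgroup_id * workgroup_size + local_id
}

fn main() {
    let n = 1000; // e.g. number of particles
    let size = 64; // must match @workgroup_size(64) in the shader
    let groups = workgroup_count(n, size);
    let last = global_id(groups - 1, size, size - 1);
    // 16 groups of 64 give 1024 invocations for 1000 elements, so ids
    // 1000..=1023 have no data: the shader needs `if id.x >= n { return; }`.
    println!("dispatch {groups} groups; last global id = {last}");
}
```

Because the group count rounds up, the last few invocations fall past the end of the data. That is why compute shaders conventionally start with a bounds check.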
Binding resources
Uniforms, storage buffers, textures, and samplers are declared at module scope with @group and @binding attributes:
@group(0) @binding(0)
var<uniform> time: f32;
@group(0) @binding(1)
var texture: texture_2d<f32>;
@group(0) @binding(2)
var tex_sampler: sampler;
The @group(n) corresponds to a bind group index, and @binding(n) is the binding within that group. These must match the bind group layout you define on the Rust side.
Control flow
WGSL supports if/else, for, while, loop, switch, break, continue, and return:
for (var i: u32 = 0u; i < 10u; i = i + 1u) {
if i == 5u {
continue;
}
// do work
}
Key takeaway: WGSL’s syntax is a blend of Rust and C-family languages. Types are explicit, entry points are marked with stage attributes (@vertex, @fragment, @compute), and data flows between stages via structs annotated with @location and @builtin. You will write WGSL for every exercise in this course.
Part 2 — Setting Up with wgpu
4. What is wgpu? Cross-platform graphics API in Rust
wgpu is a Rust crate that implements the WebGPU API specification. It provides a safe, cross-platform interface for GPU programming that works on multiple backends:
| Backend | Platform |
|---|---|
| Vulkan | Linux, Windows, Android |
| Metal | macOS, iOS |
| DX12 | Windows |
| WebGPU | Web browsers (via wasm) |
| OpenGL | Fallback for older systems |
This means you write your GPU code once and it runs everywhere — on desktop, on mobile, and in the browser.
Why not raw Vulkan/Metal/DX12?
Writing directly against a low-level graphics API like Vulkan requires thousands of lines of boilerplate before you can draw a single triangle. Vulkan’s explicit nature gives you maximum control, but the complexity is enormous. wgpu provides a higher-level abstraction that handles the platform differences and much of the boilerplate while still being close enough to the metal for serious work.
Key types in wgpu
Here are the core types you will interact with, in the order you typically create them:
Initialization Flow
====================
Instance
│
│ enumerate adapters
▼
Adapter ←── represents a physical GPU
│
│ request device
▼
Device + Queue
│ │
│ │ submit commands
│ ▼
│ (GPU execution)
│
│ create resources
▼
Buffers, Textures, Pipelines, Bind Groups, ...
- Instance: the entry point to wgpu. Created first, used to find adapters and create surfaces.
- Surface: a handle to a window’s drawable area. Created from a window (provided by a windowing library like winit).
- Adapter: represents a physical GPU. You request one from the instance, optionally specifying preferences (power preference, compatibility with your surface).
- Device: a logical connection to the GPU. You create resources (buffers, textures, pipelines) through the device. Think of it as an open connection to the GPU.
- Queue: used to submit work (command buffers) to the GPU. You get a queue together with the device.
- CommandEncoder: records GPU commands (render passes, compute dispatches, buffer copies) into a command buffer. The command buffer is then submitted to the queue.
- RenderPipeline: describes the full configuration for rendering: which shaders to use, vertex layout, blending mode, pixel format, etc.
- Buffer: a block of GPU-accessible memory. Used for vertex data, index data, uniforms, storage, etc.
- BindGroup: a collection of resources (buffers, textures, samplers) that are made available to shaders. Corresponds to @group(n) in WGSL.
The initialisation sequence in code
Here is a simplified view of wgpu initialisation (we will see the full code in Exercise 1):
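async fn init_wgpu(window: &winit::window::Window) -> Result<(), Box<dyn std::error::Error>> {
    let size = window.inner_size();
    // 1. Create an instance
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
    // 2. Create a surface from a window
    let surface = instance.create_surface(window)?;
    // 3. Request an adapter (physical GPU); None means no suitable GPU was found
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::default(),
            compatible_surface: Some(&surface),
            force_fallback_adapter: false,
        })
        .await
        .ok_or("no suitable GPU adapter")?;
    // 4. Request a device and queue (in wgpu 24 the second argument is an optional trace path)
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor::default(), None)
        .await?;
    // 5. Configure the surface for the window's size (at least 1x1)
    let config = surface
        .get_default_config(&adapter, size.width.max(1), size.height.max(1))
        .ok_or("surface is not supported by this adapter")?;
    surface.configure(&device, &config);
    // A real application would store surface, device, queue, and config for rendering
    Ok(())
}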
After this, you are ready to create pipelines, buffers, and start rendering.
Key takeaway: wgpu is a cross-platform GPU abstraction for Rust. You create an Instance, get an Adapter (physical GPU), open a Device + Queue, and then create resources and submit commands. This same code works on Vulkan, Metal, DX12, and WebGPU.
5. Exercise 1: create a window and clear it to a colour
In this exercise you will create a window using winit, initialise wgpu, and fill the window with a solid colour (cornflower blue). This is the “hello world” of GPU programming.
Step 1: project setup
Create a new Rust project and add the required dependencies to Cargo.toml:
[package]
name = "shader-exercises"
version = "0.1.0"
edition = "2021"
[dependencies]
wgpu = "24"
winit = "0.30"
pollster = "0.4"
log = "0.4"
env_logger = "0.11"
[profile.release]
opt-level = "z"
lto = true
strip = true
codegen-units = 1
- wgpu: the GPU abstraction layer
- winit: cross-platform window creation and event handling
- pollster: a minimal async executor to block on futures (wgpu uses async for initialisation)
- log and env_logger: the logging facade and a backend that prints its messages, so wgpu can report errors and warnings
Step 2: the complete code
use winit::{
application::ApplicationHandler,
event::WindowEvent,
event_loop::EventLoop,
window::{Window, WindowAttributes},
};
use std::sync::Arc;
/// Holds all wgpu state needed for rendering.
struct GpuState {
surface: wgpu::Surface<'static>,
device: wgpu::Device,
queue: wgpu::Queue,
config: wgpu::SurfaceConfiguration,
}
/// The main application struct.
struct App {
window: Option<Arc<Window>>,
gpu: Option<GpuState>,
}
impl App {
fn new() -> Self {
Self {
window: None,
gpu: None,
}
}
/// Initialise wgpu with the given window.
fn init_gpu(&mut self, window: Arc<Window>) {
let size = window.inner_size();
let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
let surface = instance.create_surface(window.clone()).unwrap();
let adapter = pollster::block_on(instance.request_adapter(
&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::default(),
compatible_surface: Some(&surface),
force_fallback_adapter: false,
},
))
.expect("Failed to find a suitable GPU adapter");
let (device, queue) = pollster::block_on(adapter.request_device(
&wgpu::DeviceDescriptor::default(),
None,
))
.expect("Failed to create device");
let config = surface
.get_default_config(&adapter, size.width.max(1), size.height.max(1))
.expect("Surface is not supported by the adapter");
surface.configure(&device, &config);
self.gpu = Some(GpuState {
surface,
device,
queue,
config,
});
}
/// Render a single frame: clear the screen to cornflower blue.
fn render(&self) {
let gpu = self.gpu.as_ref().unwrap();
// Get the next frame's texture to draw on
let output = gpu.surface.get_current_texture()
.expect("Failed to get surface texture");
let view = output.texture.create_view(&Default::default());
// Create a command encoder to record GPU commands
let mut encoder = gpu.device.create_command_encoder(
&wgpu::CommandEncoderDescriptor {
label: Some("Clear Encoder"),
},
);
// Begin a render pass that clears to cornflower blue
{
let _render_pass = encoder.begin_render_pass(
&wgpu::RenderPassDescriptor {
label: Some("Clear Pass"),
color_attachments: &[Some(
wgpu::RenderPassColorAttachment {
view: &view,
resolve_target: None,
ops: wgpu::Operations {
load: wgpu::LoadOp::Clear(
wgpu::Color {
r: 0.392,
g: 0.584,
b: 0.929,
a: 1.0,
},
),
store: wgpu::StoreOp::Store,
},
},
)],
depth_stencil_attachment: None,
..Default::default()
},
);
// The render pass is dropped here, ending it
}
// Submit the commands to the GPU
gpu.queue.submit(std::iter::once(encoder.finish()));
// Present the frame on screen
output.present();
}
}
impl ApplicationHandler for App {
fn resumed(&mut self, event_loop: &winit::event_loop::ActiveEventLoop) {
if self.window.is_none() {
let attrs = WindowAttributes::default()
.with_title("Exercise 1: Cornflower Blue");
let window = Arc::new(
event_loop.create_window(attrs).unwrap()
);
self.init_gpu(window.clone());
self.window = Some(window);
}
}
fn window_event(
&mut self,
event_loop: &winit::event_loop::ActiveEventLoop,
_window_id: winit::window::WindowId,
event: WindowEvent,
) {
match event {
WindowEvent::CloseRequested => {
event_loop.exit();
}
WindowEvent::Resized(new_size) => {
if let Some(gpu) = &mut self.gpu {
gpu.config.width = new_size.width.max(1);
gpu.config.height = new_size.height.max(1);
gpu.surface.configure(&gpu.device, &gpu.config);
}
}
WindowEvent::RedrawRequested => {
self.render();
if let Some(window) = &self.window {
window.request_redraw();
}
}
_ => {}
}
}
}
fn main() {
env_logger::init();
let event_loop = EventLoop::new().unwrap();
let mut app = App::new();
event_loop.run_app(&mut app).unwrap();
}
Step 3: run it
cargo run
You should see a window filled with cornflower blue (a pleasant mid-blue, rgb(100, 149, 237)). The window responds to resizing and closes when you click the close button.
What just happened?
Let’s break down the key parts:
- Window creation: winit creates a native window. We wrap it in Arc so that both App and the wgpu surface can own it. Passing an Arc<Window> to create_surface is what gives the surface its 'static lifetime (wgpu::Surface<'static>).
- Surface: created from the window; this is where rendered frames go.
- Adapter + Device + Queue: we find a GPU, open a logical device, and get a command queue.
- Surface configuration: tells the surface what pixel format and size to use.
- Render loop: every frame we create a CommandEncoder, begin a RenderPass with a clear colour, end the pass, submit commands, and present.
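The clear colour is specified as wgpu::Color { r, g, b, a } with f64 values in the 0.0-1.0 range. These values are treated as linear. If the surface format chosen by get_default_config is an sRGB format (which is common), the GPU converts them to sRGB when writing. The window will then look noticeably lighter than rgb(100, 149, 237).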
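If config.format is an sRGB format (its name ends in Srgb, such as Bgra8UnormSrgb), the clear colour is interpreted as linear. You can reproduce rgb(100, 149, 237) exactly by converting the 8-bit sRGB values to linear before building the wgpu::Color. Below is a sketch of the standard sRGB decoding curve. srgb_u8_to_linear is an illustrative name, and the result is f64 because wgpu::Color uses f64 fields.

```rust
/// Converts an 8-bit sRGB channel (0-255) to a linear value in 0.0-1.0,
/// using the standard piecewise sRGB transfer function.
fn srgb_u8_to_linear(c: u8) -> f64 {
    let s = c as f64 / 255.0;
    if s <= 0.04045 {
        s / 12.92
    } else {
        ((s + 0.055) / 1.055).powf(2.4)
    }
}

fn main() {
    let [r, g, b] = [100u8, 149, 237].map(srgb_u8_to_linear);
    // Values to use in wgpu::Color { r, g, b, a: 1.0 } when the
    // surface format is an ...Srgb format.
    println!("linear cornflower blue: ({r:.3}, {g:.3}, {b:.3})");
}
```

For cornflower blue this gives roughly (0.127, 0.301, 0.847).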
Try this: change the colour to something else — pure red (1.0, 0.0, 0.0, 1.0), bright green, or your favourite colour. Rebuild and see the change.
6. The render loop: swap chains, frames, command encoders
Now that you have a working window, let’s dive deeper into what happens each frame. Understanding the render loop is crucial because every shader program you write will run inside this cycle.
The frame lifecycle
Every frame follows the same sequence. Here is what happens between one screen update and the next:
Frame Lifecycle
===============
Time ──────────────────────────────────────────────────►
┌──── Frame N ─────────────────────┐ ┌── Frame N+1 ──
│ │ │
│ 1. Acquire 2. Record 3. Submit 4. Present
│ surface commands to to
│ texture (render queue screen
│ pass)
│ │ │
│ CPU CPU │ GPU executes
│ side side │ asynchronously
└──────────────────────────────────┘ └────────────────
Step 1: acquire a surface texture
let output = surface.get_current_texture()?;
let view = output.texture.create_view(&Default::default());
The surface manages a small pool of textures (typically 2-3, traditionally called a swap chain). When you call get_current_texture(), you receive the next available texture to draw on. While you are drawing into texture A, the display may still be showing the previous texture B. This is double buffering; with three textures it is triple buffering.
Double Buffering
================
┌────────────┐ ┌────────────┐
│ Texture A │ │ Texture B │
│ (drawing) │ │ (on screen)│
└────────────┘ └────────────┘
▲ ▲
│ │
You render Monitor
into this displays
one now this one
After you present texture A, the roles swap: A goes to the screen and B becomes available for the next frame.
Step 2: record commands with a command encoder
let mut encoder = device.create_command_encoder(&Default::default());
The CommandEncoder is like a tape recorder for GPU commands. You do not execute anything immediately — you record a list of operations, and then submit them all at once. This is called a command buffer model.
Why not execute commands immediately? Because the GPU operates asynchronously. Batching commands into a buffer lets the GPU execute them efficiently without constant back-and-forth with the CPU.
Step 3: begin a render pass
// `mut` because set_pipeline, draw, etc. take &mut self
let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
    color_attachments: &[Some(wgpu::RenderPassColorAttachment {
        view: &view,
        // RenderPassColorAttachment has no Default impl, so every field is listed
        resolve_target: None,
        ops: wgpu::Operations {
            load: wgpu::LoadOp::Clear(clear_color),
            store: wgpu::StoreOp::Store,
        },
    })],
    ..Default::default()
});
A render pass is a sequence of draw commands that all target the same set of attachments (colour textures, depth buffers). Within a render pass, you:
- Set the pipeline
- Bind vertex buffers and bind groups
- Issue draw calls
The load operation specifies what happens to the attachment at the start of the pass. LoadOp::Clear(color) fills it with a solid colour. LoadOp::Load preserves the previous contents.
The store operation specifies what happens at the end. StoreOp::Store keeps the results; StoreOp::Discard throws them away (useful for depth buffers you do not need after the pass).
Step 4: submit and present
// End the render pass (drop it)
drop(render_pass);
// Finish recording and get a command buffer
let command_buffer = encoder.finish();
// Submit the command buffer to the GPU
queue.submit(std::iter::once(command_buffer));
// Show the rendered frame on screen
output.present();
queue.submit() sends the command buffer to the GPU for execution. The GPU processes it asynchronously — your CPU code continues immediately. output.present() tells the surface to display this texture once the GPU finishes rendering to it.
Multiple render passes
You can have multiple render passes in a single frame. This is common for:
- Shadow mapping: render the scene from a light’s perspective (pass 1), then render the final image using the shadow map (pass 2)
- Post-processing: render the scene to a texture (pass 1), then apply a blur filter to that texture and draw the result to the screen (pass 2)
Each pass gets its own begin_render_pass / drop cycle within the same command encoder.
Key takeaway: each frame, you acquire a surface texture, record GPU commands into a command encoder (including one or more render passes), submit the commands to the queue, and present the result. The CPU and GPU work asynchronously — the CPU records commands for the next frame while the GPU executes the current one.
Part 3 — Vertex and Fragment Shaders
7. Vertices, buffers, and the vertex shader
To draw anything beyond a solid colour, you need to send geometry to the GPU. Geometry is made of vertices — points in space that define the corners of triangles. This section explains how vertex data flows from your Rust code to the vertex shader on the GPU.
What is a vertex?
A vertex is a point with associated data. At minimum, a vertex has a position, but it usually carries additional attributes:
Vertex Data (per vertex)
========================
┌──────────────────────────────────────────────┐
│ position: vec3<f32> (x, y, z) │
│ color: vec3<f32> (r, g, b) │
│ uv: vec2<f32> (texture coordinate) │
│ normal: vec3<f32> (surface direction) │
└──────────────────────────────────────────────┘
For a simple coloured triangle, you might have three vertices with position and colour:
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Vertex {
    position: [f32; 3],
    color: [f32; 3],
}
const VERTICES: &[Vertex] = &[
    Vertex { position: [ 0.0,  0.5, 0.0], color: [1.0, 0.0, 0.0] }, // top, red
    Vertex { position: [-0.5, -0.5, 0.0], color: [0.0, 1.0, 0.0] }, // left, green
    Vertex { position: [ 0.5, -0.5, 0.0], color: [0.0, 0.0, 1.0] }, // right, blue
];
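The #[repr(C)] attribute ensures the struct has a predictable memory layout matching what the GPU expects. The bytemuck derives let us safely cast the struct to raw bytes. They require adding bytemuck to Cargo.toml with its derive feature enabled: bytemuck = { version = "1", features = ["derive"] }.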
Vertex buffers
To get vertex data onto the GPU, you create a vertex buffer:
// DeviceExt provides create_buffer_init
use wgpu::util::DeviceExt;
let vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    label: Some("Vertex Buffer"),
    contents: bytemuck::cast_slice(VERTICES),
    usage: wgpu::BufferUsages::VERTEX,
});
This copies the vertex data from CPU memory into GPU memory. The VERTEX usage flag tells wgpu that this buffer will be used as a vertex buffer.
Vertex buffer layout
The GPU does not know the structure of your vertex data — you must describe it with a vertex buffer layout:
let vertex_layout = wgpu::VertexBufferLayout {
    array_stride: std::mem::size_of::<Vertex>() as u64,
    step_mode: wgpu::VertexStepMode::Vertex,
    attributes: &[
        // position: 3 floats at offset 0
        wgpu::VertexAttribute {
            format: wgpu::VertexFormat::Float32x3,
            offset: 0,
            shader_location: 0,
        },
        // color: 3 floats at offset 12 bytes (after 3 x f32)
        wgpu::VertexAttribute {
            format: wgpu::VertexFormat::Float32x3,
            offset: 12,
            shader_location: 1,
        },
    ],
};
This tells the GPU: “each vertex is 24 bytes apart (array_stride, the size of Vertex), and within each vertex, location 0 is three floats starting at byte 0, and location 1 is three floats starting at byte 12.”
The shader_location values correspond to @location(n) in your WGSL shader.
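The offsets in the layout follow directly from the attribute sizes: each attribute starts where the previous one ended. The helper below (a hypothetical name, assuming every attribute is made of tightly packed f32 components) computes the offsets and the stride. wgpu's vertex_attr_array! macro does the same bookkeeping if you prefer not to write offsets by hand.

```rust
/// Hypothetical helper: given the number of f32 components in each attribute,
/// in shader_location order, return each attribute's byte offset and the
/// total stride. Assumes tightly packed f32 attributes with no padding.
fn packed_attribute_offsets(components: &[u64]) -> (Vec<u64>, u64) {
    const F32_BYTES: u64 = 4;
    let mut offsets: Vec<u64> = Vec::with_capacity(components.len());
    let mut next: u64 = 0;
    for &count in components {
        offsets.push(next); // this attribute starts where the last one ended
        next += count * F32_BYTES;
    }
    (offsets, next) // after the loop, `next` is the stride
}
```

For the colour vertex, [3, 3] gives offsets [0, 12] and a stride of 24. A vertex with a position and a 2-component UV, [3, 2], gives [0, 12] and a stride of 20.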
How data flows from CPU to vertex shader
CPU Memory GPU Memory Vertex Shader
========== ========== =============
Vertex array copy Vertex buffer read @location(0) position
[pos, color] ──────────► [bytes...] ──────────► @location(1) color
[pos, color] [bytes...]
[pos, color] [bytes...]
The layout descriptor tells the GPU how to
interpret the bytes into typed attributes.
The vertex shader’s job
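The vertex shader runs once per vertex. It must output a @builtin(position) value in clip space, a homogeneous (x, y, z, w) coordinate. The GPU divides x, y and z by w to get normalised device coordinates, in which:
- x ranges from -1 (left) to +1 (right)
- y ranges from -1 (bottom) to +1 (top)
- z ranges from 0 (near) to 1 (far)
Anything outside these ranges is clipped (not drawn). The shaders in this section set w = 1.0, so their clip-space and normalised device coordinates are the same numbers.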
struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) color: vec3<f32>,
}
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) color: vec3<f32>,
}
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
var out: VertexOutput;
out.clip_position = vec4f(in.position, 1.0);
out.color = in.color;
return out;
}
In this simple shader, the position passes through unchanged (we are already working in clip space). In real applications, you would multiply by a model-view-projection matrix to transform from 3D world coordinates to clip space.
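To make the matrix step concrete, here is a CPU-side sketch in plain Rust (no math crate assumed). It stores a 4x4 matrix as four columns, the same column-major convention WGSL's mat4x4<f32> uses, multiplies it by a homogeneous position, and checks the result against the clip-space view volume (-w ≤ x ≤ w, -w ≤ y ≤ w, 0 ≤ z ≤ w). A translation matrix stands in for a full model-view-projection matrix, and the helper names are illustrative.

```rust
type Vec4 = [f32; 4];
/// A 4x4 matrix stored as four columns, like WGSL's column-major mat4x4<f32>.
type Mat4 = [[f32; 4]; 4];

/// Multiply a column-major matrix by a column vector: result = m * v.
fn mul(m: &Mat4, v: Vec4) -> Vec4 {
    let mut out = [0.0; 4];
    for col in 0..4 {
        for row in 0..4 {
            out[row] += m[col][row] * v[col];
        }
    }
    out
}

/// Build a translation matrix; the translation lives in the last column.
fn translation(tx: f32, ty: f32, tz: f32) -> Mat4 {
    [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [tx, ty, tz, 1.0],
    ]
}

/// True if a clip-space position lies inside the WebGPU view volume.
fn inside_view_volume(p: Vec4) -> bool {
    let [x, y, z, w] = p;
    -w <= x && x <= w && -w <= y && y <= w && 0.0 <= z && z <= w
}
```

Translating the top vertex (0.0, 0.5) right by 0.25 keeps it inside the volume; translating it by 2.0 pushes x past w, so that part of the triangle would be clipped.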
Key takeaway: vertices carry per-point data (position, colour, etc.) packed into a buffer. The vertex buffer layout tells the GPU how to decode the bytes. The vertex shader transforms each vertex’s position into clip space and passes any additional data (like colour) to the next stage.
8. Interpolation and the fragment shader
After the vertex shader has transformed all vertices and the GPU has assembled them into triangles, rasterisation takes over. This section explains what happens between the vertex shader and the fragment shader — the critical concept of interpolation.
From triangles to pixels
Rasterisation determines which pixels on screen fall inside each triangle. For each pixel inside a triangle, the rasteriser generates a fragment. But what data does each fragment carry?
Consider a triangle with a red vertex, a green vertex, and a blue vertex:
Red (1,0,0)
/\
/ \
/ \
/ what \
/ colour \
/ is this \
/ pixel? \
/______________\
Green Blue
(0,1,0) (0,0,1)
A pixel near the red vertex should be mostly red. A pixel exactly in the centre should be an equal mix of red, green, and blue. The GPU computes this automatically using barycentric interpolation.
Barycentric coordinates
Every point inside a triangle can be expressed as a weighted combination of the three vertices. These weights are called barycentric coordinates (w0, w1, w2), where:
- w0 + w1 + w2 = 1.0
- All weights are between 0 and 1
Barycentric Interpolation
=========================
Point P inside triangle ABC:
P = w0 * A + w1 * B + w2 * C
At vertex A: w0=1, w1=0, w2=0 → colour = A.color
At vertex B: w0=0, w1=1, w2=0 → colour = B.color
At vertex C: w0=0, w1=0, w2=1 → colour = C.color
At centre: w0=⅓, w1=⅓, w2=⅓ → colour = average
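The GPU performs this interpolation automatically for every field in the VertexOutput struct (except @builtin(position), which is used for rasterisation itself). This means colours, texture coordinates, normals, and any other outputs are smoothly interpolated across the triangle surface. By default the interpolation is perspective-correct; a field marked @interpolate(flat) is instead taken unchanged from a single vertex (integer outputs must be flat).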
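Both rasterisation jobs, deciding whether a pixel centre lies inside the triangle and weighting the vertex outputs, can be expressed with edge functions (signed areas). This CPU sketch in plain Rust illustrates the maths rather than how any particular GPU implements it, and it skips perspective correction, which changes nothing when every vertex has w = 1 as here. The function names are illustrative.

```rust
type P2 = [f32; 2];

/// Twice the signed area of triangle (a, b, c); positive when a, b, c are
/// counter-clockwise in a y-up coordinate system.
fn edge(a: P2, b: P2, c: P2) -> f32 {
    (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])
}

/// Barycentric weights (w0, w1, w2) of point p in triangle (a, b, c),
/// or None if p lies outside the triangle (the pixel is not covered).
fn barycentric(a: P2, b: P2, c: P2, p: P2) -> Option<[f32; 3]> {
    let area = edge(a, b, c);
    let w0 = edge(b, c, p) / area; // weight of a: sub-triangle opposite a
    let w1 = edge(c, a, p) / area; // weight of b
    let w2 = edge(a, b, p) / area; // weight of c
    if w0 >= 0.0 && w1 >= 0.0 && w2 >= 0.0 {
        Some([w0, w1, w2])
    } else {
        None
    }
}

/// Interpolate a per-vertex colour: w0 * A + w1 * B + w2 * C.
fn interpolate(w: [f32; 3], colors: [[f32; 3]; 3]) -> [f32; 3] {
    let mut out = [0.0; 3];
    for i in 0..3 {
        for ch in 0..3 {
            out[ch] += w[i] * colors[i][ch];
        }
    }
    out
}
```

At the centroid (0, -1/6) of the course's triangle every weight is about 1/3, so the interpolated colour is roughly (0.33, 0.33, 0.33), matching the diagram.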
The fragment shader
The fragment shader runs once for each fragment generated by rasterisation. It receives the interpolated values from the vertex shader and outputs a colour:
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return vec4f(in.color, 1.0);
}
In this simple example, the fragment shader just passes through the interpolated colour. But you can do much more:
- Sample a texture using interpolated UV coordinates
- Apply lighting calculations using interpolated normals
- Compute procedural patterns based on the fragment position
- Discard fragments to create transparency cutouts
What the fragment receives
The fragment shader’s input looks like the vertex shader’s output, but the values have been interpolated:
Vertex Shader Output Fragment Shader Input
==================== ====================
Vertex 0: color=(1,0,0)
Vertex 1: color=(0,1,0) ──► Fragment at centre:
Vertex 2: color=(0,0,1) color=(0.33, 0.33, 0.33)
The @builtin(position) field in the fragment shader’s input contains the fragment’s window-space coordinates — (x, y) in pixel coordinates. This can be useful for screen-space effects.
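The step from normalised device coordinates to these window-space coordinates is the viewport transform. A sketch for a viewport covering the whole window, assuming its width and height in pixels; note the y flip, because framebuffer y grows downward. The fragment's @builtin(position) then reports the centre of its pixel, so its x and y end in .5.

```rust
/// Map normalised device coordinates (x, y in -1..=1, y up) to framebuffer
/// pixel coordinates (origin top-left, y down) for a full-window viewport.
fn ndc_to_pixels(ndc: [f32; 2], width: f32, height: f32) -> [f32; 2] {
    let x = (ndc[0] + 1.0) * 0.5 * width;
    let y = (1.0 - ndc[1]) * 0.5 * height; // flip: +1 (top) maps to row 0
    [x, y]
}
```

In an 800x600 window, the centre (0, 0) lands at pixel (400, 300) and the triangle's top vertex (0, 0.5) at (400, 150).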
Visual result
When you draw our red-green-blue triangle, the interpolation produces a smooth colour gradient:
Expected visual output:
┌───────────────────────┐
│ │
│ ▲ Red │
│ ╱ ╲ │
│ ╱ ╲ │
│ ╱ gra ╲ │
│ ╱ dient ╲ │
│ ╱─────────╲ │
│ Green Blue │
│ │
└───────────────────────┘
Colours blend smoothly across the
triangle surface via interpolation.
Key takeaway: the GPU automatically interpolates all vertex shader outputs across the triangle surface using barycentric coordinates. The fragment shader receives these smoothly interpolated values and uses them to compute the final pixel colour. This is why a triangle with three different vertex colours produces a smooth gradient.
9. Exercise 2: draw a coloured triangle
Time to put theory into practice. In this exercise you will extend the Exercise 1 code to draw a coloured triangle with red, green, and blue vertices.
What you will add
- A WGSL shader with vertex and fragment entry points
- A vertex buffer with three coloured vertices
- A render pipeline that connects everything
- A draw call inside the render pass
Step 1: add bytemuck to Cargo.toml
[dependencies]
wgpu = { version = "24", features = ["wgsl"] }
winit = "0.30"
pollster = "0.4"
log = "0.4"
env_logger = "0.11"
bytemuck = { version = "1", features = ["derive"] }
Step 2: the WGSL shader
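Create a file called shader.wgsl in the same directory as main.rs (normally src/). The Rust code loads it with include_str!, which embeds the file's contents into the binary at compile time, so there is no separate file to ship with the program: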
// Vertex input: position and colour from the vertex buffer
struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) color: vec3<f32>,
}
// Output from vertex shader, input to fragment shader
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) color: vec3<f32>,
}
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
var out: VertexOutput;
out.clip_position = vec4f(in.position, 1.0);
out.color = in.color;
return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return vec4f(in.color, 1.0);
}
Step 3: the Rust code
Below is the complete program. It extends Exercise 1 with a vertex buffer, pipeline, and draw call:
use winit::{
application::ApplicationHandler,
event::WindowEvent,
event_loop::EventLoop,
window::{Window, WindowAttributes},
};
use wgpu::util::DeviceExt;
use std::sync::Arc;
// Vertex data structure — must match the shader's VertexInput
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Vertex {
position: [f32; 3],
color: [f32; 3],
}
impl Vertex {
/// Describe the memory layout for the GPU.
fn layout() -> wgpu::VertexBufferLayout<'static> {
wgpu::VertexBufferLayout {
array_stride: std::mem::size_of::<Vertex>() as u64,
step_mode: wgpu::VertexStepMode::Vertex,
attributes: &[
wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x3,
offset: 0,
shader_location: 0, // @location(0) position
},
wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x3,
offset: 12,
shader_location: 1, // @location(1) color
},
],
}
}
}
// Three vertices forming a coloured triangle
const VERTICES: &[Vertex] = &[
Vertex { position: [ 0.0, 0.5, 0.0], color: [1.0, 0.0, 0.0] }, // top — red
Vertex { position: [-0.5, -0.5, 0.0], color: [0.0, 1.0, 0.0] }, // bottom-left — green
Vertex { position: [ 0.5, -0.5, 0.0], color: [0.0, 0.0, 1.0] }, // bottom-right — blue
];
struct GpuState {
surface: wgpu::Surface<'static>,
device: wgpu::Device,
queue: wgpu::Queue,
config: wgpu::SurfaceConfiguration,
pipeline: wgpu::RenderPipeline,
vertex_buffer: wgpu::Buffer,
}
struct App {
window: Option<Arc<Window>>,
gpu: Option<GpuState>,
}
impl App {
fn new() -> Self {
Self { window: None, gpu: None }
}
fn init_gpu(&mut self, window: Arc<Window>) {
let size = window.inner_size();
let instance = wgpu::Instance::new(&Default::default());
let surface = instance.create_surface(window.clone()).unwrap();
let adapter = pollster::block_on(instance.request_adapter(
&wgpu::RequestAdapterOptions {
compatible_surface: Some(&surface),
..Default::default()
},
)).unwrap();
let (device, queue) = pollster::block_on(
adapter.request_device(&Default::default(), None)
).unwrap();
let config = surface
.get_default_config(&adapter, size.width.max(1), size.height.max(1))
.unwrap();
surface.configure(&device, &config);
// Create the shader module from WGSL source
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Triangle Shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
});
// Create the render pipeline
let pipeline_layout = device.create_pipeline_layout(
&wgpu::PipelineLayoutDescriptor {
label: Some("Pipeline Layout"),
bind_group_layouts: &[],
push_constant_ranges: &[],
},
);
let pipeline = device.create_render_pipeline(
&wgpu::RenderPipelineDescriptor {
label: Some("Triangle Pipeline"),
layout: Some(&pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
entry_point: Some("vs_main"),
buffers: &[Vertex::layout()],
compilation_options: Default::default(),
},
fragment: Some(wgpu::FragmentState {
module: &shader,
entry_point: Some("fs_main"),
targets: &[Some(wgpu::ColorTargetState {
format: config.format,
blend: Some(wgpu::BlendState::REPLACE),
write_mask: wgpu::ColorWrites::ALL,
})],
compilation_options: Default::default(),
}),
primitive: wgpu::PrimitiveState {
topology: wgpu::PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: wgpu::FrontFace::Ccw,
cull_mode: Some(wgpu::Face::Back),
unclipped_depth: false,
polygon_mode: wgpu::PolygonMode::Fill,
conservative: false,
},
depth_stencil: None,
multisample: wgpu::MultisampleState::default(),
multiview: None,
cache: None,
},
);
// Create the vertex buffer
let vertex_buffer = device.create_buffer_init(
&wgpu::util::BufferInitDescriptor {
label: Some("Vertex Buffer"),
contents: bytemuck::cast_slice(VERTICES),
usage: wgpu::BufferUsages::VERTEX,
},
);
self.gpu = Some(GpuState {
surface, device, queue, config, pipeline, vertex_buffer,
});
}
fn render(&self) {
let gpu = self.gpu.as_ref().unwrap();
let output = gpu.surface.get_current_texture().unwrap();
let view = output.texture.create_view(&Default::default());
let mut encoder = gpu.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: Some("Triangle Pass"),
color_attachments: &[Some(wgpu::RenderPassColorAttachment {
view: &view,
resolve_target: None,
ops: wgpu::Operations {
load: wgpu::LoadOp::Clear(wgpu::Color {
r: 0.1, g: 0.1, b: 0.1, a: 1.0,
}),
store: wgpu::StoreOp::Store,
},
})],
depth_stencil_attachment: None,
..Default::default()
});
pass.set_pipeline(&gpu.pipeline);
pass.set_vertex_buffer(0, gpu.vertex_buffer.slice(..));
pass.draw(0..3, 0..1); // 3 vertices, 1 instance
}
gpu.queue.submit(std::iter::once(encoder.finish()));
output.present();
}
}
impl ApplicationHandler for App {
fn resumed(&mut self, event_loop: &winit::event_loop::ActiveEventLoop) {
if self.window.is_none() {
let window = Arc::new(
event_loop.create_window(
WindowAttributes::default().with_title("Exercise 2: Coloured Triangle")
).unwrap()
);
self.init_gpu(window.clone());
self.window = Some(window);
}
}
fn window_event(
&mut self,
event_loop: &winit::event_loop::ActiveEventLoop,
_id: winit::window::WindowId,
event: WindowEvent,
) {
match event {
WindowEvent::CloseRequested => event_loop.exit(),
WindowEvent::Resized(size) => {
if let Some(gpu) = &mut self.gpu {
gpu.config.width = size.width.max(1);
gpu.config.height = size.height.max(1);
gpu.surface.configure(&gpu.device, &gpu.config);
}
}
WindowEvent::RedrawRequested => {
self.render();
if let Some(w) = &self.window { w.request_redraw(); }
}
_ => {}
}
}
}
fn main() {
env_logger::init();
let event_loop = EventLoop::new().unwrap();
event_loop.run_app(&mut App::new()).unwrap();
}
Step 4: run and observe
cargo run
You should see a triangle centred in the window with a smooth gradient: red at the top, green at the bottom-left, and blue at the bottom-right. The colours blend smoothly across the surface thanks to the interpolation discussed in section 8.
Key concepts demonstrated
- Vertex struct with #[repr(C)] and bytemuck for safe casting to bytes
- Vertex buffer layout mapping struct fields to @location(n) in the shader
- Shader module loaded from WGSL source via include_str!
- Render pipeline connecting shaders, vertex layout, and output format
- Draw call (pass.draw(0..3, 0..1)) telling the GPU to process 3 vertices as one triangle
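Challenge: replace the three vertices with six (two triangles of three vertices each, sharing a diagonal) so that together they form a rectangle, and change the draw call to pass.draw(0..6, 0..1). Because the pipeline culls back faces (front_face: Ccw with cull_mode: Some(wgpu::Face::Back)), list each triangle's vertices in counter-clockwise order; a clockwise triangle will not be drawn.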
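A quick way to check winding before running the challenge is to compute each triangle's signed area on the CPU. This sketch uses positions only and one possible rectangle layout (an assumption; any counter-clockwise ordering works). A positive area means counter-clockwise in clip space, which front_face: Ccw treats as front-facing.

```rust
/// Twice the signed area of a 2D triangle; positive means counter-clockwise
/// with y pointing up, as in clip space.
fn signed_area2(a: [f32; 2], b: [f32; 2], c: [f32; 2]) -> f32 {
    (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])
}

/// One possible vertex order for the challenge rectangle (positions only).
const RECT: [[f32; 2]; 6] = [
    [-0.5, 0.5], [-0.5, -0.5], [0.5, 0.5], // triangle 1: top-left half
    [0.5, 0.5], [-0.5, -0.5], [0.5, -0.5], // triangle 2: bottom-right half
];

/// True if every group of three vertices is a counter-clockwise triangle.
fn all_ccw(verts: &[[f32; 2]]) -> bool {
    verts
        .chunks(3)
        .all(|t| t.len() == 3 && signed_area2(t[0], t[1], t[2]) > 0.0)
}
```

Reversing any triangle's order (for example top, bottom-right, bottom-left) makes its area negative, and that triangle would be culled.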
10. Exercise 3: animate the triangle using a time uniform
Static shapes are nice, but animation is where shaders really shine. In this exercise you will pass an elapsed time value to the vertex shader and use it to rotate the triangle.
New concepts
- Uniform buffers: small, read-only buffers for data that is the same across all vertices/fragments in a draw call (like time, camera matrices, light positions)
- Bind groups: how you connect uniform buffers (and other resources) to shader bindings
- Updating buffers: writing new data to a buffer each frame
Step 1: the updated WGSL shader
struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) color: vec3<f32>,
}
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) color: vec3<f32>,
}
// A uniform buffer containing the elapsed time
@group(0) @binding(0)
var<uniform> time: f32;
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
// Rotate the vertex around the Z axis
let angle = time;
let cos_a = cos(angle);
let sin_a = sin(angle);
let rotated = vec3f(
in.position.x * cos_a - in.position.y * sin_a,
in.position.x * sin_a + in.position.y * cos_a,
in.position.z,
);
var out: VertexOutput;
out.clip_position = vec4f(rotated, 1.0);
out.color = in.color;
return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return vec4f(in.color, 1.0);
}
The key change is the time uniform and the 2D rotation matrix applied to each vertex. The rotation is:
x' = x * cos(angle) - y * sin(angle)
y' = x * sin(angle) + y * cos(angle)
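This rotates the triangle around the origin (centre of clip space) at one radian per second, because the shader uses the elapsed time in seconds directly as the angle. Clip space is stretched to fit the window, so unless the window is square the triangle will also appear to stretch and squash as it turns.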
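The same rotation can be reproduced on the CPU to check the maths. This sketch mirrors the shader's formula in plain Rust; f32::sin_cos computes both terms at once.

```rust
/// Rotate a 2D point about the origin by `angle` radians, using the same
/// formula as the vertex shader: x' = x cos - y sin, y' = x sin + y cos.
fn rotate(p: [f32; 2], angle: f32) -> [f32; 2] {
    let (sin_a, cos_a) = angle.sin_cos();
    [p[0] * cos_a - p[1] * sin_a, p[0] * sin_a + p[1] * cos_a]
}
```

A quarter turn (π/2 radians, about 1.57 seconds of elapsed time) takes (1, 0) to approximately (0, 1), and a full turn of 2π returns every point to where it started, which is why the animation repeats every 2π seconds.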
Step 2: create the uniform buffer and bind group
On the Rust side, you need to:
- Create a buffer for the time value
- Create a bind group layout describing the binding
- Create a bind group linking the buffer to the layout
- Update the pipeline layout to include the bind group layout
use std::time::Instant; // used for the start_time field in Step 3
// Create the uniform buffer (4 bytes for one f32)
let time_buffer = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("Time Uniform Buffer"),
    size: std::mem::size_of::<f32>() as u64,
    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    mapped_at_creation: false,
});
// Create a bind group layout
let bind_group_layout = device.create_bind_group_layout(
    &wgpu::BindGroupLayoutDescriptor {
        label: Some("Time Bind Group Layout"),
        entries: &[wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::VERTEX,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }],
    },
);
// Create the bind group
let time_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
    label: Some("Time Bind Group"),
    layout: &bind_group_layout,
    entries: &[wgpu::BindGroupEntry {
        binding: 0,
        resource: time_buffer.as_entire_binding(),
    }],
});
// Update the pipeline layout to include our bind group
let pipeline_layout = device.create_pipeline_layout(
    &wgpu::PipelineLayoutDescriptor {
        label: Some("Animated Pipeline Layout"),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    },
);
Step 3: update the buffer each frame
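In your render function, before beginning the render pass, write the current time to the buffer. This assumes you have added a start_time: Instant field to App (set to Instant::now() in App::new) and stored time_buffer and time_bind_group in GpuState next to the pipeline:
let elapsed = self.start_time.elapsed().as_secs_f32();
gpu.queue.write_buffer(&gpu.time_buffer, 0, bytemuck::cast_slice(&[elapsed]));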
queue.write_buffer copies data from CPU memory into the GPU buffer. This is the simplest way to update a uniform each frame.
Step 4: bind the group in the render pass
Inside your render pass, after setting the pipeline:
pass.set_pipeline(&gpu.pipeline);
pass.set_bind_group(0, &gpu.time_bind_group, &[]); // group 0
pass.set_vertex_buffer(0, gpu.vertex_buffer.slice(..));
pass.draw(0..3, 0..1);
The set_bind_group(0, ...) call makes the time buffer available to the shader as @group(0) @binding(0).
Expected result
When you run the program, you should see the coloured triangle smoothly rotating around the centre of the window. The triangle completes one full rotation every 2*pi (approximately 6.28) seconds.
Understanding the flow
Each frame:
┌─────────────────────────────────────────────────────────────┐
│ │
│ CPU: elapsed = Instant::now() - start │
│ queue.write_buffer(time_buffer, elapsed) │
│ │
│ GPU: time uniform ← time_buffer │
│ for each vertex: │
│ rotated_pos = rotate(vertex.pos, time) │
│ output clip_position = rotated_pos │
│ │
│ Result: triangle rotates smoothly │
└─────────────────────────────────────────────────────────────┘
Challenge: instead of (or in addition to) rotating, try making the triangle pulse in size using sin(time) as a scale factor. Or make it bounce by adding sin(time) * 0.3 to the y position.
Key takeaway: uniform buffers let you pass per-frame data (time, matrices, parameters) to shaders. You create a buffer, describe the binding in a bind group layout, link the buffer to it with a bind group, bind that group during the render pass, and access the data in WGSL via @group(n) @binding(n). This is how you make shaders dynamic.
Part 4 — Textures and Samplers
11. Texture coordinates (UVs), texture creation, sampler config
Solid colours and gradients are a start, but most real-world graphics use textures — images mapped onto surfaces. This section explains how textures work, how UV coordinates map image data onto geometry, and how samplers control the lookup.
What are UV coordinates?
UV coordinates (also called texture coordinates) describe where on a texture each vertex should sample from. They range from (0, 0) at the top-left of the texture to (1, 1) at the bottom-right:
Texture UV Space
================
(0,0)────────────────(1,0)
│ │
│ ┌──────────┐ │
│ │ │ │
│ │ image │ │
│ │ data │ │
│ │ │ │
│ └──────────┘ │
│ │
(0,1)────────────────(1,1)
Note: in wgpu/WebGPU, (0,0) is the top-left
and v increases downward.
Each vertex carries a UV coordinate. When the GPU rasterises a triangle, it interpolates these UVs across the surface (just like it interpolates colours). The fragment shader then uses the interpolated UV to look up a colour from the texture.
Quad with UV mapping
====================
Vertex 0: pos=(-0.5, 0.5) uv=(0, 0) ← top-left
Vertex 1: pos=( 0.5, 0.5) uv=(1, 0) ← top-right
Vertex 2: pos=(-0.5,-0.5) uv=(0, 1) ← bottom-left
Vertex 3: pos=( 0.5,-0.5) uv=(1, 1) ← bottom-right
The full texture maps exactly onto the quad.
Creating a texture in wgpu
To use a texture, you need to:
- Create the texture on the GPU
- Upload the image data
- Create a texture view for accessing it in shaders
- Create a sampler that controls how texels are looked up
// Step 1: Create the texture
let texture = device.create_texture(&wgpu::TextureDescriptor {
    label: Some("My Texture"),
    size: wgpu::Extent3d {
        width: img_width,
        height: img_height,
        depth_or_array_layers: 1,
    },
    mip_level_count: 1,
    sample_count: 1,
    dimension: wgpu::TextureDimension::D2,
    format: wgpu::TextureFormat::Rgba8UnormSrgb,
    usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
    view_formats: &[],
});
// Step 2: Upload the pixel data
queue.write_texture(
    wgpu::TexelCopyTextureInfo {
        texture: &texture,
        mip_level: 0,
        origin: wgpu::Origin3d::ZERO,
        aspect: wgpu::TextureAspect::All,
    },
    &rgba_bytes, // &[u8] of RGBA pixel data
    wgpu::TexelCopyBufferLayout {
        offset: 0,
        bytes_per_row: Some(4 * img_width),
        rows_per_image: Some(img_height),
    },
    wgpu::Extent3d {
        width: img_width,
        height: img_height,
        depth_or_array_layers: 1,
    },
);
// Step 3: Create a view
let texture_view = texture.create_view(&Default::default());
Sampler configuration
A sampler controls how the GPU looks up texels (texture pixels) when the UV does not land exactly on a texel centre. There are two key settings:
Filtering controls how texels are blended:
- Nearest: picks the closest texel (pixelated look, fast)
- Linear: blends the four nearest texels (smooth look)
Nearest filtering Linear filtering
================== =================
┌───┬───┬───┐ ┌───┬───┬───┐
│ A │ B │ │ │ A │ B │ │
├───┼───┼───┤ ├───┼╌╌╌┼───┤
│ C │ D │ │ │ C │avg│ │
├───┼───┼───┤ ├───┼───┼───┤
│ │ │ │ │ │ │ │
└───┴───┴───┘ └───┴───┴───┘
Nearest: picks one Linear: blends A,B,C,D
texel (e.g., A) based on distance
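The two filters can be imitated on the CPU for a single-channel texture. In this sketch (hypothetical names; edge clamping assumed), nearest picks the texel whose square contains (u, v), while linear treats texel centres as sitting at (i + 0.5) / size and blends the four surrounding texels by distance. A GPU does the same per RGBA channel, in hardware.

```rust
/// A single-channel texture stored row by row: texel (x, y) is at y * width + x.
struct CpuTexture {
    width: usize,
    height: usize,
    texels: Vec<f32>,
}

impl CpuTexture {
    /// Fetch a texel, clamping the indices to the texture's edge.
    fn texel(&self, x: isize, y: isize) -> f32 {
        let x = x.clamp(0, self.width as isize - 1) as usize;
        let y = y.clamp(0, self.height as isize - 1) as usize;
        self.texels[y * self.width + x]
    }

    /// Nearest filtering: pick the texel whose area contains (u, v).
    fn sample_nearest(&self, u: f32, v: f32) -> f32 {
        let x = (u * self.width as f32).floor() as isize;
        let y = (v * self.height as f32).floor() as isize;
        self.texel(x, y)
    }

    /// Linear filtering: blend the four texels around (u, v), weighted by
    /// the distance from (u, v) to their centres.
    fn sample_linear(&self, u: f32, v: f32) -> f32 {
        let x = u * self.width as f32 - 0.5;
        let y = v * self.height as f32 - 0.5;
        let (x0, y0) = (x.floor(), y.floor());
        let (fx, fy) = (x - x0, y - y0);
        let (x0, y0) = (x0 as isize, y0 as isize);
        let top = self.texel(x0, y0) * (1.0 - fx) + self.texel(x0 + 1, y0) * fx;
        let bottom = self.texel(x0, y0 + 1) * (1.0 - fx) + self.texel(x0 + 1, y0 + 1) * fx;
        top * (1.0 - fy) + bottom * fy
    }
}
```

On a 2x2 texture holding 0, 100, 200 and 300, sampling the exact centre (0.5, 0.5) returns 300 with nearest filtering but 150, the average of all four, with linear filtering.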
Address mode (wrapping) controls what happens when UVs go outside the 0-1 range:
- ClampToEdge: UVs outside 0-1 use the edge colour
- Repeat: the texture tiles
- MirrorRepeat: the texture tiles, flipping every other repetition
let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
    label: Some("Texture Sampler"),
    address_mode_u: wgpu::AddressMode::ClampToEdge,
    address_mode_v: wgpu::AddressMode::ClampToEdge,
    address_mode_w: wgpu::AddressMode::ClampToEdge,
    mag_filter: wgpu::FilterMode::Linear,
    min_filter: wgpu::FilterMode::Linear,
    mipmap_filter: wgpu::FilterMode::Nearest,
    ..Default::default()
});
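The three address modes amount to simple arithmetic on each normalised coordinate. The sketch below uses a hypothetical Wrap enum whose variants mirror wgpu::AddressMode's names; it is not wgpu's type. ClampToEdge is simplified to clamping the coordinate to [0, 1], after which the edge texels are used.

```rust
/// Hypothetical enum mirroring the names of wgpu::AddressMode's variants.
#[derive(Clone, Copy)]
enum Wrap {
    ClampToEdge,
    Repeat,
    MirrorRepeat,
}

/// Apply an address mode to one normalised texture coordinate.
fn wrap(u: f32, mode: Wrap) -> f32 {
    match mode {
        // Simplification: clamp to [0, 1] so lookups land on the edge texels.
        Wrap::ClampToEdge => u.clamp(0.0, 1.0),
        // Keep only the fractional part, so 1.25 behaves like 0.25.
        Wrap::Repeat => u - u.floor(),
        // Period of 2: the second half of each period runs backwards.
        Wrap::MirrorRepeat => {
            let t = u.rem_euclid(2.0);
            if t > 1.0 { 2.0 - t } else { t }
        }
    }
}
```

A coordinate of 1.25 therefore reads the edge (1.0) with ClampToEdge, 0.25 with Repeat, and 0.75 with MirrorRepeat.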
Bind groups for textures
Textures and samplers are bound to shaders using bind groups, just like uniform buffers:
let bind_group_layout = device.create_bind_group_layout(
    &wgpu::BindGroupLayoutDescriptor {
        label: Some("Texture Bind Group Layout"),
        entries: &[
            // The texture
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Texture {
                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
                    view_dimension: wgpu::TextureViewDimension::D2,
                    multisampled: false,
                },
                count: None,
            },
            // The sampler
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Sampler(
                    wgpu::SamplerBindingType::Filtering,
                ),
                count: None,
            },
        ],
    },
);
In WGSL, you access them like this:
@group(0) @binding(0)
var t_diffuse: texture_2d<f32>;
@group(0) @binding(1)
var s_diffuse: sampler;
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return textureSample(t_diffuse, s_diffuse, in.uv);
}
The textureSample function takes a texture, a sampler, and UV coordinates, and returns the sampled colour.
Key takeaway: textures are images stored on the GPU. UV coordinates map texture space onto geometry. Samplers control filtering (nearest vs linear) and wrapping behaviour. The fragment shader uses textureSample to look up a colour from the texture at interpolated UV coordinates.
12. Exercise 4: render a textured quad
In this exercise you will draw a rectangle (two triangles forming a quad) with a texture mapped onto it. You will create a procedural checkerboard texture in code rather than loading an image file, keeping the exercise self-contained.
Step 1: add dependencies
We do not need an image loading crate for this exercise since we generate the texture procedurally. The same Cargo.toml from Exercise 2 works, with bytemuck already included.
Step 2: the WGSL shader
struct VertexInput {
@location(0) position: vec3<f32>,
@location(1) uv: vec2<f32>,
}
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) uv: vec2<f32>,
}
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
var out: VertexOutput;
out.clip_position = vec4f(in.position, 1.0);
out.uv = in.uv;
return out;
}
@group(0) @binding(0)
var t_texture: texture_2d<f32>;
@group(0) @binding(1)
var s_sampler: sampler;
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
return textureSample(t_texture, s_sampler, in.uv);
}
Note how the vertex now carries a vec2<f32> UV coordinate instead of a colour. The fragment shader samples the texture at the interpolated UV.
Step 3: vertex data for a quad
A quad is two triangles. We define six vertices (or four vertices with an index buffer — we will use six for simplicity):
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Vertex {
    position: [f32; 3],
    uv: [f32; 2],
}
// Two triangles forming a quad
const VERTICES: &[Vertex] = &[
    // Triangle 1 (top-left half)
    Vertex { position: [-0.5, 0.5, 0.0], uv: [0.0, 0.0] }, // top-left
    Vertex { position: [-0.5, -0.5, 0.0], uv: [0.0, 1.0] }, // bottom-left
    Vertex { position: [ 0.5, 0.5, 0.0], uv: [1.0, 0.0] }, // top-right
    // Triangle 2 (bottom-right half)
    Vertex { position: [ 0.5, 0.5, 0.0], uv: [1.0, 0.0] }, // top-right
    Vertex { position: [-0.5, -0.5, 0.0], uv: [0.0, 1.0] }, // bottom-left
    Vertex { position: [ 0.5, -0.5, 0.0], uv: [1.0, 1.0] }, // bottom-right
];
Step 4: generate a procedural checkerboard texture
/// Generate a checkerboard as tightly packed RGBA8 bytes (4 bytes per pixel,
/// rows top to bottom), alternating square cells of `cell_size` pixels.
fn make_checkerboard(width: u32, height: u32, cell_size: u32) -> Vec<u8> {
    let mut pixels = Vec::with_capacity((width * height * 4) as usize);
    for y in 0..height {
        for x in 0..width {
            let is_white = ((x / cell_size) + (y / cell_size)) % 2 == 0;
            let val = if is_white { 255u8 } else { 80u8 }; // white or dark grey
            pixels.push(val); // R
            pixels.push(val); // G
            pixels.push(val); // B
            pixels.push(255); // A
        }
    }
    pixels
}
Call it with make_checkerboard(256, 256, 32) to get a 256x256 texture with 32-pixel checker cells.
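Because the pixels are tightly packed, texel (x, y) starts at byte (y * width + x) * 4, and each row is 4 * width bytes long: the bytes_per_row value passed to write_texture. The sketch below copies make_checkerboard so it stands alone and adds hypothetical texel_offset and texel_rgba helpers for reading a pixel back on the CPU, which is handy for sanity-checking generated data before uploading it.

```rust
/// Copy of make_checkerboard from the step above.
fn make_checkerboard(width: u32, height: u32, cell_size: u32) -> Vec<u8> {
    let mut pixels = Vec::with_capacity((width * height * 4) as usize);
    for y in 0..height {
        for x in 0..width {
            let is_white = ((x / cell_size) + (y / cell_size)) % 2 == 0;
            let val = if is_white { 255u8 } else { 80u8 };
            pixels.extend_from_slice(&[val, val, val, 255]); // R, G, B, A
        }
    }
    pixels
}

/// Hypothetical helper: byte offset of texel (x, y) in tightly packed RGBA8
/// data whose rows are 4 * width bytes long.
fn texel_offset(x: u32, y: u32, width: u32) -> usize {
    ((y * width + x) * 4) as usize
}

/// Hypothetical helper: read back the RGBA value of one texel.
fn texel_rgba(pixels: &[u8], x: u32, y: u32, width: u32) -> [u8; 4] {
    let i = texel_offset(x, y, width);
    [pixels[i], pixels[i + 1], pixels[i + 2], pixels[i + 3]]
}
```

With 32-pixel cells, texel (0, 0) is white, (32, 0) is dark grey, and (32, 32) is white again.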
Step 5: create the texture, view, and sampler
let tex_size = 256u32;
let tex_data = make_checkerboard(tex_size, tex_size, 32);
let texture = device.create_texture(&wgpu::TextureDescriptor {
    label: Some("Checkerboard Texture"),
    size: wgpu::Extent3d {
        width: tex_size,
        height: tex_size,
        depth_or_array_layers: 1,
    },
    mip_level_count: 1,
    sample_count: 1,
    dimension: wgpu::TextureDimension::D2,
    format: wgpu::TextureFormat::Rgba8UnormSrgb,
    usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
    view_formats: &[],
});
queue.write_texture(
    wgpu::TexelCopyTextureInfo {
        texture: &texture,
        mip_level: 0,
        origin: wgpu::Origin3d::ZERO,
        aspect: wgpu::TextureAspect::All,
    },
    &tex_data,
    wgpu::TexelCopyBufferLayout {
        offset: 0,
        bytes_per_row: Some(4 * tex_size),
        rows_per_image: Some(tex_size),
    },
    wgpu::Extent3d {
        width: tex_size,
        height: tex_size,
        depth_or_array_layers: 1,
    },
);
let texture_view = texture.create_view(&Default::default());
let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
    label: Some("Checkerboard Sampler"),
    mag_filter: wgpu::FilterMode::Nearest, // crisp pixels for checkerboard
    min_filter: wgpu::FilterMode::Nearest,
    ..Default::default()
});
Step 6: bind group setup
let bind_group_layout = device.create_bind_group_layout(
    &wgpu::BindGroupLayoutDescriptor {
        label: Some("Texture Bind Group Layout"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Texture {
                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
                    view_dimension: wgpu::TextureViewDimension::D2,
                    multisampled: false,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::FRAGMENT,
                ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                count: None,
            },
        ],
    },
);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
    label: Some("Texture Bind Group"),
    layout: &bind_group_layout,
    entries: &[
        wgpu::BindGroupEntry {
            binding: 0,
            resource: wgpu::BindingResource::TextureView(&texture_view),
        },
        wgpu::BindGroupEntry {
            binding: 1,
            resource: wgpu::BindingResource::Sampler(&sampler),
        },
    ],
});
Remember to include &bind_group_layout in your pipeline layout’s bind_group_layouts array, and update the vertex buffer layout to match the new Vertex struct (position: Float32x3 at offset 0, uv: Float32x2 at offset 12).
Step 7: draw the quad
In your render pass:
pass.set_pipeline(&gpu.pipeline);
pass.set_bind_group(0, &gpu.bind_group, &[]);
pass.set_vertex_buffer(0, gpu.vertex_buffer.slice(..));
pass.draw(0..6, 0..1); // 6 vertices = 2 triangles = 1 quad
Expected result
You should see a rectangle in the centre of the window showing a white and dark-grey checkerboard pattern (make_checkerboard uses the values 255 and 80 for its two kinds of cell). The texture is mapped so that the full checkerboard fills the quad exactly.
Challenge: try changing the sampler’s mag_filter from Nearest to Linear and see how the checkerboard edges become blurred when the quad is large. Then try setting address_mode_u and address_mode_v to Repeat, and change the UVs to go from 0 to 3 — you will see the checkerboard tile three times across the quad.
Key takeaway: texturing involves creating a texture from pixel data, configuring a sampler for filtering and wrapping, binding both via a bind group, and sampling in the fragment shader using interpolated UV coordinates. This same pattern applies whether your texture is a checkerboard, a photograph, or a render target from a previous pass.
Part 5 — Compute Shaders
13. Compute pipelines: dispatching work groups
Compute shaders break free from the graphics pipeline entirely. There are no vertices, no triangles, no pixels — just raw parallel computation. This makes them ideal for physics simulations, image processing, data transformations, and any task that benefits from GPU parallelism.
Graphics pipeline vs compute pipeline
Graphics Pipeline Compute Pipeline
================= ================
Vertices Dispatch(x, y, z)
│ │
▼ ▼
Vertex Shader Compute Shader
│ │
▼ ▼
Rasterisation Storage buffers /
│ textures (output)
▼
Fragment Shader
│
▼
Framebuffer (pixels)
Produces images Produces data
With compute shaders, you do not set up vertex buffers, render passes, or colour attachments. Instead, you dispatch work and let the compute shader read/write storage buffers or textures directly.
Work groups and invocations
When you dispatch a compute shader, you specify a 3D grid of work groups. Each work group contains a fixed number of invocations (threads), defined by @workgroup_size in the shader.
Dispatch and Work Groups
========================
dispatch(4, 3, 1) ← 4 x 3 x 1 = 12 work groups
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ WG │ │ WG │ │ WG │ │ WG │ row 0
│(0,0)│ │(1,0)│ │(2,0)│ │(3,0)│
└─────┘ └─────┘ └─────┘ └─────┘
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ WG │ │ WG │ │ WG │ │ WG │ row 1
│(0,1)│ │(1,1)│ │(2,1)│ │(3,1)│
└─────┘ └─────┘ └─────┘ └─────┘
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ WG │ │ WG │ │ WG │ │ WG │ row 2
│(0,2)│ │(1,2)│ │(2,2)│ │(3,2)│
└─────┘ └─────┘ └─────┘ └─────┘
Inside each work group (e.g., @workgroup_size(8, 8, 1)):
┌─┬─┬─┬─┬─┬─┬─┬─┐
│·│·│·│·│·│·│·│·│ 8 invocations wide
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│ x 8 invocations tall
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│ = 64 invocations per
├─┼─┼─┼─┼─┼─┼─┼─┤ work group
│·│·│·│·│·│·│·│·│
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│
├─┼─┼─┼─┼─┼─┼─┼─┤
│·│·│·│·│·│·│·│·│
└─┴─┴─┴─┴─┴─┴─┴─┘
Total invocations = 12 work groups x 64 = 768 threads
Built-in IDs
Each invocation knows its position in the grid via built-in variables:
| Built-in | Type | Meaning |
|---|---|---|
| global_invocation_id | vec3<u32> | Unique ID across the entire dispatch |
| local_invocation_id | vec3<u32> | ID within the work group (0 to workgroup_size-1) |
| workgroup_id | vec3<u32> | Which work group this invocation belongs to |
| num_workgroups | vec3<u32> | Total number of work groups dispatched |
global_invocation_id is the most commonly used — it gives each thread a unique index.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let index = id.x;
// Process element at `index`
}
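Concretely, each component of global_invocation_id equals workgroup_id * workgroup_size + local_invocation_id. The CPU-side sketch below is plain Rust with no wgpu involved, and `global_id` and `all_global_ids` are illustrative names. It applies that formula to every invocation of the dispatch(4, 3, 1) example with @workgroup_size(8, 8, 1) and confirms that all 768 IDs are distinct.

```rust
/// WGSL's global_invocation_id for one invocation:
/// workgroup_id * workgroup_size + local_invocation_id, component by component.
fn global_id(workgroup_id: [u32; 3], workgroup_size: [u32; 3], local_id: [u32; 3]) -> [u32; 3] {
    [
        workgroup_id[0] * workgroup_size[0] + local_id[0],
        workgroup_id[1] * workgroup_size[1] + local_id[1],
        workgroup_id[2] * workgroup_size[2] + local_id[2],
    ]
}

/// Lists the global IDs of every invocation in a dispatch.
/// The loop order is arbitrary: the GPU guarantees no execution order.
fn all_global_ids(num_workgroups: [u32; 3], workgroup_size: [u32; 3]) -> Vec<[u32; 3]> {
    let mut ids = Vec::new();
    for wz in 0..num_workgroups[2] {
        for wy in 0..num_workgroups[1] {
            for wx in 0..num_workgroups[0] {
                for lz in 0..workgroup_size[2] {
                    for ly in 0..workgroup_size[1] {
                        for lx in 0..workgroup_size[0] {
                            ids.push(global_id([wx, wy, wz], workgroup_size, [lx, ly, lz]));
                        }
                    }
                }
            }
        }
    }
    ids
}
```

The resulting IDs cover a 32 x 24 grid (4 x 8 by 3 x 8), so a 2D kernel can treat id.x and id.y directly as pixel coordinates.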
Choosing workgroup_size
The @workgroup_size(x, y, z) declaration sets how many invocations run per work group. Guidelines:
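- Keep the total invocations per group (x * y * z) a multiple of 32 or 64 for best performance, matching the GPU's warp (32 threads on NVIDIA) or wavefront (32 or 64 threads on AMD) size
- Common choices: @workgroup_size(64), @workgroup_size(256), @workgroup_size(8, 8) for 2D, @workgroup_size(4, 4, 4) for 3D
- The maximum total varies by GPU: WebGPU guarantees at least 256 invocations per work group (the default maxComputeInvocationsPerWorkgroup limit), and many GPUs allow 1024 if you request a higher limit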
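A candidate size can be sanity-checked before it goes into a shader. This sketch assumes WebGPU's default limits: 256 invocations per work group in total (maxComputeInvocationsPerWorkgroup), and 256, 256 and 64 for the x, y and z dimensions (maxComputeWorkgroupSizeX/Y/Z). A device can raise these if you request it. The function name is hypothetical.

```rust
/// WebGPU default limits; a device may support more if you request them.
const MAX_INVOCATIONS: u32 = 256; // maxComputeInvocationsPerWorkgroup
const MAX_SIZE: [u32; 3] = [256, 256, 64]; // maxComputeWorkgroupSizeX/Y/Z

/// Returns (total invocations, is_multiple_of_32) for @workgroup_size(x, y, z),
/// or an error if the size breaks the default limits.
fn check_workgroup_size(size: [u32; 3]) -> Result<(u32, bool), String> {
    for (axis, (&s, &max)) in size.iter().zip(MAX_SIZE.iter()).enumerate() {
        if s == 0 || s > max {
            return Err(format!("dimension {axis} is {s}, must be 1..={max}"));
        }
    }
    let total = size[0] * size[1] * size[2];
    if total > MAX_INVOCATIONS {
        return Err(format!("{total} invocations exceeds the limit of {MAX_INVOCATIONS}"));
    }
    // A multiple of 32 fills whole warps/wavefronts on common hardware.
    Ok((total, total % 32 == 0))
}
```

All of the common choices listed above come out at 64 invocations. A size such as (32, 32, 1) is rejected under the default limits, because it asks for 1024 invocations.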
Creating a compute pipeline in Rust
let compute_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
    label: Some("Compute Shader"),
    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
    label: Some("Compute Pipeline"),
    layout: Some(&pipeline_layout),
    module: &compute_shader,
    entry_point: Some("main"),
    compilation_options: Default::default(),
    cache: None,
});
Dispatching work
Instead of a render pass, you use a compute pass:
let mut encoder = device.create_command_encoder(&Default::default());
{
    let mut compute_pass = encoder.begin_compute_pass(&Default::default());
    compute_pass.set_pipeline(&compute_pipeline);
    compute_pass.set_bind_group(0, &bind_group, &[]);
    compute_pass.dispatch_workgroups(num_groups_x, num_groups_y, num_groups_z);
} // the pass borrows the encoder, so it must end before encoder.finish()
queue.submit(std::iter::once(encoder.finish()));
If you have 1024 elements and your workgroup_size is 64, you dispatch 1024 / 64 = 16 work groups: dispatch_workgroups(16, 1, 1).
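When the element count is not an exact multiple of the work group size, plain integer division drops the remainder and leaves elements unprocessed. A minimal helper (the name is illustrative) rounds up instead, using the standard library's `u32::div_ceil`:

```rust
/// Work groups needed to cover `n` elements with `workgroup_size` invocations each.
/// Ceiling division: a partially filled final group is still dispatched.
fn workgroups_needed(n: u32, workgroup_size: u32) -> u32 {
    n.div_ceil(workgroup_size)
}
```

With 1000 elements this gives 16 groups, or 1024 invocations. The 24 invocations without an element must be stopped by a bounds check in the shader.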
Key takeaway: compute shaders run outside the graphics pipeline. You dispatch a 3D grid of work groups, each containing a fixed number of invocations. Every invocation gets a unique global_invocation_id to determine which data element to process. This is how you harness the GPU’s parallelism for general-purpose computation.
14. Storage buffers and read/write access from WGSL
Compute shaders need to read input data and write output data. Storage buffers are the primary mechanism for this. Unlike uniform buffers (which are small and read-only), storage buffers can be large and support both reading and writing.
Storage buffers vs uniform buffers
| Feature | Uniform Buffer | Storage Buffer |
|---|---|---|
| Max binding size | 64 KB guaranteed by WebGPU defaults | 128 MB guaranteed by WebGPU defaults; often more |
| Access | Read-only | Read-only or read-write |
| Speed | Often faster (many GPUs cache uniforms aggressively) | Can be slightly slower |
| Use case | Small, per-frame constants | Large data arrays |
Use uniform buffers for things like transformation matrices, time values, and camera parameters. Use storage buffers for arrays of particles, pixels, mesh data, or any large dataset.
Declaring storage buffers in WGSL
// Read-only storage buffer
@group(0) @binding(0)
var<storage, read> input: array<f32>;
// Read-write storage buffer
@group(0) @binding(1)
var<storage, read_write> output: array<f32>;
You can also use structs:
struct Particle {
position: vec2<f32>,
velocity: vec2<f32>,
}
@group(0) @binding(0)
var<storage, read_write> particles: array<Particle>;
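On the Rust side, the buffer contents must use the same byte layout as the WGSL struct. In WGSL a `vec2<f32>` is 8 bytes with 8-byte alignment, so `Particle` has `position` at offset 0, `velocity` at offset 8, and a size (and array stride) of 16 bytes. Below is a sketch of a `#[repr(C)]` Rust mirror whose layout is checked at compile time with `std::mem::offset_of!` (stable since Rust 1.77).

```rust
use std::mem::{offset_of, size_of};

/// Rust mirror of the WGSL `Particle` struct.
/// #[repr(C)] keeps the fields in declaration order with C layout rules.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct Particle {
    position: [f32; 2], // WGSL vec2<f32>: offset 0, 8 bytes
    velocity: [f32; 2], // WGSL vec2<f32>: offset 8, 8 bytes
}

// Compile-time checks that the layout matches the WGSL side.
const _: () = assert!(size_of::<Particle>() == 16);
const _: () = assert!(offset_of!(Particle, position) == 0);
const _: () = assert!(offset_of!(Particle, velocity) == 8);
```

These checks are worth keeping as the struct grows. If you append an `f32` field, WGSL pads the struct to 24 bytes (a multiple of its 8-byte alignment), but the Rust version is only 20 bytes. You would then need an explicit padding field.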
Accessing storage buffer data
Storage buffers behave like regular arrays in WGSL:
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
let i = id.x;
// Bounds check — important when dispatch size
// does not evenly divide the data
if i >= arrayLength(&particles) {
return;
}
// Read
let pos = particles[i].position;
let vel = particles[i].velocity;
// Compute
let new_pos = pos + vel * delta_time; // delta_time: a uniform declared elsewhere
// Write back
particles[i].position = new_pos;
}
The arrayLength(&buffer) function returns the number of elements in a runtime-sized array. Always use it for bounds checking — if your dispatch creates more invocations than data elements, the extra threads must bail out early.
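To see why the early return matters, here is a CPU simulation of a 1D dispatch. It is plain Rust: `run_dispatch` and its doubling "kernel" are illustrative, not wgpu API. Each simulated invocation computes its global index and returns early when that index is past the end of the data, just like the `arrayLength` check.

```rust
/// Simulates dispatching `num_workgroups` groups of `workgroup_size` invocations
/// over `data`; each invocation doubles one element. Returns how many
/// invocations exited early because of the bounds check.
fn run_dispatch(data: &mut [f32], num_workgroups: u32, workgroup_size: u32) -> u32 {
    let mut skipped = 0;
    for workgroup_id in 0..num_workgroups {
        for local_id in 0..workgroup_size {
            let i = (workgroup_id * workgroup_size + local_id) as usize;
            // Equivalent of `if i >= arrayLength(&data) { return; }`
            if i >= data.len() {
                skipped += 1;
                continue;
            }
            data[i] *= 2.0;
        }
    }
    skipped
}
```

Without the check, indices 1000 to 1023 would index past the end of a 1000-element buffer. On the CPU that is a panic. On the GPU, WebGPU's robust buffer access keeps accesses inside the binding, but those invocations may then read or write the wrong elements.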
Creating storage buffers in Rust
use wgpu::util::DeviceExt; // provides create_buffer_init

// Create a storage buffer from initial data
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    label: Some("Particle Buffer"),
    contents: bytemuck::cast_slice(&initial_particles),
    usage: wgpu::BufferUsages::STORAGE
        | wgpu::BufferUsages::COPY_SRC // to read back to CPU
        | wgpu::BufferUsages::COPY_DST, // to write from CPU
});
The STORAGE usage flag is required. Add COPY_SRC if you want to read data back to the CPU, and COPY_DST if you want to upload data from the CPU.
Bind group layout for storage buffers
wgpu::BindGroupLayoutEntry {
    binding: 0,
    visibility: wgpu::ShaderStages::COMPUTE,
    ty: wgpu::BindingType::Buffer {
        ty: wgpu::BufferBindingType::Storage {
            read_only: false, // must be true for var<storage, read> bindings
        },
        has_dynamic_offset: false,
        min_binding_size: None,
    },
    count: None,
}
Reading results back to the CPU
GPU buffers are not directly accessible from CPU memory. To read results back, you copy to a staging buffer with MAP_READ usage:
// Create a staging buffer
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("Staging Buffer"),
    size: storage_buffer.size(),
    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    mapped_at_creation: false,
});
// Copy from storage to staging (recorded after the compute pass)
encoder.copy_buffer_to_buffer(
    &storage_buffer, 0,
    &staging_buffer, 0,
    storage_buffer.size(),
);
queue.submit(std::iter::once(encoder.finish()));
// Map the staging buffer and read the data
let slice = staging_buffer.slice(..);
slice.map_async(wgpu::MapMode::Read, |result| {
    result.expect("failed to map staging buffer");
});
device.poll(wgpu::Maintain::Wait); // block until the copy and the mapping complete
let data = slice.get_mapped_range();
let result: &[Particle] = bytemuck::cast_slice(&data);
// Use the result...
drop(data); // the mapped view must be dropped before unmapping
staging_buffer.unmap();
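`bytemuck::cast_slice` reinterprets the mapped bytes in place, with no copy. To make the byte layout concrete, the sketch below decodes the same bytes by hand using only the standard library: 16 bytes per particle, four `f32`s each. `decode_particles` is a hypothetical helper. It assumes the data is little-endian, which is true on the common desktop, mobile and web targets.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Particle {
    position: [f32; 2],
    velocity: [f32; 2],
}

/// Decodes mapped staging-buffer bytes into particles.
/// Per particle: position.x, position.y, velocity.x, velocity.y (4 x f32, little-endian).
fn decode_particles(bytes: &[u8]) -> Vec<Particle> {
    assert!(bytes.len() % 16 == 0, "buffer size must be a multiple of 16 bytes");
    bytes
        .chunks_exact(16)
        .map(|chunk| {
            // Reads the f32 starting at byte offset `k` within this particle.
            let f = |k: usize| f32::from_le_bytes(chunk[k..k + 4].try_into().unwrap());
            Particle { position: [f(0), f(4)], velocity: [f(8), f(12)] }
        })
        .collect()
}
```

The bytemuck version avoids this copy and checks size and alignment instead. It panics if the byte slice cannot be viewed as `Particle`s.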
Memory considerations
- Workgroup memory: WGSL also supports var<workgroup> for memory shared by all invocations in a work group. It is very fast but limited in size: WebGPU's default limit (maxComputeWorkgroupStorageSize) is 16 KB, and hardware limits are typically a few tens of KB.
- Synchronization: within a work group, call workgroupBarrier() so that all invocations finish writing shared data before any of them reads it. Work groups cannot synchronize with each other within a single dispatch; if you need a global barrier, split the work into separate dispatches.
var<workgroup> shared_data: array<f32, 64>;
@compute @workgroup_size(64)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    // Each invocation writes its own slot (a placeholder value here)
    shared_data[lid.x] = f32(lid.x) * 2.0;
    workgroupBarrier(); // wait for all threads in this group
    // Safe: every write in this group is now visible
    let neighbour = shared_data[(lid.x + 1u) % 64u];
}
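A rough CPU analogy for the barrier uses one thread per invocation and the standard library's `std::sync::Barrier`. It only illustrates the ordering guarantee: real work groups are not OS threads, and the group size of 8 used below is arbitrary. Shared memory is modelled as `AtomicU32` slots holding `f32` bit patterns so that threads can write them concurrently. The structure mirrors the WGSL example above: write your own slot, wait, then read a neighbour's.

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Barrier;
use std::thread;

/// Simulates one work group of `n` invocations. Each writes `lid * 2.0` into
/// shared memory, waits at the barrier, then reads its right-hand neighbour (wrapping).
fn simulate_workgroup(n: usize) -> Vec<f32> {
    let shared: Vec<AtomicU32> = (0..n).map(|_| AtomicU32::new(0)).collect();
    let barrier = Barrier::new(n);
    let mut results = vec![0.0f32; n];
    thread::scope(|s| {
        for (lid, out) in results.iter_mut().enumerate() {
            let shared = &shared;
            let barrier = &barrier;
            s.spawn(move || {
                // Phase 1: every invocation writes its own slot.
                shared[lid].store((lid as f32 * 2.0).to_bits(), Ordering::Relaxed);
                // Like workgroupBarrier(): no thread continues until all have written.
                barrier.wait();
                // Phase 2: reading another invocation's slot is now safe.
                *out = f32::from_bits(shared[(lid + 1) % n].load(Ordering::Relaxed));
            });
        }
    });
    results
}
```

If you remove `barrier.wait()`, some threads can read a neighbour's slot before it is written and see 0.0. That is the kind of race `workgroupBarrier()` prevents on the GPU.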
Key takeaway: storage buffers are the workhorse of compute shaders: they hold large arrays that shaders can read and write. Declare them with var<storage, read_write> in WGSL, create them with BufferUsages::STORAGE in Rust, and always bounds-check with arrayLength. To read results back to the CPU, copy to a staging buffer with MAP_READ.
15. Exercise 5: GPU-accelerate a particle simulation
In this exercise you will build a simple particle system where thousands of particles are updated each frame by a compute shader. Particles will have positions and velocities, bounce off the edges of the screen, and be rendered as points.
Overview
The architecture is:
┌──────────────┐     ┌───────────────────┐     ┌──────────────┐
│  CPU: init   │────►│ GPU: compute pass │────►│ GPU: render  │
│  particles   │     │ update positions  │     │ pass: draw   │
│  once        │     │ each frame        │     │ as points    │
└──────────────┘     └───────────────────┘     └──────────────┘
                               │
                               ▼
                        Storage buffer
                     (read/write by the
                     compute shader, read
                     by the vertex shader
                     in the render pass)
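The same buffer serves double duty: the compute shader writes updated positions into it, and the render pass's vertex shader reads those positions back, here through a read-only storage binding indexed by vertex_index. No copy is needed between the two passes.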
Step 1: particle data structure
// The Pod/Zeroable derives need bytemuck's "derive" feature
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Particle {
    position: [f32; 2],
    velocity: [f32; 2],
}
Step 2: initialise particles
use rand::Rng;

fn create_particles(count: usize) -> Vec<Particle> {
    let mut rng = rand::rng();
    (0..count)
        .map(|_| Particle {
            position: [
                rng.random_range(-1.0f32..1.0),
                rng.random_range(-1.0f32..1.0),
            ],
            velocity: [
                rng.random_range(-0.5f32..0.5),
                rng.random_range(-0.5f32..0.5),
            ],
        })
        .collect()
}
Add rand = "0.9" to your Cargo.toml.
Step 3: the compute shader (WGSL)
struct Particle {
position: vec2<f32>,
velocity: vec2<f32>,
}
@group(0) @binding(0)
var<storage, read_write> particles: array<Particle>;
@group(0) @binding(1)
var<uniform> delta_time: f32;
@compute @workgroup_size(64)
fn cs_main(@builtin(global_invocation_id) id: vec3<u32>) {
let i = id.x;
if i >= arrayLength(&particles) {
return;
}
var p = particles[i];
// Update position
p.position = p.position + p.velocity * delta_time;
// Bounce off edges
if p.position.x < -1.0 || p.position.x > 1.0 {
p.velocity.x = -p.velocity.x;
p.position.x = clamp(p.position.x, -1.0, 1.0);
}
if p.position.y < -1.0 || p.position.y > 1.0 {
p.velocity.y = -p.velocity.y;
p.position.y = clamp(p.position.y, -1.0, 1.0);
}
particles[i] = p;
}
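Before debugging on the GPU, it helps to have a CPU reference for the same update, so you can compare it against read-back results. This plain-Rust function mirrors `cs_main` step for step; the name `step_particle` is illustrative.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Particle {
    position: [f32; 2],
    velocity: [f32; 2],
}

/// CPU mirror of `cs_main`: integrate the position, then reflect the velocity
/// and clamp the position on any axis that left the [-1, 1] range.
fn step_particle(mut p: Particle, delta_time: f32) -> Particle {
    for axis in 0..2 {
        p.position[axis] += p.velocity[axis] * delta_time;
        if p.position[axis] < -1.0 || p.position[axis] > 1.0 {
            p.velocity[axis] = -p.velocity[axis];
            p.position[axis] = p.position[axis].clamp(-1.0, 1.0);
        }
    }
    p
}
```

For example, a particle at x = 0.9 moving right at 0.5 units per second overshoots to 1.15 after half a second. It is clamped back to the edge at 1.0 and now moves left.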
Step 4: the render shader (WGSL)
To render particles as points, the vertex shader reads each position from the particle buffer, bound here as a read-only storage buffer and indexed by vertex_index. With the render pipeline's primitive topology set to wgpu::PrimitiveTopology::PointList, each particle becomes one point:
// Must match the compute shader's struct (omit if both shaders share one module)
struct Particle {
    position: vec2<f32>,
    velocity: vec2<f32>,
}
struct RenderOutput {
    @builtin(position) pos: vec4<f32>,
}
// We read the same particle buffer as a storage buffer for rendering
@group(0) @binding(0)
var<storage, read> render_particles: array<Particle>;
@vertex
fn vs_render(@builtin(vertex_index) vi: u32) -> RenderOutput {
    var out: RenderOutput;
    let p = render_particles[vi];
    out.pos = vec4f(p.position, 0.0, 1.0);
    return out;
}
@fragment
fn fs_render() -> @location(0) vec4<f32> {
    return vec4f(0.2, 0.8, 0.4, 1.0); // green particles
}
Note: WGSL has no point_size built-in, so point-list primitives are always rasterised as single pixels. For larger particles, render each particle as a small quad using instancing.
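For the instanced-quad approach, a common pattern is to draw 6 vertices per instance (two triangles) and derive each corner from `vertex_index`. The offset table below is a hypothetical example of what that vertex shader would compute. It is written in plain Rust so it can be checked, and the half-size passed in is an arbitrary clip-space value.

```rust
/// Corner offsets for two counter-clockwise triangles forming a quad centred
/// on the particle: (-1,-1), (1,-1), (1,1) and (-1,-1), (1,1), (-1,1).
const QUAD_CORNERS: [[f32; 2]; 6] = [
    [-1.0, -1.0], [1.0, -1.0], [1.0, 1.0],
    [-1.0, -1.0], [1.0, 1.0], [-1.0, 1.0],
];

/// Clip-space position of vertex `vertex_index` (0..6) of the quad for a
/// particle at `center`, scaled by `half_size`.
fn quad_vertex(center: [f32; 2], vertex_index: usize, half_size: f32) -> [f32; 2] {
    let c = QUAD_CORNERS[vertex_index % 6];
    [center[0] + c[0] * half_size, center[1] + c[1] * half_size]
}
```

In the render pass you would then call rpass.draw(0..6, 0..num_particles) with the default TriangleList topology. The shader indexes the particle buffer with instance_index instead of vertex_index.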
Step 5: buffer creation
use wgpu::util::DeviceExt; // provides create_buffer_init

let num_particles = 10_000u32;
let particles = create_particles(num_particles as usize);
let particle_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
    label: Some("Particle Buffer"),
    contents: bytemuck::cast_slice(&particles),
    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::VERTEX,
});
let dt_buffer = device.create_buffer(&wgpu::BufferDescriptor {
    label: Some("Delta Time Buffer"),
    size: 4, // one f32
    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
    mapped_at_creation: false,
});
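The particle buffer needs the STORAGE usage flag because both the compute shader and the render shader bind it as a storage buffer. The VERTEX flag is only required if you instead bind it with set_vertex_buffer and a vertex buffer layout; keeping it lets you switch between the two approaches.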
Step 6: frame loop
Each frame:
- Calculate delta time
- Write delta time to the uniform buffer
- Run the compute pass to update particles
- Run the render pass to draw particles
let mut encoder = device.create_command_encoder(&Default::default());
// Upload this frame's delta time (`dt: f32`, seconds since the previous frame)
queue.write_buffer(&dt_buffer, 0, bytemuck::bytes_of(&dt));
// Compute pass
{
    let mut cpass = encoder.begin_compute_pass(&Default::default());
    cpass.set_pipeline(&compute_pipeline);
    cpass.set_bind_group(0, &compute_bind_group, &[]);
    let num_workgroups = (num_particles + 63) / 64; // round up
    cpass.dispatch_workgroups(num_workgroups, 1, 1);
}
// Render pass (the render pipeline uses PrimitiveTopology::PointList)
{
    let mut rpass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
        color_attachments: &[Some(wgpu::RenderPassColorAttachment {
            view: &view,
            resolve_target: None,
            ops: wgpu::Operations {
                load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
                store: wgpu::StoreOp::Store,
            },
        })],
        ..Default::default()
    });
    rpass.set_pipeline(&render_pipeline);
    rpass.set_bind_group(0, &render_bind_group, &[]);
    rpass.draw(0..num_particles, 0..1); // one vertex per particle
}
queue.submit(std::iter::once(encoder.finish()));
Note how dispatch_workgroups rounds up: (10000 + 63) / 64 = 157 work groups, giving 10048 invocations. The bounds check in the shader (if i >= arrayLength(&particles)) prevents the extra 48 threads from accessing out-of-bounds memory.
Expected result
You should see thousands of small green particles bouncing around the window, all updated in parallel on the GPU. With 10,000 particles at 60 FPS, the GPU handles 600,000 particle updates per second with ease — and it could handle millions.
Challenge: add a gravity force (p.velocity.y -= 9.8 * delta_time) and watch the particles fall and bounce off the bottom edge. Or add mouse interaction — pass the mouse position as a uniform and apply a force toward or away from the cursor.
Key takeaway: compute shaders can update large datasets in parallel every frame. Because the compute pass and the render pass bind the same buffer (read-write as storage for compute; read-only as storage or, with the VERTEX usage flag, as a vertex buffer for rendering), you can update data in a compute pass and render it in a render pass without copying between buffers. This compute-then-render pattern is the foundation of GPU-driven simulations.
Part 6 — Going Further
16. Post-processing effects (bloom, blur): conceptual overview
So far, you have rendered directly to the screen. But many visual effects require multi-pass rendering: render the scene to an intermediate texture first, then process that texture in subsequent passes before displaying the final result. This is called post-processing.
Render-to-texture
Instead of targeting the swap chain texture directly, you create an off-screen texture and render to it:
Render-to-Texture
==================
Pass 1: Render scene Pass 2: Post-process
┌──────────────────┐ ┌──────────────────┐
│ │ │ │
│ Scene geometry │──render to──► │ Full-screen │──render to──► Screen
│ (3D objects) │ off-screen │ quad sampling │ swap chain
│ │ texture │ the texture │ texture
└──────────────────┘ └──────────────────┘
In wgpu, this means creating a wgpu::Texture with RENDER_ATTACHMENT | TEXTURE_BINDING usage. You render to it in pass 1, then sample from it in pass 2.
Bloom effect
Bloom makes bright areas of an image glow, simulating how real cameras and eyes perceive very bright light. The algorithm has three stages:
Bloom Pipeline
==============
Scene ──► [1. Threshold] ──► [2. Blur] ──► [3. Composite] ──► Final
           Extract bright     Gaussian      Add blurred bright
           pixels only        blur the      areas back onto
                              result        the original
Stage 1 — Threshold: a fragment shader that outputs only pixels brighter than a threshold, and black for everything else.
@fragment
fn threshold(in: FullscreenInput) -> @location(0) vec4<f32> {
let color = textureSample(scene_texture, samp, in.uv);
let brightness = dot(color.rgb, vec3f(0.2126, 0.7152, 0.0722));
if brightness > 0.8 {
return color;
}
return vec4f(0.0, 0.0, 0.0, 1.0);
}
The dot with (0.2126, 0.7152, 0.0722) computes perceptual luminance — the human eye is most sensitive to green, then red, then blue.
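The three weights are the Rec. 709 luminance coefficients, and they sum to 1.0. White therefore has luminance 1.0, and a saturated primary never passes a 0.8 threshold on its own: pure green scores only 0.7152. A plain-Rust version of the threshold pass (the function names are illustrative) makes this easy to check:

```rust
/// Rec. 709 luminance weights used by the threshold shader.
const LUMA: [f32; 3] = [0.2126, 0.7152, 0.0722];

fn luminance(rgb: [f32; 3]) -> f32 {
    rgb[0] * LUMA[0] + rgb[1] * LUMA[1] + rgb[2] * LUMA[2]
}

/// CPU version of the `threshold` pass: keep bright pixels, output opaque black otherwise.
fn threshold(rgba: [f32; 4], cutoff: f32) -> [f32; 4] {
    if luminance([rgba[0], rgba[1], rgba[2]]) > cutoff {
        rgba
    } else {
        [0.0, 0.0, 0.0, 1.0]
    }
}
```

Lowering the cutoff lets more of the scene bloom. Raising it restricts the glow to near-white highlights.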
Stage 2 — Gaussian blur: blur the thresholded image so bright spots become soft glows. Gaussian blur is separable — you can split a 2D blur into two 1D passes (horizontal then vertical), which is much faster:
Separable Gaussian Blur
=======================
Bright    ──► Horizontal  ──► Vertical   ──► Blurred
pixels        blur pass       blur pass      result
              (1D, left       (1D, up
              to right)       to down)
A 9x9 2D kernel       = 81 samples per pixel
Two 9-wide 1D kernels = 18 samples per pixel
Same result, 4.5x fewer texture samples!
A single-direction blur shader samples several neighbouring texels with Gaussian weights:
@fragment
fn blur_horizontal(in: FullscreenInput) -> @location(0) vec4<f32> {
let texel_size = 1.0 / f32(textureDimensions(source).x);
var result = vec4f(0.0);
// Gaussian weights for a 9-tap symmetric kernel (centre + 4 taps per side);
// `var` so the arrays can be indexed with the loop variable
var weights = array<f32, 5>(0.227, 0.194, 0.122, 0.054, 0.016);
var offsets = array<f32, 5>(0.0, 1.0, 2.0, 3.0, 4.0);
for (var i = 0u; i < 5u; i = i + 1u) {
let offset = vec2f(offsets[i] * texel_size, 0.0);
result += textureSample(source, samp, in.uv + offset) * weights[i];
if i > 0u {
result += textureSample(source, samp, in.uv - offset) * weights[i];
}
}
return result;
}
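The "same result" claim can be checked numerically. The sketch below is plain Rust with illustrative names, and the test image is arbitrary. It uses the shader's five weights as a 9-tap symmetric kernel and blurs a small image with clamp-to-edge sampling in two ways: once with the full 9x9 kernel, and once as a horizontal pass followed by a vertical pass.

```rust
/// Weights from the blur shader: index 0 is the centre tap, 1..=4 apply on both sides.
const WEIGHTS: [f32; 5] = [0.227, 0.194, 0.122, 0.054, 0.016];

fn weight(offset: i32) -> f32 {
    WEIGHTS[offset.unsigned_abs() as usize]
}

/// Reads pixel (x, y) with clamp-to-edge addressing.
fn sample(img: &[Vec<f32>], x: i32, y: i32) -> f32 {
    let (h, w) = (img.len() as i32, img[0].len() as i32);
    img[y.clamp(0, h - 1) as usize][x.clamp(0, w - 1) as usize]
}

/// One 1D pass along direction (dx, dy): 9 samples per pixel.
fn blur_1d(img: &[Vec<f32>], dx: i32, dy: i32) -> Vec<Vec<f32>> {
    let (h, w) = (img.len() as i32, img[0].len() as i32);
    (0..h)
        .map(|y| {
            (0..w)
                .map(|x| (-4..=4).map(|k| weight(k) * sample(img, x + k * dx, y + k * dy)).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Direct 2D convolution with the 9x9 kernel weight(kx) * weight(ky): 81 samples per pixel.
fn blur_2d(img: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (h, w) = (img.len() as i32, img[0].len() as i32);
    (0..h)
        .map(|y| {
            (0..w)
                .map(|x| {
                    let mut acc = 0.0f32;
                    for ky in -4..=4 {
                        for kx in -4..=4 {
                            acc += weight(kx) * weight(ky) * sample(img, x + kx, y + ky);
                        }
                    }
                    acc
                })
                .collect::<Vec<f32>>()
        })
        .collect()
}
```

`blur_1d(&blur_1d(&img, 1, 0), 0, 1)` matches `blur_2d(&img)` up to floating-point rounding. This works because the 2D kernel is the product of two 1D kernels, and clamp-to-edge addressing clamps each axis independently.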
Stage 3 — Composite: add the blurred bright areas back onto the original scene:
@fragment
fn composite(in: FullscreenInput) -> @location(0) vec4<f32> {
let scene = textureSample(scene_texture, samp, in.uv);
let bloom = textureSample(bloom_texture, samp, in.uv);
return scene + bloom * bloom_intensity;
}
Other post-processing effects
The render-to-texture pattern enables many effects:
- Colour grading: adjust contrast, saturation, colour curves
- Vignette: darken the edges of the screen
- Chromatic aberration: split RGB channels with slight offsets
- Motion blur: blend the current frame with previous frames
- Depth of field: blur based on distance from a focal point (requires a depth buffer)
- Screen-space ambient occlusion (SSAO): approximate indirect shadows
Each effect is a fragment shader running on a full-screen quad, sampling from the previous pass’s texture.
Key takeaway: post-processing effects are implemented as multi-pass rendering. You render the scene to an off-screen texture, then process it through one or more full-screen fragment shader passes. Bloom is a classic example: threshold bright pixels, blur them with separable Gaussian passes, and composite the glow back onto the original. This pattern is the backbone of modern real-time visual effects.
17. Signed Distance Fields for font rendering
Rendering crisp text at any size and rotation is surprisingly difficult with traditional bitmap fonts. Signed Distance Fields (SDFs) provide an elegant solution that gives resolution-independent, anti-aliased text with a single texture.
The problem with bitmap fonts
A bitmap font is a texture where each character is stored as a grid of pixels:
Bitmap "A" at 32px: Zoomed in (pixelated):
┌────────────┐ ┌──┬──┬──┬──┬──┬──┐
│ ██ │ │ │ │██│██│ │ │
│ █ █ │ │ │██│ │ │██│ │
│ ██████ │ │██│██│██│██│██│██│
│ █ █ │ │██│ │ │ │ │██│
│ █ █ │ │██│ │ │ │ │██│
└────────────┘ └──┴──┴──┴──┴──┴──┘
Looks fine at 32px. Looks blocky at 128px.
If you scale the bitmap up, it becomes pixelated. If you scale it down, details are lost. You would need multiple texture sizes, wasting memory.
What is a Signed Distance Field?
An SDF stores, for each texel, the distance to the nearest edge of the shape. Texels inside the shape have negative distances; texels outside have positive distances. The zero-crossing is the exact edge.
SDF for a diamond (Manhattan distance from the centre, which keeps every value an integer):
+3 +2 +1 0 -1 -2 -1 0 +1 +2 +3
+2 +1 0 -1 -2 -3 -2 -1 0 +1 +2
+1 0 -1 -2 -3 -4 -3 -2 -1 0 +1
0 -1 -2 -3 -4 -5 -4 -3 -2 -1 0
+1 0 -1 -2 -3 -4 -3 -2 -1 0 +1
+2 +1 0 -1 -2 -3 -2 -1 0 +1 +2
+3 +2 +1 0 -1 -2 -1 0 +1 +2 +3
← 0 is the edge. Negative = inside. Positive = outside.
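The key insight is that distances interpolate smoothly. When the texture is magnified, the sampler's bilinear filtering blends neighbouring distance values, so the reconstructed zero crossing stays a clean edge instead of turning into blocky pixels. To render, you simply check whether the interpolated distance is negative (inside, draw the character) or positive (outside, draw nothing).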
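Interpolation is what makes a low-resolution SDF usable at larger sizes. A GPU sampler with linear filtering blends the four nearest stored distances, so the sign can be recovered at positions between texels. The sketch below stores a Euclidean circle SDF at a coarse 8 x 8 resolution and samples it bilinearly between texels. It is plain Rust, and the grid size, radius and function names are arbitrary choices.

```rust
/// Signed distance from point p to a circle: negative inside, 0 on the edge, positive outside.
fn circle_sdf(p: [f32; 2], center: [f32; 2], radius: f32) -> f32 {
    let (dx, dy) = (p[0] - center[0], p[1] - center[1]);
    (dx * dx + dy * dy).sqrt() - radius
}

/// Stores the SDF at integer texel positions 0..n on both axes.
fn bake(n: usize, center: [f32; 2], radius: f32) -> Vec<Vec<f32>> {
    (0..n)
        .map(|y| (0..n).map(|x| circle_sdf([x as f32, y as f32], center, radius)).collect::<Vec<f32>>())
        .collect()
}

/// Bilinear interpolation of the four texels around (x, y), like a linear sampler.
fn sample_bilinear(grid: &[Vec<f32>], x: f32, y: f32) -> f32 {
    let last = grid.len() - 1;
    let (x, y) = (x.clamp(0.0, last as f32), y.clamp(0.0, last as f32));
    let (x0, y0) = (x.floor() as usize, y.floor() as usize);
    let (x1, y1) = ((x0 + 1).min(last), (y0 + 1).min(last));
    let (tx, ty) = (x - x0 as f32, y - y0 as f32);
    let top = grid[y0][x0] * (1.0 - tx) + grid[y0][x1] * tx;
    let bottom = grid[y1][x0] * (1.0 - tx) + grid[y1][x1] * tx;
    top * (1.0 - ty) + bottom * ty
}
```

With a circle of radius 2.5 centred at (3.5, 3.5), the true edge crosses the horizontal centre line at x = 1.0, halfway between two rows of stored texels. The interpolated distance there is about 0.05, close to zero.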
The smoothstep trick
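Hard thresholding (inside vs outside) gives jagged edges. The smoothstep function provides cheap anti-aliasing by creating a smooth transition in a narrow band around the edge. The shader below assumes the usual texture encoding, in which distances are remapped so that 0.5 is the edge and larger values are inside the glyph: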
@fragment
fn sdf_text(in: VertexOutput) -> @location(0) vec4<f32> {
// Sample the SDF texture — value is distance to edge
let distance = textureSample(sdf_texture, samp, in.uv).r;
// smoothstep creates a smooth transition near the edge
// 0.5 is the edge; the range (0.45, 0.55) is the anti-alias band
let alpha = smoothstep(0.45, 0.55, distance);
return vec4f(text_color.rgb, alpha);
}
smoothstep visualised:
alpha
1.0                      ╭──────────
                        ╱
                       ╱  ← smooth transition
                      ╱     (anti-aliased edge)
0.0  ────────────────╯
     outside        edge         inside
                0.45  0.5  0.55
The width of the transition band can be adjusted: a narrower band gives sharper text, a wider band gives softer text. You can also derive the band width from how quickly the sampled distance changes between adjacent screen pixels (using fwidth), which keeps the anti-aliased edge roughly constant in screen pixels at any scale:
let distance = textureSample(sdf_texture, samp, in.uv).r;
let edge = 0.5;
let aa_width = fwidth(distance) * 0.75;
let alpha = smoothstep(edge - aa_width, edge + aa_width, distance);
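WGSL's `smoothstep(low, high, x)` clamps `t = (x - low) / (high - low)` to [0, 1] and returns `t * t * (3 - 2t)`. The plain-Rust sketch below reproduces it. It stands in for `fwidth` (which is `abs(dpdx(d)) + abs(dpdy(d))` on the GPU) by passing the change in distance to the next pixel across and down as arguments, which is an assumption made so the code can run on the CPU.

```rust
/// Same formula as WGSL smoothstep(low, high, x).
fn smoothstep(low: f32, high: f32, x: f32) -> f32 {
    let t = ((x - low) / (high - low)).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// Alpha for SDF value `d` (0.5 = edge), given how much `d` changes to the next
/// pixel across (`ddx`) and down (`ddy`): a stand-in for fwidth(distance).
fn sdf_alpha(d: f32, ddx: f32, ddy: f32) -> f32 {
    // max() avoids a zero-width band, which would divide by zero in smoothstep
    let aa_width = ((ddx.abs() + ddy.abs()) * 0.75).max(1e-6);
    smoothstep(0.5 - aa_width, 0.5 + aa_width, d)
}
```

When a glyph is drawn small, the distance changes quickly from pixel to pixel, so the band widens. At large sizes the band narrows and the edge stays crisp.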
Advantages of SDF text
- Resolution-independent: one small texture (e.g., 64x64 per glyph) looks crisp across a wide range of display sizes
- Cheap anti-aliasing: just smoothstep, no multisampling needed
- Effects for free: outlines, drop shadows, and glow are trivial to add by adjusting the distance threshold:
// Outline effect
let outline_alpha = smoothstep(0.35, 0.40, distance); // outer edge of outline
let fill_alpha = smoothstep(0.45, 0.55, distance); // inner fill
let color = mix(outline_color, fill_color, fill_alpha);
let alpha = outline_alpha;
SDF effects by varying the threshold:
┌────────────────────────────────────────┐
│ dist < 0.35    → outside (transparent) │
│ 0.35 to 0.45   → outline               │
│ dist > 0.45    → fill (solid text)     │
└────────────────────────────────────────┘
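Here is the outline snippet in plain Rust, with colours as `[r, g, b]` arrays. The thresholds come from the snippet; `outlined` and the test colours are illustrative. It shows how the two smoothstep bands combine into one RGBA result.

```rust
fn smoothstep(low: f32, high: f32, x: f32) -> f32 {
    let t = ((x - low) / (high - low)).clamp(0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// Per-channel linear blend, like WGSL mix(a, b, t) = a * (1 - t) + b * t.
fn mix(a: [f32; 3], b: [f32; 3], t: f32) -> [f32; 3] {
    [0, 1, 2].map(|i| a[i] * (1.0 - t) + b[i] * t)
}

/// RGBA for SDF value `d`: transparent outside, outline colour in the
/// 0.35-0.45 band, fill colour inside.
fn outlined(d: f32, outline: [f32; 3], fill: [f32; 3]) -> [f32; 4] {
    let outline_alpha = smoothstep(0.35, 0.40, d); // outer edge of the outline
    let fill_alpha = smoothstep(0.45, 0.55, d); // inner fill
    let [r, g, b] = mix(outline, fill, fill_alpha);
    [r, g, b, outline_alpha]
}
```

The outline's alpha band (0.35 to 0.40) sits entirely outside the fill's colour band (0.45 to 0.55). Every pixel is therefore transparent, pure outline colour, or blending from outline into fill.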
Generating SDF textures
SDF textures are typically pre-generated offline. Tools include:
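- msdfgen: generates multi-channel SDFs (MSDFs), which preserve sharp corners better than single-channel SDFs
- Hiero (LibGDX): generates SDF font atlases
- fontdue (Rust crate): rasterises glyph bitmaps, which you can then convert to an SDF with a distance transform
A standard single-channel SDF texture is a greyscale image where 0.5 represents the edge, values above 0.5 are inside the glyph, and values below 0.5 are outside. In other words, the signed distance is negated and remapped into the 0-1 range.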
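For small cases you can generate a single-channel SDF yourself. The brute-force sketch below is illustrative only: real tools use much faster distance transforms, and msdfgen works from the glyph outlines rather than from a bitmap. For each texel of a binary bitmap it measures the distance to the nearest texel of the opposite kind, makes that distance positive inside, and remaps it to the encoding where 0.5 is the edge. The `spread` parameter, the distance that maps to 0.0 or 1.0, is an assumption you would tune.

```rust
/// Builds an SDF texture (values in [0, 1], 0.5 at the edge, > 0.5 inside)
/// from a binary bitmap where `true` means inside the glyph.
fn bitmap_to_sdf(bitmap: &[Vec<bool>], spread: f32) -> Vec<Vec<f32>> {
    let h = bitmap.len();
    let w = bitmap[0].len();
    let mut out = vec![vec![0.0; w]; h];
    for y in 0..h {
        for x in 0..w {
            let inside = bitmap[y][x];
            // Distance to the nearest texel on the other side of the edge (brute force).
            let mut nearest = f32::INFINITY;
            for (oy, row) in bitmap.iter().enumerate() {
                for (ox, &other) in row.iter().enumerate() {
                    if other != inside {
                        let dx = x as f32 - ox as f32;
                        let dy = y as f32 - oy as f32;
                        nearest = nearest.min((dx * dx + dy * dy).sqrt());
                    }
                }
            }
            // Positive inside, negative outside, then remap so 0.5 is the edge.
            let signed = if inside { nearest } else { -nearest };
            out[y][x] = (0.5 + signed / (2.0 * spread)).clamp(0.0, 1.0);
        }
    }
    out
}
```

Because this sketch measures from texel centre to texel centre, texels on either side of the boundary land at 0.5 plus or minus 1 / (2 * spread) rather than exactly on 0.5. Production generators measure the distance to the true edge.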
Key takeaway: signed distance fields store the distance to a shape’s edge at each texel. This allows rendering crisp, anti-aliased shapes across a wide range of sizes from a small texture. The smoothstep function provides the anti-aliasing, and varying the distance threshold enables outlines, glows, and shadows. SDF-based text rendering is used in game engines, mapping applications, and anywhere resolution-independent text is needed.
18. Resources: Learn WGPU, Shadertoy, The Book of Shaders
This section collects the best resources for continuing your shader programming journey. Each resource approaches the topic from a different angle — use them together for a well-rounded education.
Tutorials and courses
Learn WGPU — the definitive tutorial for wgpu in Rust. It walks through window setup, textures, camera systems, lighting, instancing, and more, with complete working code at each step. If you want to build on the exercises in this course, this is the natural next step.
The Book of Shaders by Patricio Gonzalez Vivo and Jen Lowe — a gentle, visual introduction to fragment shaders. It uses GLSL (not WGSL), but the concepts translate directly: noise functions, patterns, colour mixing, shapes, and animation. The interactive editor lets you experiment in real time. Excellent for building shader intuition.
GPU Gems — NVIDIA’s classic book series (available free online). Covers advanced topics like water rendering, subsurface scattering, shadow techniques, and GPU physics. The techniques are presented in HLSL/GLSL but the algorithms are API-agnostic.
WebGPU Fundamentals — explains WebGPU concepts from the ground up with JavaScript examples. Since wgpu implements the WebGPU spec, the API concepts map directly to Rust. Useful for understanding the “why” behind API design decisions.
Interactive playgrounds
Shadertoy — a web-based shader playground where you write fragment shaders (GLSL) and see results immediately. The community has created incredible effects: raymarched landscapes, fluid simulations, fractal zooms, entire games. Study other people’s shaders to learn techniques — the compact format forces creative solutions. You can port Shadertoy ideas to WGSL in your wgpu projects.
WGSL Playground — Google’s Tour of WGSL. An interactive introduction to the WGSL language with runnable examples. Good for quickly testing WGSL syntax.
Specifications and references
WebGPU Specification — the official W3C specification that wgpu implements. Dense but authoritative. Useful when you need to understand exact behaviour.
WGSL Specification — the complete language specification for WGSL. Reference for built-in functions, types, memory models, and grammar.
wgpu documentation (docs.rs) — Rust API documentation for the wgpu crate. Essential reference for looking up function signatures, enum variants, and descriptor fields.
Advanced topics to explore
Once you are comfortable with the basics covered in this course, here are directions to explore:
- 3D rendering: model-view-projection matrices, depth buffers, camera systems
- Lighting: Phong, Blinn-Phong, physically-based rendering (PBR)
- Shadow mapping: rendering depth from light’s perspective, shadow comparison
- Instancing: drawing thousands of objects efficiently with a single draw call
- Raymarching: rendering 3D scenes using signed distance functions (no triangles)
- Procedural generation: noise functions (Perlin, Simplex) for terrain, textures, and clouds
- Deferred rendering: separating geometry and lighting into different passes
- Skeletal animation: vertex skinning with bone matrices
Community
- wgpu GitHub — the source code, issue tracker, and examples
- WebGPU Matrix channel — real-time chat with the wgpu developers and community
- r/rust_gamedev — Rust game development community on Reddit, where wgpu projects are frequently shared
Key takeaway: shader programming is a vast field. Start with Learn WGPU for Rust-specific guidance, The Book of Shaders for visual intuition, and Shadertoy for inspiration. Keep the WGSL spec and wgpu docs.rs handy as references. The GPU programming community is active and welcoming — share your work and learn from others.
Machine Learning: Training a Game AI Through Self-Play
This is a self-guided course on reinforcement learning through the lens of a concrete goal: training a Rust program to play Tic-Tac-Toe by playing against itself, in the style of AlphaGo Zero. No prior ML experience is assumed. You will build everything from scratch — a game engine, a search algorithm, and eventually a neural network that guides that search. Sections marked 🚧 are stubs whose full content is tracked in a beans ticket.
Table of Contents
Part 1 — Foundations
- What is reinforcement learning?
- Monte Carlo Tree Search — algorithm explained
- Why self-play? The AlphaGo Zero insight
Part 2 — The Game
- Choosing a simple game: Tic-Tac-Toe
- Representing game state in Rust
- Exercise 1: implement the game logic
Part 3 — MCTS
Part 4 — Neural Network Policy/Value Head
- Neural network architecture overview
- Integrating a neural network crate
- Exercise 3: train the network on MCTS data
- Exercise 4: replace rollout with the value network
Part 5 — Self-Play Loop
Part 1 — Foundations
1. What is reinforcement learning?
Imagine teaching a dog to sit. You don’t hand it a textbook on canine biomechanics. You don’t show it a thousand videos of other dogs sitting. Instead, you wait for it to do something close to sitting, and then you give it a treat. Over time, the dog figures out: that particular action leads to something good. It learns from experience.
That, in essence, is reinforcement learning (RL). It is one of the three major branches of machine learning, and it is the one that most closely mirrors how humans and animals learn to interact with the world. Before we dive into RL itself, let’s briefly look at the other two branches so you can see what makes RL different.
The three branches of machine learning
Supervised learning is like studying with an answer key. You are given a dataset of questions paired with correct answers — pictures labeled “cat” or “dog,” emails labeled “spam” or “not spam,” house features paired with sale prices. The algorithm’s job is to learn the pattern that maps inputs to outputs so it can make predictions on new, unseen data.
The key ingredient is labeled data. Someone (often a human) has already done the hard work of providing the right answer for every example. The algorithm learns by comparing its predictions to those answers and adjusting itself to get closer.
Unsupervised learning is like being handed a box of puzzle pieces with no picture on the box. There are no labels, no right answers. The algorithm’s job is to find structure on its own — clusters of similar data points, hidden patterns, compressed representations. Think of a music streaming service grouping songs into genres based on audio features alone, without anyone telling it what “jazz” or “rock” means.
Reinforcement learning is different from both. There are no labeled examples to study. There is no dataset to mine for hidden structure. Instead, there is an agent that takes actions in an environment and receives rewards (or penalties) based on what happens. The agent’s goal is to figure out, through trial and error, which actions lead to the most reward over time.
Here is a table that summarizes the differences:
| | Supervised | Unsupervised | Reinforcement |
|---|---|---|---|
| Learning signal | Correct answer for each example | None — find structure | Reward signal after actions |
| Data | Labeled dataset, provided upfront | Unlabeled dataset, provided upfront | Generated through interaction |
| Goal | Predict the right output | Discover hidden patterns | Maximize cumulative reward |
| Analogy | Studying with an answer key | Sorting a box of unlabeled photos | Training a dog with treats |
The critical difference for our purposes: in RL, the agent generates its own data by interacting with the environment. It doesn’t need a human to label millions of examples. It learns by doing. This is exactly why RL is so powerful for games — the agent can play millions of games against itself and learn from every single one.
The core vocabulary of RL
Let’s define the key concepts. We will use a concrete example throughout: imagine you are playing a simple board game.
Agent
The agent is the learner and decision-maker. It is the entity that observes the world, chooses actions, and tries to maximize its reward. In our course, the agent will be a program that plays Tic-Tac-Toe. In the real world, an agent could be a robot navigating a warehouse, a program trading stocks, or a thermostat controlling room temperature.
The agent is not the environment. It does not control the rules. It can only observe what is happening and choose from the actions available to it.
Environment
The environment is everything outside the agent. It is the world the agent interacts with. The environment receives the agent’s actions, updates its internal state, and sends back two things: the new state and a reward.
For a board game, the environment is the game itself — the board, the rules, and the opponent. The agent places a piece, and the environment responds with the updated board position and (eventually) a signal about whether the agent won or lost.
The environment does not need to be physical. It can be a simulation, a video game, a mathematical function — anything that accepts actions and produces states and rewards.
State
The state is a snapshot of the environment at a particular moment. It contains all the information the agent needs to make a decision.
In Tic-Tac-Toe, the state is the current board configuration: which squares have X, which have O, and which are empty. In chess, it is the position of every piece on the board. For a self-driving car, the state might include speed, position, nearby obstacles, traffic light status, and more.
A good state representation captures everything relevant. If the agent cannot see something that matters, it will struggle to make good decisions. We will spend a full section later on choosing a state representation for Tic-Tac-Toe.
Action
An action is something the agent can do. At each step, the agent looks at the current state and selects one action from the set of available actions.
In Tic-Tac-Toe, an action is placing your mark on one of the empty squares. If there are four empty squares, the agent has four possible actions. In a video game, actions might be “move left,” “jump,” or “fire.” For a robot arm, actions could be “rotate joint 30 degrees” or “close gripper.”
The set of available actions can change from state to state. Early in a Tic-Tac-Toe game, you have nine possible moves. Late in the game, you might have only one or two.
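As a concrete illustration of state and action, here is a minimal sketch of one possible Tic-Tac-Toe representation in Rust. The names `Player`, `Board`, and `legal_actions` are hypothetical choices for this example; the course picks its own representation in a later section.

```rust
/// One possible Tic-Tac-Toe state: nine squares indexed 0..9 row by row.
/// `None` means the square is empty.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Player {
    X,
    O,
}

type Board = [Option<Player>; 9];

/// The legal actions in a state are the indices of the empty squares.
fn legal_actions(board: &Board) -> Vec<usize> {
    (0..9).filter(|&i| board[i].is_none()).collect()
}

fn main() {
    let mut board: Board = [None; 9];
    println!("{} legal actions on an empty board", legal_actions(&board).len());
    board[4] = Some(Player::X); // X takes the center
    board[0] = Some(Player::O); // O takes a corner
    println!("remaining actions: {:?}", legal_actions(&board));
}
```

Because the action set is recomputed from the state, it shrinks naturally as the board fills up.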
Reward
The reward is a numerical signal from the environment that tells the agent how well it is doing. It is the only feedback the agent gets. The agent’s entire goal is to accumulate as much reward as possible over time.
Rewards can be:
- Positive — “That was good, do more of that.” Winning a game might give +1.
- Negative — “That was bad, do less of that.” Losing might give -1.
- Zero — “Nothing interesting happened.” Most moves in a game might give 0 reward.
One of the challenges of RL is that rewards can be sparse. In Tic-Tac-Toe, you might only get a reward at the very end of the game — win, lose, or draw. All the moves leading up to that final outcome get zero reward. The agent has to figure out which of its earlier moves were responsible for the eventual outcome. This is called the credit assignment problem: which actions deserve credit (or blame) for a delayed reward?
Think about learning to cook. You follow a recipe, making dozens of small decisions along the way — how much salt, how long to sear, when to add the garlic. You only get your “reward” when you taste the final dish. If it is terrible, which decision was the mistake? Was it the salt? The garlic timing? The heat? RL faces exactly this challenge.
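The sparse reward described above can be sketched as a function of the board. This sketch uses a hypothetical `[Option<Player>; 9]` board with these rewards:

- +1 for a win.
- -1 for a loss.
- 0 for a draw or an unfinished game.

Every move before the end receives 0, which is exactly what makes credit assignment hard.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Player {
    X,
    O,
}

type Board = [Option<Player>; 9];

/// The eight winning lines: three rows, three columns, two diagonals.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8],
    [0, 3, 6], [1, 4, 7], [2, 5, 8],
    [0, 4, 8], [2, 4, 6],
];

/// Returns the player who has completed a line, if any.
fn winner(board: &Board) -> Option<Player> {
    LINES
        .iter()
        .find(|&&[a, b, c]| board[a].is_some() && board[a] == board[b] && board[b] == board[c])
        .and_then(|&[a, _, _]| board[a])
}

/// Sparse reward for `agent`: nonzero only once someone has won.
fn reward(board: &Board, agent: Player) -> f64 {
    match winner(board) {
        Some(p) if p == agent => 1.0,
        Some(_) => -1.0,
        None => 0.0, // a draw, or a game still in progress
    }
}

fn main() {
    let mut board: Board = [None; 9];
    for i in [0, 1, 2] {
        board[i] = Some(Player::X); // X completes the top row
    }
    board[4] = Some(Player::O);
    println!("X: {}, O: {}", reward(&board, Player::X), reward(&board, Player::O));
}
```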
Policy
A policy is the agent’s strategy — its rule for choosing actions. Given a state, the policy tells the agent what to do.
A policy can be simple: “always pick the first available square” (a terrible Tic-Tac-Toe strategy, but a valid policy). Or it can be complex: a neural network that takes the board state as input and outputs a probability distribution over all possible moves.
We write the policy as pi (the Greek letter π). Formally, a deterministic policy maps states to actions:
Given state s, the policy pi(s) tells the agent which action a to take.
A deterministic policy always picks the same action for a given state. A stochastic policy outputs probabilities — “play center with 60% chance, play corner with 30% chance, play edge with 10% chance.” Stochastic policies are powerful because they allow exploration. If the agent always does the same thing, it can never discover that something else might be better.
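The two kinds of policy can each be sketched in a few lines. This minimal illustration makes a few assumptions:

- Squares are numbered 0 to 8 row by row, so 4 is the center, 0 a corner, and 1 an edge.
- `first_available` and `sample_action` are hypothetical names.
- `u` stands in for a uniform random number that a real program would draw from a random number generator.

```rust
/// Deterministic policy: always pick the first available square.
fn first_available(legal: &[usize]) -> Option<usize> {
    legal.first().copied()
}

/// Stochastic policy: sample an action from a probability distribution.
/// `u` is assumed to be a uniform random number in [0, 1).
fn sample_action(actions: &[usize], probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (&a, &p) in actions.iter().zip(probs) {
        cumulative += p;
        if u < cumulative {
            return a;
        }
    }
    // Guard against floating-point rounding leaving u just above the total.
    *actions.last().expect("policy needs at least one action")
}

fn main() {
    // Center 60%, corner 30%, edge 10%, as in the text.
    let actions = [4, 0, 1];
    let probs = [0.6, 0.3, 0.1];
    println!("deterministic: {:?}", first_available(&actions));
    for u in [0.1, 0.7, 0.95] {
        println!("u = {u}: play square {}", sample_action(&actions, &probs, u));
    }
}
```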
The goal of RL is to find a good policy — one that leads to high cumulative reward. The optimal policy is the one that maximizes expected reward. Finding it (or approximating it) is what RL algorithms do.
Value function
The value function answers a simple but important question: how good is it to be in a particular state?
Not just the immediate reward, but how much total future reward can the agent expect to accumulate from this state onward, assuming it follows its current policy?
Consider two Tic-Tac-Toe positions, both with you to move. In one, you have two in a row with the third square open: you are about to win. In the other, your opponent has two separate lines that each need only one more square, and you can block only one of them: you are about to lose. The immediate reward for both states is zero (the game is not over yet), but their values are very different. The first state has high value; the second has very low value.
The value function is what lets the agent think ahead. Instead of being greedy (just maximizing immediate reward), the agent can evaluate states by their long-term potential. “This move gives me zero reward right now, but it puts me in a state with high value.”
There are two flavors:
- State value function V(s): How much total future reward can I expect from state s?
- Action value function Q(s, a): How much total future reward can I expect if I take action a in state s?
The Q function is especially useful because it directly tells the agent which action is best: just pick the action with the highest Q value.
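Picking the action with the highest Q value is a one-line argmax. The sketch below assumes the Q estimates are already available as hypothetical `(action, value)` pairs for the current state. Where those estimates come from is the subject of later sections.

```rust
/// Greedy action selection: return the action with the highest Q(s, a).
/// Each entry is (action, estimated Q value) for the current state.
fn greedy_action(q_values: &[(usize, f64)]) -> Option<usize> {
    q_values
        .iter()
        .max_by(|x, y| x.1.total_cmp(&y.1))
        .map(|&(action, _)| action)
}

fn main() {
    // Hypothetical Q estimates for three candidate squares in some state s.
    let q = [(0, 0.1), (4, 0.7), (8, 0.3)];
    println!("best action: {:?}", greedy_action(&q));
}
```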
The RL loop
Now let’s put all these concepts together. Reinforcement learning follows a cycle — a loop that repeats over and over:
┌──────────────────────────────────────────────────┐
│                                                  │
│   ┌─────────┐      action      ┌─────────────┐   │
│   │         │─────────────────▶│             │   │
│   │  Agent  │                  │ Environment │   │
│   │         │◀─────────────────│             │   │
│   └─────────┘  state, reward   └─────────────┘   │
│                                                  │
└──────────────────────────────────────────────────┘
Here is the loop, step by step:
1. Observe: The agent observes the current state of the environment. In Tic-Tac-Toe, it looks at the board.
2. Decide: The agent uses its policy to choose an action. Maybe it picks the center square. Maybe it picks a corner. The policy determines how this choice is made.
3. Act: The agent takes the action. It places its mark on the board.
4. Receive feedback: The environment responds with two things: the new state (the updated board, possibly after the opponent also moves) and a reward (did the game end? did we win?).
5. Learn: The agent uses this experience (state, action, reward, new state) to update its policy and/or value function. “Last time I was in that state and took that action, things went well. I should do that more often.”
6. Repeat: Go back to step 1.
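The loop can be sketched generically in Rust. The `Environment` trait, the toy `Corridor` environment, and `run_episode` are hypothetical names invented for this illustration, not the course's actual API. The learning step is left as a comment.

```rust
/// What an environment exposes to the agent in the RL loop (hypothetical trait).
trait Environment {
    type State;
    type Action;
    /// Start a new episode and return the initial state.
    fn reset(&mut self) -> Self::State;
    /// Apply an action and return (new state, reward, whether the episode ended).
    fn step(&mut self, action: Self::Action) -> (Self::State, f64, bool);
}

/// Toy environment: positions 0..=goal on a line; reaching `goal` pays +1.
struct Corridor {
    position: i32,
    goal: i32,
}

#[derive(Clone, Copy, Debug)]
enum Move {
    Left,
    Right,
}

impl Environment for Corridor {
    type State = i32;
    type Action = Move;
    fn reset(&mut self) -> i32 {
        self.position = 0;
        self.position
    }
    fn step(&mut self, action: Move) -> (i32, f64, bool) {
        self.position = match action {
            Move::Left => (self.position - 1).max(0), // cannot walk off the left end
            Move::Right => self.position + 1,
        };
        let done = self.position == self.goal;
        (self.position, if done { 1.0 } else { 0.0 }, done)
    }
}

/// One episode of the loop: observe, decide, act, receive feedback, repeat.
/// Returns (total reward, steps taken).
fn run_episode<E: Environment>(
    env: &mut E,
    policy: impl Fn(&E::State) -> E::Action,
    max_steps: usize,
) -> (f64, usize) {
    let mut state = env.reset(); // observe the initial state
    let mut total_reward = 0.0;
    for step in 1..=max_steps {
        let action = policy(&state); // decide
        let (next_state, reward, done) = env.step(action); // act, receive feedback
        total_reward += reward;
        // A learning algorithm would update its policy or value estimates here.
        state = next_state;
        if done {
            return (total_reward, step);
        }
    }
    (total_reward, max_steps)
}

fn main() {
    let mut env = Corridor { position: 0, goal: 3 };
    println!("always right: {:?}", run_episode(&mut env, |_| Move::Right, 10));
    println!("always left:  {:?}", run_episode(&mut env, |_| Move::Left, 10));
}
```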
This loop runs thousands, millions, even billions of times. Early on, the agent has no idea what it is doing. It explores randomly, stumbling into wins and losses. But over time, it accumulates experience. It starts to learn which states are good and which are bad, which actions lead to rewards and which lead to disaster. The policy improves. The value estimates get more accurate. The agent gets better.
This is fundamentally different from supervised learning, where the data is collected upfront and the model trains on a fixed dataset. In RL, the agent’s behavior changes the data it collects. A better policy leads to different states, which lead to different experiences, which lead to further improvements. The agent is simultaneously exploring the environment and exploiting what it has learned.
Exploration vs. exploitation
This brings us to one of the deepest tensions in RL: the exploration-exploitation tradeoff.
Suppose you have found a restaurant you like. Do you keep going there (exploit what you know) or try a new restaurant that might be even better (explore the unknown)? If you always exploit, you might miss something great. If you always explore, you never enjoy the best thing you have already found.
RL agents face this same dilemma at every step. The agent has a policy that it thinks is pretty good. Should it follow that policy (exploit) or try something random to see what happens (explore)?
- Too much exploitation: The agent gets stuck in a rut. It finds a decent strategy and never discovers that a much better one exists. In Tic-Tac-Toe, it might learn to always play in the center and never discover that corner openings can be just as strong.
- Too much exploration: The agent never commits to anything. It keeps trying random things and never builds on what it has learned. It plays like a beginner forever.
Good RL algorithms balance exploration and exploitation carefully. A common approach is to start with lots of exploration (try everything!) and gradually shift toward exploitation as the agent becomes more confident in its knowledge. We will see this principle in action when we implement Monte Carlo Tree Search in a later section.
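MCTS handles this tradeoff with its own formula. A common standalone way to implement "explore a lot early, exploit more later" is epsilon-greedy with a decaying epsilon. This sketch illustrates that general technique; the course itself does not use it. `u` and `random_index` stand in for draws from a random number generator.

```rust
/// Epsilon decays linearly from `start` to `end` over `decay_steps` steps.
fn epsilon_at(step: u32, start: f64, end: f64, decay_steps: u32) -> f64 {
    if step >= decay_steps {
        return end;
    }
    start + (end - start) * (step as f64 / decay_steps as f64)
}

/// With probability epsilon, explore (a random action); otherwise exploit
/// (the action with the highest estimated value).
/// `u` is assumed uniform in [0, 1); `random_index` is any random usize.
fn epsilon_greedy(q_values: &[f64], epsilon: f64, u: f64, random_index: usize) -> usize {
    if u < epsilon {
        random_index % q_values.len()
    } else {
        (0..q_values.len())
            .max_by(|&a, &b| q_values[a].total_cmp(&q_values[b]))
            .expect("need at least one action")
    }
}

fn main() {
    let q = [0.1, 0.9, 0.4];
    for step in [0, 500, 1000] {
        println!("step {step}: epsilon = {:.3}", epsilon_at(step, 1.0, 0.05, 1000));
    }
    println!("exploit -> {}", epsilon_greedy(&q, 0.2, 0.5, 0));
    println!("explore -> {}", epsilon_greedy(&q, 0.2, 0.1, 2));
}
```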
Why RL is a natural fit for games
You might be wondering: why are we using reinforcement learning for a game? Why not just use supervised learning — collect a bunch of expert games, label the moves as “good” or “bad,” and train on that?
There are several reasons why RL is a particularly natural fit for games:
Games have clear rules and rewards. The environment is well-defined. The state is fully observable (in most board games, you can see the whole board). The reward is unambiguous: win, lose, or draw. There is no messy, noisy real-world data to deal with. This makes games an ideal testbed for RL algorithms.
You can generate unlimited training data. Unlike supervised learning, where you need humans to label examples, an RL agent can play games against itself forever. It generates its own training data through self-play. No labeling required. No expensive human expertise needed. The agent just plays and learns.
The search space is structured. Games have turns, legal moves, and terminal states. This structure makes it possible to reason about future consequences systematically. Algorithms like Monte Carlo Tree Search (which we will cover in the next section) exploit this structure brilliantly.
There is no “expert” bottleneck. Supervised learning is limited by the quality of its training data. If you train on expert games, the best your agent can do is imitate those experts. But what if the experts are not actually optimal? RL agents can surpass human performance because they are not trying to copy anyone — they are trying to maximize reward. AlphaGo Zero famously discovered Go strategies that no human had ever played, because it was not constrained by human conventions.
Games have been the proving ground for AI breakthroughs. From IBM’s Deep Blue defeating Kasparov at chess in 1997, to DeepMind’s AlphaGo defeating Lee Sedol at Go in 2016, to OpenAI Five competing at Dota 2, games have consistently pushed the boundaries of what AI can do. The techniques developed for games — including the ones you will learn in this course — have gone on to be applied in robotics, drug discovery, chip design, and more.
A preview of where we are headed
In this course, we will build a complete RL system from scratch. Here is a high-level preview:
- Part 2 builds the game: a clean, efficient Tic-Tac-Toe engine in Rust that can represent states, enumerate legal moves, and detect wins.
- Part 3 builds the search: Monte Carlo Tree Search, an algorithm that explores possible future game states by simulating random playouts. MCTS does not need any machine learning at all; it works through pure simulation.
- Part 4 adds the brain: a neural network that learns to evaluate board positions (value) and suggest promising moves (policy). This replaces the random simulations in MCTS with learned intuition.
- Part 5 closes the loop: the full self-play training cycle where the agent plays itself, generates training data, improves its neural network, and repeats. This is the AlphaGo Zero recipe.
By the end, you will have a program that starts knowing nothing about Tic-Tac-Toe, not even the basic strategy, and teaches itself to play at a strong level through pure self-play. The same recipe, scaled up, produced AlphaGo Zero, which beat the version of AlphaGo that had defeated Lee Sedol by 100 games to 0.
But first, we need to understand the search algorithm that makes this all possible. In the next section, we will explore Monte Carlo Tree Search.
2. Monte Carlo Tree Search — algorithm explained
In the last section, we learned that an RL agent needs a policy — a strategy for choosing actions. But how do you come up with a good policy for a game like Tic-Tac-Toe, or chess, or Go? One powerful answer is to search: look ahead at possible future moves, evaluate what might happen, and pick the move that leads to the best outcome.
The problem is that looking ahead in a game is staggeringly expensive.
The game tree problem
Every game can be represented as a game tree. The root is the current position. Each branch is a possible move. Each node is a resulting position. The leaves are terminal states — wins, losses, or draws.
For Tic-Tac-Toe, the game tree is manageable. From the opening position, there are 9 possible first moves, then 8 possible responses, then 7, and so on. The total number of possible games is at most 9! = 362,880 (and fewer in practice because many games end before all squares are filled). A modern computer can explore the entire Tic-Tac-Toe tree in milliseconds.
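The gap between the 9! bound and the real number of games is easy to measure by brute force. This sketch makes two assumptions:

- The board is nine `u8` squares: 0 for empty, 1 for X, 2 for O.
- Each game stops at the first completed line or when the board is full.

It enumerates every possible game; the well-known answer is 255,168.

```rust
/// The eight winning lines, with squares numbered 0..9 row by row.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
    [0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
    [0, 4, 8], [2, 4, 6],            // diagonals
];

fn has_won(board: &[u8; 9], player: u8) -> bool {
    LINES.iter().any(|line| line.iter().all(|&sq| board[sq] == player))
}

/// Count every complete game reachable from `board`. A game ends as soon as
/// someone completes a line or the board fills up. 0 = empty, 1 = X, 2 = O.
fn count_games(board: &mut [u8; 9], to_move: u8, filled: u32) -> u64 {
    let mut total = 0;
    for sq in 0..9 {
        if board[sq] != 0 {
            continue;
        }
        board[sq] = to_move;
        if has_won(board, to_move) || filled + 1 == 9 {
            total += 1; // this move ends the game
        } else {
            total += count_games(board, 3 - to_move, filled + 1);
        }
        board[sq] = 0; // undo the move before trying the next square
    }
    total
}

fn main() {
    let games = count_games(&mut [0; 9], 1, 0);
    println!("{games} complete games (upper bound 9! = 362880)");
}
```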
Now consider chess. From the opening position, White has 20 possible first moves. Black has 20 possible responses. That is already 400 positions after one move each. The average chess game is about 40 moves per side, and at each turn there are roughly 30 legal moves. The total number of possible chess games is estimated at around 10^120 — a number so large it dwarfs the number of atoms in the observable universe (roughly 10^80).
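Here is a back-of-the-envelope check using the paragraph's own numbers. About 30 legal moves per turn over about 40 moves per side (80 turns) gives roughly 30^80 move sequences. Rounding 30 × 30 = 900 up to 10^3 for each pair of moves gives the often-quoted 10^120, the estimate Claude Shannon made in 1950.

```latex
30^{80} = 10^{80 \log_{10} 30} \approx 10^{118}
\qquad\qquad
(30 \times 30)^{40} \approx \left(10^{3}\right)^{40} = 10^{120}
```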
Go is even worse. On a 19x19 board, the first move has 361 options. Go has roughly 10^170 legal board positions, and its game tree is estimated at around 10^360 possible games.
Exhaustive search is impossible. You cannot look at every possible future. You need a smarter approach — one that focuses its computational effort on the parts of the tree that matter most.
What “Monte Carlo” means
The term “Monte Carlo” comes from the famous casino in Monaco and refers to a broad class of algorithms that use random sampling to estimate quantities that are too expensive to compute exactly.
The core idea is simple: if you cannot calculate something precisely, you can estimate it by running many random experiments and averaging the results.
Suppose someone asks you: “What fraction of the area of a square is covered by a circle inscribed inside it?” You could calculate this analytically (it is pi/4). Or you could take the Monte Carlo approach: throw 10,000 darts randomly at the square, count how many land inside the circle, and divide. With enough darts, your estimate converges to the true answer.
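The dart-throwing estimate is easy to try. This sketch uses a small hand-written xorshift generator, an assumption made to avoid external crates. It throws points at the square from -1 to 1 on each axis; the fraction landing inside the unit circle, times 4, estimates pi.

```rust
/// Minimal xorshift64 pseudo-random generator (the seed must be nonzero).
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a pseudo-random float in [0, 1).
    fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        // Use the top 53 bits to build a float in [0, 1).
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Throw `darts` random points at [-1, 1] x [-1, 1] and count hits inside the
/// inscribed unit circle. The hit fraction estimates pi/4.
fn estimate_pi(darts: u32, seed: u64) -> f64 {
    let mut rng = XorShift64(seed);
    let mut inside = 0u32;
    for _ in 0..darts {
        let x = 2.0 * rng.next_f64() - 1.0;
        let y = 2.0 * rng.next_f64() - 1.0;
        if x * x + y * y <= 1.0 {
            inside += 1;
        }
    }
    4.0 * inside as f64 / darts as f64
}

fn main() {
    for darts in [1_000, 100_000, 1_000_000] {
        println!("{darts} darts: pi is approximately {:.4}", estimate_pi(darts, 0x9E37_79B9_7F4A_7C15));
    }
}
```

More darts shrink the error roughly in proportion to one over the square root of the number of darts. MCTS estimates improve with more rollouts in the same way.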
Monte Carlo Tree Search applies this same idea to game trees. Instead of exhaustively analyzing every possible sequence of moves, we play out random games (called rollouts or simulations) and use the results to estimate how good each move is. A move that leads to lots of random wins is probably a good move. A move that leads to lots of random losses is probably a bad one.
The remarkable thing is that this works. Even completely random play — just picking legal moves at random until the game ends — gives you useful signal about which positions are strong and which are weak. And as we will see, MCTS is much smarter than pure random sampling. It focuses its simulations on the most promising parts of the tree.
The four phases of MCTS
MCTS builds a search tree incrementally, one simulation at a time. Each simulation consists of four phases: Selection, Expansion, Simulation, and Backpropagation. Let’s walk through each one.
Phase 1: Selection
Starting from the root of the tree (the current game position), we walk down through nodes that we have already explored, choosing which child to visit at each step.
The key question is: which child do we pick? We want to visit moves that look promising (to refine our estimate of their value), but we also want to try moves we have not explored much (because they might turn out to be even better). This is the exploration-exploitation tradeoff from the previous section, and MCTS handles it with a formula called UCT, which we will explain shortly.
Selection continues until we reach a node that has at least one child we have never visited (a frontier of the known tree), or a terminal position where the game is already over.
Phase 2: Expansion
Once selection reaches a node with unexplored children, we pick one of those unexplored moves and add it to the tree as a new node. This is how the tree grows: one new node per simulation.
The new node starts with no statistics — we have never simulated a game from this position before. We are about to fix that.
Phase 3: Simulation (rollout)
From the newly expanded node, we play a random game to completion. Both sides make random legal moves until one side wins or the game ends in a draw. This is called a rollout or playout.
The rollout does not add any nodes to the tree. It is purely a quick-and-dirty way to estimate how good the new position is. If random play from this position tends to result in wins, the position is probably decent. If it tends to result in losses, the position is probably bad.
The result of the rollout is a simple outcome: typically +1 for a win, -1 for a loss, and 0 for a draw.
Phase 4: Backpropagation
After the rollout finishes, we take the result and propagate it back up the tree, updating every node we visited during the selection phase.
Each node in the tree tracks two numbers:
- N — the number of times this node has been visited (i.e., how many simulations have passed through it)
- W — the total accumulated reward from all simulations that passed through it
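We increment N by 1 and add the rollout result to W for every node on the path from the new node back to the root. The result must be credited from the right player’s perspective. A common convention stores each node’s W from the point of view of the player who made the move leading into that node. That way, a parent choosing the child with the highest average reward picks the move that is best for the player doing the choosing. Because a win for player X is a loss for player O, the sign of the reward flips at each step as we move up through alternating players.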
After backpropagation, the tree has slightly better statistics. The next simulation can make a slightly more informed selection. Over hundreds or thousands of simulations, the tree’s estimates converge toward the true values of each move.
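Here is a minimal sketch of backpropagation with the sign flip. It assumes the +1/-1/0 scoring above and a path stored as a slice from the root to the new node. The `Stats` type and `backpropagate` function are hypothetical names.

```rust
/// Per-node statistics tracked by MCTS.
#[derive(Debug, Default, Clone)]
struct Stats {
    n: u32, // N: number of visits
    w: f64, // W: total reward, for the player who moved into this node
}

/// Backpropagate one rollout result along the path chosen during selection.
/// `path[0]` is the root; the last element is the newly expanded node.
/// `result` (+1 win, -1 loss, 0 draw) is from the perspective of the player
/// who moved into that last node; its sign flips at each step toward the root.
fn backpropagate(path: &mut [Stats], result: f64) {
    let mut reward = result;
    for stats in path.iter_mut().rev() {
        stats.n += 1;
        stats.w += reward;
        reward = -reward;
    }
}

fn main() {
    let mut path = vec![Stats::default(); 3]; // root, child, grandchild
    backpropagate(&mut path, 1.0);
    println!("{path:?}");
}
```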
The MCTS loop in pseudocode
Here is the complete algorithm:
function mcts(root_state, num_simulations):
    root = create_node(root_state)
    repeat num_simulations times:
        // Phase 1: Selection
        node = root
        state = copy(root_state)
        while node is fully expanded and not terminal:
            node = select_child(node)          // using UCT formula
            state = apply_move(state, node.move)
        // Phase 2: Expansion
        if node is not terminal:
            move = pick an unexplored move from node
            state = apply_move(state, move)
            child = create_node(state, move, parent = node)
            add child to node's children
            node = child
        // Phase 3: Simulation (rollout)
        while state is not terminal:
            move = random legal move
            state = apply_move(state, move)
        result = reward from terminal state (for the player who moved into node)
        // Phase 4: Backpropagation
        while node is not null:
            node.N += 1
            node.W += result
            result = result from the other player's perspective
            node = node.parent
    // After all simulations, pick the most-visited child of root
    return the move of the child of root with highest N
Notice the final step: after all simulations are done, we pick the move whose corresponding child has been visited the most (highest N), not the one with the highest average reward. Visit count is a more robust measure than average reward because it reflects how much computational effort the algorithm devoted to that move. A move that was visited many times was consistently attractive enough to survive the selection process.
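To see the four phases run end to end, here is a compact Rust translation of the pseudocode. To stay short, it uses a hypothetical toy game in place of Tic-Tac-Toe: single-pile Nim, where players alternately remove 1 to 3 stones and whoever takes the last stone wins. Scoring is 1 for a win and 0 for a loss, so flipping perspective is `1 - result`. The arena-based tree, the xorshift `Rng`, and all names are assumptions for this sketch, not the course's implementation. The UCT score used in selection is explained in the next subsection.

```rust
// Hypothetical toy game standing in for Tic-Tac-Toe: single-pile Nim.
// Each turn removes 1-3 stones; whoever takes the last stone wins.
#[derive(Clone, Copy)]
struct Nim { stones: u32 }

impl Nim {
    fn legal_moves(&self) -> Vec<u32> { (1..=3).filter(|&m| m <= self.stones).collect() }
    fn apply(&self, m: u32) -> Nim { Nim { stones: self.stones - m } }
    fn is_terminal(&self) -> bool { self.stones == 0 }
}

// Tiny xorshift PRNG so the sketch needs no external crates (seed must be nonzero).
struct Rng(u64);

impl Rng {
    fn below(&mut self, n: usize) -> usize {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 % n as u64) as usize
    }
}

// Nodes live in one Vec and refer to each other by index (an "arena").
struct Node {
    state: Nim,
    mv: Option<u32>,       // move that led here (None at the root)
    parent: Option<usize>,
    children: Vec<usize>,
    untried: Vec<u32>,     // legal moves not yet expanded
    visits: u32,           // N
    wins: f64,             // W, credited to the player who moved into this node
}

fn new_node(state: Nim, mv: Option<u32>, parent: Option<usize>) -> Node {
    let untried = state.legal_moves();
    Node { state, mv, parent, children: Vec::new(), untried, visits: 0, wins: 0.0 }
}

fn uct(node: &Node, parent_visits: u32) -> f64 {
    let n = node.visits as f64;
    node.wins / n + std::f64::consts::SQRT_2 * ((parent_visits as f64).ln() / n).sqrt()
}

fn mcts(root_state: Nim, num_simulations: u32, rng: &mut Rng) -> u32 {
    let mut tree = vec![new_node(root_state, None, None)];
    for _ in 0..num_simulations {
        // Phase 1: Selection (descend while fully expanded and not terminal)
        let mut node = 0;
        while tree[node].untried.is_empty() && !tree[node].children.is_empty() {
            let pv = tree[node].visits;
            node = *tree[node].children.iter()
                .max_by(|&&a, &&b| uct(&tree[a], pv).total_cmp(&uct(&tree[b], pv)))
                .unwrap();
        }
        // Phase 2: Expansion (add one untried move as a new child)
        if !tree[node].untried.is_empty() {
            let i = rng.below(tree[node].untried.len());
            let mv = tree[node].untried.swap_remove(i);
            let child_state = tree[node].state.apply(mv);
            tree.push(new_node(child_state, Some(mv), Some(node)));
            let child = tree.len() - 1;
            tree[node].children.push(child);
            node = child;
        }
        // Phase 3: Simulation (random moves until the pile is empty)
        let mut state = tree[node].state;
        let mut rollout_moves = 0;
        while !state.is_terminal() {
            let moves = state.legal_moves();
            state = state.apply(moves[rng.below(moves.len())]);
            rollout_moves += 1;
        }
        // Even rollout length: the player who moved into `node` also made the
        // final move, so they won. Score 1 for a win, 0 for a loss.
        let mut result = if rollout_moves % 2 == 0 { 1.0 } else { 0.0 };
        // Phase 4: Backpropagation, flipping perspective at each level
        let mut current = Some(node);
        while let Some(idx) = current {
            tree[idx].visits += 1;
            tree[idx].wins += result;
            result = 1.0 - result;
            current = tree[idx].parent;
        }
    }
    // Recommend the root child with the highest visit count
    let best = tree[0].children.iter().max_by_key(|&&c| tree[c].visits).expect("no legal moves");
    tree[*best].mv.unwrap()
}

fn main() {
    let mut rng = Rng(0x9E37_79B9_7F4A_7C15);
    for stones in 5..=7 {
        println!("{stones} stones: take {}", mcts(Nim { stones }, 3000, &mut rng));
    }
}
```

In this Nim variant, the winning move is to leave the opponent a multiple of 4 stones. From 5, 6, and 7 stones, the search should therefore settle on taking 1, 2, and 3 respectively. MCTS discovers this rule without being told it.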
UCB1 and the UCT formula
The heart of MCTS is the selection phase, and the heart of selection is the formula that decides which child to visit. This is where exploration and exploitation are balanced.
The standard formula is called UCT (Upper Confidence Bounds for Trees), adapted from a statistical formula called UCB1 (Upper Confidence Bound 1).
For each child node, UCT computes a score:
UCT(child) = (W / N) + C * sqrt(ln(N_parent) / N)
Where:
- W / N is the child’s average reward — its win rate so far. This is the exploitation term. Children with high win rates get high scores.
- sqrt(ln(N_parent) / N) is the exploration term. N_parent is the visit count of the parent node, and N is the visit count of the child. This term is large when the child has been visited relatively few times compared to its siblings. It shrinks as the child gets more visits.
- C is a constant that controls the balance between exploration and exploitation. A common choice is C = sqrt(2), which comes from the theoretical analysis of UCB1. In practice, this value is often tuned.
The intuition behind UCT is straightforward:
If a child has a high win rate, visit it more (exploitation). You want to refine your estimate of moves that look strong.
If a child has been visited much less than its siblings, visit it more (exploration). You might be missing something good. A move with 2 visits and a 50% win rate might actually be the best move — you just haven’t given it enough chances yet.
As N grows, the exploration term shrinks. The more you visit a child, the less “exploration bonus” it gets. Eventually, exploitation dominates, and the algorithm focuses on the moves with genuinely high win rates.
As N_parent grows but N stays small, the exploration term grows. If the parent has been visited 1000 times but one child has only been visited 3 times, that child gets a large exploration bonus. The algorithm is saying: “You’ve visited this position 1000 times but barely looked at this move. Go check it out.”
UCT is elegant because it automatically and gradually shifts from exploration to exploitation. In the early simulations, when visit counts are low, the exploration term dominates and the algorithm tries many different moves. As simulations accumulate, the exploitation term takes over and the algorithm focuses on the best moves.
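The UCT score and the selection rule translate directly into code. This sketch represents each child as a hypothetical `(W, N)` pair and takes `N_parent` as an argument. It chooses an unvisited child first, because the formula divides by N.

```rust
use std::f64::consts::SQRT_2;

/// UCT(child) = W/N + C * sqrt(ln(N_parent) / N); assumes the child has N >= 1.
fn uct(w: f64, n: f64, n_parent: f64, c: f64) -> f64 {
    w / n + c * (n_parent.ln() / n).sqrt()
}

/// Return the index of the child to visit next. Children are (W, N) pairs.
/// Any unvisited child (N == 0) is chosen first, since UCT is undefined for it.
fn select_child(children: &[(f64, f64)], n_parent: f64, c: f64) -> usize {
    if let Some(i) = children.iter().position(|&(_, n)| n == 0.0) {
        return i;
    }
    (0..children.len())
        .max_by(|&a, &b| {
            let (wa, na) = children[a];
            let (wb, nb) = children[b];
            uct(wa, na, n_parent, c).total_cmp(&uct(wb, nb, n_parent, c))
        })
        .expect("node has no children")
}

fn main() {
    // Statistics for A, B, C after four simulations of the worked example that follows.
    let children = [(1.0, 2.0), (0.0, 1.0), (1.0, 1.0)];
    for (name, &(w, n)) in ["A", "B", "C"].iter().zip(&children) {
        println!("UCT({name}) = {:.2}", uct(w, n, 4.0, SQRT_2));
    }
    println!("select index {}", select_child(&children, 4.0, SQRT_2));
}
```

On those statistics it picks C (index 2). Using C = sqrt(2) exactly makes each score slightly higher than the hand calculation below, which rounds C to 1.41, but the ranking is the same.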
A worked example
Let’s trace through a small example to see MCTS in action. Imagine a very simple game — not Tic-Tac-Toe, but something even smaller so we can fit the tree in ASCII art. Suppose it is our turn and we have three possible moves: A, B, and C.
After 0 simulations — We have only the root node. Nothing has been explored yet.
      Root
     (0/0)
    /  |  \
   ?   ?   ?
   A   B   C
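The notation (W/N) means W total reward out of N visits. The root has no visits yet. To keep the arithmetic simple, this example scores each rollout as 1 for a win and 0 for a loss (instead of +1 and -1), so W is simply the number of wins.
Simulations 1-3: The root is not yet fully expanded, so selection stops at the root and each simulation expands one new child. Suppose the random rollouts give us: A wins, B loses, C wins.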
      Root
      (2/3)
    /   |   \
(1/1) (0/1) (1/1)
  A     B     C
A has 1 win out of 1 visit (average 1.0). B has 0 wins out of 1 visit (average 0.0). C has 1 win out of 1 visit (average 1.0).
Simulation 4 — Now all children have been visited at least once, so selection uses UCT to choose. Let’s compute UCT scores with C = sqrt(2) ~ 1.41:
UCT(A) = 1/1 + 1.41 * sqrt(ln(3)/1) = 1.0 + 1.41 * 1.048 = 1.0 + 1.48 = 2.48
UCT(B) = 0/1 + 1.41 * sqrt(ln(3)/1) = 0.0 + 1.41 * 1.048 = 0.0 + 1.48 = 1.48
UCT(C) = 1/1 + 1.41 * sqrt(ln(3)/1) = 1.0 + 1.41 * 1.048 = 1.0 + 1.48 = 2.48
A and C are tied. Say we break ties randomly and pick A. Suppose the rollout from A results in a loss this time.
      Root
      (2/4)
    /   |   \
(1/2) (0/1) (1/1)
  A     B     C
Simulation 5 — Recompute UCT scores:
UCT(A) = 1/2 + 1.41 * sqrt(ln(4)/2) = 0.50 + 1.41 * 0.833 = 0.50 + 1.17 = 1.67
UCT(B) = 0/1 + 1.41 * sqrt(ln(4)/1) = 0.00 + 1.41 * 1.177 = 0.00 + 1.66 = 1.66
UCT(C) = 1/1 + 1.41 * sqrt(ln(4)/1) = 1.00 + 1.41 * 1.177 = 1.00 + 1.66 = 2.66
C has the highest UCT score (2.66) — it has a perfect win rate and has only been visited once, so both exploitation and exploration favor it. MCTS selects C. Suppose the rollout from C is a win again.
      Root
      (3/5)
    /   |   \
(1/2) (0/1) (2/2)
  A     B     C
Notice what is happening. C is emerging as the most promising move (2 wins out of 2 visits). But MCTS has not abandoned A (which might just have been unlucky) or even B (which gets an exploration bonus for being rarely visited). Over many more simulations, if C is truly the best move, it will accumulate the most visits.
After 100 simulations, the tree might look something like this:
         Root
       (62/100)
      /   |   \
(24/40) (7/16) (31/44)
   A      B       C
C has the highest visit count (44) and the highest win rate (31/44 ≈ 70%). A is close behind (24/40 = 60%). B has been explored less because its early results were poor, but MCTS still checked it 16 times to make sure. The algorithm would recommend move C.
How MCTS builds an asymmetric tree
A critical property of MCTS is that it does not explore the game tree uniformly. Unlike minimax with a fixed depth limit (which examines every move to the same depth), MCTS builds an asymmetric tree that is deep in promising branches and shallow in unpromising ones.
Consider this scenario in a more complex game:
           Root
          /    \
         A      B
        / \     |
      A1   A2   B1
      /
    A1a
    / \
 A1a1 A1a2
Move A looks promising, so MCTS explores it deeply. Within A, move A1 looks better than A2, so it gets explored even deeper. Move B looks weak, so it only has one child — MCTS checked it out, confirmed it was not great, and moved on.
This is exactly the right strategy. There is no point spending computational effort analyzing a bad move 15 moves deep. It is much better to spend those simulations refining your understanding of the good moves. MCTS achieves this automatically through the UCT formula — no hand-tuning required.
This property becomes especially important in games with large branching factors. In Go, where each position might have 200+ legal moves, you simply cannot afford to explore every move to the same depth. MCTS naturally concentrates on a small number of promising moves and explores them deeply while giving the others only a cursory glance.
Why MCTS works without domain knowledge
One of the most remarkable properties of MCTS is that it works with zero domain knowledge. The algorithm needs only three things:
- A way to enumerate legal moves from any position
- A way to apply a move and get the resulting position
- A way to detect when the game is over and who won
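Those three requirements fit in a small Rust trait. A random rollout can be written against the trait alone, without knowing which game it is playing. `Game`, `random_rollout`, and the Nim example are hypothetical names for this sketch. `pick(n)` stands in for a random index below n, so the example can use fixed choices.

```rust
/// The three things MCTS needs from a game (a hypothetical trait, not a library API).
trait Game {
    type Move: Copy;
    /// 1. Enumerate the legal moves (an empty list once the game is over).
    fn legal_moves(&self) -> Vec<Self::Move>;
    /// 2. Apply a move and return the resulting position.
    fn apply(&self, mv: Self::Move) -> Self;
    /// 3. Report the winner (player 0 or 1), or None for a draw or unfinished game.
    fn winner(&self) -> Option<usize>;
}

/// A rollout that uses only the trait: it knows nothing about the game itself.
fn random_rollout<G: Game>(mut state: G, mut pick: impl FnMut(usize) -> usize) -> Option<usize> {
    loop {
        let moves = state.legal_moves();
        if moves.is_empty() {
            return state.winner();
        }
        state = state.apply(moves[pick(moves.len())]);
    }
}

/// Example game: single-pile Nim, take 1-3 stones, taking the last stone wins.
#[derive(Clone)]
struct Nim {
    stones: u32,
    to_move: usize,
}

impl Game for Nim {
    type Move = u32;
    fn legal_moves(&self) -> Vec<u32> {
        (1..=3).filter(|&m| m <= self.stones).collect()
    }
    fn apply(&self, mv: u32) -> Nim {
        Nim { stones: self.stones - mv, to_move: 1 - self.to_move }
    }
    fn winner(&self) -> Option<usize> {
        // When the pile is empty, the player who just moved (not `to_move`) won.
        if self.stones == 0 { Some(1 - self.to_move) } else { None }
    }
}

fn main() {
    let start = Nim { stones: 5, to_move: 0 };
    // Always take one stone: five moves, so player 0 takes the last stone.
    println!("{:?}", random_rollout(start.clone(), |_| 0));
    // Always take the most allowed: 3, then 2, so player 1 takes the last stone.
    println!("{:?}", random_rollout(start, |n| n - 1));
}
```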
That is it. MCTS does not need to know that controlling the center matters in Tic-Tac-Toe. It does not need to know that connected stones are strong in Go. It does not need an evaluation function crafted by a human expert. It discovers all of this on its own through random rollouts.
This is a profound advantage. For many games and decision problems, writing a good evaluation function is extremely difficult. What makes a chess position “good”? Generations of chess programmers have painstakingly encoded hundreds of heuristic rules — piece values, king safety, pawn structure, mobility, and more. MCTS sidesteps this entirely. You just need the rules of the game.
Of course, random rollouts are not as accurate as expert evaluation. A random game of chess is a terrible guide to who is winning — random play is so bad that the result is nearly uncorrelated with the actual quality of the position. This is where the AlphaGo Zero innovation comes in (which we will cover in the next section): replace random rollouts with a neural network that has learned to evaluate positions. But pure MCTS with random rollouts is still surprisingly effective for many games, especially those with lower branching factors.
Strengths of vanilla MCTS
Anytime algorithm. You can stop MCTS after 100 simulations or 100,000 simulations. More simulations give better results, but even a small number of simulations gives a reasonable answer. This is useful when you have a time limit — run as many simulations as you can in the available time, then pick the best move.
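In code, the anytime property usually shows up as a time-budgeted loop. In this sketch, `run_with_budget` is a hypothetical helper and `simulate` is a placeholder for one MCTS iteration. Whatever statistics the tree holds when the deadline passes are what the final move choice uses.

```rust
use std::time::{Duration, Instant};

/// Run `simulate` repeatedly until the time budget is spent and return how many
/// simulations fit. `simulate` is a placeholder for one MCTS iteration.
fn run_with_budget(budget: Duration, mut simulate: impl FnMut()) -> u64 {
    let deadline = Instant::now() + budget;
    let mut count = 0;
    while Instant::now() < deadline {
        simulate();
        count += 1;
    }
    count
}

fn main() {
    let mut work = 0u64;
    let sims = run_with_budget(Duration::from_millis(50), || work += 1);
    println!("ran {sims} simulations in 50 ms");
}
```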
No evaluation function needed. As discussed above, MCTS works with just the game rules. This makes it applicable to any game (or sequential decision problem) immediately, without domain-specific engineering.
Handles high branching factors. Unlike alpha-beta search (the traditional game-tree search algorithm), which slows down dramatically as the branching factor increases, MCTS handles large branching factors gracefully because it focuses on promising branches.
Naturally balances exploration and exploitation. The UCT formula provides an elegant, principled solution to this fundamental tradeoff.
Asymptotically optimal. Given enough simulations, MCTS converges to the optimal move. The more you simulate, the closer you get to perfect play.
Limitations of vanilla MCTS
Random rollouts can be uninformative. In complex games, random play is essentially noise. A random chess game tells you almost nothing about who is actually winning. This limits how well pure MCTS can play strategic games. (The fix: replace rollouts with a learned evaluation function — exactly what AlphaGo Zero does.)
Many simulations needed for complex games. While MCTS is an anytime algorithm, the number of simulations needed for good play scales with the complexity of the game. For Go on a 19x19 board, millions of simulations per move are needed for strong play with random rollouts alone.
Memory usage. MCTS stores the search tree in memory. In long games or games with high branching factors, this tree can grow large. Various techniques exist to manage this (e.g., recycling parts of the tree between moves).
Not well-suited to all problems. MCTS works best in domains with discrete actions, clear terminal states, and well-defined rewards. It is less natural for continuous control problems or domains where the outcome of a random policy is completely uninformative.
Looking ahead
We now have the two foundational concepts for this course: reinforcement learning (from section 1) and Monte Carlo Tree Search (from this section). In the next section, we will see how DeepMind combined these ideas — along with deep neural networks and self-play — to create AlphaGo Zero, an agent that mastered Go with no human knowledge at all. That combination of MCTS + neural networks + self-play is exactly what we will build in Rust over the rest of this course.
3. Why self-play? The AlphaGo Zero insight
You now understand reinforcement learning and Monte Carlo Tree Search as separate ideas. RL gives us the framework: an agent, an environment, a reward signal, and a policy to optimize. MCTS gives us a powerful search algorithm that explores game trees intelligently using random simulations. Both are impressive on their own. But neither one, by itself, is enough to produce superhuman game play in a complex game like Go.
The missing ingredient is self-play — the idea that an agent can generate its own training data by playing against copies of itself, creating a feedback loop where improvement begets further improvement, with no human knowledge required. This section tells the story of how that insight emerged and why it changed the field.
From handcrafted to learned: a brief history
To appreciate what self-play made possible, it helps to understand what came before. The history of game-playing AI is a story of progressively removing the need for human expertise.
Deep Blue: the triumph of engineering (1997)
In 1997, IBM’s Deep Blue defeated world chess champion Garry Kasparov in a six-game match. It was a landmark moment for AI: the first time a computer had won a match against a reigning world champion under standard tournament conditions.
But Deep Blue was not a learning system. It was a masterpiece of engineering. At its core was an alpha-beta search algorithm (a refined version of minimax) running on custom hardware that could evaluate 200 million chess positions per second. The critical component was its evaluation function: a hand-tuned formula that scored any chess position based on hundreds of features — material count, king safety, pawn structure, piece mobility, control of the center, and dozens of other heuristics. These features were designed and tuned by a team that included grandmaster Joel Benjamin.
Deep Blue played brilliantly, but everything it knew about chess had been put there by humans. The evaluation function was the product of decades of chess programming knowledge. If you wanted to apply the same approach to a different game — say, Go — you would need a team of Go experts to design a completely new evaluation function from scratch. And for Go, no one could figure out how to do that well. The game is too subtle, too positional, too intuitive. Handcrafted evaluation functions for Go never came close to expert human play.
Deep Blue’s approach was powerful but fundamentally limited: you need human experts to encode their knowledge, and you can never exceed what those experts know.
AlphaGo: learning from humans, then surpassing them (2016)
Nearly two decades later, DeepMind tackled Go — a game widely considered to be far beyond the reach of brute-force search. The original AlphaGo system, which defeated European champion Fan Hui in 2015 and world champion Lee Sedol in 2016, took a hybrid approach.
AlphaGo’s training had two stages:
Stage 1: Supervised learning from human expert games. DeepMind collected a database of roughly 30 million positions from games played by strong human Go players on online servers. They trained a deep neural network (the policy network) to predict the move a human expert would play in any given position. This network learned to imitate human play — given a board position, it output a probability distribution over possible moves, reflecting how likely a strong human would be to play each one.
This supervised policy network was already impressive. It could predict the expert move about 57% of the time, and it played at a level roughly equivalent to a strong amateur. But imitation has limits. The network was constrained by the data it was trained on — it could only learn patterns that appeared in human games.
Stage 2: Reinforcement learning through self-play. Starting from the supervised policy network, DeepMind then used RL to improve it further. The network played millions of games against earlier versions of itself. After each game, the policy was updated to make the winner’s moves more likely and the loser’s moves less likely. The resulting RL policy network won more than 80% of its games against the supervised network it started from.
AlphaGo also trained a value network: a separate neural network that, given a board position, estimated the probability of winning from that position. This value network reduced AlphaGo’s reliance on rollouts in MCTS. Instead of only playing a quick simulated game to the end to estimate a position’s value, AlphaGo could also ask its value network: “How likely am I to win from here?”
The final AlphaGo system combined MCTS with these neural networks. The policy network guided which branches to explore, so the search did not have to weigh every possible move equally. At the leaves of the search tree, each position was scored by mixing the value network’s estimate with the result of a fast rollout. This was far more efficient and far more accurate than either component alone.
AlphaGo’s victory over Lee Sedol in March 2016 was a watershed moment. Go had been considered a “grand challenge” for AI for decades. But there was a nagging asterisk: AlphaGo still started from human data. It needed those 30 million expert positions to bootstrap its learning. What if the human experts were wrong about something? What if there were strategies that no human had ever discovered? The supervised learning stage was both a crutch and a ceiling.
AlphaGo Zero: tabula rasa (2017)
In October 2017, DeepMind published a paper that answered the question: What happens if you remove the human data entirely?
AlphaGo Zero started with no human games, no handcrafted features, and no human knowledge beyond the rules of Go. The network’s weights were initialized randomly. It knew nothing about Go strategy — not that the center is important, not that connected stones are strong, not that you should surround territory. Nothing.
It learned everything from self-play alone. And within 40 days, it surpassed every previous version of AlphaGo, including the version that defeated Lee Sedol. It was not even close — AlphaGo Zero defeated the Lee Sedol version 100 games to 0.
Let that sink in. A system that started knowing literally nothing about Go, and that learned entirely by playing against itself, crushed a system bootstrapped from 30 million positions of human expert play. Human expert knowledge was not just unnecessary; it was actively holding the system back.
Why human data can be a ceiling
This result is counterintuitive. Surely starting with human knowledge should help? You get a head start. You don’t have to waste time rediscovering basic strategy. How could it be worse to start with expert data?
The answer lies in the nature of imitation learning. When you train on human expert games, you are teaching the network to predict what a human would do. But “what a human would do” is not the same as “the objectively best move.” Human play reflects human biases, conventions, fashions, and limitations.
Consider a few ways human data can limit an AI:
Blind spots. If no human has ever played a particular style of move, that move will be absent from the training data. The network will learn to assign it near-zero probability, even if it is actually strong. Human Go players avoided certain moves because they violated conventional wisdom — but some of those moves turn out to be excellent when analyzed deeply enough. AlphaGo Zero discovered several such moves that professional players had dismissed for centuries.
Convergent play styles. Strong human players tend to play similarly to each other because they study the same games, read the same books, and follow the same meta-game. Training on this data produces a network that imitates this convergent style. It cannot easily break out of the mold to find something genuinely novel.
Suboptimal consensus. Sometimes the entire human community is wrong about a position or a strategy. If every expert agrees that a particular opening is weak, the training data will reflect that consensus — even if the consensus is wrong. A self-play system has no such preconceptions.
Distributional mismatch. The states a strong human encounters during a game are not the same states the AI will encounter once it starts improving beyond human level. The AI ends up in positions that no human has ever seen, and its policy — trained only on human-encountered positions — has no useful guidance for these alien board states.
A self-play system avoids all of these problems. It is not trying to imitate anyone. It is trying to win. It explores whatever strategies lead to victories, no matter how unconventional. Its training data comes from its own games, so the distribution of positions it trains on always matches the positions it actually encounters during play.
The self-play insight
The core idea of self-play is deceptively simple: the agent generates its own training data by playing against itself.
Here is how it works in broad strokes:
1. Start with an agent that plays randomly (or very poorly).
2. Have the agent play a game against a copy of itself.
3. Use the game’s outcome (who won, who lost) and the move-by-move decisions as training data.
4. Update the agent’s neural network using this data, making it slightly better.
5. Go back to step 2 with the improved agent.
At first, both sides play terribly. The games are random chaos. But even in random chaos, one side wins and the other loses, and that outcome provides a training signal. The network learns, very slightly, which positions tend to lead to wins and which tend to lead to losses. It learns, very slightly, which moves are better than others.
Now something remarkable happens. The slightly-improved agent plays against itself again. Because it is slightly better, it generates slightly better training data. The games are slightly less random, slightly more strategic. The agent learns from this better data and improves further. Which produces even better data. Which produces even more improvement.
The virtuous cycle
This is the key dynamic that makes self-play so powerful — a virtuous cycle where improvement in the agent’s policy leads to higher-quality training data, which leads to further improvement:
Better policy
↓
Higher-quality self-play games
↓
More informative training data
↓
Better neural network
↓
Better policy (and the cycle repeats)
This is in sharp contrast to supervised learning, where the training data is fixed. No matter how good the model gets, it is still learning from the same dataset. The quality of the data is a hard ceiling on the quality of the model. In self-play, the ceiling rises with the agent. The data gets better because the agent gets better.
There is an important subtlety here: why doesn’t the agent just learn to exploit weaknesses in its own play, going around in circles? Much of the answer is that MCTS provides a stabilizing force. Even when the neural network’s policy is weak, MCTS improves upon it by running hundreds or thousands of simulations. The search acts as a “policy improvement operator”: the MCTS-enhanced policy is typically stronger than the raw neural network policy. When the network is then trained to match the MCTS-improved policy, it genuinely gets better. The combination of search and learning helps guard against the kind of circular, degenerate strategies that might arise from naive self-play.
The AlphaGo Zero architecture
Let’s look at how AlphaGo Zero puts all of this together. The architecture is elegant in its simplicity — it unifies several components that were separate in earlier systems.
A single neural network with two heads
The original AlphaGo used two separate networks: a policy network (which moves to play) and a value network (who is winning). AlphaGo Zero combined these into a single neural network with two output heads:
- Policy head: Given the current board state, output a probability distribution over all legal moves. “Play here with 35% probability, here with 20%, here with 15%…” This guides the MCTS search toward promising moves.
- Value head: Given the current board state, output a single number between -1 and +1 estimating who is winning, from the perspective of the player to move. +1 means “I am definitely winning,” -1 means “I am definitely losing,” and 0 means “the position is roughly equal.”
Both heads share the same underlying neural network (a deep residual network, or ResNet), which processes the board state into a rich internal representation. The two heads are just different “lenses” on that same representation — one for suggesting moves, one for evaluating positions.
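As a rough illustration of what the two heads hand back to the search, here is a minimal sketch for a 9-cell Tic-Tac-Toe board rather than Go. It assumes the heads emit raw scores that still need squashing. The policy scores go through a softmax restricted to the legal cells; masking illegal cells this way is an implementation choice assumed here. The value score goes through tanh, which keeps it in the [-1, +1] range described above. The function names are hypothetical.

```rust
/// Hypothetical helper: turns raw policy-head scores (logits), one per cell,
/// into probabilities over the legal moves only. Illegal cells get 0.
pub fn policy_from_logits(logits: &[f32; 9], legal_moves: &[usize]) -> [f32; 9] {
    let mut probs = [0.0f32; 9];
    if legal_moves.is_empty() {
        return probs; // terminal position: nothing to choose
    }
    // Subtract the largest legal logit before exponentiating (numerical stability).
    let max_logit = legal_moves
        .iter()
        .map(|&m| logits[m])
        .fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0f32;
    for &m in legal_moves {
        let weight = (logits[m] - max_logit).exp();
        probs[m] = weight;
        total += weight;
    }
    for &m in legal_moves {
        probs[m] /= total;
    }
    probs
}

/// Hypothetical helper: squashes the raw value-head score into [-1, +1].
/// The result is read from the perspective of the player to move.
pub fn value_from_raw(raw: f32) -> f32 {
    raw.tanh()
}
```

For example, equal logits over three legal cells give each of those cells probability 1/3 and every other cell 0.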
MCTS guided by the neural network
In vanilla MCTS (as we described in section 2), the simulation phase uses random rollouts — random moves played to the end of the game. AlphaGo Zero replaces this entirely. There are no random rollouts. Instead:
- During selection, MCTS uses a modified UCT formula (AlphaGo Zero uses a variant known as PUCT) that incorporates the neural network’s policy output. Moves that the network thinks are promising get a higher prior probability, so MCTS explores them first. This is far more efficient than treating all moves equally.
- During evaluation (replacing simulation), when MCTS reaches a leaf node, it does not play a random game to completion. Instead, it queries the neural network’s value head to get an instant estimate of the position’s value. This is faster and far more accurate than a random rollout, especially in complex games where random play is essentially noise.
The result is an MCTS that is guided by learned intuition. The network says “these moves look promising and this position looks good for me,” and MCTS uses that guidance to search much more efficiently. But MCTS also corrects the network — if the network thinks a move is good but search reveals it leads to a loss, MCTS will discover that through deeper exploration.
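AlphaGo Zero’s selection rule picks the move that maximizes Q(s, a) + U(s, a). Here Q is the average value backed up through that move, and the exploration bonus is U(s, a) = c_puct · P(s, a) · √(Σ_b N(s, b)) / (1 + N(s, a)), built from the policy head’s prior P and the visit counts N. Below is a minimal Rust sketch of that rule. The EdgeStats struct and the function names are assumptions for illustration, not the course’s final MCTS code. An unvisited move’s Q is taken as 0, which matches AlphaGo Zero’s initialization.

```rust
/// Per-move statistics that MCTS keeps on each edge (s, a).
#[derive(Debug, Clone, Copy)]
pub struct EdgeStats {
    /// Prior probability P(s, a) from the policy head.
    pub prior: f64,
    /// Visit count N(s, a).
    pub visits: u32,
    /// Sum of values backed up through this edge, W(s, a).
    pub total_value: f64,
}

impl EdgeStats {
    /// Mean action value Q(s, a) = W(s, a) / N(s, a), or 0 if never visited.
    pub fn mean_value(&self) -> f64 {
        if self.visits == 0 {
            0.0
        } else {
            self.total_value / f64::from(self.visits)
        }
    }
}

/// PUCT score: Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a)).
pub fn puct_score(edge: &EdgeStats, parent_visits: u32, c_puct: f64) -> f64 {
    let exploration =
        c_puct * edge.prior * f64::from(parent_visits).sqrt() / (1.0 + f64::from(edge.visits));
    edge.mean_value() + exploration
}

/// Index of the child edge with the highest PUCT score, or None if there are none.
pub fn select_child(edges: &[EdgeStats], c_puct: f64) -> Option<usize> {
    // sum_b N(s, b): total visits over all children of this node.
    let parent_visits: u32 = edges.iter().map(|e| e.visits).sum();
    edges
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| {
            puct_score(a, parent_visits, c_puct).total_cmp(&puct_score(b, parent_visits, c_puct))
        })
        .map(|(index, _)| index)
}
```

A high prior keeps a rarely visited move attractive until its visit count catches up, which is how the policy head steers the search. When no child has been visited yet, every score is 0. The tie is then broken arbitrarily: max_by keeps the last maximum.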
The training loop
Here is the AlphaGo Zero training loop in its entirety:
Step 1: Self-play. The current neural network plays games against itself, using MCTS to select moves. During each game, at every move, the system records three things: the board state, the MCTS visit counts for each move (which represent the search-improved policy), and the eventual game outcome (win or loss).
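Here is a minimal sketch of what Step 1 might record for Tic-Tac-Toe. The visit counts are normalized into a probability distribution. AlphaGo Zero raises counts to a temperature exponent first, and dividing by the total corresponds to a temperature of 1. Once the game ends, every stored position is stamped with +1, -1, or 0 from the perspective of the player who was to move there, so a single game yields targets for both sides. The Player enum is redeclared so the block compiles on its own. TrainingExample, visits_to_policy, and assign_outcomes are hypothetical names.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Player {
    X,
    O,
}

/// One recorded position from a self-play game.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    pub board: [Option<Player>; 9],
    /// The player who was to move in this position.
    pub to_move: Player,
    /// MCTS visit counts normalized to probabilities (the policy target).
    pub policy_target: [f32; 9],
    /// Filled in after the game: +1 win, -1 loss, 0 draw, as seen by `to_move`.
    pub value_target: f32,
}

/// Normalizes MCTS visit counts into a probability distribution (temperature 1).
pub fn visits_to_policy(visits: &[u32; 9]) -> [f32; 9] {
    let total: u32 = visits.iter().sum();
    let mut policy = [0.0f32; 9];
    if total > 0 {
        for (p, &n) in policy.iter_mut().zip(visits.iter()) {
            *p = n as f32 / total as f32;
        }
    }
    policy
}

/// After the game ends, writes the outcome into every recorded position
/// from the perspective of the player who was to move there.
pub fn assign_outcomes(examples: &mut [TrainingExample], winner: Option<Player>) {
    for example in examples.iter_mut() {
        example.value_target = match winner {
            None => 0.0,
            Some(w) if w == example.to_move => 1.0,
            Some(_) => -1.0,
        };
    }
}
```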
Step 2: Train. The neural network is updated to better match the self-play data. Specifically, it is trained to:
- Make the policy head’s output match the MCTS visit count distribution (because MCTS, with its search, has a better sense of which moves are good than the raw network does)
- Make the value head’s output match the actual game outcome (because the final result tells us who really was winning)
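These two goals combine into a single per-example loss. In the AlphaGo Zero paper the loss is (z − v)² − π·log p plus an L2 weight penalty. Here z is the game outcome, v the value head’s output, π the MCTS visit distribution, and p the policy head’s output. The sketch below computes the first two terms for one example and leaves the weight penalty to the optimizer. The function name and the small epsilon that guards against log(0) are implementation assumptions.

```rust
/// Per-example loss without the L2 weight penalty:
/// squared value error plus cross-entropy between the MCTS policy and the network policy.
pub fn alphazero_loss(
    policy_target: &[f32],
    policy_pred: &[f32],
    value_target: f32,
    value_pred: f32,
) -> f32 {
    assert_eq!(policy_target.len(), policy_pred.len());
    // (z - v)^2: how far the value head is from the real game outcome.
    let value_loss = (value_target - value_pred).powi(2);
    // -sum_a pi(a) * ln p(a); EPS keeps ln() finite if p(a) is exactly 0.
    const EPS: f32 = 1e-8;
    let policy_loss: f32 = policy_target
        .iter()
        .zip(policy_pred)
        .map(|(&pi, &p)| -pi * (p + EPS).ln())
        .sum();
    value_loss + policy_loss
}
```

The loss is smallest when the policy head already agrees with the search and the value head already predicts the final result.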
Step 3: Evaluate. Periodically, the newly trained network plays a match against the previous best network. If it wins convincingly (in the paper, more than 55% of the games), it becomes the new best network and replaces the old one for future self-play games.
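Here is a sketch of that promotion check. It assumes a draw counts as half a win. Go games scored with fractional komi cannot be drawn, but draws are common in Tic-Tac-Toe, so the rule needs some policy for them. The threshold is left as a parameter, and should_promote is a hypothetical name.

```rust
/// Hypothetical gating rule: the candidate network replaces the current best
/// only if its match score (wins plus half of draws, divided by games played)
/// is strictly greater than `threshold`.
pub fn should_promote(candidate_wins: u32, draws: u32, games: u32, threshold: f64) -> bool {
    if games == 0 {
        return false; // no evidence: keep the current best
    }
    let score = (f64::from(candidate_wins) + 0.5 * f64::from(draws)) / f64::from(games);
    score > threshold
}
```

One consequence of this rule: once both networks play perfectly, every evaluation game is a draw and the score stays at 0.5, so the current best network is kept.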
Step 4: Repeat. Go back to step 1 with the improved network.
In the paper’s first run, this loop produced about 4.9 million self-play games over 3 days. A second, 40-day run with a larger network produced about 29 million. The network improved steadily throughout.
The breakthrough results
The results from the 2017 paper (“Mastering the Game of Go without Human Knowledge,” published in Nature) were striking:
Surpassed all previous versions. After just 36 hours of training, AlphaGo Zero surpassed the version of AlphaGo that defeated Lee Sedol. In a second, longer run with a larger network, it reached the level of AlphaGo Master after about 21 days. (AlphaGo Master was an improved version that had defeated the world’s top professionals 60-0 in online blitz games.) By the end of 40 days, AlphaGo Zero defeated AlphaGo Master 89-11 in a 100-game match.
Learned from scratch. The system started with completely random play. In the early hours of training, it played nonsensically. Within a few hours, it had discovered basic Go concepts — influence, territory, life and death. Within days, it had developed sophisticated strategic understanding. It rediscovered many patterns of play that took humans thousands of years to develop, and it also discovered novel strategies that professional players had never considered.
Simpler architecture. Counterintuitively, AlphaGo Zero used a simpler architecture than the original AlphaGo. It had one network instead of two, no handcrafted features (just raw stone positions on the board), and no rollouts at all. And yet it was dramatically stronger. Simplicity and generality turned out to be advantages, not handicaps.
More efficient. Despite learning from scratch, AlphaGo Zero needed far less hardware to play than its predecessor. It ran on a single machine with 4 TPUs (Google’s custom machine learning accelerators), whereas the version that played Lee Sedol was distributed over many machines and used 48 TPUs. Dropping the human-data stage, the handcrafted features, and the rollouts made the system leaner without making it weaker.
A year later, DeepMind published AlphaZero (2018), which applied the same approach to chess, shogi (Japanese chess), and Go. It used the same algorithm, the same neural network architecture, and the same hyperparameters, trained separately for each game. Training took about 9 hours for chess, 12 hours for shogi, and 13 days for Go. Afterward, AlphaZero defeated the strongest existing programs in each game: Stockfish in chess, Elmo in shogi, and AlphaGo Zero itself in Go. This demonstrated that the self-play approach was not Go-specific. It was a general recipe for mastering two-player perfect-information games.
Why this matters beyond Go
The AlphaGo Zero result matters not just because it plays Go well, but because of what it demonstrates as a principle:
You do not need human knowledge to achieve superhuman performance. Given the right architecture (neural network + MCTS), the right training signal (self-play outcomes), and enough computation, an agent can discover strategies that exceed the best human understanding. The human experts are not the ceiling — they are a waypoint.
Self-generated data can be superior to human-curated data. This challenges the assumption, common in machine learning, that more data (especially human-labeled data) is always better. For certain problems, an agent’s own experience — tailored to its own level of play, covering the exact situations it encounters — is more valuable than any static dataset.
Search and learning are complementary. MCTS alone (with random rollouts) plays Go at an amateur level. A neural network alone (without search) plays at a strong amateur level. But MCTS guided by a neural network, with the network trained on MCTS-improved data, produces superhuman play. The whole is vastly greater than the sum of its parts.
How we will adapt this for Tic-Tac-Toe
In the remaining sections of this course, we will build a miniature version of the AlphaGo Zero system. The game will be Tic-Tac-Toe instead of Go, the neural network will be small and simple instead of a massive ResNet, and training will take seconds instead of days. But the core architecture and training loop will be the same:
- A game engine (Part 2) that represents Tic-Tac-Toe state, enumerates legal moves, and detects terminal positions. This is the “environment” in RL terms.
- An MCTS implementation (Part 3) that searches the game tree, initially using random rollouts, just as vanilla MCTS does.
- A neural network with policy and value heads (Part 4) that takes a board state as input and outputs move probabilities and a position evaluation. We will integrate this with MCTS, replacing random rollouts with the value head and using the policy head to guide search.
- A self-play training loop (Part 5) where the system plays Tic-Tac-Toe against itself, generates training data from those games, trains the neural network on that data, and repeats. We will watch the agent improve from completely random play to strong (ideally perfect) Tic-Tac-Toe.
Tic-Tac-Toe is a solved game — perfect play from both sides always results in a draw. This actually makes it an ideal testbed. We know what the “right answer” looks like, so we can measure exactly how close our agent gets to optimal play. And because the game is small (only 9 squares, games last at most 9 moves), training is fast enough to iterate and experiment on a single laptop.
The AlphaGo Zero recipe is remarkably general. The same ideas that produce superhuman Go play will produce strong Tic-Tac-Toe play. And understanding the recipe at this small scale — where you can inspect every node in the search tree and every weight in the neural network — will give you genuine insight into how and why it works at large scale.
In the next section, we will start building. First up: choosing our game and understanding what properties it needs to have for this approach to work.
Part 2 — The Game
4. Choosing a simple game: Tic-Tac-Toe
We ended Part 1 with a grand vision: a system that masters a game purely through self-play, combining MCTS with a neural network in the style of AlphaGo Zero. Before we start writing code, we need to pick a game. The choice matters more than it might seem — the wrong game can bury the core ideas under irrelevant complexity, while the right game keeps the focus exactly where it belongs.
Why start with a toy problem
Complex algorithms are hard to debug. When your self-play agent produces nonsensical moves after training, you want to know whether the bug is in your game logic, your MCTS implementation, your neural network architecture, or your training loop. If your game is Go (a 19x19 board with roughly 10^170 legal game positions), you cannot feasibly enumerate the game tree to check correctness. If your game is Tic-Tac-Toe, you can.
Toy problems are not a detour — they are a microscope. At small scale, you can inspect every node that MCTS visits, print the entire game tree, verify that the neural network’s policy output assigns probability to every legal move, and confirm that the value head’s evaluation matches the true game-theoretic value of each position. Once you understand the system at this resolution, scaling up to larger games becomes an engineering problem rather than a conceptual one.
This is standard practice when developing a new algorithm: start simple, get it working, then scale.
What properties does a game need?
Not every game is suitable for the AlphaGo Zero approach. The framework we described in Part 1 — MCTS guided by a neural network, trained through self-play — relies on the game having several specific properties:
Two-player. Self-play requires exactly two players taking alternating roles. The agent plays both sides, generating training data from both perspectives. Single-player optimization problems and multiplayer (3+) games introduce complications that would distract from the core loop.
Zero-sum. One player’s gain is the other’s loss. When the agent wins as Player X, that same game is a loss from Player O’s perspective. This gives us a clean training signal: every game produces data for both “winning” and “losing” play, and the value of any position to one player is the exact negative of its value to the other.
Deterministic. No dice, no card draws, no random events. Given a state and a move, the next state is completely determined. This means MCTS can search the game tree without needing to account for chance nodes, and the neural network’s value estimate is a function of the position alone, not of hidden randomness.
Perfect information. Both players can see the entire game state at all times. There are no hidden cards, no fog of war. This is essential because our neural network takes the full board state as input — if information were hidden, we would need a fundamentally different architecture to handle uncertainty about the opponent’s private state.
Finite. The game always terminates after a bounded number of moves. Every branch of the game tree ends in a win, loss, or draw. This guarantees that MCTS rollouts terminate and that the game-theoretic value of every position is well-defined. (Contrast this with Go: without the ko rule, two players could keep recapturing the same stone and repeat a position forever.)
These five properties — two-player, zero-sum, deterministic, perfect information, finite — define the class of games that the AlphaGo Zero approach handles directly. Chess, Go, Othello, Connect Four, and Tic-Tac-Toe all qualify. Poker does not (imperfect information). Backgammon does not (dice introduce randomness). Risk does not (multiplayer, random).
Tic-Tac-Toe: the rules
You almost certainly know this game already, but let us state the rules precisely, since we will need to implement them exactly:
- The board is a 3x3 grid, initially empty.
- Two players alternate turns. Player X moves first, Player O moves second.
- On your turn, you place your mark (X or O) in any empty cell.
- The first player to get three of their marks in a row — horizontally, vertically, or diagonally — wins.
- If all nine cells are filled and neither player has three in a row, the game is a draw.
That is the entire game. No special moves, no optional rules, no edge cases. This simplicity is a feature, not a limitation.
The game tree: small enough to verify, large enough to matter
How big is Tic-Tac-Toe, computationally?
If we count naively — the number of complete sequences of moves from the start to a full board, ignoring early termination — there are 9! = 362,880 possible games. But many games end before the board is full (someone wins), and many positions can be reached by different move orders. The actual numbers:
- 255,168 possible games (distinct sequences of moves from the start to a terminal state, counting games that end early with a win).
- 5,478 distinct board states (every position reachable in legal play, including the empty board). Each position is counted once regardless of move order, and symmetry is not yet taken into account.
- 765 distinct states after symmetry reduction (treating rotations and reflections as equivalent).
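These numbers are easy to check by brute force. The standalone sketch below does not use the GameState type from §5. Instead, cells are plain u8 values: 0 for empty, 1 for X, 2 for O. The code walks the game tree depth-first, counting finished games and recording every distinct board it reaches. It then collapses boards that are rotations or reflections of one another by mapping each board to the smallest of its 8 symmetric variants.

```rust
use std::collections::HashSet;

const EMPTY: u8 = 0;

/// The 8 winning lines: 3 rows, 3 columns, 2 diagonals.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8],
    [0, 3, 6], [1, 4, 7], [2, 5, 8],
    [0, 4, 8], [2, 4, 6],
];

/// True if some line holds three identical non-empty marks.
fn has_line(b: &[u8; 9]) -> bool {
    LINES.iter().any(|&[p, q, r]| b[p] != EMPTY && b[p] == b[q] && b[q] == b[r])
}

/// Depth-first walk with `player` (1 = X, 2 = O) to move. Records every board
/// reached in `seen` and returns how many finished games lie below `board`.
fn walk(board: &mut [u8; 9], player: u8, seen: &mut HashSet<[u8; 9]>) -> u64 {
    seen.insert(*board);
    if has_line(board) || !board.contains(&EMPTY) {
        return 1; // someone won or the board is full: one finished game
    }
    let mut games = 0;
    for cell in 0..9 {
        if board[cell] == EMPTY {
            board[cell] = player;
            games += walk(board, 3 - player, seen);
            board[cell] = EMPTY; // undo the move before trying the next one
        }
    }
    games
}

/// Rotates the board 90 degrees: new cell (r, c) takes old cell (2 - c, r).
fn rotate(b: [u8; 9]) -> [u8; 9] {
    let mut out = [EMPTY; 9];
    for r in 0..3 {
        for c in 0..3 {
            out[r * 3 + c] = b[(2 - c) * 3 + r];
        }
    }
    out
}

/// Mirrors the board left to right: new cell (r, c) takes old cell (r, 2 - c).
fn mirror(b: [u8; 9]) -> [u8; 9] {
    let mut out = [EMPTY; 9];
    for r in 0..3 {
        for c in 0..3 {
            out[r * 3 + c] = b[r * 3 + (2 - c)];
        }
    }
    out
}

/// Smallest of the 8 rotations and reflections of `b`, used as a canonical key.
fn canonical(b: &[u8; 9]) -> [u8; 9] {
    let (mut turned, mut flipped) = (*b, mirror(*b));
    let mut best = turned;
    for _ in 0..4 {
        best = best.min(turned).min(flipped);
        turned = rotate(turned);
        flipped = rotate(flipped);
    }
    best
}

/// Returns (finished games, distinct boards, distinct boards up to symmetry).
pub fn count_tic_tac_toe() -> (u64, usize, usize) {
    let mut seen = HashSet::new();
    let games = walk(&mut [EMPTY; 9], 1, &mut seen);
    let classes: HashSet<[u8; 9]> = seen.iter().map(canonical).collect();
    (games, seen.len(), classes.len())
}
```

Calling count_tic_tac_toe() returns (255168, 5478, 765), matching the figures above.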
For comparison, chess has roughly 10^44 legal positions. Go has roughly 10^170. Tic-Tac-Toe’s entire game tree fits comfortably in memory on any modern machine. You can enumerate every state, compute the minimax value of every position, and verify that your MCTS agent converges to the correct answer. This is extraordinarily valuable when debugging.
At the same time, 5,478 states (or even 765 after symmetry) is not trivially small. It is large enough that a neural network has something meaningful to learn, MCTS has a tree worth searching, and the training loop has real work to do. The network is never handed a lookup table of answers. It only sees the positions that self-play happens to reach, so it has to generalize from them.
The known solution: a built-in answer key
Tic-Tac-Toe is a solved game. Game theorists have exhaustively analyzed every possible position and determined the optimal move in each. The result: with perfect play from both sides, the game always ends in a draw. Player X (who moves first) cannot force a win against a perfect opponent, and Player O cannot force a win either.
This is enormously useful for us. We have a ground truth to validate our AI against:
- After training, our agent should never lose when playing against a perfect opponent (or against itself).
- Every game between two copies of a fully trained agent should end in a draw.
- The value head of the neural network should evaluate the starting position as approximately 0 (a draw), not as favoring either player.
- The policy head should never assign high probability to moves that are known blunders in solved positions.
Without this ground truth, we would have no way to know whether our trained agent was merely good or actually optimal. With Tic-Tac-Toe, we can measure exactly.
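Checking the draw claim takes only a few lines of negamax. Negamax is the two-player zero-sum form of minimax: a position’s value for the player to move is the negation of the best value any reply gives the opponent. The sketch below is independent of the GameState type. It uses plain u8 cells (0 empty, 1 X, 2 O) and returns +1, 0, or -1 from the perspective of the player to move.

```rust
/// The 8 winning lines: 3 rows, 3 columns, 2 diagonals.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8],
    [0, 3, 6], [1, 4, 7], [2, 5, 8],
    [0, 4, 8], [2, 4, 6],
];

/// True if some line holds three identical non-empty marks.
fn has_line(b: &[u8; 9]) -> bool {
    LINES.iter().any(|&[p, q, r]| b[p] != 0 && b[p] == b[q] && b[q] == b[r])
}

/// Negamax value of `board` for `player` (1 = X, 2 = O), who is to move:
/// +1 = forced win, 0 = draw with best play, -1 = forced loss.
pub fn negamax(board: &mut [u8; 9], player: u8) -> i32 {
    if has_line(board) {
        return -1; // the opponent completed a line on the previous move
    }
    if !board.contains(&0) {
        return 0; // full board with no line: draw
    }
    let mut best = -1_i32;
    for cell in 0..9 {
        if board[cell] == 0 {
            board[cell] = player;
            // The opponent's best result, negated, is our result for this move.
            best = best.max(-negamax(board, 3 - player));
            board[cell] = 0;
            if best == 1 {
                break; // a forced win cannot be improved on
            }
        }
    }
    best
}
```

On the empty board with X to move, negamax returns 0, which is the draw result stated above.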
Why not a more interesting game?
You might wonder: why not use Connect Four, Othello, or even a small-board variant of Go? These are all valid targets for the AlphaGo Zero approach, and we will mention them as stretch goals later. But each adds complexity that makes it harder to learn the algorithm:
Connect Four (7 columns, 6 rows, four in a row to win) has roughly 4.5 trillion possible positions. It is also solved — Player 1 can force a win with perfect play — but the game tree is too large to exhaustively verify your implementation against. Training takes minutes to hours rather than seconds. Bugs become much harder to diagnose.
Othello (8x8 board, pieces flip) introduces a more complex move-legality rule (you must flip at least one opponent piece) and a board evaluation that shifts dramatically throughout the game. The neural network needs to be meaningfully larger to capture the patterns. This is interesting but distracting when you are trying to understand MCTS and self-play.
Small-board Go (e.g., 5x5 or 9x9) is closer to the original AlphaGo Zero domain, but Go’s rules — liberties, captures, ko, scoring — are substantially more complex to implement correctly than Tic-Tac-Toe’s. Getting the game engine right becomes a project in itself.
The principle is simple: learn the algorithm on the easiest possible game, then scale up. Once your self-play loop is producing a perfect Tic-Tac-Toe agent, adapting it to Connect Four requires only changing the game engine and enlarging the neural network. The MCTS code, the training loop, and the self-play pipeline remain the same. Tic-Tac-Toe is where we learn; other games are where we apply.
Mapping Tic-Tac-Toe to RL concepts
In §1, we introduced the core RL vocabulary: agent, environment, state, action, reward, and policy. Now that we have a concrete game, we can pin each concept to something specific:
| RL concept | Tic-Tac-Toe equivalent |
|---|---|
| Environment | The game rules — the 3x3 grid and the logic for placing marks, detecting wins, and detecting draws. |
| State | A specific board configuration: which cells contain X, which contain O, which are empty, and whose turn it is. |
| Action | Placing a mark in an empty cell. There are at most 9 possible actions, decreasing as the board fills. |
| Legal actions | The subset of empty cells in the current state. Our game engine must enumerate these. |
| Terminal state | A state where someone has three in a row (win/loss) or the board is full (draw). |
| Reward | +1 for a win, -1 for a loss, 0 for a draw. Only assigned at terminal states. |
| Policy | A function that takes a board state and returns a probability distribution over the legal moves. This is what the neural network’s policy head will learn. |
| Value | A function that takes a board state and returns an estimate of the expected outcome (between -1 and +1). This is what the neural network’s value head will learn. |
In the next section, we will translate these concepts into Rust data structures: an enum for cell contents, a struct for the board state, methods for making moves and checking for wins, and an interface that MCTS can plug into. The mapping above is the bridge between the theory of Part 1 and the code of Part 2.
5. Representing game state in Rust
In §4, we mapped reinforcement learning concepts onto Tic-Tac-Toe: state, action, terminal condition, reward. Now we translate that mapping into Rust data structures. By the end of this section, you will have a complete, compilable set of types and methods that represent a Tic-Tac-Toe game — the foundation that MCTS will plug into later.
We will build the code incrementally. Each piece is small and builds on the ones before it. The full set compiles once the winner check later in this section is in place.
Choosing a board representation
There are two common ways to represent a Tic-Tac-Toe board in code:
Bitboard. Use two u16 values — one bitmask for X’s marks and one for O’s. Each of the 9 cells maps to a bit position. Checking for a winner becomes a handful of bitwise AND operations against precomputed winning masks. This approach is compact (two integers per board) and fast (winner detection is branchless), but it is harder to read and debug. Bitboards shine in games like chess where the board is large and performance-critical search must evaluate millions of positions per second.
Array of options. Use a fixed-size array of 9 elements, where each element is either Some(Player::X), Some(Player::O), or None (empty). This representation maps directly to how a human thinks about the board. Indexing, display, move generation, and winner checking are all straightforward loops and matches. The cost is a bit more memory and a few more branches — completely negligible for a 3x3 board.
We will use the array approach. Clarity matters more than performance at this scale, and the code will be much easier to follow as we layer MCTS and neural network integration on top. If you later scale up to a larger game, switching to a bitboard is a localized change — the interface stays the same.
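For comparison, here is roughly what the bitboard alternative would look like, assuming bit i of each u16 stands for cell i in the same 0 to 8 layout. The BitBoard type and its methods are hypothetical. This is a side sketch, and the rest of this section sticks with the array. The point is that a win check reduces to testing each of the 8 line masks with a bitwise AND.

```rust
/// Bitboard representation: bit i of `x` (or `o`) is set when that player
/// owns cell i, using the same 0..=8 cell indexing as the array layout.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BitBoard {
    pub x: u16,
    pub o: u16,
}

/// The 8 winning lines as bitmasks (bit i = cell i, written high bit first).
const WIN_MASKS: [u16; 8] = [
    0b000_000_111, // top row: cells 0, 1, 2
    0b000_111_000, // middle row: cells 3, 4, 5
    0b111_000_000, // bottom row: cells 6, 7, 8
    0b001_001_001, // left column: cells 0, 3, 6
    0b010_010_010, // middle column: cells 1, 4, 7
    0b100_100_100, // right column: cells 2, 5, 8
    0b100_010_001, // diagonal: cells 0, 4, 8
    0b001_010_100, // anti-diagonal: cells 2, 4, 6
];

/// All nine cell bits set.
const FULL: u16 = 0b111_111_111;

impl BitBoard {
    /// Returns a copy with `cell` claimed by X (if `is_x`) or by O.
    pub fn with_move(mut self, cell: usize, is_x: bool) -> Self {
        let bit = 1u16 << cell;
        if is_x {
            self.x |= bit;
        } else {
            self.o |= bit;
        }
        self
    }

    /// True if the marks in `marks` cover at least one complete line.
    pub fn has_line(marks: u16) -> bool {
        WIN_MASKS.iter().any(|&mask| (marks & mask) == mask)
    }

    /// Bitmask of the empty cells (ignoring whether the game is already over).
    pub fn empty_cells(&self) -> u16 {
        !(self.x | self.o) & FULL
    }
}
```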
The Player enum
A cell on the board is either empty, occupied by X, or occupied by O. We model the occupant as a Player enum:
/// One of the two players in a Tic-Tac-Toe game.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player {
    X,
    O,
}
impl Player {
    /// Returns the other player.
    ///
    /// This is used after each move to switch turns, and when
    /// computing rewards from the opponent's perspective.
    pub fn opponent(self) -> Player {
        match self {
            Player::X => Player::O,
            Player::O => Player::X,
        }
    }
}
impl std::fmt::Display for Player {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Player::X => write!(f, "X"),
            Player::O => write!(f, "O"),
        }
    }
}
A few things to note:
- We derive Copy because Player is a small enum with no heap data. Passing it around by value is at least as cheap as borrowing and simpler to write.
- We derive Hash because later we will want to use game states as keys in hash maps (for MCTS node lookup), and GameState can only derive Hash if Player implements it too.
- The opponent method will appear everywhere: switching turns, computing rewards from the other player’s perspective, and building the self-play loop.
The GameState struct
A game state is a board plus whose turn it is:
/// A complete snapshot of a Tic-Tac-Toe game at a moment in time.
///
/// The board is a flat array of 9 cells. `current_player` indicates
/// whose turn it is to move. A fresh game starts with an empty board
/// and `Player::X` to move.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
    /// The 3x3 board stored as a flat array.
    /// Index 0 is the top-left cell; index 8 is the bottom-right.
    pub board: [Option<Player>; 9],
    /// The player who will make the next move.
    pub current_player: Player,
}
impl GameState {
    /// Creates a new game with an empty board. X moves first.
    pub fn new() -> Self {
        GameState {
            board: [None; 9],
            current_player: Player::X,
        }
    }
}
The board is [Option<Player>; 9] — exactly the type we discussed. None means empty, Some(Player::X) means X occupies that cell. The struct owns all its data (no references, no lifetimes), so it is easy to clone, compare, and hash.
Board indexing
We store the board as a flat array, but we think about it as a grid with rows and columns:
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
The conversion is straightforward:
- From (row, col) to index: index = row * 3 + col
- From index to (row, col): row = index / 3, col = index % 3
We do not need separate functions for this — the arithmetic is simple enough to inline — but it helps to be explicit about the layout. Index 0 is the top-left corner, index 4 is the center, and index 8 is the bottom-right corner. Row 0 is the top row, column 0 is the left column.
This flat layout is also the format we will use when encoding the board as input to the neural network: 9 values in a fixed order.
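As a sketch of what those 9 values could be, the hypothetical encode_board below writes one number per cell in index order, from the perspective of the player to move. It uses +1.0 for that player’s marks, -1.0 for the opponent’s, and 0.0 for empty cells. This particular scheme is an assumption for illustration. AlphaGo Zero itself fed its network separate binary planes for each player’s stones, and the encoding used in Part 4 may differ. Player is redeclared so the block compiles on its own.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Player {
    X,
    O,
}

/// Hypothetical network input: one f32 per cell, in index order 0..=8,
/// relative to `to_move` (+1.0 own mark, -1.0 opponent mark, 0.0 empty).
pub fn encode_board(board: &[Option<Player>; 9], to_move: Player) -> [f32; 9] {
    let mut input = [0.0f32; 9];
    for (value, cell) in input.iter_mut().zip(board.iter()) {
        *value = match cell {
            None => 0.0,
            Some(p) if *p == to_move => 1.0,
            Some(_) => -1.0,
        };
    }
    input
}
```

Encoding relative to the mover lets a single network evaluate positions for both X and O. This matches how the value head’s output is read.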
Displaying the board
A Display implementation makes debugging and interactive play much easier. We want output that looks like a human-readable grid:
impl std::fmt::Display for GameState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for row in 0..3 {
            for col in 0..3 {
                let index = row * 3 + col;
                match self.board[index] {
                    Some(Player::X) => write!(f, " X ")?,
                    Some(Player::O) => write!(f, " O ")?,
                    None => write!(f, " . ")?,
                }
                if col < 2 {
                    write!(f, "|")?;
                }
            }
            writeln!(f)?;
            if row < 2 {
                writeln!(f, "-----------")?;
            }
        }
        Ok(())
    }
}
A board mid-game might display as:
X | . | O
-----------
. | X | .
-----------
. | . | .
Using . for empty cells avoids ambiguity with whitespace and makes it easy to see the board structure at a glance.
Representing moves
A move in Tic-Tac-Toe is simply placing a mark in an empty cell. Since the board is a flat array of 9 cells, a move is just an index from 0 to 8. There is no need for a dedicated struct — a usize carries all the information we need.
We could wrap it in a newtype (struct Move(usize)) for type safety, but the extra wrapping and unwrapping at every call site is not worth it for a game this simple. Every function that takes a move also takes or has access to the game state, which provides all the context needed to interpret the index.
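For reference, the newtype alternative would look something like the sketch below. `Move`, `Move::new`, and `Move::index` are hypothetical names and are not part of the course's code. The validating constructor is the main gain. Every call site would pay for it with an extra wrap or unwrap.

```rust
/// Hypothetical newtype for a board index (the course uses plain `usize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Move(usize);

impl Move {
    /// Returns `Some(Move)` only for indices 0-8, so an out-of-range
    /// move can never be constructed.
    fn new(index: usize) -> Option<Move> {
        if index < 9 {
            Some(Move(index))
        } else {
            None
        }
    }

    /// Unwraps the move back into a board index.
    fn index(self) -> usize {
        self.0
    }
}

fn main() {
    let center = Move::new(4).expect("4 is a valid cell");
    assert_eq!(center.index(), 4);
    assert_eq!(Move::new(9), None); // off the board
}
```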
Generating legal moves
A move is legal if and only if the target cell is empty. Generating all legal moves is a matter of iterating the board and collecting the indices of None cells:
impl GameState {
    /// Returns the indices of all empty cells: the legal moves
    /// in this position.
    ///
    /// In a terminal state (someone has won or the board is full),
    /// this returns an empty vector.
    pub fn legal_moves(&self) -> Vec<usize> {
        // If the game is already over, no moves are legal.
        if self.winner().is_some() {
            return Vec::new();
        }
        self.board
            .iter()
            .enumerate()
            .filter(|(_, cell)| cell.is_none())
            .map(|(index, _)| index)
            .collect()
    }
}
MCTS calls this method often: whenever it creates a node, and on every step of a random rollout. For a 3x3 board the cost is trivial. Note that we check for a winner first. Once someone has won, no further moves should be legal, even if empty cells remain.
Checking for a winner
There are exactly 8 ways to win a Tic-Tac-Toe game: 3 rows, 3 columns, and 2 diagonals. We enumerate them explicitly:
/// The eight lines that constitute a win: three rows, three columns,
/// and two diagonals. Each line is a triple of board indices.
const WIN_LINES: [[usize; 3]; 8] = [
    // Rows
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8],
    // Columns
    [0, 3, 6],
    [1, 4, 7],
    [2, 5, 8],
    // Diagonals
    [0, 4, 8],
    [2, 4, 6],
];
impl GameState {
    /// Returns the winner if one exists.
    ///
    /// A player wins by occupying all three cells in any row,
    /// column, or diagonal. Returns `None` if no player has
    /// three in a row.
    pub fn winner(&self) -> Option<Player> {
        for line in &WIN_LINES {
            let [a, b, c] = *line;
            if let (Some(p1), Some(p2), Some(p3)) =
                (self.board[a], self.board[b], self.board[c])
            {
                if p1 == p2 && p2 == p3 {
                    return Some(p1);
                }
            }
        }
        None
    }
}
The WIN_LINES constant is defined outside the impl block — it is pure data with no connection to any particular game state. The winner method checks each line: if all three cells are occupied by the same player, that player has won. If no line is fully claimed, the game has no winner (yet, or ever, in the case of a draw).
Why a const array rather than computing the lines from row/column arithmetic? Explicitness. There are only 8 lines. Writing them out makes the logic immediately verifiable by inspection — you can count the rows, columns, and diagonals yourself. No off-by-one errors to worry about.
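To see the trade-off concretely, the sketch below derives the same eight lines from row/column arithmetic and checks them against the explicit table. The function name `computed_win_lines` is hypothetical. The generated version takes fewer keystrokes but is harder to verify by eye, which is the argument made above.

```rust
/// The explicit table from the text.
const WIN_LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
    [0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
    [0, 4, 8], [2, 4, 6],            // diagonals
];

/// Hypothetical alternative: compute the lines from (row, col) arithmetic.
fn computed_win_lines() -> Vec<[usize; 3]> {
    let mut lines = Vec::new();
    for r in 0..3 {
        // Row r covers cells r*3, r*3+1, r*3+2.
        lines.push([r * 3, r * 3 + 1, r * 3 + 2]);
    }
    for c in 0..3 {
        // Column c covers cells c, c+3, c+6.
        lines.push([c, c + 3, c + 6]);
    }
    // Main diagonal (step 4) and anti-diagonal (step 2).
    lines.push([0, 4, 8]);
    lines.push([2, 4, 6]);
    lines
}

fn main() {
    let computed = computed_win_lines();
    assert_eq!(computed.as_slice(), &WIN_LINES[..]);
    println!("{} lines match", computed.len()); // prints "8 lines match"
}
```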
Terminal state detection
A state is terminal when either someone has won or the board is full (a draw). We also define a method to check whether the game is a draw specifically:
impl GameState {
    /// Returns `true` if the game is over: either a player has won
    /// or the board is completely filled (a draw).
    pub fn is_terminal(&self) -> bool {
        self.winner().is_some() || self.board.iter().all(|cell| cell.is_some())
    }
    /// Returns `true` if the game ended in a draw: the board is full
    /// and no player has three in a row.
    pub fn is_draw(&self) -> bool {
        self.winner().is_none() && self.board.iter().all(|cell| cell.is_some())
    }
}
These two methods, combined with winner, give us everything we need to evaluate terminal states:
- winner() returns Some(player) if someone has won.
- is_draw() returns true if the board is full with no winner.
- is_terminal() returns true in either case.
MCTS will call is_terminal to know when to stop a rollout, and winner to assign the reward (+1 for a win, -1 for a loss, 0 for a draw).
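The reward rule fits in a small function. This is a sketch: `reward_for` is a hypothetical name. The block re-declares a minimal `Player` and a `winner` function over a bare board array so that it compiles on its own. It assumes the board it is given is already terminal, so "no winner" means a draw.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
    X,
    O,
}

const WIN_LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8],
    [0, 3, 6], [1, 4, 7], [2, 5, 8],
    [0, 4, 8], [2, 4, 6],
];

/// Returns the player with three in a row, if any.
fn winner(board: &[Option<Player>; 9]) -> Option<Player> {
    WIN_LINES.iter().find_map(|&[a, b, c]| match (board[a], board[b], board[c]) {
        (Some(p1), Some(p2), Some(p3)) if p1 == p2 && p2 == p3 => Some(p1),
        _ => None,
    })
}

/// Hypothetical helper: the reward of a finished game from `player`'s
/// perspective (+1 win, -1 loss, 0 draw). Assumes `board` is terminal.
fn reward_for(board: &[Option<Player>; 9], player: Player) -> f64 {
    match winner(board) {
        Some(w) if w == player => 1.0,
        Some(_) => -1.0,
        None => 0.0,
    }
}

fn main() {
    use Player::{O, X};
    // X holds the top row.
    let board = [Some(X), Some(X), Some(X), Some(O), Some(O), None, None, None, None];
    assert_eq!(reward_for(&board, X), 1.0);
    assert_eq!(reward_for(&board, O), -1.0);
}
```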
Applying moves: the immutability design
When MCTS explores the game tree, it needs to “try” a move and see what state results. There are two ways to design this:
Mutable approach: modify the board in place, then undo the move when backtracking. This is memory-efficient (one board allocation for the entire search) but error-prone. Forgetting to undo a move — or undoing it in the wrong order — produces subtle, hard-to-diagnose bugs. The bookkeeping complexity grows with game complexity.
Immutable approach: apply_move takes a state and a move, returns a new state with the move applied. The original state is unchanged. This is conceptually cleaner — each node in the search tree corresponds to an independent GameState value. Cloning a 9-element array is cheap. For Tic-Tac-Toe, the memory cost is negligible.
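For contrast, here is roughly what the mutable alternative looks like. It is a sketch, not code the course uses. `MutableBoard`, `play`, `undo`, and `count_sequences` are hypothetical names, and cells hold `Option<char>` to keep the block short. The recursive search must pair every `play` with exactly one `undo`, which is the bookkeeping the immutable design avoids.

```rust
/// Hypothetical mutable board used only to illustrate make/undo.
struct MutableBoard {
    cells: [Option<char>; 9],
    to_move: char,
}

impl MutableBoard {
    fn new() -> Self {
        MutableBoard { cells: [None; 9], to_move: 'X' }
    }

    /// Places the current player's mark in place and switches turns.
    fn play(&mut self, index: usize) {
        assert!(self.cells[index].is_none(), "cell {index} is occupied");
        self.cells[index] = Some(self.to_move);
        self.to_move = if self.to_move == 'X' { 'O' } else { 'X' };
    }

    /// Reverses `play(index)`. Must be called in reverse order of the plays.
    fn undo(&mut self, index: usize) {
        self.cells[index] = None;
        self.to_move = if self.to_move == 'X' { 'O' } else { 'X' };
    }

    /// Counts move sequences of length `depth` by trying each move in
    /// place and undoing it afterwards (ignores wins for simplicity).
    fn count_sequences(&mut self, depth: u32) -> u64 {
        if depth == 0 {
            return 1;
        }
        let mut total = 0;
        for index in 0..9 {
            if self.cells[index].is_none() {
                self.play(index);
                total += self.count_sequences(depth - 1);
                self.undo(index); // forgetting this line corrupts the search
            }
        }
        total
    }
}

fn main() {
    let mut board = MutableBoard::new();
    assert_eq!(board.count_sequences(2), 72); // 9 * 8
    // The search left the board exactly as it found it.
    assert!(board.cells.iter().all(|c| c.is_none()));
    assert_eq!(board.to_move, 'X');
}
```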
We choose the immutable approach:
impl GameState {
    /// Returns a new `GameState` with the given move applied.
    ///
    /// The move is an index (0-8) into the board. The current player's
    /// mark is placed in that cell, and the turn switches to the
    /// opponent.
    ///
    /// # Panics
    ///
    /// Panics if the cell is already occupied or the index is out of
    /// bounds. In a correct program, `apply_move` is only called with
    /// indices returned by `legal_moves`.
    pub fn apply_move(&self, index: usize) -> GameState {
        assert!(
            self.board[index].is_none(),
            "Cell {} is already occupied",
            index
        );
        let mut new_board = self.board;
        new_board[index] = Some(self.current_player);
        GameState {
            board: new_board,
            current_player: self.current_player.opponent(),
        }
    }
}
Notice that let mut new_board = self.board copies the array rather than borrowing it. [Option<Player>; 9] implements Copy because Player does, so we are not mutating the original. The new state has the mark placed and the turn switched to the opponent.
The assert! is a deliberate choice: calling apply_move on an occupied cell is always a bug in the caller. Panicking immediately makes such bugs loud and easy to find. In production game engines, you might return a Result instead, but for a learning project, panicking on programmer error is the right call.
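The copy that `apply_move` relies on can be checked in isolation. In this sketch, `with_mark` is a hypothetical helper that mirrors the first lines of `apply_move`, with `u8` standing in for `Player` so the block is self-contained. Because the element type is `Copy`, dereferencing the array reference duplicates all nine cells.

```rust
/// Returns a copy of `board` with `mark` placed at `index`; `board`
/// itself is left unchanged. (Hypothetical helper; `u8` stands in
/// for `Player`.)
fn with_mark(board: &[Option<u8>; 9], index: usize, mark: u8) -> [Option<u8>; 9] {
    let mut copy = *board; // copies all nine cells: the array type is Copy
    copy[index] = Some(mark);
    copy
}

fn main() {
    let empty = [None; 9];
    let after = with_mark(&empty, 4, 1);
    assert_eq!(empty[4], None); // the original is untouched
    assert_eq!(after[4], Some(1));
}
```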
Putting it all together
Here is the complete module, assembled from the pieces above. This is a single, compilable Rust file:
use std::fmt;
/// One of the two players in a Tic-Tac-Toe game.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player {
    X,
    O,
}
impl Player {
    /// Returns the other player.
    pub fn opponent(self) -> Player {
        match self {
            Player::X => Player::O,
            Player::O => Player::X,
        }
    }
}
impl fmt::Display for Player {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Player::X => write!(f, "X"),
            Player::O => write!(f, "O"),
        }
    }
}
/// The eight winning lines: three rows, three columns, two diagonals.
const WIN_LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
    [0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
    [0, 4, 8], [2, 4, 6],            // diagonals
];
/// A complete snapshot of a Tic-Tac-Toe game.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
    pub board: [Option<Player>; 9],
    pub current_player: Player,
}
impl GameState {
    /// Creates a new game with an empty board. X moves first.
    pub fn new() -> Self {
        GameState {
            board: [None; 9],
            current_player: Player::X,
        }
    }
    /// Returns the winner, if any.
    pub fn winner(&self) -> Option<Player> {
        for line in &WIN_LINES {
            let [a, b, c] = *line;
            if let (Some(p1), Some(p2), Some(p3)) =
                (self.board[a], self.board[b], self.board[c])
            {
                if p1 == p2 && p2 == p3 {
                    return Some(p1);
                }
            }
        }
        None
    }
    /// Returns `true` if the game is over.
    pub fn is_terminal(&self) -> bool {
        self.winner().is_some()
            || self.board.iter().all(|cell| cell.is_some())
    }
    /// Returns `true` if the game is a draw.
    pub fn is_draw(&self) -> bool {
        self.winner().is_none()
            && self.board.iter().all(|cell| cell.is_some())
    }
    /// Returns the legal moves (indices of empty cells).
    pub fn legal_moves(&self) -> Vec<usize> {
        if self.winner().is_some() {
            return Vec::new();
        }
        self.board
            .iter()
            .enumerate()
            .filter(|(_, cell)| cell.is_none())
            .map(|(index, _)| index)
            .collect()
    }
    /// Returns a new state with the move applied.
    ///
    /// # Panics
    ///
    /// Panics if the cell is already occupied.
    pub fn apply_move(&self, index: usize) -> GameState {
        assert!(self.board[index].is_none(), "Cell {} is occupied", index);
        let mut new_board = self.board;
        new_board[index] = Some(self.current_player);
        GameState {
            board: new_board,
            current_player: self.current_player.opponent(),
        }
    }
}
impl fmt::Display for GameState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for row in 0..3 {
            for col in 0..3 {
                let index = row * 3 + col;
                match self.board[index] {
                    Some(Player::X) => write!(f, " X ")?,
                    Some(Player::O) => write!(f, " O ")?,
                    None => write!(f, " . ")?,
                }
                if col < 2 {
                    write!(f, "|")?;
                }
            }
            writeln!(f)?;
            if row < 2 {
                writeln!(f, "-----------")?;
            }
        }
        Ok(())
    }
}
What we have built
Let us take stock. We now have a game representation that provides everything MCTS needs:
| RL concept | Rust implementation |
|---|---|
| State | GameState — board contents + current player |
| Action | usize — index into the board (0-8) |
| Legal actions | GameState::legal_moves() — indices of empty cells |
| Transition | GameState::apply_move(index) — returns a new state |
| Terminal test | GameState::is_terminal() — win or draw |
| Reward | Derived from GameState::winner() — +1, -1, or 0 |
This interface is deliberately minimal. There is no game loop, no input handling, and no AI, just the pure game logic. In §6, you will test these types and use them to play out complete random games. In §7, MCTS will call legal_moves, apply_move, is_terminal, and winner to search the game tree. The neural network in Part 4 will take the board array as its input.
The immutable design means every GameState is a self-contained snapshot. You can store them in a hash map, compare them for equality, collect them into training datasets, and pass them between functions without worrying about aliasing or mutation bugs. This simplicity will pay dividends as the system grows more complex.
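As an illustration of the derived `Hash` and `Eq` paying off, the sketch below uses `GameState` values as `HashMap` keys. It counts how many move orders lead to each position after a fixed number of moves. It re-declares a trimmed `GameState` (only `new` and `apply_move`, with no win detection, which cannot trigger this early) so it compiles on its own. The function `positions_after` is hypothetical. After three moves, 504 move orders collapse into 252 distinct positions, because X's two moves can be played in either order.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Player {
    X,
    O,
}

impl Player {
    fn opponent(self) -> Player {
        match self {
            Player::X => Player::O,
            Player::O => Player::X,
        }
    }
}

/// Trimmed copy of the §5 GameState (no win detection; no line can be
/// completed within the first four moves anyway).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GameState {
    board: [Option<Player>; 9],
    current_player: Player,
}

impl GameState {
    fn new() -> Self {
        GameState { board: [None; 9], current_player: Player::X }
    }

    fn apply_move(&self, index: usize) -> GameState {
        let mut board = self.board; // copy, as in §5
        board[index] = Some(self.current_player);
        GameState { board, current_player: self.current_player.opponent() }
    }
}

/// Hypothetical helper: maps each position reachable in `depth` moves
/// to the number of move orders that reach it. Valid for depth <= 4.
fn positions_after(depth: u32) -> HashMap<GameState, u32> {
    let mut counts = HashMap::new();
    counts.insert(GameState::new(), 1);
    for _ in 0..depth {
        let mut next = HashMap::new();
        for (state, ways) in &counts {
            for index in 0..9 {
                if state.board[index].is_none() {
                    *next.entry(state.apply_move(index)).or_insert(0) += *ways;
                }
            }
        }
        counts = next;
    }
    counts
}

fn main() {
    let after_three = positions_after(3);
    assert_eq!(after_three.len(), 252);
    assert!(after_three.values().all(|&ways| ways == 2));
}
```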
6. Exercise 1: implement the game logic
Time to get your hands dirty. In this exercise you will create a Rust project, copy in the game module from §5, write unit tests to prove it works, and then build a small main function that plays a random game so you can watch the board evolve turn by turn.
Step 1: create the project
cargo new ttt-selfplay
cd ttt-selfplay
This gives you a standard Cargo project with src/main.rs. You will add one more file.
Step 2: create the game module
Create a new file src/game.rs and paste in the complete module from the end of §5 — everything from use std::fmt; through the Display impl for GameState. The code is self-contained and has no external dependencies.
Then tell Rust about the module by adding this line at the top of src/main.rs:
mod game;
Verify that it compiles:
cargo check
If you see errors, make sure every item in game.rs that needs to be used from outside the module is marked pub (all types and methods in the §5 listing already are).
Step 3: write unit tests
Add a test module at the bottom of src/game.rs:
#[cfg(test)]
mod tests {
    use super::*;
    // Your tests go here.
}
Write the following eight tests. The names and descriptions tell you what to assert — try writing the bodies yourself before expanding the solution.
test_new_game_is_empty — A freshly created GameState::new() should have all nine cells set to None, and the current player should be Player::X.
test_legal_moves_initial — On a new board, legal_moves() should return all nine indices (0 through 8).
test_apply_move — After calling apply_move(4) on a new game, cell 4 should contain Some(Player::X) and the current player should have switched to Player::O.
test_winner_row — Set up a board where X occupies the top row (cells 0, 1, 2) and verify that winner() returns Some(Player::X).
test_winner_col — Set up a board where O occupies the left column (cells 0, 3, 6) and verify that winner() returns Some(Player::O).
test_winner_diagonal — Set up a board where X occupies the main diagonal (cells 0, 4, 8) and verify that winner() returns Some(Player::X).
test_draw — Play out a full game that ends in a draw. Verify that is_terminal() is true, winner() is None, and is_draw() is true.
test_no_moves_after_win — From a state where someone has won, legal_moves() should return an empty vector — even though empty cells remain on the board.
Run the tests with:
cargo test
All eight should pass. If any fail, re-read the §5 code — the logic is already correct, so the bug is in your test setup.
Solution — all eight tests
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_new_game_is_empty() {
        let state = GameState::new();
        for cell in &state.board {
            assert_eq!(*cell, None);
        }
        assert_eq!(state.current_player, Player::X);
    }
    #[test]
    fn test_legal_moves_initial() {
        let state = GameState::new();
        let moves = state.legal_moves();
        assert_eq!(moves.len(), 9);
        for i in 0..9 {
            assert!(moves.contains(&i));
        }
    }
    #[test]
    fn test_apply_move() {
        let state = GameState::new();
        let next = state.apply_move(4);
        assert_eq!(next.board[4], Some(Player::X));
        assert_eq!(next.current_player, Player::O);
        // Original state is unchanged (immutable design).
        assert_eq!(state.board[4], None);
    }
    #[test]
    fn test_winner_row() {
        // X plays 0, 1, 2 (top row). O plays 3, 4.
        let state = GameState::new()
            .apply_move(0) // X
            .apply_move(3) // O
            .apply_move(1) // X
            .apply_move(4) // O
            .apply_move(2); // X wins
        assert_eq!(state.winner(), Some(Player::X));
    }
    #[test]
    fn test_winner_col() {
        // O plays 0, 3, 6 (left column). X plays 1, 4, 8.
        let state = GameState::new()
            .apply_move(1) // X
            .apply_move(0) // O
            .apply_move(4) // X
            .apply_move(3) // O
            .apply_move(8) // X
            .apply_move(6); // O wins
        assert_eq!(state.winner(), Some(Player::O));
    }
    #[test]
    fn test_winner_diagonal() {
        // X plays 0, 4, 8 (main diagonal). O plays 1, 2.
        let state = GameState::new()
            .apply_move(0) // X
            .apply_move(1) // O
            .apply_move(4) // X
            .apply_move(2) // O
            .apply_move(8); // X wins
        assert_eq!(state.winner(), Some(Player::X));
    }
    #[test]
    fn test_draw() {
        // A classic draw. Final board: X O X / X O O / O X X
        let state = GameState::new()
            .apply_move(0) // X
            .apply_move(1) // O
            .apply_move(2) // X
            .apply_move(4) // O
            .apply_move(3) // X
            .apply_move(5) // O
            .apply_move(7) // X
            .apply_move(6) // O
            .apply_move(8); // X
        assert!(state.is_terminal());
        assert!(state.is_draw());
        assert_eq!(state.winner(), None);
    }
    #[test]
    fn test_no_moves_after_win() {
        // X wins on the top row; cells 5-8 are still empty.
        let state = GameState::new()
            .apply_move(0) // X
            .apply_move(3) // O
            .apply_move(1) // X
            .apply_move(4) // O
            .apply_move(2); // X wins
        assert_eq!(state.winner(), Some(Player::X));
        assert!(state.legal_moves().is_empty());
    }
}
Step 4: play a random game
Replace the contents of src/main.rs with the following. It picks a random legal move each turn, prints the board, and announces the result.
mod game;
use game::{GameState, Player};
use rand::Rng;
fn main() {
    let mut rng = rand::rng();
    let mut state = GameState::new();
    let mut turn = 0;
    while !state.is_terminal() {
        let moves = state.legal_moves();
        let chosen = moves[rng.random_range(..moves.len())];
        println!("Turn {}: {} plays cell {}", turn, state.current_player, chosen);
        state = state.apply_move(chosen);
        println!("{}", state);
        turn += 1;
    }
    match state.winner() {
        Some(Player::X) => println!("X wins!"),
        Some(Player::O) => println!("O wins!"),
        None => println!("Draw!"),
    }
}
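You will need the rand crate. Add it to Cargo.toml with:
cargo add rand
This code uses the rand 0.9 API (rand::rng() and Rng::random_range). The older 0.8 releases call these thread_rng() and gen_range().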
Then run the program:
cargo run
You should see a full game play out with the board printed after every move, ending in a win or a draw. Run it a few times — you will see different outcomes each time.
Readiness checklist
Before moving on to §7, confirm:
- cargo check compiles without errors
- cargo test passes all eight unit tests
- cargo run plays a random game to completion and prints the result
- You understand how apply_move returns a new state rather than mutating the old one
- You can trace how legal_moves returns an empty vector after a win, even if empty cells remain
If everything checks out, your game engine is ready. In the next section, we will build an MCTS player that uses this exact interface to search for strong moves.
Part 3 — MCTS
7. Implementing MCTS in Rust
In §2, we walked through MCTS as an algorithm: selection, expansion, simulation, backpropagation, and the UCT formula that balances exploration against exploitation. In §5 and §6, we built a complete game engine with GameState, Player, legal_moves, apply_move, is_terminal, and winner. Now we bring those two threads together and implement MCTS in Rust.
This is the most substantial implementation section in the course. Take your time with it. By the end, you will have a working MCTS player that can choose strong moves in Tic-Tac-Toe — no neural network required.
The problem with tree pointers in Rust
Before we write any code, we need to talk about a design decision that trips up nearly everyone who implements tree structures in Rust for the first time.
MCTS operates on a tree. Each node has a parent and zero or more children. In a language like Python or Java, you would represent this with object references: each node holds a pointer to its parent and a list of pointers to its children. Trees are natural in garbage-collected languages because the garbage collector handles ownership and cleanup.
Rust does not have a garbage collector; it has ownership, and every value has exactly one owner. A parent node can own its children, but a child that needs a back-pointer to its parent cannot simply hold a reference to it, because the parent already owns the child. Plain ownership and borrowing cannot express that cycle, so you immediately run into a conflict.
There are several ways to handle this in Rust:
Rc<RefCell<MctsNode>>: reference-counted, interior-mutable nodes. Each node is wrapped in Rc (shared ownership) and RefCell (runtime borrow checking). The parent holds Rc pointers to its children, and each child holds a Weak pointer back to its parent; strong Rc pointers in both directions would form a reference cycle that is never freed. This works, but it is verbose, allocates each node separately on the heap, and adds runtime borrow-checking overhead. Every access requires .borrow() or .borrow_mut() calls, and borrow violations become runtime panics instead of compile-time errors.
unsafe raw pointers. You can use *mut MctsNode pointers for parent references. This is how you might do it in C. It works, but you lose Rust’s safety guarantees and have to manually reason about pointer validity.
Arena allocation with indices. Store all nodes in a single Vec<MctsNode>. Instead of pointers, use usize indices into the vector. A node’s “parent” is an index. Its “children” are a list of indices. The Vec owns all the nodes, and indices are just numbers — no ownership conflicts, no lifetimes, no Rc, no unsafe.
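To make the first option concrete, here is a minimal sketch of an Rc-based tree with Weak parent links. `RcNode`, `new_node`, `add_child`, and `visit_to_root` are hypothetical names, and a node stores only a visit counter. The point is the amount of ceremony (Rc, Weak, RefCell, Cell, borrow_mut, upgrade) it takes just to link nodes and walk back to the root.

```rust
use std::cell::{Cell, RefCell};
use std::rc::{Rc, Weak};

/// Hypothetical tree node using shared ownership instead of an arena.
struct RcNode {
    visits: Cell<u32>,
    parent: RefCell<Weak<RcNode>>,      // Weak avoids an Rc cycle
    children: RefCell<Vec<Rc<RcNode>>>, // strong pointers downward
}

fn new_node() -> Rc<RcNode> {
    Rc::new(RcNode {
        visits: Cell::new(0),
        parent: RefCell::new(Weak::new()),
        children: RefCell::new(Vec::new()),
    })
}

/// Links `child` under `parent`: strong pointer down, weak pointer up.
fn add_child(parent: &Rc<RcNode>, child: Rc<RcNode>) {
    *child.parent.borrow_mut() = Rc::downgrade(parent);
    parent.children.borrow_mut().push(child);
}

/// Walks from `node` up to the root, incrementing each visit count.
fn visit_to_root(node: &Rc<RcNode>) {
    let mut current = Some(Rc::clone(node));
    while let Some(n) = current {
        n.visits.set(n.visits.get() + 1);
        current = n.parent.borrow().upgrade();
    }
}

fn main() {
    let root = new_node();
    let child = new_node();
    add_child(&root, Rc::clone(&child));
    let grandchild = new_node();
    add_child(&child, Rc::clone(&grandchild));
    visit_to_root(&grandchild);
    assert_eq!(root.visits.get(), 1);
    assert_eq!(child.visits.get(), 1);
    // Two strong owners of `child`: this binding and root's children Vec.
    assert_eq!(Rc::strong_count(&child), 2);
}
```

The arena approach replaces all of these smart pointers with plain usize indices.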
We will use the arena approach. It is simple, performs well, and is a common Rust idiom for trees and graphs. Here is why:
- No lifetime annotations. The Vec owns everything. Indices are plain integers that can be copied, stored, and compared freely.
- Cache-friendly. All nodes live in one contiguous block of memory instead of many separate heap allocations, so walking the tree tends to touch memory that is close together.
- Easy to debug. You can print the entire tree by iterating the vector. Each node's index is its identity.
- Low runtime overhead. Indexing a Vec is a bounds-checked array access. There is no reference counting, no runtime borrow checking, and no per-node allocation.
The tradeoff is that you cannot remove nodes from the middle of the vector without invalidating indices (you would need a more sophisticated arena for that). For MCTS, this is not a problem — we only add nodes, never remove them.
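Stripped of MCTS details, the arena pattern takes only a few lines. This is a generic sketch with hypothetical names (`Arena`, `Node`, `add_child`, `path_to_root`). The MctsTree in this section follows the same shape, with more fields per node.

```rust
/// Hypothetical minimal arena: nodes live in a Vec, links are indices.
struct Node<T> {
    value: T,
    parent: Option<usize>,
    children: Vec<usize>,
}

struct Arena<T> {
    nodes: Vec<Node<T>>,
}

impl<T> Arena<T> {
    /// Creates an arena whose root (index 0) holds `value`.
    fn new(value: T) -> Self {
        Arena { nodes: vec![Node { value, parent: None, children: Vec::new() }] }
    }

    /// Appends a child under `parent` and returns its index.
    fn add_child(&mut self, parent: usize, value: T) -> usize {
        let index = self.nodes.len();
        self.nodes.push(Node { value, parent: Some(parent), children: Vec::new() });
        self.nodes[parent].children.push(index);
        index
    }

    /// Returns the indices from `index` up to and including the root.
    fn path_to_root(&self, index: usize) -> Vec<usize> {
        let mut path = vec![index];
        let mut current = self.nodes[index].parent;
        while let Some(p) = current {
            path.push(p);
            current = self.nodes[p].parent;
        }
        path
    }
}

fn main() {
    let mut tree = Arena::new("root");
    let a = tree.add_child(0, "a");
    let b = tree.add_child(a, "b");
    assert_eq!(tree.path_to_root(b), vec![2, 1, 0]);
    assert_eq!(tree.nodes[0].children, vec![1]);
    println!("{}", tree.nodes[b].value); // prints "b"
}
```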
The MctsNode struct
Each node in the MCTS tree represents a game state reached by a particular sequence of moves. It stores the statistics needed by UCT (visit count and total value), the connections to its parent and children, and the move that led to this position:
/// A single node in the MCTS search tree.
///
/// Nodes are stored in a flat `Vec` (arena). Parent and child
/// relationships are represented as indices into that `Vec`,
/// not as pointers or references.
#[derive(Debug, Clone)]
struct MctsNode {
    /// The game state at this node.
    state: GameState,
    /// The move (board index) that was played to reach this state
    /// from the parent. `None` for the root node.
    move_from_parent: Option<usize>,
    /// Index of the parent node in the arena. `None` for the root.
    parent: Option<usize>,
    /// Indices of child nodes in the arena.
    children: Vec<usize>,
    /// Number of times this node has been visited during search.
    visit_count: u32,
    /// Total accumulated value from all simulations through this node.
    /// Positive means good for the player who *just moved* to reach
    /// this state (i.e., the opponent of `state.current_player`).
    total_value: f64,
    /// The legal moves from this state that have not yet been expanded
    /// into child nodes. When this list is empty, the node is "fully
    /// expanded."
    unexpanded_moves: Vec<usize>,
}
A few things deserve explanation:
move_from_parent stores the action that led here. When MCTS finishes and we want to know which move to play, we look at the root’s children and read their move_from_parent values. The root itself has no parent move, so it is None.
unexpanded_moves is a key bookkeeping field. When we create a node, we populate this with all legal moves from the state. Each time we expand a child, we remove a move from this list. When the list is empty, the node is fully expanded and selection will use UCT to pick among existing children instead of expanding a new one.
total_value perspective. This is the trickiest part of MCTS to get right. The value stored at a node is from the perspective of the player who moved to reach this node — that is, the opponent of state.current_player. Why? Because when the parent node is choosing among its children using UCT, it needs to know how good each child is for the parent’s player. The parent’s player is the one who made the move to reach the child. So the child’s value should reflect how well things went for that player.
This will become clearer when we implement backpropagation.
The MCTS tree (arena)
The tree itself is just a Vec of nodes; by convention, the root lives at index 0:
/// The MCTS search tree, stored as an arena of nodes.
struct MctsTree {
    /// All nodes in the tree. The root is always at index 0.
    nodes: Vec<MctsNode>,
}
The root is always at index 0 — we create it first, and the Vec never reorders elements. Let’s add a constructor:
impl MctsTree {
    /// Creates a new MCTS tree rooted at the given game state.
    fn new(root_state: GameState) -> Self {
        let unexpanded = root_state.legal_moves();
        let root = MctsNode {
            state: root_state,
            move_from_parent: None,
            parent: None,
            children: Vec::new(),
            visit_count: 0,
            total_value: 0.0,
            unexpanded_moves: unexpanded,
        };
        MctsTree {
            nodes: vec![root],
        }
    }
}
When the root is created, we immediately compute its legal moves and store them in unexpanded_moves. The first few iterations of MCTS will expand these into child nodes.
The UCT formula
Before implementing the four phases, let’s write the UCT calculation as a standalone function. Recall from §2:
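UCT(child) = (W / N) + C * sqrt(ln(N_parent) / N)
Here W is the child's total value, N is the child's visit count, N_parent is the parent's visit count, and C is the exploration constant.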
In Rust:
use std::f64::consts::SQRT_2;
/// Computes the UCT score for a child node.
///
/// - `child_value`: total value accumulated at the child
/// - `child_visits`: number of times the child has been visited
/// - `parent_visits`: number of times the parent has been visited
///
/// Returns `f64::INFINITY` if the child has never been visited
/// (ensuring unvisited children are selected first).
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
    if child_visits == 0 {
        return f64::INFINITY;
    }
    let n = child_visits as f64;
    let exploitation = child_value / n;
    let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
    exploitation + exploration
}
The exploration constant C = sqrt(2) comes from the theoretical UCB1 analysis and is a good default. Some implementations tune this value — a larger C explores more aggressively, a smaller C exploits more greedily. For Tic-Tac-Toe, sqrt(2) works well.
Note the special case: if a child has never been visited, we return infinity. This guarantees that unvisited children are always selected before any visited child, which means MCTS tries every move at least once before starting to use statistics to discriminate.
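A quick numeric check makes the two terms tangible. The sketch copies `uct_score` from above and evaluates it for a hypothetical parent visited 10 times with two children. Child A was visited 5 times with total value 3.0, and child B was visited once with total value 0.0. B's lower average is outweighed by its much larger exploration bonus, so selection would try it next.

```rust
use std::f64::consts::SQRT_2;

/// Same formula as `uct_score` above.
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
    if child_visits == 0 {
        return f64::INFINITY;
    }
    let n = child_visits as f64;
    let exploitation = child_value / n;
    let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
    exploitation + exploration
}

fn main() {
    // Child A: well explored, average value 3.0 / 5 = 0.6.
    let a = uct_score(3.0, 5, 10);
    // Child B: tried once, average value 0.0, large exploration bonus.
    let b = uct_score(0.0, 1, 10);
    println!("A = {a:.3}, B = {b:.3}"); // prints "A = 1.560, B = 2.146"
    assert!(b > a);
    // An unvisited child always wins the comparison.
    assert_eq!(uct_score(0.0, 0, 10), f64::INFINITY);
}
```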
Phase 1: Selection
Selection starts at the root and walks down the tree, picking the child with the highest UCT score at each step. It stops when it reaches a node that has unexpanded moves (meaning we should expand rather than descend further) or a terminal state (meaning the game is over and there is nothing to expand).
impl MctsTree {
    /// Selects a leaf node for expansion by walking from the root,
    /// always choosing the child with the highest UCT score.
    ///
    /// Returns the index of the selected node.
    fn select(&self) -> usize {
        let mut current = 0; // start at root
        loop {
            let node = &self.nodes[current];
            // If this node has unexpanded moves, stop here; we will
            // expand one of them.
            if !node.unexpanded_moves.is_empty() {
                return current;
            }
            // If this node is terminal (no children and no unexpanded
            // moves), stop here; there is nothing to expand.
            if node.children.is_empty() {
                return current;
            }
            // All moves have been expanded. Pick the child with the
            // highest UCT score.
            let parent_visits = node.visit_count;
            let mut best_child = node.children[0];
            let mut best_score = f64::NEG_INFINITY;
            for &child_idx in &node.children {
                let child = &self.nodes[child_idx];
                let score = uct_score(
                    child.total_value,
                    child.visit_count,
                    parent_visits,
                );
                if score > best_score {
                    best_score = score;
                    best_child = child_idx;
                }
            }
            current = best_child;
        }
    }
}
This is a straightforward loop, not recursion. We walk from the root toward the leaves, always following the highest-UCT child. The loop terminates because:
- If the current node has unexpanded moves, we return it immediately.
- If the current node is terminal (no children, no unexpanded moves), we return it.
- Otherwise, we descend to a child, which is deeper in the tree.
Since the tree is finite (Tic-Tac-Toe games end in at most 9 moves), we are guaranteed to reach case 1 or 2.
Phase 2: Expansion
Once selection hands us a node, we expand it by picking one of its unexpanded moves, creating a new child node, and returning the child’s index:
impl MctsTree {
    /// Expands the given node by picking one unexpanded move, creating
    /// a child node for it, and returning the child's index.
    ///
    /// If the node has no unexpanded moves (it is fully expanded or
    /// terminal), returns `None`.
    fn expand(&mut self, node_idx: usize) -> Option<usize> {
        // Pop an unexpanded move. If there are none, we cannot expand.
        let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
        // Create the child state by applying the move.
        let child_state = self.nodes[node_idx].state.apply_move(mv);
        let unexpanded = child_state.legal_moves();
        let child_idx = self.nodes.len();
        let child = MctsNode {
            state: child_state,
            move_from_parent: Some(mv),
            parent: Some(node_idx),
            children: Vec::new(),
            visit_count: 0,
            total_value: 0.0,
            unexpanded_moves: unexpanded,
        };
        self.nodes.push(child);
        self.nodes[node_idx].children.push(child_idx);
        Some(child_idx)
    }
}
We use pop() to grab an unexpanded move. The order does not matter — MCTS will eventually try all of them. Using pop is convenient because it returns None when the list is empty, which signals that expansion is not possible.
Notice the two-step process for adding the child: first we push it onto self.nodes (which gives it an index), then we record that index in the parent’s children list. We also record node_idx as the child’s parent. These bidirectional links let us walk both up (for backpropagation) and down (for selection) the tree.
Because we are using indices rather than references, there is no borrow-checker conflict. We can freely read self.nodes[node_idx] and then push to self.nodes without Rust complaining about aliased mutable borrows — we just need to be careful not to hold a reference across the push. The code above handles this correctly by extracting the move and state into local variables before mutating the vector.
Phase 3: Simulation (rollout)
From the newly expanded node (or a terminal node if expansion was not possible), we play a random game to completion and return the result:
// rand 0.9: `choose` is provided by the IndexedRandom trait.
use rand::seq::IndexedRandom;
impl MctsTree {
    /// Plays random moves from the given state until the game ends.
    /// Returns the reward from the perspective of the player who
    /// just moved to reach `state`:
    ///
    /// - `+1.0` if that player eventually wins
    /// - `-1.0` if that player eventually loses
    /// - `0.0` for a draw
    ///
    /// The "player who just moved" is the opponent of
    /// `state.current_player`.
    fn simulate(&self, state: &GameState, rng: &mut impl rand::Rng) -> f64 {
        let mut current = state.clone();
        // Play random moves until the game is over.
        while !current.is_terminal() {
            let moves = current.legal_moves();
            let &mv = moves.choose(rng).expect("non-terminal state has moves");
            current = current.apply_move(mv);
        }
        // Determine the reward from the perspective of the player
        // who moved to reach the *starting* state of this rollout.
        // That player is state.current_player.opponent(), the one
        // whose turn it is NOT.
        let rollout_player = state.current_player.opponent();
        match current.winner() {
            Some(winner) if winner == rollout_player => 1.0,
            Some(_) => -1.0,
            None => 0.0, // draw
        }
    }
}
The rollout is intentionally simple: pick a random legal move, apply it, repeat. No strategy, no heuristics — pure random play. Despite this simplicity, the statistical signal is strong enough for MCTS to work. The key insight from §2 is that MCTS does not need accurate rollouts, just informative ones. Random play provides enough information to distinguish good positions from bad ones.
The reward perspective needs care. The value we return will be stored at the node from which the rollout started. That node was reached by a move from its parent. The parent’s player is state.current_player.opponent() — that is the player “who just moved.” So we compute the reward from that player’s perspective: +1 if they won, -1 if they lost, 0 for a draw.
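We use the rand crate for random number generation. The choose method picks a random element of a slice and returns None if the slice is empty. In rand 0.9 it is provided by the IndexedRandom trait (rand::seq::IndexedRandom), so that trait must be in scope; earlier releases provided it through SliceRandom. We pass in a mutable reference to an Rng (random number generator) so the caller controls the source of randomness. This makes testing easier and avoids creating a new RNG on every call.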
Phase 4: Backpropagation
After the rollout produces a reward, we propagate it back up the tree from the expanded node to the root, updating visit counts and value sums:
impl MctsTree {
    /// Propagates the simulation result back up from `node_idx` to
    /// the root, updating visit counts and value sums.
    ///
    /// `value` is the reward from the perspective of the player who
    /// moved to reach `node_idx`.
    fn backpropagate(&mut self, node_idx: usize, value: f64) {
        let mut current = Some(node_idx);
        let mut current_value = value;
        while let Some(idx) = current {
            self.nodes[idx].visit_count += 1;
            self.nodes[idx].total_value += current_value;
            // Move to the parent. Flip the sign of the value because
            // the parent represents the opponent's perspective.
            current = self.nodes[idx].parent;
            current_value = -current_value;
        }
    }
}
This is the step where the sign-flipping happens, and it is the most important detail to get right.
Consider a concrete scenario. Suppose X plays move 4 (center), reaching a child node. From that child, the rollout plays out and X wins. The reward is +1.0 (good for X, who moved to reach this node).
Now we backpropagate:
- At the child node (the one X moved to): add +1.0 to total_value and increment visit_count.
- At the parent node (where it was X’s turn): flip the sign and add -1.0 to total_value. The value stored at a node is from the perspective of whoever moved to reach it, and the player who moved to reach the parent is O. X’s win is bad for O, hence -1.0.
The sign flip works because of a simple invariant: a node’s total_value records how good this node is for the player who chose to come here. The parent node was reached by a move from the grandparent’s player. That player is the opponent of the player who made the move from parent to child. So the child’s result, from the grandparent-player’s perspective, has the opposite sign.
If this still feels abstract, trace through a small example on paper. Place three nodes: grandparent (O’s turn), parent (X’s turn, reached by O’s move), child (O’s turn, reached by X’s move). If X wins the rollout from the child:
- Child: +1.0 (X moved here, X won — good for X)
- Parent: -1.0 (O moved here, X won — bad for O)
- Grandparent: +1.0 (X moved here, X won — good for X)
The alternating sign correctly reflects each player’s perspective.
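The three-node trace can be checked mechanically. Below is a minimal, self-contained sketch. The stripped-down Node struct and free backpropagate function are hypothetical stand-ins for MctsNode and MctsTree::backpropagate, keeping only the fields the walk touches.

```rust
/// Hypothetical stand-in for MctsNode: only what backpropagation needs.
struct Node {
    parent: Option<usize>,
    visit_count: u32,
    total_value: f64,
}

/// Same sign-flipping walk as MctsTree::backpropagate, over a plain slice.
fn backpropagate(nodes: &mut [Node], node_idx: usize, value: f64) {
    let mut current = Some(node_idx);
    let mut current_value = value;
    while let Some(idx) = current {
        nodes[idx].visit_count += 1;
        nodes[idx].total_value += current_value;
        current = nodes[idx].parent;
        current_value = -current_value; // next level up is the other player
    }
}

fn main() {
    // Index 0 = grandparent (O to move), 1 = parent (X to move), 2 = child (O to move).
    let mut nodes = vec![
        Node { parent: None, visit_count: 0, total_value: 0.0 },
        Node { parent: Some(0), visit_count: 0, total_value: 0.0 },
        Node { parent: Some(1), visit_count: 0, total_value: 0.0 },
    ];
    // X moved to reach the child and X won the rollout: reward +1.0.
    backpropagate(&mut nodes, 2, 1.0);
    for (name, node) in ["grandparent", "parent", "child"].iter().zip(&nodes) {
        println!("{}: visits={}, total_value={:+.1}", name, node.visit_count, node.total_value);
    }
    // grandparent: visits=1, total_value=+1.0
    // parent: visits=1, total_value=-1.0
    // child: visits=1, total_value=+1.0
}
```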
The main MCTS loop
Now we assemble the four phases into the main search loop:
impl MctsTree {
    /// Runs MCTS for the given number of iterations and returns the
    /// best move (a board index) from the root position.
    fn search(
        &mut self,
        iterations: u32,
        rng: &mut impl rand::Rng,
    ) -> usize {
        for _ in 0..iterations {
            // Phase 1: Selection
            let selected = self.select();
            // Phase 2: Expansion
            // If the selected node can be expanded, expand it.
            // Otherwise (terminal state), use the selected node itself.
            let node_to_simulate = match self.expand(selected) {
                Some(child_idx) => child_idx,
                None => selected,
            };
            // Phase 3: Simulation
            let state = self.nodes[node_to_simulate].state.clone();
            let value = self.simulate(&state, rng);
            // Phase 4: Backpropagation
            self.backpropagate(node_to_simulate, value);
        }
        // After all iterations, pick the child of root with the
        // highest visit count.
        self.best_move()
    }
    /// Returns the move (board index) of the root's most-visited child.
    fn best_move(&self) -> usize {
        let root = &self.nodes[0];
        let best_child_idx = root
            .children
            .iter()
            .max_by_key(|&&idx| self.nodes[idx].visit_count)
            .expect("root must have at least one child after search");
        self.nodes[*best_child_idx]
            .move_from_parent
            .expect("root's children always have a move")
    }
}
Each iteration is straightforward: select, expand, simulate, backpropagate. After all iterations are complete, we look at the root’s children and return the move of the one with the most visits.
Why highest visit count rather than highest average value? As we discussed in §2, visit count is a more robust signal. A child with 500 visits and a 60% win rate is a more reliable choice than a child with 3 visits and a 100% win rate. The 100% win rate might just be lucky — three rollouts is not enough data. The UCT formula naturally funnels more visits toward genuinely strong moves, so the most-visited child is almost always the best.
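A small sketch makes the difference concrete. The statistics are hypothetical. With rewards of +1 and -1, a 60% win rate over 500 visits (300 wins, 200 losses) sums to a total_value of 100.0, and three straight wins sum to 3.0. most_visited and best_average are illustrative helpers, not part of the MCTS module.

```rust
/// (move, visit_count, total_value) for one child of the root.
type ChildStats = (usize, u32, f64);

/// Picks the move with the most visits (the rule best_move uses).
fn most_visited(children: &[ChildStats]) -> usize {
    children.iter().max_by_key(|c| c.1).expect("at least one child").0
}

/// Picks the move with the highest average value instead.
fn best_average(children: &[ChildStats]) -> usize {
    children
        .iter()
        .max_by(|a, b| (a.2 / a.1 as f64).total_cmp(&(b.2 / b.1 as f64)))
        .expect("at least one child")
        .0
}

fn main() {
    // Move 4: 500 visits, average +0.2. Move 0: 3 visits, average +1.0.
    let children: [ChildStats; 2] = [(4, 500, 100.0), (0, 3, 3.0)];
    println!("most visited: move {}", most_visited(&children)); // move 4
    println!("best average: move {}", best_average(&children)); // move 0
}
```

The average-value rule would bet on three lucky rollouts. The visit-count rule picks the move the search has tested and kept returning to.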
A public API function
Let’s wrap everything in a clean public function that takes a game state, a number of iterations, and returns the best move:
/// Runs MCTS from the given position for the specified number of
/// iterations and returns the best move (a board index).
pub fn mcts_best_move(
    state: &GameState,
    iterations: u32,
    rng: &mut impl rand::Rng,
) -> usize {
    let mut tree = MctsTree::new(state.clone());
    tree.search(iterations, rng)
}
This is the only function that code outside the MCTS module needs to call. Pass in the current game state, the computational budget (number of iterations), and a random number generator. Get back a move.
Complete implementation
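Here is the entire MCTS module assembled into a single, compilable block, assuming your GameState and Player types from earlier sections are in scope. It goes in src/mcts.rs (declared with mod mcts; in main.rs) or alongside your game code in main.rs; your choice.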
use rand::seq::SliceRandom;
use std::f64::consts::SQRT_2;
// -- Import your game types (adjust the path to match your module layout) --
// use crate::{GameState, Player};
/// A single node in the MCTS search tree.
///
/// Nodes are stored in a flat `Vec` (arena allocation). Parent and
/// child links are indices, not pointers.
#[derive(Debug, Clone)]
struct MctsNode {
    /// The game state at this node.
    state: GameState,
    /// The move that led to this state from the parent. `None` for root.
    move_from_parent: Option<usize>,
    /// Index of the parent node in the arena. `None` for the root.
    parent: Option<usize>,
    /// Indices of child nodes in the arena.
    children: Vec<usize>,
    /// Number of times this node has been visited.
    visit_count: u32,
    /// Total accumulated value (from the perspective of the player
    /// who moved to reach this node).
    total_value: f64,
    /// Legal moves not yet expanded into children.
    unexpanded_moves: Vec<usize>,
}
/// The MCTS search tree, backed by arena allocation.
struct MctsTree {
    /// All nodes. The root is always at index 0.
    nodes: Vec<MctsNode>,
}
/// Computes the UCT score for a child node.
///
/// Returns `f64::INFINITY` for unvisited children, ensuring they
/// are selected before any visited child.
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
    if child_visits == 0 {
        return f64::INFINITY;
    }
    let n = child_visits as f64;
    let exploitation = child_value / n;
    let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
    exploitation + exploration
}
impl MctsTree {
    /// Creates a new tree rooted at the given game state.
    fn new(root_state: GameState) -> Self {
        let unexpanded = root_state.legal_moves();
        let root = MctsNode {
            state: root_state,
            move_from_parent: None,
            parent: None,
            children: Vec::new(),
            visit_count: 0,
            total_value: 0.0,
            unexpanded_moves: unexpanded,
        };
        MctsTree {
            nodes: vec![root],
        }
    }
    /// Phase 1: Selection.
    ///
    /// Walks from the root to a node that either has unexpanded moves
    /// or is terminal, always picking the child with the highest UCT
    /// score.
    fn select(&self) -> usize {
        let mut current = 0;
        loop {
            let node = &self.nodes[current];
            if !node.unexpanded_moves.is_empty() {
                return current;
            }
            if node.children.is_empty() {
                return current;
            }
            let parent_visits = node.visit_count;
            let mut best_child = node.children[0];
            let mut best_score = f64::NEG_INFINITY;
            for &child_idx in &node.children {
                let child = &self.nodes[child_idx];
                let score = uct_score(
                    child.total_value,
                    child.visit_count,
                    parent_visits,
                );
                if score > best_score {
                    best_score = score;
                    best_child = child_idx;
                }
            }
            current = best_child;
        }
    }
    /// Phase 2: Expansion.
    ///
    /// Picks one unexpanded move from the given node, creates a child
    /// node for it, and returns the child's index. Returns `None` if
    /// the node is fully expanded or terminal.
    fn expand(&mut self, node_idx: usize) -> Option<usize> {
        let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
        let child_state = self.nodes[node_idx].state.apply_move(mv);
        let unexpanded = child_state.legal_moves();
        let child_idx = self.nodes.len();
        let child = MctsNode {
            state: child_state,
            move_from_parent: Some(mv),
            parent: Some(node_idx),
            children: Vec::new(),
            visit_count: 0,
            total_value: 0.0,
            unexpanded_moves: unexpanded,
        };
        self.nodes.push(child);
        self.nodes[node_idx].children.push(child_idx);
        Some(child_idx)
    }
    /// Phase 3: Simulation (rollout).
    ///
    /// Plays random moves from the given state until the game ends.
    /// Returns the reward from the perspective of the player who
    /// moved to reach this state.
    fn simulate(
        &self,
        state: &GameState,
        rng: &mut impl rand::Rng,
    ) -> f64 {
        let mut current = state.clone();
        while !current.is_terminal() {
            let moves = current.legal_moves();
            let &mv = moves
                .choose(rng)
                .expect("non-terminal state has moves");
            current = current.apply_move(mv);
        }
        let rollout_player = state.current_player.opponent();
        match current.winner() {
            Some(winner) if winner == rollout_player => 1.0,
            Some(_) => -1.0,
            None => 0.0,
        }
    }
    /// Phase 4: Backpropagation.
    ///
    /// Walks from the given node back to the root, updating visit
    /// counts and value sums. Flips the sign at each step to account
    /// for alternating players.
    fn backpropagate(&mut self, node_idx: usize, value: f64) {
        let mut current = Some(node_idx);
        let mut current_value = value;
        while let Some(idx) = current {
            self.nodes[idx].visit_count += 1;
            self.nodes[idx].total_value += current_value;
            current = self.nodes[idx].parent;
            current_value = -current_value;
        }
    }
    /// Runs the full MCTS loop for the given number of iterations.
    /// Returns the best move (board index) from the root position.
    fn search(
        &mut self,
        iterations: u32,
        rng: &mut impl rand::Rng,
    ) -> usize {
        for _ in 0..iterations {
            let selected = self.select();
            let node_to_simulate = match self.expand(selected) {
                Some(child_idx) => child_idx,
                None => selected,
            };
            let state = self.nodes[node_to_simulate].state.clone();
            let value = self.simulate(&state, rng);
            self.backpropagate(node_to_simulate, value);
        }
        self.best_move()
    }
    /// Returns the move of the root's most-visited child.
    fn best_move(&self) -> usize {
        let root = &self.nodes[0];
        let best_child_idx = root
            .children
            .iter()
            .max_by_key(|&&idx| self.nodes[idx].visit_count)
            .expect("root must have children after search");
        self.nodes[*best_child_idx]
            .move_from_parent
            .expect("root children have moves")
    }
}
/// Runs MCTS from the given position and returns the best move.
///
/// `iterations` controls the computational budget — more iterations
/// means stronger play but longer computation.
pub fn mcts_best_move(
    state: &GameState,
    iterations: u32,
    rng: &mut impl rand::Rng,
) -> usize {
    let mut tree = MctsTree::new(state.clone());
    tree.search(iterations, rng)
}
Make sure you have the rand crate in your Cargo.toml:
[dependencies]
rand = "0.8"
Walking through key decisions
Let’s revisit the important design choices in this implementation:
Arena allocation. The Vec<MctsNode> arena avoids all the complexity of reference-counted trees. Every node has a stable index for the lifetime of the search. The one subtlety is in expand: pushing onto self.nodes may reallocate the vector, so Rust will not let you hold a reference into self.nodes across the push. expand therefore reads what it needs from self.nodes[node_idx] (mv and child_state) into local variables before calling self.nodes.push(child). Only after the push does it index self.nodes[node_idx] again to record the new child.
unexpanded_moves as a Vec. We store the unexpanded moves directly in each node and use pop() to grab one during expansion. This is simpler than computing the set difference between legal moves and existing children on every expansion. The cost is a small amount of extra memory per node, but for Tic-Tac-Toe (at most 9 moves) this is negligible.
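Because pop() takes from the end of the Vec, moves are expanded in reverse order. The sketch below assumes legal_moves() returns empty cells in ascending order, which is an assumption about your GameState. Under that assumption, the root's children are created as 8, 7, ..., 0. expansion_order is a hypothetical helper for illustration.

```rust
/// Returns the order in which moves come out of an `unexpanded_moves`
/// Vec when each expansion calls `pop()`.
fn expansion_order(mut unexpanded: Vec<usize>) -> Vec<usize> {
    let mut order = Vec::new();
    // pop() removes the last element in O(1), so no shifting is needed.
    while let Some(mv) = unexpanded.pop() {
        order.push(mv);
    }
    order
}

fn main() {
    // Assumption: legal_moves() on an empty board returns 0..=8 in order.
    let opening_moves: Vec<usize> = (0..9).collect();
    println!("{:?}", expansion_order(opening_moves)); // [8, 7, 6, 5, 4, 3, 2, 1, 0]
}
```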
Reward perspective. Each node’s total_value is from the perspective of the player who moved to reach that node. This means:
- During selection, UCT compares children from the perspective of the parent’s current player, which is exactly the player who will choose among those children.
- During backpropagation, we flip the sign at each level because adjacent levels represent alternating players.
Cloning GameState during simulation. The rollout clones the starting state and then, at each step, replaces it with the new state returned by apply_move. Since GameState is small (a 9-element array plus a Player enum), these copies are cheap. In a game with a larger state, you might instead apply and undo moves on a single state in place, but for Tic-Tac-Toe this is perfectly fine.
No parallelism. This is a single-threaded implementation. MCTS can be parallelized (run multiple rollouts in parallel), but that adds complexity we do not need yet. A single-threaded MCTS can run tens of thousands of iterations per second on Tic-Tac-Toe — far more than needed for strong play.
Testing it out
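Let’s add a quick test to see MCTS in action. Add this to your main.rs. It reads tree.nodes directly, so it assumes the MCTS code lives in the same module (for example, pasted into main.rs). If you keep it in src/mcts.rs, you would need to mark MctsTree, its new and search methods, its nodes field, and the MctsNode fields used below as pub(crate).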
use rand::SeedableRng;
use rand::rngs::StdRng;
fn main() {
// Use a seeded RNG for reproducible results.
let mut rng = StdRng::seed_from_u64(42);
let state = GameState::new();
println!("Running MCTS from the opening position...");
println!("{}", state);
// Run 10,000 iterations — far more than needed for Tic-Tac-Toe,
// but it runs in milliseconds.
let mut tree = MctsTree::new(state.clone());
let best = tree.search(10_000, &mut rng);
println!("Best move: cell {}", best);
println!();
// Print visit counts for all children of root.
let root = &tree.nodes[0];
println!("Move Visits Avg Value");
println!("---- ------ ---------");
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent.unwrap();
let avg = if child.visit_count > 0 {
child.total_value / child.visit_count as f64
} else {
0.0
};
println!(
" {} {:>5} {:>+.3}",
mv, child.visit_count, avg
);
}
}
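Run it with cargo run (or cargo run --release for faster execution). You should see output like the example below. Your exact numbers will differ, since they depend on the rand version and on details of your GameState, such as the order in which legal_moves() lists cells: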
Running MCTS from the opening position...
. | . | .
-----------
. | . | .
-----------
. | . | .
Best move: cell 4
Move Visits Avg Value
---- ------ ---------
8 797 -0.048
7 629 -0.091
6 806 -0.039
5 655 -0.078
4 3493 +0.108
3 660 -0.072
2 810 -0.042
1 567 -0.097
0 582 -0.085
Several things to notice:
- Cell 4 (center) gets the most visits by far. This is what we would hope for: against imperfect play, the center is the strongest opening move in Tic-Tac-Toe. MCTS discovers this entirely from random rollouts, with no built-in knowledge of Tic-Tac-Toe strategy.
- Corner cells (0, 2, 6, 8) mostly get more visits than edge cells (1, 3, 5, 7). Cells 2, 6, and 8 each outscore every edge cell, though cell 0 trails three of the four edge cells in this run. Corners are generally considered the second-best opening moves, and MCTS roughly recovers that ordering on its own.
- The center has a positive average value; all other moves have negative average values. This suggests the center leads to positions where the first player has a statistical advantage, while the other openings score slightly worse. These averages are noisy estimates from the search, not game-theoretic values: Tic-Tac-Toe is a draw with perfect play from both sides.
- 10,000 iterations run almost instantly. For a 3x3 board, even 100,000 iterations should finish in well under a second, at least in a release build. Tic-Tac-Toe’s small game tree means MCTS converges quickly.
What we have built
You now have a complete, working MCTS implementation in Rust. Let’s review what each component does:
| Component | Responsibility |
|---|---|
| MctsNode | Stores state, statistics (visits, value), parent/child links, unexpanded moves |
| MctsTree | Arena of nodes; owns the entire search tree |
| uct_score | Balances exploitation (average value) against exploration (a bonus that shrinks as a child is visited more) |
| select | Walks down from the root by highest UCT score until it reaches a node with unexpanded moves or a terminal node |
| expand | Adds a new child node for one unexpanded move |
| simulate | Plays a random game to completion, returns the reward |
| backpropagate | Updates statistics from leaf to root, flipping signs |
| search | Runs the four-phase loop for N iterations, returns the best move |
| mcts_best_move | Public API: creates a tree, searches, returns the best move |
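To see the exploration term at work, here is a standalone copy of uct_score from the implementation above, scored for two hypothetical children of a node that has been visited 100 times. The barely explored child has a worse average but a much larger exploration bonus, so selection would visit it next.

```rust
use std::f64::consts::SQRT_2;

/// Standalone copy of `uct_score` from the MCTS module.
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
    if child_visits == 0 {
        return f64::INFINITY; // unvisited children are always tried first
    }
    let n = child_visits as f64;
    let exploitation = child_value / n;
    let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
    exploitation + exploration
}

fn main() {
    let parent_visits = 100;
    // Well explored: average +0.5 over 80 visits.
    let explored = uct_score(40.0, 80, parent_visits);
    // Barely explored: average 0.0 over 5 visits.
    let unexplored = uct_score(0.0, 5, parent_visits);
    // Roughly 0.839 vs 1.357: the exploration bonus wins here.
    println!("explored = {:.3}, unexplored = {:.3}", explored, unexplored);
}
```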
This is a pure MCTS implementation — the rollouts are random, with no learned guidance. In Part 4, we will replace the random rollouts with a neural network that evaluates positions directly. That is the AlphaGo Zero innovation: instead of playing a random game to estimate a position’s value, ask a neural network that has learned to predict the value from training data generated by self-play.
But even without a neural network, this MCTS player is surprisingly strong at Tic-Tac-Toe. In the next section, you will build a playable game loop and see for yourself.
8. Exercise 2: play Tic-Tac-Toe with pure MCTS
You have a working MCTS implementation and a complete Tic-Tac-Toe game engine. Time to wire them together and watch the AI play — first against itself, then against you.
This exercise has two parts:
- Part A: MCTS vs MCTS — run a tournament of 100 games and verify that strong play produces mostly draws.
- Part B: Human vs MCTS — build a CLI where you can challenge the AI yourself.
Part A: MCTS vs MCTS tournament
The game loop is straightforward: start from an empty board, alternate MCTS moves until the game ends, record the result. Repeat 100 times.
Think about what you need:
- A loop that plays a single game to completion.
- An outer loop that plays many games and tallies results.
- Statistics printed at the end: X wins, O wins, draws.
Try writing this yourself before looking at the solution. The key function you need is mcts_best_move(state, iterations, rng) from §7.
Solution: MCTS vs MCTS tournament
use rand::thread_rng;
fn play_one_game(iterations: u32) -> Option<Player> {
let mut state = GameState::new();
let mut rng = thread_rng();
loop {
// Check if the game is over.
if state.is_terminal() {
return state.winner(); // None means draw
}
// MCTS picks the best move for whoever's turn it is.
let best_move = mcts_best_move(&state, iterations, &mut rng);
state = state.apply_move(best_move);
}
}
fn run_tournament(num_games: u32, iterations: u32) {
let mut x_wins = 0;
let mut o_wins = 0;
let mut draws = 0;
for game_num in 1..=num_games {
match play_one_game(iterations) {
Some(Player::X) => x_wins += 1,
Some(Player::O) => o_wins += 1,
None => draws += 1,
}
// Print progress every 10 games.
if game_num % 10 == 0 {
println!(
"After {} games: X={}, O={}, Draw={}",
game_num, x_wins, o_wins, draws
);
}
}
println!("\n=== Final Results ({} games, {} iterations) ===", num_games, iterations);
println!("X wins: {} ({:.1}%)", x_wins, x_wins as f64 / num_games as f64 * 100.0);
println!("O wins: {} ({:.1}%)", o_wins, o_wins as f64 / num_games as f64 * 100.0);
println!("Draws: {} ({:.1}%)", draws, draws as f64 / num_games as f64 * 100.0);
}
fn main() {
run_tournament(100, 1000);
}
What to expect at different iteration counts
Run the tournament with different MCTS budgets and observe how play quality changes:
| Iterations | Typical results (100 games) | Why |
|---|---|---|
| 100 | Mixed results — X wins often, some draws, occasional O wins | Too few iterations to explore the tree deeply. MCTS makes blunders, especially in the mid-game where tactical accuracy matters. |
| 1,000 | Mostly draws, a few X or O wins | Enough iterations to find reasonable moves. Occasional mistakes on tricky positions, but overall solid play. |
| 10,000 | Nearly 100% draws | The search tree is thoroughly explored. Both sides play near-perfectly. This is what we expect — Tic-Tac-Toe is a theoretical draw with optimal play from both sides. |
The trend is clear: more iterations = stronger play = more draws. Tic-Tac-Toe’s game tree is small enough that 10,000 iterations essentially solves it. For larger games like Go, this brute-force approach would not be enough — which is exactly why AlphaGo Zero added a neural network.
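Note on randomness: Your exact numbers will differ between runs because MCTS uses random rollouts, but the trend will be consistent. If you want reproducible results, seed your RNG with a fixed value:
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
(this needs use rand::SeedableRng; in scope). Create the seeded RNG once, for example in run_tournament, and pass it to each game (for instance by giving play_one_game an rng: &mut impl rand::Rng parameter) rather than seeding inside play_one_game. Reseeding with the same value at the start of every game would make all 100 games identical.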
Part B: Human vs MCTS
Now build an interactive CLI. The human picks a cell by number (0–8), and MCTS responds. The board should be printed after every move so the human can see the state.
For a good user experience, you will want to:
- Show the board with numbered empty cells so the player knows which moves are available.
- Read a line from stdin and parse it as a usize.
- Validate that the chosen cell is actually empty.
- Handle invalid input gracefully (prompt again instead of crashing).
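One way to keep the validation rules testable is to separate parsing from reading stdin. The sketch below uses parse_move, a hypothetical helper that is not part of the solution later in this section. It takes the raw input line and the list of legal moves and returns either the chosen cell or a message to show before prompting again.

```rust
/// Parses one line of user input as a board index and checks it against
/// the legal moves. Hypothetical helper: returns the cell or an error message.
fn parse_move(input: &str, legal: &[usize]) -> Result<usize, String> {
    match input.trim().parse::<usize>() {
        Ok(index) if legal.contains(&index) => Ok(index),
        Ok(index) => Err(format!("Cell {} is not available. Try again.", index)),
        Err(_) => Err("Please enter a number between 0 and 8.".to_string()),
    }
}

fn main() {
    let legal = [0, 1, 2, 3, 5, 6, 7, 8]; // the center (4) is already taken
    for input in ["4\n", " 2 \n", "nine\n"] {
        match parse_move(input, &legal) {
            Ok(index) => println!("accepted {}", index),
            Err(msg) => println!("{}", msg),
        }
    }
}
```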
Here is a helper function that prints the board with position numbers for empty cells — this is much friendlier than showing dots:
fn print_board_with_indices(state: &GameState) {
    for row in 0..3 {
        for col in 0..3 {
            let index = row * 3 + col;
            match state.board[index] {
                Some(Player::X) => print!(" X "),
                Some(Player::O) => print!(" O "),
                None => print!(" {} ", index),
            }
            if col < 2 {
                print!("|");
            }
        }
        println!();
        if row < 2 {
            println!("-----------");
        }
    }
}
At the start of a game, this displays:
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
After X plays center, it shows:
0 | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
Try building the full interactive loop yourself. You will need std::io::stdin() and its read_line method.
Solution: Human vs MCTS
use rand::thread_rng;
use std::io::{self, Write};
fn get_human_move(state: &GameState) -> usize {
let legal = state.legal_moves();
loop {
print!("Your move (enter 0-8): ");
io::stdout().flush().unwrap();
let mut input = String::new();
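let bytes_read = io::stdin().read_line(&mut input).unwrap();
if bytes_read == 0 {
// End of input (e.g. Ctrl-D): quit instead of re-prompting forever.
std::process::exit(0);
}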
match input.trim().parse::<usize>() {
Ok(index) if legal.contains(&index) => return index,
Ok(index) => println!("Cell {} is not available. Try again.", index),
Err(_) => println!("Please enter a number between 0 and 8."),
}
}
}
fn play_human_vs_mcts(human_player: Player, iterations: u32) {
let mut state = GameState::new();
let mut rng = thread_rng();
println!("\n=== Tic-Tac-Toe: You are {} ===\n", human_player);
print_board_with_indices(&state);
loop {
if state.is_terminal() {
match state.winner() {
Some(p) if p == human_player => println!("\nYou win!"),
Some(_) => println!("\nMCTS wins!"),
None => println!("\nIt's a draw!"),
}
break;
}
let chosen_move = if state.current_player == human_player {
println!();
get_human_move(&state)
} else {
println!("\nMCTS is thinking...");
let m = mcts_best_move(&state, iterations, &mut rng);
println!("MCTS plays position {}.", m);
m
};
state = state.apply_move(chosen_move);
println!();
print_board_with_indices(&state);
}
}
fn main() {
// Human plays X (moves first), MCTS uses 1000 iterations.
play_human_vs_mcts(Player::X, 1000);
}
A typical session looks like this:
=== Tic-Tac-Toe: You are X ===
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
Your move (enter 0-8): 4
0 | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
MCTS is thinking...
MCTS plays position 0.
O | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
Your move (enter 0-8): 2
...
Experimentation prompts
Now that you have both programs, try these experiments:
- Can you beat MCTS at 10 iterations? With only 10 iterations per move, the search barely explores the tree. You should be able to win consistently. Try it: the AI will make obvious blunders.
- At what iteration count does MCTS become unbeatable? Start low and work your way up: try 50, 100, 200, 500. Find the threshold where you can no longer win. It often lands somewhere around 200–500 iterations, though the exact value depends on your implementation and on how well you play.
- Does it matter who goes first? In the MCTS vs MCTS tournament, compare X’s win count (X moves first) with O’s. With perfect play Tic-Tac-Toe is a draw, but against imperfect play the first player has an advantage. Does that show up in your results at lower iteration counts?
- What about asymmetric strength? Modify play_one_game so X uses 100 iterations and O uses 10,000. The stronger player should win more often. How lopsided do the results get?
Solution: Asymmetric tournament
use rand::thread_rng;
fn play_asymmetric_game(x_iterations: u32, o_iterations: u32) -> Option<Player> {
    let mut state = GameState::new();
    let mut rng = thread_rng();
    loop {
        if state.is_terminal() {
            return state.winner(); // None means draw
        }
        // Each side searches with its own iteration budget.
        let iterations = match state.current_player {
            Player::X => x_iterations,
            Player::O => o_iterations,
        };
        let best_move = mcts_best_move(&state, iterations, &mut rng);
        state = state.apply_move(best_move);
    }
}
Readiness checklist
Before moving on to Part 4, make sure you can answer “yes” to all of these:
- You ran a 100-game MCTS vs MCTS tournament and saw mostly draws at high iteration counts.
- You understand why more iterations lead to stronger play (deeper tree exploration, more accurate value estimates).
- You played against MCTS interactively and experienced its strength at high iteration counts.
- You found the approximate iteration threshold where MCTS becomes unbeatable at Tic-Tac-Toe.
- You understand that this approach (brute-force random rollouts) does not scale to complex games — and that replacing rollouts with a neural network is the key insight of AlphaGo Zero.
If so, you are ready for Part 4, where we introduce neural networks to guide MCTS. Instead of random rollouts, the network will learn to predict which moves are good (policy) and who is winning (value) — trained entirely from self-play data.
Part 4 — Neural Network Policy/Value Head
9. Neural network architecture overview
In Part 3, we built an MCTS agent that plays Tic-Tac-Toe by simulating thousands of random games from each position. It works, but it is brute force. Every evaluation requires playing a game to completion with random moves, and the quality of the evaluation depends on how many of those random rollouts you can afford. For a 3x3 board, that is fine. For a 19x19 Go board, random rollouts are far too slow and too noisy to reach top-level play.
AlphaGo Zero’s breakthrough was replacing those random rollouts with a neural network — a function that looks at a board position and instantly outputs two things: which moves are likely to be good (policy) and who is probably winning (value). No random simulation needed. The network learns these judgments from millions of self-play games.
This section teaches you what a neural network is, how it works, and why the specific architecture used in AlphaGo Zero has two “heads.” We will build up from a single artificial neuron to the complete dual-headed network we will implement in Rust in the next section. No prior neural network knowledge is assumed.
What is a neural network?
A neural network is a function approximator. That is a fancy way of saying: it is a machine that takes in some numbers, does math to them, and produces some numbers as output. What makes it special is that the math it does is controlled by thousands (or millions) of adjustable parameters, and those parameters can be tuned so the function produces whatever output you want for any given input.
Think of it like a programmable calculator with thousands of tiny knobs. At first, the knobs are set randomly, and the calculator outputs nonsense. But if you had a way to measure how wrong the output is and nudge each knob in the right direction, eventually the calculator would learn to compute the function you care about.
For our game AI, the function we want is:
- Input: a Tic-Tac-Toe board position
- Output: (1) which move to play, and (2) who is winning
The neural network will learn this function from data generated by self-play. We never hand it a strategy book or a database of expert games. It figures out good play entirely on its own.
The artificial neuron
The fundamental building block of a neural network is the neuron (also called a node or unit). A single neuron does three things:
- Takes in one or more numbers as inputs.
- Multiplies each input by a weight and sums them up, then adds a bias.
- Passes the result through an activation function.
Here is a concrete example. Suppose we have a neuron with two inputs:
Inputs: x₁ = 3.0, x₂ = -1.0
Weights: w₁ = 0.5, w₂ = 2.0
Bias: b = 0.1
Step 1 — Weighted sum:
z = (x₁ × w₁) + (x₂ × w₂) + b
z = (3.0 × 0.5) + (-1.0 × 2.0) + 0.1
z = 1.5 + (-2.0) + 0.1
z = -0.4
Step 2 — Activation function (ReLU):
output = max(0, z) = max(0, -0.4) = 0.0
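The same two steps translate directly into code. This is a minimal sketch of one neuron with a ReLU activation; neuron is an illustrative name, not part of the network we build later. It reproduces the example numbers above.

```rust
/// One artificial neuron: weighted sum of inputs plus bias, then ReLU.
/// Returns (z, output) so both steps are visible.
fn neuron(inputs: &[f64], weights: &[f64], bias: f64) -> (f64, f64) {
    assert_eq!(inputs.len(), weights.len(), "one weight per input");
    // Step 1: z = x1*w1 + x2*w2 + ... + b
    let z = inputs.iter().zip(weights).map(|(x, w)| x * w).sum::<f64>() + bias;
    // Step 2: ReLU(z) = max(0, z)
    let output = z.max(0.0);
    (z, output)
}

fn main() {
    let (z, output) = neuron(&[3.0, -1.0], &[0.5, 2.0], 0.1);
    println!("z = {:.1}, output = {:.1}", z, output); // z = -0.4, output = 0.0
}
```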
Let’s break down each component:
Weights are the knobs we mentioned earlier. Each input gets its own weight, which determines how much that input matters to this neuron. A large positive weight means “this input is very important — when it goes up, push my output up.” A large negative weight means “this input is important in the opposite direction.” A weight near zero means “I don’t care about this input.”
Bias is an offset. It shifts the neuron’s output up or down regardless of the inputs. Think of it as the neuron’s “default inclination” — even when all inputs are zero, the bias determines whether the neuron tends to activate or stay quiet.
Activation function introduces nonlinearity. Without it, a neuron is just computing a weighted sum, which is a straight line. No matter how many straight-line computations you stack together, the result is still a straight line. Activation functions add curves, which let networks learn complex, nonlinear patterns.
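The straight-line argument can be made precise for whole layers. Compose two layers without an activation function, using weight matrices W_1, W_2 and bias vectors b_1, b_2 (notation introduced here for illustration):

```latex
W_2\,(W_1 x + b_1) + b_2 = (W_2 W_1)\,x + (W_2 b_1 + b_2) = W' x + b'
```

The result is a single linear layer with W' = W_2 W_1 and b' = W_2 b_1 + b_2. Putting a nonlinearity such as ReLU between the two layers prevents this collapse.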
The most common activation function in modern networks is ReLU (Rectified Linear Unit):
ReLU(z) = max(0, z)
If the weighted sum is positive, pass it through unchanged. If negative, output zero. It is dead simple, yet it works remarkably well. The key insight is that it creates a “hinge” — the neuron either fires (positive sum) or stays silent (negative sum), which is enough nonlinearity for stacking layers to produce complex behavior.
Other activation functions you may encounter:
- Sigmoid: squashes output to the range (0, 1). Useful for probabilities.
- Tanh: squashes output to the range (-1, 1). Useful when you need centered outputs — like our value head.
- Softmax: applied to a group of neurons, converts raw scores into a probability distribution that sums to 1. We will use this for the policy head.
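Here is a minimal sketch of these functions in Rust, using only the standard library (tanh is built into f64). The softmax subtracts the largest score before exponentiating. This is a common numerical-stability trick that avoids overflow in exp() without changing the result.

```rust
/// Sigmoid: squashes any real number into (0, 1).
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Softmax: turns raw scores into probabilities that sum to 1.
fn softmax(scores: &[f64]) -> Vec<f64> {
    // Shift by the max score so the largest exponent is exp(0) = 1.
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    println!("sigmoid(0) = {}", sigmoid(0.0)); // 0.5
    println!("tanh(0.5) = {:.3}", 0.5f64.tanh()); // always within (-1, 1)
    let probs = softmax(&[2.0, 1.0, 0.1]);
    let total: f64 = probs.iter().sum();
    println!("softmax = {:?}, sum = {:.3}", probs, total);
}
```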
Layers: stacking neurons together
A single neuron can only learn simple, linear-ish decision boundaries. The real power comes from organizing neurons into layers and stacking those layers on top of each other.
A neural network has three kinds of layers:
Input layer — This is not really a layer of neurons at all. It is simply the raw data you feed into the network. For our Tic-Tac-Toe network, the input layer will be 18 numbers representing the board state (more on the encoding later).
Hidden layers — These are the layers in the middle where the actual learning happens. Each hidden layer is a group of neurons that all take the same inputs (the outputs from the previous layer) and each produce one output. “Hidden” just means they are not directly visible as input or output — they are internal to the network.
Output layer — The final layer that produces the network’s answer. For our game AI, this will produce move probabilities and a value estimate.
Here is how a small network with two hidden layers looks:
Input layer Hidden layer 1 Hidden layer 2 Output layer
(3 inputs) (4 neurons) (4 neurons) (2 outputs)
x₁ ──────────→ h₁ ──────────→ h₅ ──────────→ o₁
╲ ╱ ╲ ╱ ╲ ╱
╲ ╱ ╲ ╱ ╲ ╱
x₂ ────╳────→ h₂ ────╳────→ h₆ ────╳────→ o₂
╱ ╲ ╱ ╲ ╱
╱ ╲ ╱ ╲ ╱
x₃ ──────────→ h₃ ──────────→ h₇
╱ ╱
╱ ╱
h₄ ──────────→ h₈
Every neuron in one layer is connected to every neuron in the next layer — this is called a fully connected (or dense) layer. Each connection has its own weight. In this diagram:
- The connections from the input layer to hidden layer 1 have 3 x 4 = 12 weights, plus 4 biases = 16 parameters.
- Hidden layer 1 to hidden layer 2 has 4 x 4 = 16 weights, plus 4 biases = 20 parameters.
- Hidden layer 2 to the output has 4 x 2 = 8 weights, plus 2 biases = 10 parameters.
- Total: 46 learnable parameters.
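The same bookkeeping works for any stack of fully connected layers. Each layer contributes (inputs × outputs) weights plus one bias per output. The sketch below reproduces the 46 above. It also counts the Tic-Tac-Toe network this section builds toward (18 inputs, two hidden layers of 64 neurons, 10 outputs), under the assumption that every layer is fully connected. count_parameters is an illustrative helper.

```rust
/// Counts learnable parameters in a stack of fully connected layers.
/// `sizes` lists the width of each layer, input layer first.
fn count_parameters(sizes: &[usize]) -> usize {
    sizes
        .windows(2)
        .map(|pair| pair[0] * pair[1] + pair[1]) // weights + biases
        .sum()
}

fn main() {
    println!("{}", count_parameters(&[3, 4, 4, 2])); // 46
    println!("{}", count_parameters(&[18, 64, 64, 10])); // 6026
}
```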
Even this tiny network has 46 knobs to tune. Real networks have millions.
Why does stacking layers help? Each layer learns progressively more abstract features. In image recognition, the first layer might learn to detect edges, the second layer combines edges into shapes, and the third layer combines shapes into objects. In game evaluation, the first layer might learn to detect individual piece patterns, the second layer might recognize tactical motifs, and deeper layers might understand strategic concepts. The network discovers these intermediate representations on its own — we never tell it what to look for.
The forward pass
The forward pass is the process of pushing data through the network from input to output. There is no mystery to it — it is just repeated application of the neuron computation (weighted sum, add bias, apply activation) at every layer.
Here is the forward pass in pseudocode for a two-hidden-layer network:
function forward(input):
# Layer 1: input → hidden1
hidden1 = activation(weights1 × input + biases1)
# Layer 2: hidden1 → hidden2
hidden2 = activation(weights2 × hidden1 + biases2)
# Output layer: hidden2 → output
output = output_activation(weights3 × hidden2 + biases3)
return output
Each step takes the outputs from the previous layer, multiplies by a weight matrix, adds biases, and applies the activation function. The whole thing is pure arithmetic — matrix multiplications and element-wise function applications. This is why neural networks can run so fast on GPUs, which are designed for exactly this kind of math.
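The pseudocode translates into a minimal Rust sketch. It keeps the pseudocode's convention of multiplying a weight matrix by the input, where row j of the matrix holds the weights into neuron j. To stay short, it uses one hidden layer instead of two, and the weights are hand-picked rather than learned. A second hidden layer would be one more `dense` call.

```rust
/// One dense layer in the pseudocode's convention: out = act(W × x + b),
/// where `w[j]` is row j of W (the weights feeding output neuron j).
fn dense(w: &[Vec<f32>], b: &[f32], x: &[f32], act: fn(f32) -> f32) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, &bias)| {
            // Weighted sum of the inputs, then bias, then activation.
            let sum: f32 = row.iter().zip(x).map(|(wi, xi)| wi * xi).sum();
            act(sum + bias)
        })
        .collect()
}

fn relu(z: f32) -> f32 {
    z.max(0.0)
}

fn identity(z: f32) -> f32 {
    z
}

fn main() {
    // Hypothetical weights for a 2 → 2 → 1 network, chosen by hand.
    let w1 = vec![vec![1.0, -1.0], vec![0.5, 0.5]];
    let b1 = vec![0.0, 0.0];
    let w2 = vec![vec![1.0, 2.0]];
    let b2 = vec![0.5];

    let input = [2.0, 1.0];
    let hidden = dense(&w1, &b1, &input, relu);
    let output = dense(&w2, &b2, &hidden, identity);
    println!("{:?} -> {:?}", hidden, output); // prints [1.0, 1.5] -> [4.5]
}
```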
For our Tic-Tac-Toe network, a single forward pass will take the 18-number board encoding, push it through two hidden layers of 64 neurons each, and produce 10 outputs (9 move probabilities + 1 value). The entire computation takes microseconds — vastly faster than running hundreds of random rollouts.
Training: how the network learns
A freshly created neural network has random weights. Its forward pass produces garbage. Training is the process of adjusting those weights so the network’s outputs match what we want.
Training requires three ingredients:
1. Training data — input/output pairs.
For our game AI, the data comes from self-play. After playing thousands of games using MCTS, we collect triples of (board position, MCTS move probabilities, game outcome). The board position is the input; the MCTS probabilities and game outcome are the “correct” outputs we want the network to learn.
2. A loss function — a measure of how wrong the network is.
The loss function compares the network’s output to the desired output and produces a single number: the loss. A high loss means the network is doing badly; a low loss means it is doing well. Common choices:
- Cross-entropy loss for the policy head: measures how different the network’s predicted move probabilities are from the MCTS move probabilities. If MCTS strongly preferred one move but the network assigns it low probability, the loss is high.
- Mean squared error for the value head: measures how far the network’s predicted value is from the actual game outcome. If the network said +0.8 (strongly favoring the current player) but the game was actually a loss (-1.0), the loss is high.
The total loss is typically the sum of these two:
total_loss = policy_loss + value_loss
3. An optimizer — a method for adjusting weights to reduce the loss.
This is where gradient descent comes in. The core idea is beautifully simple:
- Run the forward pass to get the network’s output.
- Compute the loss (how wrong is the output?).
- Figure out which direction to nudge each weight to make the loss smaller. This is the gradient — it tells you, for every single weight in the network, “if you increase this weight slightly, does the loss go up or down, and by how much?”
- Nudge each weight a small step in the direction that reduces the loss.
- Repeat thousands of times.
The process of computing gradients is called backpropagation. The name comes from the fact that the computation flows backward through the network — starting from the loss at the output and propagating back through each layer to determine how each weight contributed to the error. You do not need to understand the calculus behind it to use it; every neural network library handles backpropagation automatically.
Here is the intuition with an analogy. Imagine you are blindfolded on a hilly landscape and trying to reach the lowest valley. You cannot see, but you can feel the slope of the ground beneath your feet. At each step, you feel which direction goes downhill and take a small step that way. That is gradient descent. The “landscape” is the loss function, the “position” is the current set of weights, and each step adjusts the weights to reduce the loss.
The size of each step is controlled by the learning rate — a small number like 0.001 or 0.01. Too large and you overshoot the valley, bouncing around wildly. Too small and training takes forever. Finding a good learning rate is one of the practical arts of training neural networks.
The typical training loop looks like this:
for each batch of training examples:
predictions = forward(inputs)
loss = compute_loss(predictions, targets)
gradients = backpropagate(loss)
for each weight in network:
weight = weight - learning_rate × gradient
After thousands of iterations through the training data, the weights converge to values that make the network produce good outputs for the kinds of inputs it has seen. Crucially, the network also generalizes to similar inputs it has not seen. This is the magic: a network trained on a million board positions can evaluate positions it has never encountered before.
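The training loop and the learning-rate trade-off can both be seen on a problem far smaller than a neural network. The sketch below is an illustration, not code from this course. It fits a straight line y = w·x + b to ten points taken from y = 2x + 1, using mini-batches, mean-squared-error gradients, and the same "weight minus learning rate times gradient" update. The data, batch sizes, and learning rates are arbitrary choices.

```rust
/// Fit y = w·x + b by mini-batch gradient descent on mean squared error.
/// Returns the learned (w, b).
fn train(epochs: usize, batch_size: usize, learning_rate: f32) -> (f32, f32) {
    // Ten training examples on the line y = 2x + 1, with x = 0.0, 0.1, ..., 0.9.
    let data: Vec<(f32, f32)> = (0..10)
        .map(|i| {
            let x = i as f32 / 10.0;
            (x, 2.0 * x + 1.0)
        })
        .collect();
    let (mut w, mut b) = (0.0_f32, 0.0_f32);
    for _ in 0..epochs {
        for batch in data.chunks(batch_size) {
            // Forward pass and gradient of the batch's mean squared error.
            let (mut grad_w, mut grad_b) = (0.0_f32, 0.0_f32);
            for &(x, target) in batch {
                let error = (w * x + b) - target; // prediction minus target
                grad_w += 2.0 * error * x;
                grad_b += 2.0 * error;
            }
            let n = batch.len() as f32;
            // Step each parameter against its own (batch-averaged) gradient.
            w -= learning_rate * grad_w / n;
            b -= learning_rate * grad_b / n;
        }
    }
    (w, b)
}

fn main() {
    // (learning rate, epochs, batch size)
    for (lr, epochs, batch) in [(0.0001, 1000, 5), (0.1, 1000, 5), (1.0, 100, 10)] {
        let (w, b) = train(epochs, batch, lr);
        println!("lr = {lr}: w = {w:.3}, b = {b:.3}");
    }
}
```

With a learning rate of 0.1, the parameters settle near w = 2 and b = 1. With 0.0001 they barely move in the same number of epochs. With 1.0 (using the whole dataset as one batch), every step overshoots by more than the last, so the values blow up.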
The dual-headed architecture
Now we can talk about the specific architecture used in AlphaGo Zero and our Tic-Tac-Toe AI. It is called a dual-headed network because it has one body and two “heads” — one for policy and one for value.
The shared trunk
The first part of the network is a set of hidden layers that process the raw board state into a learned representation — an internal encoding that captures the important features of the position. This is called the “trunk” or “backbone” of the network.
The trunk takes the board encoding as input and outputs a vector of numbers that represents what it has learned about the position. These numbers are not directly interpretable by humans — they are abstract features that the network has discovered are useful for both predicting good moves and predicting who will win.
The policy head
The policy head branches off from the trunk and produces a probability distribution over all possible moves. For Tic-Tac-Toe, that means 9 numbers (one per cell) that sum to 1.0. Each number represents how strongly the network recommends playing in that cell.
For example, the policy head might output:
Cell: 0 1 2 3 4 5 6 7 8
Probability: 0.02 0.03 0.01 0.05 0.70 0.08 0.01 0.05 0.05
This says: “I strongly recommend cell 4 (the center), with 70% probability.” The MCTS algorithm uses these probabilities to focus its search on the most promising moves rather than exploring everything equally.
The policy head’s final activation function is softmax, which converts raw scores into a valid probability distribution.
The value head
The value head also branches off from the trunk, but instead of 9 outputs it produces a single number between -1 and 1. This number estimates who is winning from the perspective of the current player:
- +1 means “the current player is definitely winning.”
- -1 means “the current player is definitely losing.”
- 0 means “the position is roughly even.”
For example, suppose X has two in a row with the third cell of that line still empty, and it is X’s turn. The value head might output +0.95, meaning “X is almost certainly going to win.”
The value head’s final activation function is tanh, which squashes any number into the range (-1, 1).
Why two heads on one trunk?
You might wonder: why not train two completely separate networks, one for policy and one for value? There are three reasons for sharing a trunk.
Shared feature learning. Many features of a position are useful for both predicting good moves and predicting who is winning. For instance, recognizing that a player has two in a row is relevant both to “what move should I play?” (block or complete the row) and “who is winning?” (the player with two in a row has an advantage). By sharing a trunk, both heads benefit from the same learned features without having to discover them independently.
Efficiency. One network is faster to run than two. During MCTS, we call the network for every node we expand, so speed matters. A shared trunk means the expensive computation happens once, and then the two heads — which are small — each do a tiny bit of additional work.
Regularization. Training the trunk to be useful for two different tasks simultaneously acts as a form of regularization — it prevents the network from overfitting to either task alone. The trunk must learn features that are generally useful, which tends to produce more robust representations.
Input representation: encoding the board as a tensor
Neural networks only understand numbers. We cannot feed a Tic-Tac-Toe board directly into the network — we need to convert it into a numerical format called a tensor (which is just a fancy word for a multi-dimensional array of numbers).
The standard approach for board games is to use binary feature planes — one plane per piece type. For Tic-Tac-Toe, we create two 3x3 grids:
- Plane 1 (X positions): a 1 where X has played, 0 elsewhere.
- Plane 2 (O positions): a 1 where O has played, 0 elsewhere.
For example, given this board:
X | O | .
-----------
. | X | .
-----------
. | . | O
The two planes are:
Plane 1 (X):     Plane 2 (O):
1 0 0            0 1 0
0 1 0            0 0 0
0 0 0            0 0 1
We flatten these two 3x3 planes into a single 1-dimensional vector of 18 numbers:
[1, 0, 0, 0, 1, 0, 0, 0, 0, ← X plane (row by row)
0, 1, 0, 0, 0, 0, 0, 0, 1] ← O plane (row by row)
This 18-element vector is what the network’s input layer receives.
Why two separate planes instead of a single grid with, say, +1 for X and -1 for O? The two-plane encoding gives the network richer information. It can learn features about X’s position and O’s position independently, rather than having them tangled together in a single number. It also scales naturally to more complex games. Chess, for instance, needs at least 12 planes just for the pieces (one per piece type per color).
We also always encode the board from the perspective of the current player. If it is O’s turn, we swap the planes so that “my pieces” are always in Plane 1 and “opponent’s pieces” are always in Plane 2. This way the network answers “am I winning?” and “what should I play?” from one player-independent point of view, rather than needing to learn separate strategies for X and O.
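A short sketch can check both the worked example and the perspective swap. It is an illustration rather than the course's encoder; §10 builds that on its own `GameState` type. Here the board is written as three hypothetical strings of `X`, `O`, and `.`, exactly 9 cells in total, and the caller says whose turn it is.

```rust
/// Encode a board written as three rows of 'X', 'O', '.' into the
/// 18-element vector: 9 "my pieces" values, then 9 "opponent" values.
/// `to_move` is the symbol of the player whose turn it is.
fn encode(rows: [&str; 3], to_move: char) -> Vec<f32> {
    let mut planes = vec![0.0_f32; 18];
    for (i, cell) in rows.concat().chars().enumerate() {
        if cell == to_move {
            planes[i] = 1.0; // current player's plane
        } else if cell != '.' {
            planes[i + 9] = 1.0; // opponent's plane
        }
    }
    planes
}

fn main() {
    // The example position from the text.
    let board = ["XO.", ".X.", "..O"];
    println!("X to move: {:?}", encode(board, 'X'));
    println!("O to move: {:?}", encode(board, 'O'));
}
```

With O to move, the two 9-element halves trade places, so the network always sees its own pieces first.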
Output shapes
The network produces two outputs:
Policy output: 9 values.
One value for each of the 9 cells on the board. After the softmax activation, these form a probability distribution. Illegal moves (cells already occupied) will still get some probability assigned by the network, but we mask them out before using the policy — we set their probabilities to zero and renormalize the remaining probabilities to sum to 1.
Value output: 1 value.
A single scalar in the range [-1, 1], produced by the tanh activation. This estimates the game-theoretic value of the position from the current player’s perspective.
Our concrete architecture
Here is the specific network we will build for Tic-Tac-Toe:
┌─────────────────────────────────┐
│ INPUT LAYER │
│ 18 neurons (2 × 3 × 3) │
│ X-plane (9) + O-plane (9) │
└────────────────┬────────────────┘
│
┌──────┴──────┐
┌─────────┴─────────────┴─────────┐
│ HIDDEN LAYER 1 (Trunk) │
│ 64 neurons, ReLU activation │
│ Parameters: 18×64 + 64 = 1216 │
└────────────────┬────────────────┘
│
┌──────┴──────┐
┌─────────┴─────────────┴─────────┐
│ HIDDEN LAYER 2 (Trunk) │
│ 64 neurons, ReLU activation │
│ Parameters: 64×64 + 64 = 4160 │
└──────────┬─────────┬────────────┘
│ │
┌────────────┘ └────────────┐
│ │
┌────────────┴────────────┐ ┌──────────────┴──────────────┐
│ POLICY HEAD │ │ VALUE HEAD │
│ 9 neurons (one/cell) │ │ 1 neuron │
│ Softmax activation │ │ Tanh activation │
│ Params: 64×9 + 9=585 │ │ Params: 64×1 + 1 = 65 │
└────────────┬────────────┘ └──────────────┬──────────────┘
│ │
▼ ▼
┌─────────────────┐ ┌──────────────────┐
│ Move probs [9] │ │ Value [-1, +1] │
│ (sums to 1.0) │ │ (single scalar) │
└─────────────────┘ └──────────────────┘
Let’s count the total parameters:
| Layer | Weights | Biases | Total |
|---|---|---|---|
| Input → Hidden 1 | 18 × 64 = 1,152 | 64 | 1,216 |
| Hidden 1 → Hidden 2 | 64 × 64 = 4,096 | 64 | 4,160 |
| Hidden 2 → Policy | 64 × 9 = 576 | 9 | 585 |
| Hidden 2 → Value | 64 × 1 = 64 | 1 | 65 |
| Total | 5,888 | 138 | 6,026 |
Just over 6,000 learnable parameters. That is tiny by modern standards (large language models have billions), but it is more than enough for Tic-Tac-Toe. The game has only 5,478 possible board positions — our network has enough capacity to learn a good representation.
Summary notation
The architecture can be written compactly as:
18 → 64 (ReLU) → 64 (ReLU) → { 9 (Softmax), 1 (Tanh) }
Read this as: “18 inputs, through a 64-neuron hidden layer with ReLU, through another 64-neuron hidden layer with ReLU, splitting into a policy head of 9 outputs with softmax and a value head of 1 output with tanh.”
How this replaces random rollouts
In Part 3, our pure MCTS agent evaluated positions by running random rollouts — playing random moves until the game ended and using the win/loss result as the value estimate. This works, but has two problems:
- Random moves are a terrible proxy for good play. A position might be winning for X, but if the random rollout happens to have X make several blunders, the rollout says X lost. You need many rollouts to average out the noise.
- Rollouts give no guidance on which moves to explore first. MCTS starts by trying all moves roughly equally, relying on UCB1 to gradually focus on the better ones. Early in the search, a lot of time is wasted exploring obviously bad moves.
The neural network solves both problems:
Value head replaces rollouts. Instead of playing a game to completion with random moves and seeing who wins, we simply ask the value head: “given this board position, who is winning?” The network returns a number immediately. One forward pass takes microseconds, compared to potentially hundreds of random rollouts. And because a trained network’s estimate comes from learned patterns rather than random play, a single evaluation can be far more accurate than many rollouts.
Policy head guides the search. Instead of exploring all moves equally at first, we use the policy head’s probabilities to focus on the moves the network thinks are best. If the network assigns 70% probability to one move and 2% to another, MCTS will explore that first move far more eagerly. This dramatically reduces the amount of search needed to find good moves.
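One common way to turn priors into search behavior is the PUCT rule described in the AlphaGo Zero paper. It adds an exploration bonus proportional to the prior. The sketch below assumes that form. The constant `c_puct` and the function name are illustrative, and the MCTS built later in this course may combine the terms differently.

```rust
/// AlphaGo Zero-style PUCT score (assumed form):
/// Q + c_puct * P * sqrt(parent_visits) / (1 + child_visits)
fn puct_score(q: f32, prior: f32, parent_visits: u32, child_visits: u32, c_puct: f32) -> f32 {
    let exploration =
        c_puct * prior * (parent_visits as f32).sqrt() / (1.0 + child_visits as f32);
    q + exploration
}

fn main() {
    let c_puct = 1.5; // hypothetical exploration constant
    // Two unvisited moves with equal value estimates; the parent has 10 visits.
    let favored = puct_score(0.0, 0.70, 10, 0, c_puct);
    let unlikely = puct_score(0.0, 0.02, 10, 0, c_puct);
    // The favored move again, after it has been visited 9 times.
    let favored_later = puct_score(0.0, 0.70, 10, 9, c_puct);
    println!("{favored:.3} {unlikely:.3} {favored_later:.3}");
}
```

With both moves unvisited and equal Q, the move with the 0.70 prior scores 35 times higher than the move with the 0.02 prior, so it is searched first. As a move accumulates visits, its bonus shrinks by a factor of 1 / (1 + N), which lets Q and the other moves catch up.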
Together, these two improvements are what allowed AlphaGo Zero to surpass all previous Go programs. The network provides strong “intuition”: good priors on which moves to try, and good evaluations of positions. MCTS provides explicit “calculation”: deeper search that verifies and refines the network’s intuition. They complement each other beautifully.
In the next section, we will integrate a neural network crate into our Rust project and build this architecture in code.
10. Integrating a neural network crate
In §9 we designed a dual-headed neural network on paper: 18 inputs, two hidden layers of 64 neurons each, a 9-output policy head, and a 1-output value head. Now we build it in Rust.
Choosing an approach
The Rust ecosystem has several neural network options:
- tch-rs: Rust bindings to PyTorch’s C++ library (libtorch). Full-featured and fast, but it drags in a massive native dependency (~2 GB). Overkill for our 6,026-parameter network.
- candle: A pure-Rust ML framework from Hugging Face. No C++ dependency, but still a large crate with GPU support, autograd, and many features we will never touch.
- Implement from scratch: Write our own matrix math, forward pass, and backpropagation.
For a network this small, we will implement everything from scratch. Here is why:
- It is more educational. You will understand exactly what happens inside a neural network — not just how to call an API, but what the API does internally.
- Zero dependencies. No build headaches, no version conflicts, no multi-gigabyte downloads.
- Tic-Tac-Toe does not need GPU acceleration. Our forward pass multiplies an 18-element vector through a few small matrices. That takes microseconds on a CPU. An entire training run completes in seconds.
The only downside is that we must implement backpropagation ourselves. For our simple architecture (a few dense layers with standard activations), this is entirely manageable. Let’s get started.
The building blocks: vectors and matrices
Neural networks are fundamentally about matrix multiplication. Before building layers, we need a few utility functions that operate on flat Vec<f32> vectors. We represent an m × n matrix as a Vec<f32> of length m × n, stored in row-major order — row 0 first, then row 1, and so on.
/// Multiply matrix `a` (rows × inner) by matrix `b` (inner × cols).
/// Both are stored in row-major order. Returns a vector of length rows × cols.
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            let mut sum = 0.0;
            for k in 0..inner {
                sum += a[r * inner + k] * b[k * cols + c];
            }
            out[r * cols + c] = sum;
        }
    }
    out
}
/// Add vector `b` to vector `a` element-wise (both length `n`).
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
This is a straightforward triple-loop matrix multiply with cost O(rows × inner × cols). Our largest multiplication is a 1 × 64 input vector times a 64 × 64 weight matrix, which is only 64 × 64 = 4,096 multiply-adds. Modern CPUs execute billions of floating-point operations per second, so this will be effectively instantaneous.
The Layer struct
A dense (fully connected) layer is defined by its weight matrix and bias vector. During a forward pass, it computes output = activation(input × weights + bias).
We represent the activation function as an enum so we can branch on it during both the forward and backward passes:
/// Supported activation functions.
#[derive(Debug, Clone, Copy)]
enum Activation {
    /// ReLU: max(0, x). Used in hidden layers.
    Relu,
    /// Tanh: maps to (-1, +1). Used in the value head.
    Tanh,
    /// Identity: no transformation. We apply softmax separately for the
    /// policy head because softmax operates on the entire output vector,
    /// not element-wise.
    Identity,
}
/// A single fully connected layer.
#[derive(Debug, Clone)]
struct Layer {
    /// Weight matrix, shape: input_size × output_size, row-major.
    weights: Vec<f32>,
    /// Bias vector, length: output_size.
    biases: Vec<f32>,
    /// Number of inputs this layer expects.
    input_size: usize,
    /// Number of outputs this layer produces.
    output_size: usize,
    /// Activation function applied after the linear transform.
    activation: Activation,
}
Why is the policy head’s activation Identity instead of Softmax? Softmax is not element-wise: it depends on all the outputs simultaneously, because the probabilities must sum to 1. So we compute the linear output first (called “logits”), then apply softmax as a separate step. Keeping the logits around also simplifies training. When softmax is followed by cross-entropy loss, the gradient of the loss with respect to the logits reduces to predicted − target.
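You can check the predicted − target identity numerically. The sketch below compares that formula against central finite differences of the loss. It uses f64 to keep the finite-difference error small, and the logits, target, function names, and step size `h` are arbitrary choices for illustration.

```rust
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Cross-entropy of softmax(logits) against a target distribution.
fn loss(logits: &[f64], target: &[f64]) -> f64 {
    -softmax(logits).iter().zip(target).map(|(p, t)| t * p.ln()).sum::<f64>()
}

/// Estimate dLoss/dlogit_i for every i with central differences.
fn numeric_grad(logits: &[f64], target: &[f64]) -> Vec<f64> {
    let h = 1e-6;
    (0..logits.len())
        .map(|i| {
            let mut up = logits.to_vec();
            let mut down = logits.to_vec();
            up[i] += h;
            down[i] -= h;
            (loss(&up, target) - loss(&down, target)) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    let target = [0.0, 1.0, 0.0];
    // Closed form: softmax(logits) - target.
    let analytic: Vec<f64> = softmax(&logits).iter().zip(&target).map(|(p, t)| p - t).collect();
    println!("analytic {:?}", analytic);
    println!("numeric  {:?}", numeric_grad(&logits, &target));
}
```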
Weight initialization
Before a network can learn, its weights need initial values. Setting them all to zero is a disaster — every neuron would compute the same thing and receive the same gradient, so they would never differentiate. We need random initial values, but the scale matters.
Xavier initialization (also called Glorot initialization) sets each weight to a random value drawn from a uniform distribution:
w ~ Uniform(-limit, +limit) where limit = sqrt(6 / (fan_in + fan_out))
Here fan_in is the number of inputs to the layer and fan_out is the number of outputs. The idea is to keep the variance of activations roughly constant as signals flow through the network. If weights are too large, activations explode; too small, and they vanish to zero. Xavier initialization balances this.
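Here are the Xavier limits for the four layers of our network. A uniform distribution on (−limit, +limit) has variance limit² / 3, so this choice gives every weight a variance of 2 / (fan_in + fan_out). The helper name `xavier_limit` is ours for illustration; `Layer::new` below computes the same expression inline.

```rust
/// Xavier/Glorot uniform limit: sqrt(6 / (fan_in + fan_out)).
fn xavier_limit(fan_in: usize, fan_out: usize) -> f32 {
    (6.0 / (fan_in + fan_out) as f32).sqrt()
}

fn main() {
    let layers = [("hidden1", 18, 64), ("hidden2", 64, 64), ("policy", 64, 9), ("value", 64, 1)];
    for (name, fan_in, fan_out) in layers {
        let limit = xavier_limit(fan_in, fan_out);
        // Variance of Uniform(-limit, limit) is limit^2 / 3 = 2 / (fan_in + fan_out).
        println!("{name:8} limit = {limit:.4}, variance = {:.5}", limit * limit / 3.0);
    }
}
```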
We will use a simple linear congruential generator (LCG) so our code has zero dependencies — not even on rand. For a production system you would want a better RNG, but for training a Tic-Tac-Toe network, this is perfectly adequate:
/// A simple pseudo-random number generator (LCG).
/// Not cryptographically secure, but fine for weight initialization.
struct Rng {
    state: u64,
}
impl Rng {
    fn new(seed: u64) -> Self {
        Rng { state: seed }
    }
    /// Returns a random f32 in [0, 1).
    fn next_f32(&mut self) -> f32 {
        // 64-bit LCG step; 6364136223846793005 is Knuth's MMIX multiplier.
        self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
        // The low bits of a power-of-two LCG repeat with short periods, so use
        // the top 24 bits (exactly representable in an f32's mantissa).
        let bits = (self.state >> 40) as f32;
        bits / 16_777_216.0 // 2^24, so the result is always < 1.0
    }
    /// Returns a random f32 in [-limit, +limit).
    fn uniform(&mut self, limit: f32) -> f32 {
        self.next_f32() * 2.0 * limit - limit
    }
}
Now we can create a layer with Xavier-initialized weights:
impl Layer {
    /// Create a new layer with Xavier-initialized weights and zero biases.
    fn new(input_size: usize, output_size: usize, activation: Activation, rng: &mut Rng) -> Self {
        let limit = (6.0 / (input_size + output_size) as f32).sqrt();
        let weights: Vec<f32> = (0..input_size * output_size)
            .map(|_| rng.uniform(limit))
            .collect();
        let biases = vec![0.0; output_size];
        Layer {
            weights,
            biases,
            input_size,
            output_size,
            activation,
        }
    }
}
Biases start at zero. This is standard practice — the random weights already break symmetry, so zero biases are fine.
Forward pass
The forward pass pushes an input vector through the layer. We need to save the intermediate values (the pre-activation output and the input) because backpropagation will need them later:
impl Layer {
    /// Run the forward pass. Returns (output, pre_activation, input_clone).
    /// We save intermediate values for backpropagation.
    fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        // Linear transform: output = input × weights + biases
        let pre_act = vec_add(
            &mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
            &self.biases,
        );
        // Apply activation function element-wise.
        let output = match self.activation {
            Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
            Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
            Activation::Identity => pre_act.clone(),
        };
        (output, pre_act, input.to_vec())
    }
}
Notice that we treat the input as a 1 × input_size matrix and the weights as an input_size × output_size matrix. The result is a 1 × output_size vector, which we then add the bias to and activate.
Softmax
The policy head outputs raw logits — unbounded numbers that can be positive or negative. Softmax converts them into a probability distribution (all positive, summing to 1):
softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)
A numerical stability trick: we subtract the maximum value before exponentiating. This prevents overflow when a logit is large:
/// Convert a vector of logits into a probability distribution.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
If the logits are [2.0, 1.0, 0.1], softmax produces roughly [0.66, 0.24, 0.10] — the largest logit gets the highest probability, but every value is positive and they sum to 1.
The Network struct
Our network has four layers total: two shared “trunk” layers, one policy head, and one value head. After the trunk, the computation branches — the trunk’s output feeds into both heads independently:
/// A dual-headed neural network for Tic-Tac-Toe.
///
/// Architecture: 18 → 64 (ReLU) → 64 (ReLU) → { 9 (Softmax), 1 (Tanh) }
#[derive(Debug, Clone)]
struct Network {
    /// First hidden layer (trunk): 18 → 64.
    hidden1: Layer,
    /// Second hidden layer (trunk): 64 → 64.
    hidden2: Layer,
    /// Policy head: 64 → 9. Activation is Identity (softmax applied separately).
    policy_head: Layer,
    /// Value head: 64 → 1. Activation is Tanh.
    value_head: Layer,
}
impl Network {
    /// Create a new network with Xavier-initialized weights.
    fn new(seed: u64) -> Self {
        let mut rng = Rng::new(seed);
        Network {
            hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
            hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
            policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
            value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
        }
    }
}
This exactly matches the architecture diagram from §9. The total parameter count is 1,216 + 4,160 + 585 + 65 = 6,026 — the same number we calculated on paper.
Encoding the board state
Before feeding a position to the network, we must convert our GameState into the 18-element input vector described in §9. Recall the encoding: the first 9 elements are the “current player’s pieces” plane, and the second 9 elements are the “opponent’s pieces” plane.
/// Encode a game state into the 18-element input vector.
///
/// Elements 0..9: 1.0 where the current player has a piece, else 0.0.
/// Elements 9..18: 1.0 where the opponent has a piece, else 0.0.
fn encode_board(state: &GameState) -> Vec<f32> {
    let mut input = vec![0.0_f32; 18];
    let current = state.current_player;
    for i in 0..9 {
        match state.board[i] {
            Some(p) if p == current => input[i] = 1.0,
            Some(_) => input[i + 9] = 1.0,
            None => {}
        }
    }
    input
}
By always putting the current player’s pieces first, the network learns a single perspective: “what should I do?” This works for both X and O without needing separate networks or any special-casing.
Illegal move masking
The policy head outputs a probability for all 9 cells, but some cells may already be occupied. We must not select an occupied cell. The solution: mask the illegal moves by setting their probabilities to zero, then renormalize so the remaining probabilities sum to 1.
/// Zero out probabilities for occupied cells and renormalize.
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
    let mut masked: Vec<f32> = policy
        .iter()
        .enumerate()
        .map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
        .collect();
    let sum: f32 = masked.iter().sum();
    if sum > 0.0 {
        for p in &mut masked {
            *p /= sum;
        }
    }
    masked
}
We apply masking after softmax. This is simpler than masking logits (which would require setting illegal logits to negative infinity before softmax). Both approaches work, but post-softmax masking is easier to understand.
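The two masking strategies give the same distribution. Renormalizing the legal probabilities divides out the same denominator that softmax would have used over the legal logits alone. The sketch below implements both. The function names are hypothetical, and legality is passed as a `&[bool]` instead of a `GameState`. It assumes at least one legal move; with none, both versions produce NaN.

```rust
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Approach used in the text: softmax first, then zero illegal cells and renormalize.
fn mask_after_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = softmax(logits)
        .iter()
        .zip(legal)
        .map(|(&p, &ok)| if ok { p } else { 0.0 })
        .collect();
    let sum: f32 = masked.iter().sum();
    masked.iter().map(|&p| p / sum).collect()
}

/// Alternative: set illegal logits to negative infinity, then softmax.
/// exp(-inf) is 0, so illegal cells get exactly zero probability.
fn mask_before_softmax(logits: &[f32], legal: &[bool]) -> Vec<f32> {
    let masked: Vec<f32> = logits
        .iter()
        .zip(legal)
        .map(|(&z, &ok)| if ok { z } else { f32::NEG_INFINITY })
        .collect();
    softmax(&masked)
}

fn main() {
    let logits = [2.0_f32, 1.0, 0.1];
    let legal = [true, false, true];
    println!("{:?}", softmax(&logits)); // the unmasked distribution, roughly [0.66, 0.24, 0.10]
    println!("{:?}", mask_after_softmax(&logits, &legal));
    println!("{:?}", mask_before_softmax(&logits, &legal));
}
```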
The predict method
Now we wire everything together. The predict method takes a GameState and returns a (Vec<f32>, f32) tuple — the move probability distribution and the position’s value:
impl Network {
    /// Run a forward pass and return (policy, value).
    ///
    /// `policy` is a 9-element probability distribution over moves,
    /// with illegal moves masked to zero.
    /// `value` is a scalar in (-1, +1) estimating the expected game
    /// outcome for the current player (+1 win, -1 loss, 0 draw).
    fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
        let input = encode_board(state);
        // Forward through the trunk.
        let (h1_out, _, _) = self.hidden1.forward(&input);
        let (h2_out, _, _) = self.hidden2.forward(&h1_out);
        // Policy head: logits → softmax → mask illegal moves.
        let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
        let policy = softmax(&policy_logits);
        let policy = mask_illegal(&policy, state);
        // Value head: single tanh output.
        let (value_out, _, _) = self.value_head.forward(&h2_out);
        let value = value_out[0];
        (policy, value)
    }
}
That is the complete forward pass. Given a board position, we encode it, push it through two hidden layers, then branch into the two heads. The policy head’s raw output goes through softmax and illegal-move masking; the value head’s output is already in [-1, +1] thanks to tanh.
Training infrastructure
A network with random weights is useless — it needs to learn from data. In the self-play loop (covered in Part 5), we will generate training examples by having the network play games against itself, using MCTS to produce improved move probabilities. Each training example is a tuple of:
- State: a board position (encoded as 18 floats)
- Target policy: the MCTS-improved move probabilities (9 floats summing to 1)
- Target value: the actual game outcome from this player’s perspective (+1 win, -1 loss, 0 draw)
The network learns by adjusting its weights to make its predictions match these targets. This requires three components: a loss function, backpropagation, and an optimizer.
Loss functions
We need to measure how wrong the network’s predictions are. Different heads use different loss functions:
Cross-entropy loss for the policy head. Cross-entropy measures how far a predicted probability distribution is from a target distribution. If the target says move 4 should have probability 0.8 and the network says 0.2, the loss is high. The formula is:
L_policy = -sum(target_i * ln(predicted_i)) for all i
We add a tiny epsilon to avoid taking the log of zero:
/// Cross-entropy loss between target and predicted distributions.
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
    let eps = 1e-8;
    -target
        .iter()
        .zip(predicted.iter())
        .map(|(&t, &p)| t * (p + eps).ln())
        .sum::<f32>()
}
MSE loss for the value head. Mean squared error is the simplest regression loss: how far is the predicted value from the target?
L_value = (target - predicted)^2
/// Mean squared error between a target and predicted scalar.
fn mse_loss(target: f32, predicted: f32) -> f32 {
    (target - predicted).powi(2)
}
Combined loss. The total loss is the sum of both, optionally with a weighting factor. We keep it simple and weight them equally:
L_total = L_policy + L_value
In practice, some implementations add an L2 regularization term to prevent overfitting, but for our tiny network and simple game, it is unnecessary.
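Here is a worked example with the two loss functions above, copied so the snippet runs on its own. Suppose MCTS put all of its weight on cell 4 and the game was eventually lost (target value -1.0). The probability vectors and predicted values are made up for illustration.

```rust
/// Cross-entropy loss between target and predicted distributions.
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
    let eps = 1e-8;
    -target
        .iter()
        .zip(predicted.iter())
        .map(|(&t, &p)| t * (p + eps).ln())
        .sum::<f32>()
}

/// Mean squared error between a target and predicted scalar.
fn mse_loss(target: f32, predicted: f32) -> f32 {
    (target - predicted).powi(2)
}

fn main() {
    // MCTS target: all probability on the center cell; the game was lost.
    let mut target_policy = [0.0_f32; 9];
    target_policy[4] = 1.0;
    let target_value: f32 = -1.0;

    // A network that agrees with MCTS and correctly expects to lose...
    let good_policy = [0.02_f32, 0.03, 0.01, 0.05, 0.70, 0.08, 0.01, 0.05, 0.05];
    let good = cross_entropy_loss(&target_policy, &good_policy) + mse_loss(target_value, -0.9);

    // ...and one that gives the center only 5% and predicts +0.8.
    let mut bad_policy = [0.11875_f32; 9];
    bad_policy[4] = 0.05;
    let bad = cross_entropy_loss(&target_policy, &bad_policy) + mse_loss(target_value, 0.8);

    println!("good network loss: {good:.3}"); // -ln(0.70) + 0.1^2, about 0.367
    println!("bad network loss:  {bad:.3}"); // -ln(0.05) + 1.8^2, about 6.236
}
```

Because the target is one-hot, the policy loss is just -ln of the probability the network gave cell 4. The other eight terms are multiplied by zero.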
Backpropagation
Backpropagation is the algorithm that computes how much each weight contributed to the loss, so we know which direction to adjust it. It works by applying the chain rule of calculus backwards through the network.
The key insight: if we know the gradient of the loss with respect to a layer’s output, we can compute the gradient with respect to that layer’s weights, biases, and input. The input gradient then becomes the output gradient for the previous layer, and we repeat.
For a dense layer computing output = activation(input × weights + bias):
1. Gradient through activation: Multiply the output gradient element-wise by the activation’s derivative, evaluated at the saved pre-activation value x.
   - ReLU derivative: 1 if x > 0, else 0.
   - Tanh derivative: 1 - tanh(x)^2.
   - Identity derivative: 1.
2. Bias gradient: Equal to the activation gradient (since the bias is just added).
3. Weight gradient: input^T × activation_gradient (an outer product when the input is a single vector).
4. Input gradient: activation_gradient × weights^T (passed back as the output gradient of the previous layer).
Here is the implementation:
impl Layer {
    /// Compute gradients and return (input_gradient, weight_gradient, bias_gradient).
    ///
    /// `output_grad`: gradient of the loss with respect to this layer's output.
    /// `pre_act`: the pre-activation values saved during forward pass.
    /// `input`: the input to this layer saved during forward pass.
    fn backward(
        &self,
        output_grad: &[f32],
        pre_act: &[f32],
        input: &[f32],
    ) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        // Step 1: gradient through activation function.
        let act_grad: Vec<f32> = output_grad
            .iter()
            .zip(pre_act.iter())
            .map(|(&og, &pa)| {
                let d_act = match self.activation {
                    Activation::Relu => if pa > 0.0 { 1.0 } else { 0.0 },
                    Activation::Tanh => 1.0 - pa.tanh().powi(2),
                    Activation::Identity => 1.0,
                };
                og * d_act
            })
            .collect();
        // Step 2: bias gradient = activation gradient.
        let bias_grad = act_grad.clone();
        // Step 3: weight gradient = input^T × act_grad.
        // input is 1 × input_size, act_grad is 1 × output_size.
        // Result is input_size × output_size (same shape as weights).
        let mut weight_grad = vec![0.0; self.input_size * self.output_size];
        for i in 0..self.input_size {
            for j in 0..self.output_size {
                weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
            }
        }
        // Step 4: input gradient = act_grad × weights^T.
        // act_grad is 1 × output_size, weights is input_size × output_size.
        // We need to multiply act_grad by weights transposed.
        let mut input_grad = vec![0.0; self.input_size];
        for i in 0..self.input_size {
            for j in 0..self.output_size {
                input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
            }
        }
        (input_grad, weight_grad, bias_grad)
    }
    /// Apply gradients using SGD: weight -= learning_rate * gradient.
    fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
        for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
            *w -= lr * g;
        }
        for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
            *b -= lr * g;
        }
    }
}
The backward pass mirrors the forward pass in reverse. Each mathematical operation has a corresponding gradient rule, and we chain them together. This is the same algorithm that every deep learning framework implements — we are just doing it by hand.
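A finite-difference check is a good way to confirm these gradient rules. Nudge one weight by a tiny amount, re-run the forward pass, and compare the change in loss to the analytic gradient. The sketch below is self-contained and does not reuse the course's `Layer` type. Its tiny 2-input, 1-output tanh layer, its squared-error loss, and the helper names `forward` and `loss` are all hypothetical choices made for illustration.

```rust
/// Forward pass of a tiny dense layer: output = tanh(input · weights + bias).
/// `weights` has one entry per input (a 2 × 1 weight matrix).
/// Returns (output, pre_activation).
fn forward(input: &[f64], weights: &[f64], bias: f64) -> (f64, f64) {
    let pre_act: f64 = input.iter().zip(weights).map(|(x, w)| x * w).sum::<f64>() + bias;
    (pre_act.tanh(), pre_act)
}

/// Squared-error loss against a scalar target.
fn loss(input: &[f64], weights: &[f64], bias: f64, target: f64) -> f64 {
    let (out, _) = forward(input, weights, bias);
    (out - target).powi(2)
}

fn main() {
    let input = [0.5, -1.2];
    let weights = [0.3, 0.8];
    let bias = 0.1;
    let target = 0.25;

    // Analytic gradients, following the rules listed above.
    let (out, pre_act) = forward(&input, &weights, bias);
    let output_grad = 2.0 * (out - target); // dL/d(output)
    let act_grad = output_grad * (1.0 - pre_act.tanh().powi(2)); // through tanh
    let analytic_bias_grad = act_grad; // bias gradient = activation gradient
    let analytic_weight_grad: Vec<f64> = input.iter().map(|x| x * act_grad).collect();

    // Numerical gradients via central differences.
    let h = 1e-6;
    for i in 0..weights.len() {
        let mut plus = weights;
        let mut minus = weights;
        plus[i] += h;
        minus[i] -= h;
        let numeric =
            (loss(&input, &plus, bias, target) - loss(&input, &minus, bias, target)) / (2.0 * h);
        println!("w[{}]: analytic {:.6}, numeric {:.6}", i, analytic_weight_grad[i], numeric);
        assert!((numeric - analytic_weight_grad[i]).abs() < 1e-6);
    }
    let numeric_bias = (loss(&input, &weights, bias + h, target)
        - loss(&input, &weights, bias - h, target))
        / (2.0 * h);
    println!("bias: analytic {:.6}, numeric {:.6}", analytic_bias_grad, numeric_bias);
    assert!((numeric_bias - analytic_bias_grad).abs() < 1e-6);
}
```

The same idea scales to the full network. Perturb a single entry of `weights` in any `Layer`, recompute the loss, and compare the result against the corresponding entry of `weight_grad`.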
The train_step method
Now we combine the forward pass, loss computation, and backward pass into a single training step. This method takes one training example and updates all the weights:
#![allow(unused)]
fn main() {
impl Network {
/// Perform one training step on a single example.
///
/// Returns the total loss (for monitoring training progress).
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
// === Forward pass (saving intermediates) ===
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
// === Compute losses ===
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
// === Backward pass ===
// Policy head gradient.
// For cross-entropy loss with softmax, the gradient of the loss
// with respect to the logits simplifies beautifully to:
// d_logits = predicted - target
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
// Value head gradient.
// For MSE loss with tanh, the gradient w.r.t. the value head output is:
// d_value = 2 * (predicted - target)
// The tanh derivative is handled inside backward().
let value_grad = vec![2.0 * (value - target_value)];
// Backprop through heads.
let (policy_input_grad, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (value_input_grad, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
// The trunk's output feeds both heads, so its gradient is the sum.
let h2_output_grad: Vec<f32> = policy_input_grad
.iter()
.zip(value_input_grad.iter())
.map(|(&a, &b)| a + b)
.collect();
// Backprop through trunk.
let (h1_output_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_output_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) = self.hidden1.backward(&h1_output_grad, &h1_pre, &h1_in);
// === Apply gradients (SGD) ===
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
}
There is one subtle but important detail in the policy gradient. When you combine softmax with cross-entropy loss, the gradient of the loss with respect to the logits (not the probabilities) simplifies to predicted - target. This is one of the most elegant results in neural network math, and it is why softmax + cross-entropy is such a popular combination. We take advantage of this by passing the simplified gradient directly to the policy head’s backward method, which treats the head as having Identity activation — the softmax derivative is already baked into the predicted - target formula.
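Here is the short derivation behind that simplification. Write z for the logits, p = softmax(z) for the predicted policy, and t for the target policy. The final step uses the fact that t sums to 1, which holds for MCTS visit distributions. The derivation ignores the small epsilon inside cross_entropy_loss.

```latex
\begin{aligned}
p_i &= \frac{e^{z_i}}{\sum_k e^{z_k}}, \qquad L = -\sum_i t_i \log p_i, \\
\frac{\partial \log p_i}{\partial z_j} &= \delta_{ij} - p_j, \\
\frac{\partial L}{\partial z_j} &= -\sum_i t_i \left(\delta_{ij} - p_j\right)
  = -t_j + p_j \sum_i t_i = p_j - t_j \quad \text{(using } \textstyle\sum_i t_i = 1\text{)}.
\end{aligned}
```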
Another detail: where the trunk branches into the two heads, the gradient flowing back is the sum of the gradients from both heads. This follows from calculus — if a variable contributes to two terms in the loss, its total gradient is the sum of its contributions to each.
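A quick numerical check of the summation rule follows. A hypothetical scalar x feeds two branches, a(x) = 3x and b(x) = x^2, and the loss is their sum. This plays the same role as the trunk output feeding both heads. The finite-difference derivative matches the sum of the two branch gradients, 3 + 2x.

```rust
/// Loss where one variable `x` contributes through two branches,
/// like the trunk output feeding both the policy and value heads.
fn loss(x: f64) -> f64 {
    let branch_a = 3.0 * x; // da/dx = 3
    let branch_b = x * x; // db/dx = 2x
    branch_a + branch_b
}

/// Analytic gradient: the sum of each branch's contribution.
fn analytic_grad(x: f64) -> f64 {
    3.0 + 2.0 * x
}

/// Central-difference estimate of dL/dx.
fn numeric_grad(x: f64) -> f64 {
    let h = 1e-5;
    (loss(x + h) - loss(x - h)) / (2.0 * h)
}

fn main() {
    let x = 1.5;
    println!("analytic {:.6}, numeric {:.6}", analytic_grad(x), numeric_grad(x));
    assert!((analytic_grad(x) - numeric_grad(x)).abs() < 1e-6);
}
```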
SGD optimizer
Our optimizer is stochastic gradient descent (SGD) in its simplest form: weight -= learning_rate * gradient. The apply_gradients method on Layer already implements this. There are fancier optimizers (Adam, RMSProp) that adapt the learning rate per-parameter, but plain SGD works well for small networks. The learning rate is a hyperparameter we will tune — a typical starting point is 0.01.
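The update rule can be seen in isolation on a hypothetical one-parameter loss L(w) = (w - 3)^2, whose gradient is 2(w - 3). This minimal sketch uses the exact gradient rather than a per-example estimate. That makes it plain gradient descent rather than the stochastic kind. It still shows how the learning rate sets the step size, and how too large a rate makes the weight overshoot and diverge.

```rust
/// Gradient of the toy loss L(w) = (w - 3)^2.
fn grad(w: f64) -> f64 {
    2.0 * (w - 3.0)
}

/// Applies `w -= lr * grad(w)` for `steps` steps and returns the final weight.
fn descend(mut w: f64, lr: f64, steps: usize) -> f64 {
    for _ in 0..steps {
        w -= lr * grad(w);
    }
    w
}

fn main() {
    let good = descend(0.0, 0.1, 100);
    let too_big = descend(0.0, 1.1, 100);
    println!("lr = 0.1 -> w = {:.4}", good); // converges toward 3
    println!("lr = 1.1 -> w = {:e}", too_big); // oscillates with growing magnitude
}
```

Each step multiplies the distance to the optimum by (1 - 2·lr). For lr = 0.1 that factor is 0.8, so the weight shrinks toward 3. For lr = 1.1 the factor is -1.2, so the weight flips sides and grows.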
Putting it all together
Here is the complete, compilable code. All the pieces above are assembled into a single program that creates a network, runs a forward pass on a sample board, and performs a training step:
use std::fmt;
// ─── Player and GameState (from Part 1) ──────────────────────────
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player { X, O }
impl Player {
pub fn opponent(self) -> Self {
match self { Player::X => Player::O, Player::O => Player::X }
}
}
impl fmt::Display for Player {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { Player::X => write!(f, "X"), Player::O => write!(f, "O") }
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
pub board: [Option<Player>; 9],
pub current_player: Player,
}
impl GameState {
pub fn new() -> Self {
GameState { board: [None; 9], current_player: Player::X }
}
}
// ─── Random number generator ─────────────────────────────────────
struct Rng { state: u64 }
impl Rng {
fn new(seed: u64) -> Self { Rng { state: seed } }
fn next_f32(&mut self) -> f32 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
let bits = ((self.state >> 17) & 0x7FFF) as f32;
bits / 32768.0
}
fn uniform(&mut self, limit: f32) -> f32 {
self.next_f32() * 2.0 * limit - limit
}
}
// ─── Matrix utilities ────────────────────────────────────────────
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
let mut out = vec![0.0; rows * cols];
for r in 0..rows {
for c in 0..cols {
let mut sum = 0.0;
for k in 0..inner {
sum += a[r * inner + k] * b[k * cols + c];
}
out[r * cols + c] = sum;
}
}
out
}
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
fn softmax(logits: &[f32]) -> Vec<f32> {
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
// ─── Layer ───────────────────────────────────────────────────────
#[derive(Debug, Clone, Copy)]
enum Activation { Relu, Tanh, Identity }
#[derive(Debug, Clone)]
struct Layer {
weights: Vec<f32>,
biases: Vec<f32>,
input_size: usize,
output_size: usize,
activation: Activation,
}
impl Layer {
fn new(
input_size: usize,
output_size: usize,
activation: Activation,
rng: &mut Rng,
) -> Self {
let limit = (6.0 / (input_size + output_size) as f32).sqrt();
let weights: Vec<f32> = (0..input_size * output_size)
.map(|_| rng.uniform(limit))
.collect();
let biases = vec![0.0; output_size];
Layer { weights, biases, input_size, output_size, activation }
}
fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let pre_act = vec_add(
&mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
&self.biases,
);
let output = match self.activation {
Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
Activation::Identity => pre_act.clone(),
};
(output, pre_act, input.to_vec())
}
fn backward(
&self,
output_grad: &[f32],
pre_act: &[f32],
input: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let act_grad: Vec<f32> = output_grad
.iter()
.zip(pre_act.iter())
.map(|(&og, &pa)| {
let d_act = match self.activation {
Activation::Relu => if pa > 0.0 { 1.0 } else { 0.0 },
Activation::Tanh => 1.0 - pa.tanh().powi(2),
Activation::Identity => 1.0,
};
og * d_act
})
.collect();
let bias_grad = act_grad.clone();
let mut weight_grad = vec![0.0; self.input_size * self.output_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
}
}
let mut input_grad = vec![0.0; self.input_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
}
}
(input_grad, weight_grad, bias_grad)
}
fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
*w -= lr * g;
}
for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
*b -= lr * g;
}
}
}
// ─── Network ─────────────────────────────────────────────────────
#[derive(Debug, Clone)]
struct Network {
hidden1: Layer,
hidden2: Layer,
policy_head: Layer,
value_head: Layer,
}
fn encode_board(state: &GameState) -> Vec<f32> {
let mut input = vec![0.0_f32; 18];
let current = state.current_player;
for i in 0..9 {
match state.board[i] {
Some(p) if p == current => input[i] = 1.0,
Some(_) => input[i + 9] = 1.0,
None => {}
}
}
input
}
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
let mut masked: Vec<f32> = policy
.iter()
.enumerate()
.map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
.collect();
let sum: f32 = masked.iter().sum();
if sum > 0.0 {
for p in &mut masked {
*p /= sum;
}
}
masked
}
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
let eps = 1e-8;
-target
.iter()
.zip(predicted.iter())
.map(|(&t, &p)| t * (p + eps).ln())
.sum::<f32>()
}
fn mse_loss(target: f32, predicted: f32) -> f32 {
(target - predicted).powi(2)
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = Rng::new(seed);
Network {
hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
}
}
fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
let input = encode_board(state);
let (h1_out, _, _) = self.hidden1.forward(&input);
let (h2_out, _, _) = self.hidden2.forward(&h1_out);
let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
let policy = softmax(&policy_logits);
let policy = mask_illegal(&policy, state);
let (value_out, _, _) = self.value_head.forward(&h2_out);
(policy, value_out[0])
}
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
// Backprop: policy head
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
let value_grad = vec![2.0 * (value - target_value)];
let (pg_input, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (vg_input, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
// Sum gradients from both heads at the trunk output.
let h2_grad: Vec<f32> = pg_input.iter()
.zip(vg_input.iter())
.map(|(&a, &b)| a + b)
.collect();
let (h1_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) =
self.hidden1.backward(&h1_grad, &h1_pre, &h1_in);
// SGD update.
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
// ─── Test ────────────────────────────────────────────────────────
fn main() {
let mut net = Network::new(42);
// Create a sample board position:
// X | O | .
// ---------
// . | X | .
// ---------
// . | . | O
let mut state = GameState::new();
state.board[0] = Some(Player::X);
state.board[1] = Some(Player::O);
state.board[4] = Some(Player::X);
state.board[8] = Some(Player::O);
// It is X's turn (4 pieces placed, 2 each).
// Forward pass.
let (policy, value) = net.predict(&state);
println!("Board:");
for r in 0..3 {
for c in 0..3 {
let i = r * 3 + c;
match state.board[i] {
Some(Player::X) => print!(" X "),
Some(Player::O) => print!(" O "),
None => print!(" . "),
}
if c < 2 { print!("|"); }
}
println!();
if r < 2 { println!("-----------"); }
}
println!("\nPolicy (move probabilities):");
for i in 0..9 {
let label = if state.board[i].is_some() { "occupied" } else { "legal" };
println!(" cell {}: {:.4} ({})", i, policy[i], label);
}
let policy_sum: f32 = policy.iter().sum();
println!(" sum: {:.4}", policy_sum);
println!("\nValue: {:.4}", value);
// Verify shapes and constraints.
assert_eq!(policy.len(), 9, "Policy must have 9 elements");
assert!((policy_sum - 1.0).abs() < 1e-5, "Policy must sum to ~1.0");
assert!(policy[0] == 0.0, "Occupied cell 0 must have 0 probability");
assert!(policy[1] == 0.0, "Occupied cell 1 must have 0 probability");
assert!(policy[4] == 0.0, "Occupied cell 4 must have 0 probability");
assert!(policy[8] == 0.0, "Occupied cell 8 must have 0 probability");
assert!(value >= -1.0 && value <= 1.0, "Value must be in [-1, +1]");
println!("\nAll shape checks passed!\n");
// Training test: run a few training steps and verify the loss decreases.
let target_policy = [0.0, 0.0, 0.3, 0.2, 0.0, 0.2, 0.1, 0.2, 0.0];
let target_value = 0.5;
let lr = 0.01;
let loss_before = net.train_step(&state, &target_policy, target_value, lr);
println!("Loss before training: {:.4}", loss_before);
// Run 100 training steps on the same example.
let mut last_loss = loss_before;
for _ in 0..100 {
last_loss = net.train_step(&state, &target_policy, target_value, lr);
}
println!("Loss after 100 steps: {:.4}", last_loss);
assert!(
last_loss < loss_before,
"Loss should decrease after training"
);
println!("Training is working — loss decreased.\n");
}
Running this program produces output like:
Board:
X | O | .
-----------
. | X | .
-----------
. | . | O
Policy (move probabilities):
cell 0: 0.0000 (occupied)
cell 1: 0.0000 (occupied)
cell 2: 0.1891 (legal)
cell 3: 0.2004 (legal)
cell 4: 0.0000 (occupied)
cell 5: 0.2147 (legal)
cell 6: 0.1965 (legal)
cell 7: 0.1993 (legal)
cell 8: 0.0000 (occupied)
sum: 1.0000
Value: -0.0274
All shape checks passed!
Loss before training: 2.0835
Loss after 100 steps: 0.5921
Training is working — loss decreased.
The exact numbers depend on the random seed, but the structure is always the same: occupied cells have zero probability, legal cells share the remaining probability mass (roughly uniform initially because the weights are random), the value is a small number near zero, and training reduces the loss.
What we built
Let’s step back and appreciate what we have. With zero external dependencies, we implemented:
- A complete feedforward neural network with dense layers, ReLU/tanh/softmax activations, and Xavier initialization.
- A forward pass that converts a board position into a policy distribution and value estimate.
- Illegal move masking that ensures the network never suggests occupied cells.
- Cross-entropy and MSE loss functions to measure prediction quality.
- Full backpropagation through the entire network, including the tricky softmax + cross-entropy gradient simplification and gradient summation at the trunk fork.
- SGD optimization that updates weights to reduce the loss.
This is a real, working neural network. Its architecture (input → hidden layers → output heads, trained with backprop and SGD) is built from the same ingredients as much larger models, from image classifiers to language models. Those models also add specialized layers such as convolutions and attention. Ours happens to have 6,026 parameters instead of billions, but the core principles are the same.
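The 6,026 figure can be checked directly. A dense layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights plus `n_out` biases. The helper name `dense_params` is hypothetical.

```rust
/// Number of trainable parameters in one dense layer (weights + biases).
fn dense_params(n_in: usize, n_out: usize) -> usize {
    n_in * n_out + n_out
}

fn main() {
    let hidden1 = dense_params(18, 64); // 1,152 + 64 = 1,216
    let hidden2 = dense_params(64, 64); // 4,096 + 64 = 4,160
    let policy_head = dense_params(64, 9); // 576 + 9 = 585
    let value_head = dense_params(64, 1); // 64 + 1 = 65
    let total = hidden1 + hidden2 + policy_head + value_head;
    println!("total parameters: {}", total); // prints "total parameters: 6026"
}
```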
In the next section, we will use this network to generate training data through MCTS self-play and teach the network to actually play good Tic-Tac-Toe.
11. Exercise 3: train the network on MCTS data
We have two powerful components sitting side by side: a pure-MCTS player that discovers strong moves through thousands of random rollouts (§7-8), and a neural network that can be trained to predict move probabilities and position values (§9-10). Right now they are strangers. This exercise introduces them.
The idea is simple but profound: use MCTS as a teacher and the neural network as a student. MCTS plays games against itself, generating high-quality data about which moves are good and who is winning. We then train the network to mimic MCTS’s judgments. If successful, the network will learn to approximate, in a single forward pass, what MCTS would conclude after thousands of iterations of search.
This teacher-student idea is at the heart of AlphaGo Zero, with one simplification. The network starts knowing nothing: its weights are random, and its predictions are noise. But our MCTS does not need the network, because it can play reasonably well using random rollouts alone. So we let MCTS generate training data, train the network on it, and later (in §12) plug the trained network back into MCTS to make it even stronger. AlphaGo Zero skips the rollout-only stage. Its search is guided by the (initially random) network from the very first game, and the network and the search improve together.
What data does the network need?
Recall from §10 that our network has two output heads:
- Policy head: a probability distribution over the 9 cells — “which move should I play?”
- Value head: a single number in [-1, +1] — “who is winning from this position?”
To train these heads, we need labeled examples: positions where we know the right policy and the right value. MCTS gives us both.
Policy targets from visit counts. After MCTS finishes searching a position (say, 1000 iterations), the root’s children have accumulated visit counts. The most-visited move is the one MCTS considers best, but the full visit distribution is more informative. If cell 4 got 500 visits, cell 0 got 200, and cell 2 got 300, the visit distribution is [0.2, 0.0, 0.3, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0]. This is our policy target — it tells the network not just the best move, but how the search budget was distributed across all moves.
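The arithmetic in that example is just normalization: divide each visit count by the total. The minimal sketch below assumes a plain array of counts indexed by cell, and its helper name `normalize_visits` is hypothetical. The course's own `visit_distribution` function reads the counts from the MCTS tree instead.

```rust
/// Converts raw visit counts (one per cell) into a probability distribution.
/// Returns all zeros if no visits were recorded.
fn normalize_visits(counts: [u32; 9]) -> [f32; 9] {
    let total: u32 = counts.iter().sum();
    let mut dist = [0.0_f32; 9];
    if total == 0 {
        return dist;
    }
    for (d, &c) in dist.iter_mut().zip(counts.iter()) {
        *d = c as f32 / total as f32;
    }
    dist
}

fn main() {
    // Cell 0: 200 visits, cell 2: 300 visits, cell 4: 500 visits.
    let counts = [200, 0, 300, 0, 500, 0, 0, 0, 0];
    let dist = normalize_visits(counts);
    println!("{:?}", dist); // [0.2, 0.0, 0.3, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0]
}
```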
Value targets from game outcomes. After the game finishes, we know the result: X won (+1), O won (-1), or draw (0). We assign this outcome to every position in the game, adjusted for perspective. If X won, then positions where X is the current player get a value target of +1 (the current player won), and positions where O is the current player get -1 (the current player lost).
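The perspective flip is easy to get backwards, so here is a small sketch of just that rule. The `Player` enum mirrors the one used throughout the course. `value_target` is a hypothetical helper name.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
    X,
    O,
}

/// Value target for a position where `to_move` is the current player,
/// given the final result of the game (`None` means a draw).
fn value_target(to_move: Player, winner: Option<Player>) -> f32 {
    match winner {
        Some(w) if w == to_move => 1.0, // current player went on to win
        Some(_) => -1.0, // current player went on to lose
        None => 0.0, // draw
    }
}

fn main() {
    // X won a game whose positions alternated X to move, O to move, ...
    let movers = [Player::X, Player::O, Player::X, Player::O, Player::X];
    let targets: Vec<f32> = movers
        .iter()
        .map(|&p| value_target(p, Some(Player::X)))
        .collect();
    println!("{:?}", targets); // [1.0, -1.0, 1.0, -1.0, 1.0]
}
```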
The TrainingExample struct
Each position in a self-play game produces one training example:
#![allow(unused)]
fn main() {
/// A single training example generated from MCTS self-play.
///
/// Contains the board position, the MCTS visit distribution as a
/// policy target, and the eventual game outcome as a value target.
struct TrainingExample {
/// The game state at this position.
state: GameState,
/// The MCTS visit distribution (visits per move / total visits).
/// This is the policy target — what the network should learn to predict.
policy_target: [f32; 9],
/// The game outcome from the perspective of the current player.
/// +1.0 = current player won, -1.0 = current player lost, 0.0 = draw.
value_target: f32,
}
}
Generating training data
The data generation pipeline works as follows:
1. Start a new game from the empty board.
2. At each position, run MCTS (e.g., 800 iterations) to get the visit distribution.
3. Record a TrainingExample with the current state and the visit distribution. Leave the value target blank for now.
4. Play the move that MCTS chose (the most-visited child).
5. Repeat steps 2-4 until the game ends.
6. Once the game ends, go back and fill in the value target for every recorded position using the game outcome.
7. Repeat for many games (e.g., 100-500 games).
The key function we need is one that extracts the visit distribution from an MCTS tree after search:
#![allow(unused)]
fn main() {
/// Extracts the visit distribution from the root of an MCTS tree.
///
/// Returns an array of 9 floats representing the fraction of visits
/// each move received. Moves that were never visited (or are illegal)
/// get 0.0. The values sum to 1.0.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0_f32; 9];
if total == 0 {
return dist;
}
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent
.expect("root children always have a move");
dist[mv] = child.visit_count as f32 / total as f32;
}
dist
}
}
Now we can write the self-play data generation:
#![allow(unused)]
fn main() {
/// Plays one complete game using MCTS and collects training examples.
///
/// Returns a vector of `TrainingExample` — one per move in the game.
/// Each example has the board state, the MCTS visit distribution as
/// the policy target, and the game outcome as the value target.
fn generate_game_data(
mcts_iterations: u32,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
// Play the game, recording each position and its visit distribution.
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let best_move = tree.best_move();
state = state.apply_move(best_move);
}
// Determine the game outcome.
let winner = state.winner();
// Convert history into training examples with value targets.
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
}
To generate a full dataset, we call this function many times:
#![allow(unused)]
fn main() {
/// Generates training data from many self-play games.
fn generate_training_data(
num_games: u32,
mcts_iterations: u32,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut all_examples = Vec::new();
for game_num in 1..=num_games {
let examples = generate_game_data(mcts_iterations, rng);
all_examples.extend(examples);
if game_num % 50 == 0 {
println!(
"Generated {} games ({} examples so far)",
game_num,
all_examples.len()
);
}
}
println!(
"Total: {} games, {} training examples",
num_games,
all_examples.len()
);
all_examples
}
}
Each Tic-Tac-Toe game lasts 5-9 moves, so 100 games produce roughly 500-900 training examples. With 800 MCTS iterations per move, generating 100 games takes a few seconds on a modern machine.
The training loop
Now we train the network on this data. The approach is standard mini-batch stochastic gradient descent:
- Shuffle the training examples to break correlations (examples from the same game are adjacent, which could bias learning).
- Split into mini-batches of size B (e.g., 32).
- For each mini-batch, run every example through train_step and add its loss to a running total.
- Repeat for multiple epochs (full passes through the dataset).
- Track the average loss per epoch, which should decrease over time.
Why mini-batches rather than processing one example at a time? Single-example SGD is noisy, because the gradient from one position might point in a misleading direction. Averaging gradients over a batch of 32 examples smooths out the noise, leading to more stable learning. Why not process the entire dataset at once? That would be full-batch gradient descent, which is slow for large datasets because the weights are updated only once per pass. Mini-batches strike a balance: they are stable enough to converge and frequent enough to learn quickly.
Here is the training loop:
#![allow(unused)]
fn main() {
/// Shuffles the training examples in place using Fisher-Yates.
fn shuffle(examples: &mut [TrainingExample], rng: &mut impl rand::Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
/// Trains the network on the given examples for a specified number
/// of epochs using mini-batch SGD.
///
/// Returns the average loss from the final epoch.
fn train_network(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl rand::Rng,
) -> f32 {
let mut final_avg_loss = 0.0;
for epoch in 0..epochs {
shuffle(examples, rng);
let mut epoch_loss = 0.0;
let mut count = 0;
// Process in mini-batches.
for batch_start in (0..examples.len()).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let loss = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_loss += loss;
count += 1;
}
}
let avg_loss = epoch_loss / count as f32;
final_avg_loss = avg_loss;
if (epoch + 1) % 5 == 0 || epoch == 0 {
println!("Epoch {:>3}: avg loss = {:.4}", epoch + 1, avg_loss);
}
}
final_avg_loss
}
}
A note on our mini-batch implementation: our train_step method applies gradients after every single example (online SGD). As a result, the batch loop above does not change the outcome; it only marks where batch boundaries would fall. A true mini-batch implementation would accumulate gradients across the batch and apply them once. For our small network and dataset, per-example SGD works well enough. The shuffle is what matters most, because it ensures the network sees a diverse mix of positions rather than all moves from the same game in sequence.
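For comparison, here is a minimal sketch of true mini-batch SGD. It uses a hypothetical one-weight linear model y = w·x with squared-error loss, not the course's `Network`. The gradients from every example in a batch are averaged and then applied in a single update. Adapting `Network` this way would mean splitting `train_step` into a gradient-only pass and a separate update, which is not shown here.

```rust
/// Gradient of the squared error (w*x - y)^2 with respect to w.
fn example_grad(w: f64, x: f64, y: f64) -> f64 {
    2.0 * (w * x - y) * x
}

/// One epoch of mini-batch SGD: average the gradients within each batch,
/// then apply a single update per batch.
fn minibatch_epoch(w: &mut f64, data: &[(f64, f64)], batch_size: usize, lr: f64) {
    for batch in data.chunks(batch_size) {
        let grad_sum: f64 = batch.iter().map(|&(x, y)| example_grad(*w, x, y)).sum();
        let avg_grad = grad_sum / batch.len() as f64;
        *w -= lr * avg_grad; // one update per batch, not per example
    }
}

fn main() {
    // Points on the line y = 2x, so the best weight is w = 2.
    let data: Vec<(f64, f64)> = (1..=8)
        .map(|i| {
            let x = i as f64 / 8.0;
            (x, 2.0 * x)
        })
        .collect();
    let mut w = 0.0;
    for _ in 0..200 {
        minibatch_epoch(&mut w, &data, 4, 0.5);
    }
    println!("learned w = {:.4}", w);
}
```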
Your task
Write a program that:
- Generates training data from 100 MCTS self-play games (800 iterations per move).
- Creates a fresh neural network.
- Trains the network for 30 epochs with a learning rate of 0.01 and a batch size of 32.
- Evaluates the trained network on a few sample positions:
  - The empty board (opening position).
  - A mid-game position.
- Prints the network’s policy predictions and compares them to what MCTS would choose.
Try writing this yourself before looking at the solution. You already have all the building blocks — this exercise is about wiring them together.
Solution
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::Rng;
use std::f64::consts::SQRT_2;
use std::fmt;
// ─── Player and GameState ────────────────────────────────────────
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player { X, O }
impl Player {
pub fn opponent(self) -> Self {
match self { Player::X => Player::O, Player::O => Player::X }
}
}
impl fmt::Display for Player {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { Player::X => write!(f, "X"), Player::O => write!(f, "O") }
}
}
const WIN_LINES: [[usize; 3]; 8] = [
[0, 1, 2], [3, 4, 5], [6, 7, 8],
[0, 3, 6], [1, 4, 7], [2, 5, 8],
[0, 4, 8], [2, 4, 6],
];
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
pub board: [Option<Player>; 9],
pub current_player: Player,
}
impl GameState {
pub fn new() -> Self {
GameState { board: [None; 9], current_player: Player::X }
}
pub fn winner(&self) -> Option<Player> {
for line in &WIN_LINES {
let [a, b, c] = *line;
if let (Some(p1), Some(p2), Some(p3)) =
(self.board[a], self.board[b], self.board[c])
{
if p1 == p2 && p2 == p3 {
return Some(p1);
}
}
}
None
}
pub fn is_terminal(&self) -> bool {
self.winner().is_some() || self.board.iter().all(|c| c.is_some())
}
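pub fn legal_moves(&self) -> Vec<usize> {
// Once someone has won, the game is over and no moves are legal.
// Without this check, MCTS would expand positions past the end of the game.
if self.winner().is_some() {
return Vec::new();
}
(0..9).filter(|&i| self.board[i].is_none()).collect()
}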
pub fn apply_move(&self, index: usize) -> Self {
assert!(
self.board[index].is_none(),
"Cell {} is already occupied",
index
);
let mut new_board = self.board;
new_board[index] = Some(self.current_player);
GameState {
board: new_board,
current_player: self.current_player.opponent(),
}
}
}
impl fmt::Display for GameState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for r in 0..3 {
for c in 0..3 {
let i = r * 3 + c;
match self.board[i] {
Some(Player::X) => write!(f, " X ")?,
Some(Player::O) => write!(f, " O ")?,
None => write!(f, " . ")?,
}
if c < 2 { write!(f, "|")?; }
}
writeln!(f)?;
if r < 2 { writeln!(f, "-----------")?; }
}
Ok(())
}
}
// ─── MCTS ────────────────────────────────────────────────────────
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
if child_visits == 0 {
return f64::INFINITY;
}
let n = child_visits as f64;
let exploitation = child_value / n;
let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
exploitation + exploration
}
impl MctsTree {
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
MctsTree { nodes: vec![root] }
}
fn select(&self) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
if !node.unexpanded_moves.is_empty() { return current; }
if node.children.is_empty() { return current; }
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = uct_score(
child.total_value,
child.visit_count,
parent_visits,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
fn expand(&mut self, node_idx: usize) -> Option<usize> {
let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
let child_state = self.nodes[node_idx].state.apply_move(mv);
let unexpanded = child_state.legal_moves();
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
Some(child_idx)
}
fn simulate(&self, state: &GameState, rng: &mut impl Rng) -> f64 {
let mut current = state.clone();
while !current.is_terminal() {
let moves = current.legal_moves();
let &mv = moves.choose(rng).expect("non-terminal has moves");
current = current.apply_move(mv);
}
let rollout_player = state.current_player.opponent();
match current.winner() {
Some(winner) if winner == rollout_player => 1.0,
Some(_) => -1.0,
None => 0.0,
}
}
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
fn search(&mut self, iterations: u32, rng: &mut impl Rng) -> usize {
for _ in 0..iterations {
let selected = self.select();
let node_to_simulate = match self.expand(selected) {
Some(child_idx) => child_idx,
None => selected,
};
let state = self.nodes[node_to_simulate].state.clone();
let value = self.simulate(&state, rng);
self.backpropagate(node_to_simulate, value);
}
self.best_move()
}
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have children after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root children have moves")
}
}
// ─── Neural network (from §9-10) ────────────────────────────────
struct SimpleRng { state: u64 }
impl SimpleRng {
fn new(seed: u64) -> Self { SimpleRng { state: seed } }
fn next_f32(&mut self) -> f32 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
let bits = ((self.state >> 17) & 0x7FFF) as f32;
bits / 32768.0
}
fn uniform(&mut self, limit: f32) -> f32 {
self.next_f32() * 2.0 * limit - limit
}
}
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
let mut out = vec![0.0; rows * cols];
for r in 0..rows {
for c in 0..cols {
let mut sum = 0.0;
for k in 0..inner {
sum += a[r * inner + k] * b[k * cols + c];
}
out[r * cols + c] = sum;
}
}
out
}
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
fn softmax(logits: &[f32]) -> Vec<f32> {
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
#[derive(Debug, Clone, Copy)]
enum Activation { Relu, Tanh, Identity }
#[derive(Debug, Clone)]
struct Layer {
weights: Vec<f32>,
biases: Vec<f32>,
input_size: usize,
output_size: usize,
activation: Activation,
}
impl Layer {
fn new(
input_size: usize,
output_size: usize,
activation: Activation,
rng: &mut SimpleRng,
) -> Self {
let limit = (6.0 / (input_size + output_size) as f32).sqrt();
let weights: Vec<f32> = (0..input_size * output_size)
.map(|_| rng.uniform(limit))
.collect();
let biases = vec![0.0; output_size];
Layer { weights, biases, input_size, output_size, activation }
}
fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let pre_act = vec_add(
&mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
&self.biases,
);
let output = match self.activation {
Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
Activation::Identity => pre_act.clone(),
};
(output, pre_act, input.to_vec())
}
fn backward(
&self,
output_grad: &[f32],
pre_act: &[f32],
input: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let act_grad: Vec<f32> = output_grad
.iter()
.zip(pre_act.iter())
.map(|(&og, &pa)| {
let d_act = match self.activation {
Activation::Relu => if pa > 0.0 { 1.0 } else { 0.0 },
Activation::Tanh => 1.0 - pa.tanh().powi(2),
Activation::Identity => 1.0,
};
og * d_act
})
.collect();
let bias_grad = act_grad.clone();
let mut weight_grad = vec![0.0; self.input_size * self.output_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
}
}
let mut input_grad = vec![0.0; self.input_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
}
}
(input_grad, weight_grad, bias_grad)
}
fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
*w -= lr * g;
}
for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
*b -= lr * g;
}
}
}
#[derive(Debug, Clone)]
struct Network {
hidden1: Layer,
hidden2: Layer,
policy_head: Layer,
value_head: Layer,
}
fn encode_board(state: &GameState) -> Vec<f32> {
let mut input = vec![0.0_f32; 18];
let current = state.current_player;
for i in 0..9 {
match state.board[i] {
Some(p) if p == current => input[i] = 1.0,
Some(_) => input[i + 9] = 1.0,
None => {}
}
}
input
}
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
let mut masked: Vec<f32> = policy
.iter()
.enumerate()
.map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
.collect();
let sum: f32 = masked.iter().sum();
if sum > 0.0 {
for p in &mut masked {
*p /= sum;
}
}
masked
}
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
let eps = 1e-8;
-target
.iter()
.zip(predicted.iter())
.map(|(&t, &p)| t * (p + eps).ln())
.sum::<f32>()
}
fn mse_loss(target: f32, predicted: f32) -> f32 {
(target - predicted).powi(2)
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = SimpleRng::new(seed);
Network {
hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
}
}
fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
let input = encode_board(state);
let (h1_out, _, _) = self.hidden1.forward(&input);
let (h2_out, _, _) = self.hidden2.forward(&h1_out);
let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
let policy = softmax(&policy_logits);
let policy = mask_illegal(&policy, state);
let (value_out, _, _) = self.value_head.forward(&h2_out);
(policy, value_out[0])
}
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
let value_grad = vec![2.0 * (value - target_value)];
let (pg_input, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (vg_input, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
let h2_grad: Vec<f32> = pg_input.iter()
.zip(vg_input.iter())
.map(|(&a, &b)| a + b)
.collect();
let (h1_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) =
self.hidden1.backward(&h1_grad, &h1_pre, &h1_in);
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
// ─── Training data generation ────────────────────────────────────
struct TrainingExample {
state: GameState,
policy_target: [f32; 9],
value_target: f32,
}
/// Extracts the visit distribution from the root of an MCTS tree.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0_f32; 9];
if total == 0 {
return dist;
}
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent
.expect("root children always have a move");
dist[mv] = child.visit_count as f32 / total as f32;
}
dist
}
/// Plays one complete game using MCTS and collects training examples.
fn generate_game_data(
mcts_iterations: u32,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let best_move = tree.best_move();
state = state.apply_move(best_move);
}
let winner = state.winner();
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
/// Generates training data from many self-play games.
fn generate_training_data(
num_games: u32,
mcts_iterations: u32,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut all_examples = Vec::new();
for game_num in 1..=num_games {
let examples = generate_game_data(mcts_iterations, rng);
all_examples.extend(examples);
if game_num % 50 == 0 {
println!(
"Generated {} games ({} examples so far)",
game_num,
all_examples.len()
);
}
}
println!(
"Total: {} games, {} training examples",
num_games,
all_examples.len()
);
all_examples
}
// ─── Training loop ───────────────────────────────────────────────
/// Shuffles the training examples using Fisher-Yates.
fn shuffle_examples(examples: &mut [TrainingExample], rng: &mut impl Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
/// Trains the network on the given examples for the specified
/// number of epochs. Returns the average loss from the final epoch.
fn train_network(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl Rng,
) -> f32 {
let mut final_avg_loss = 0.0;
for epoch in 0..epochs {
shuffle_examples(examples, rng);
let mut epoch_loss = 0.0;
let mut count = 0;
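// Note: train_step applies gradients after every example, so this
// is per-example SGD; batch_size only groups the loop and does not
// average gradients across a batch.
for batch_start in (0..examples.len()).step_by(batch_size) {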
let batch_end = (batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let loss = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_loss += loss;
count += 1;
}
}
let avg_loss = epoch_loss / count as f32;
final_avg_loss = avg_loss;
if (epoch + 1) % 5 == 0 || epoch == 0 {
println!("Epoch {:>3}: avg loss = {:.4}", epoch + 1, avg_loss);
}
}
final_avg_loss
}
// ─── Evaluation ──────────────────────────────────────────────────
/// Prints the network's policy and value predictions for a position,
/// alongside the MCTS visit distribution for comparison.
fn evaluate_position(
label: &str,
state: &GameState,
net: &Network,
mcts_iterations: u32,
rng: &mut impl Rng,
) {
// Get network predictions.
let (net_policy, net_value) = net.predict(state);
// Get MCTS visit distribution.
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let mcts_dist = visit_distribution(&tree);
println!("=== {} ===", label);
println!("{}", state);
println!("Current player: {}", state.current_player);
println!();
println!("Cell Network MCTS");
println!("---- ------- -----");
for i in 0..9 {
let label = if state.board[i].is_some() { " (occupied)" } else { "" };
println!(
" {} {:.3} {:.3}{}",
i, net_policy[i], mcts_dist[i], label
);
}
println!();
let net_best = net_policy
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap();
let mcts_best = tree.best_move();
println!(
"Network best move: {} | MCTS best move: {} | {}",
net_best,
mcts_best,
if net_best == mcts_best { "MATCH" } else { "differ" }
);
println!("Network value: {:.4} (positive = good for {})", net_value, state.current_player);
println!();
}
// ─── Main ────────────────────────────────────────────────────────
fn main() {
let mut rng = StdRng::seed_from_u64(42);
// Step 1: Generate training data.
println!("--- Generating training data ---\n");
let mut examples = generate_training_data(100, 800, &mut rng);
// Step 2: Create a fresh network.
let mut net = Network::new(123);
// Step 3: Evaluate BEFORE training (to see the baseline).
println!("\n--- Before training ---\n");
let opening = GameState::new();
evaluate_position("Opening (before training)", &opening, &net, 1000, &mut rng);
// Step 4: Train.
println!("--- Training ---\n");
let final_loss = train_network(&mut net, &mut examples, 30, 32, 0.01, &mut rng);
println!("\nFinal average loss: {:.4}\n", final_loss);
// Step 5: Evaluate AFTER training.
println!("--- After training ---\n");
// Opening position.
let opening = GameState::new();
evaluate_position("Opening (after training)", &opening, &net, 1000, &mut rng);
// Mid-game: X took center, O took corner.
let mut mid_game = GameState::new();
mid_game = mid_game.apply_move(4); // X takes center
mid_game = mid_game.apply_move(0); // O takes corner
evaluate_position("Mid-game: X=center, O=corner", &mid_game, &net, 1000, &mut rng);
// Late game: X is about to win.
let mut late_game = GameState::new();
late_game = late_game.apply_move(4); // X center
late_game = late_game.apply_move(0); // O corner
late_game = late_game.apply_move(2); // X top-right
late_game = late_game.apply_move(3); // O mid-left
// X's turn — should play cell 6 to win (diagonal 2-4-6).
evaluate_position("Late game: X to play (winning move at 6)", &late_game, &net, 1000, &mut rng);
}
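The policy gradient in train_step, policy minus target, relies on a standard identity: for softmax followed by cross-entropy with a target that sums to 1, the gradient of the loss with respect to the logits is softmax(logits) - target. A finite-difference check is a quick way to confirm this, or to debug a hand-written backward pass. The sketch below is standalone. It reimplements softmax in f64 instead of reusing the f32 code above, and analytic_grad, numeric_grad, the logits, and the target are all made up for illustration.

```rust
// Finite-difference check that d(loss)/d(logits) = softmax(logits) - target
// for softmax followed by cross-entropy, assuming the target sums to 1.

fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Cross-entropy of `target` against softmax(`logits`).
fn loss(logits: &[f64], target: &[f64]) -> f64 {
    let p = softmax(logits);
    -target.iter().zip(&p).map(|(&t, &pi)| t * pi.ln()).sum::<f64>()
}

/// The closed-form gradient that train_step uses: p - t.
fn analytic_grad(logits: &[f64], target: &[f64]) -> Vec<f64> {
    softmax(logits).iter().zip(target).map(|(&p, &t)| p - t).collect()
}

/// Central-difference estimate of the same gradient.
fn numeric_grad(logits: &[f64], target: &[f64], h: f64) -> Vec<f64> {
    (0..logits.len())
        .map(|i| {
            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[i] += h;
            minus[i] -= h;
            (loss(&plus, target) - loss(&minus, target)) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let logits = [0.5, -1.0, 2.0, 0.0];
    // A visit-distribution-style target that sums to 1.
    let target = [0.1, 0.0, 0.7, 0.2];
    let a = analytic_grad(&logits, &target);
    let n = numeric_grad(&logits, &target, 1e-5);
    for (x, y) in a.iter().zip(&n) {
        println!("analytic {:+.6}  numeric {:+.6}", x, y);
    }
}
```

The two columns should agree to about six decimal places. If the target did not sum to 1 (for example, an all-zero visit distribution), the identity would no longer hold.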
Running the solution
The output will look something like this (exact numbers vary with the RNG seed):
--- Generating training data ---
Generated 50 games (346 examples so far)
Generated 100 games (671 examples so far)
Total: 100 games, 671 training examples
--- Before training ---
=== Opening (before training) ===
. | . | .
-----------
. | . | .
-----------
. | . | .
Current player: X
Cell Network MCTS
---- ------- -----
0 0.104 0.077
1 0.098 0.058
2 0.101 0.078
3 0.124 0.061
4 0.126 0.349
5 0.104 0.067
6 0.098 0.081
7 0.140 0.066
8 0.106 0.163
Network best move: 7 | MCTS best move: 4 | differ
Network value: -0.0274 (positive = good for X)
--- Training ---
Epoch 1: avg loss = 2.3849
Epoch 5: avg loss = 1.7542
Epoch 10: avg loss = 1.4321
Epoch 15: avg loss = 1.2188
Epoch 20: avg loss = 1.0763
Epoch 25: avg loss = 0.9612
Epoch 30: avg loss = 0.8895
Final average loss: 0.8895
--- After training ---
=== Opening (after training) ===
Cell Network MCTS
---- ------- -----
0 0.084 0.079
1 0.053 0.055
2 0.092 0.083
3 0.066 0.064
4 0.371 0.348
5 0.068 0.069
6 0.089 0.076
7 0.064 0.059
8 0.113 0.167
Network best move: 4 | MCTS best move: 4 | MATCH
Network value: 0.0412 (positive = good for X)
Several things to notice in the output:
The loss decreases steadily, from around 2.4 in the first epoch to around 0.9 by epoch 30. This confirms the network is learning: its predictions are getting closer to the MCTS targets.
Before training, the network’s policy is roughly uniform. With random weights, the network assigns approximately equal probability to all legal moves. It has no idea that the center is special.
After training, the network strongly prefers the center. Cell 4 gets around 35-40% of the probability mass, closely matching the MCTS visit distribution. The network has learned from MCTS data that the center is the best opening move.
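The value prediction for the opening is near zero. This is consistent with theory: Tic-Tac-Toe is a draw with perfect play, so neither player has an advantage from the starting position. Keep in mind that the untrained network also predicted a value near zero (-0.0274), so this position alone does not show that the value head has learned anything. Positions with a clear winner are a better test.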
The network’s best move matches MCTS on common positions. After training, the network agrees with MCTS on the opening and many mid-game positions. It has internalized the strategic patterns that MCTS discovers through search.
Understanding what was learned
The trained network has essentially compressed hundreds of MCTS searches (one per recorded position, 671 in the run above) into a set of weights. When MCTS searches the opening position, it runs 800 iterations of tree traversal and random rollouts to determine that the center is best. The trained network reaches the same conclusion with a single forward pass: a few matrix multiplications that take microseconds.
This is the key insight of AlphaGo Zero’s approach: the neural network is a fast approximation of a slow search. The search is accurate but expensive. The network is less accurate but nearly instant. And as we will see in §12, we can combine them: use the network to guide the search, making the search both faster and more accurate.
What to look for (and what can go wrong)
Loss not decreasing? Check your learning rate. If it is too high (e.g., 0.1), the gradients overshoot and training oscillates. If it is too low (e.g., 0.0001), learning is too slow to see progress in 30 epochs. Start with 0.01.
Network's policy looks erratic or reflects only the last few games? One common cause is forgetting to shuffle the training data. Without shuffling, the network sees all positions from game 1, then all positions from game 2, and so on. Each stretch of updates pulls the weights toward the most recent game, so the network overfits to it and forgets everything else.
Value predictions stuck near zero? This is partly expected for Tic-Tac-Toe because most MCTS self-play games end in draws (value target = 0.0). The value head has less signal to learn from compared to the policy head. You can verify it is working by checking value predictions on clearly winning or losing positions.
Network disagrees with MCTS on some positions? This is normal. 100 games of training data is not enough for the network to perfectly approximate MCTS on every position. Rare or unusual board states may not appear in the training data at all. More training games and more epochs improve coverage.
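To see why a too-high learning rate makes training oscillate while a too-low one barely moves, consider a toy problem rather than the network: plain gradient descent on f(w) = 10w², whose gradient is 20w. Each step multiplies w by (1 - 20 * lr), so the iteration shrinks w only when |1 - 20 * lr| < 1. The curvature of 20 is an arbitrary choice that puts the boundary at lr = 0.1 to echo the values above; a real network's thresholds depend on its own loss surface, and this sketch (with its hypothetical descend helper) does not predict them.

```rust
// Toy illustration only: gradient descent on f(w) = 10 * w^2.
// The curvature is chosen so the learning-rate thresholds line up with
// the values mentioned in the text; it says nothing about the real network.

/// Runs `steps` of plain gradient descent from w = 1.0 and returns
/// (initial loss, final loss).
fn descend(lr: f64, steps: u32) -> (f64, f64) {
    let loss = |w: f64| 10.0 * w * w;
    let mut w = 1.0_f64;
    let initial = loss(w);
    for _ in 0..steps {
        let grad = 20.0 * w;
        w -= lr * grad; // each step multiplies w by (1 - 20 * lr)
    }
    (initial, loss(w))
}

fn main() {
    for &lr in &[0.1, 0.01, 0.0001] {
        let (start, end) = descend(lr, 30);
        println!("lr = {:<7} loss {:.4} -> {:.4}", lr, start, end);
    }
}
```

With these settings, lr = 0.1 flips w between 1 and -1 on every step, so the loss stays at 10; lr = 0.0001 leaves the loss near 8.9 after 30 steps; and lr = 0.01 drives it close to zero.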
Experiments to try
- Vary the number of training games. Try 50, 100, 200, and 500 games. How does the final loss change? How much more closely does the network agree with MCTS on the evaluation positions?
- Vary MCTS iterations per move. Try 200, 800, and 2000. Higher iteration counts produce higher-quality training targets, but take longer to generate. Is the improvement in network quality worth the extra computation?
- Vary the learning rate. Try 0.001, 0.01, and 0.05. Plot (or print) the loss curve for each. Which converges fastest? Which achieves the lowest final loss?
- Train for more epochs. Try 50 or 100 epochs instead of 30. Does the loss keep decreasing, or does it plateau? If it plateaus, the network has learned everything it can from this dataset and needs more training data to improve further.
- Test on unusual positions. Create board states that are unlikely to appear in MCTS self-play (e.g., positions where one player made obvious blunders). How does the network handle positions outside its training distribution?
These experiments will build your intuition for the key hyperparameters in neural network training — data quantity, data quality, learning rate, and training duration. These same levers matter whether you are training a Tic-Tac-Toe network with 6,000 parameters or a language model with 6 billion.
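The "6,000 parameters" figure follows directly from the layer sizes in Network::new: each dense layer has inputs × outputs weights plus one bias per output. The standalone check below uses dense_params, a helper introduced here for illustration and not part of the course code.

```rust
// Counts trainable parameters (weights + biases) for the network shape
// used in §11: 18 -> 64 -> 64, then a 9-output policy head and a
// 1-output value head, both reading from the second hidden layer.

/// A dense layer has `inputs * outputs` weights and `outputs` biases.
fn dense_params(inputs: usize, outputs: usize) -> usize {
    inputs * outputs + outputs
}

fn main() {
    let hidden1 = dense_params(18, 64); // 1,216
    let hidden2 = dense_params(64, 64); // 4,160
    let policy_head = dense_params(64, 9); // 585
    let value_head = dense_params(64, 1); // 65
    let total = hidden1 + hidden2 + policy_head + value_head;
    println!("total parameters: {}", total); // 6026
}
```

Because each weight takes part in exactly one multiply-add per forward pass, the same arithmetic puts a forward pass at about 5,900 multiply-adds (1,152 + 4,096 + 576 + 64 = 5,888), which is why a single prediction takes only microseconds.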
12. Exercise 4: replace rollout with the value network
In §11 we trained the neural network to mimic MCTS: given a board position, predict the MCTS visit distribution (policy) and the game outcome (value). The network learned to compress thousands of iterations of random-rollout search into a single forward pass. Now we close the loop — we plug the network back into MCTS, replacing the two weakest parts of the algorithm with neural network predictions.
This is the key innovation from AlphaGo Zero. Pure MCTS (§7-8) has two limitations:
- Random rollouts are noisy. Playing random moves to the end of the game is a crude way to evaluate a position. It takes many iterations to average out the noise.
- Blind exploration. UCT treats every unexplored move equally. It has to try each one at least once before it can start discriminating. In games with large branching factors, this wastes a lot of search budget.
The neural network fixes both problems:
- The value head replaces rollouts. Instead of playing random moves to a terminal state, we call network.predict(state) and use the value output directly. One forward pass, one number: no randomness, no wasted computation.
- The policy head guides exploration. Instead of treating all moves equally, we use the policy output as a prior probability on each move. Moves the network considers promising get explored first and more often.
The exercise
Modify your MCTS implementation from §8 to accept a Network and use it in place of random rollouts. You need to make four changes:
- Add a prior field to MctsNode: the network's policy probability for the move that led to this node.
- Replace the UCT formula with PUCT (Predictor + UCT), which incorporates the prior.
- When expanding a node, set child priors from the network's policy output.
- Replace the simulate function with a network forward pass.
Try implementing these changes yourself before looking at the solution. The rest of this section explains each change in detail.
Change 1: add a prior to MctsNode
The original MctsNode has no notion of “how promising is this move?” — it discovers that purely through simulation. We need to add a field that stores the network’s prior probability:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
/// Prior probability from the network's policy head.
/// This is the network's estimate of how good the move
/// that led to this node is, before any search.
prior: f64,
}
}
The root node has no parent move, so its prior does not matter — set it to 0.0 (or any value; it is never read during selection).
Change 2: the PUCT formula
In §7 we used the UCT formula for selection:
UCT(child) = W/N + C * sqrt(ln(N_parent) / N_child)
AlphaGo Zero replaces this with PUCT (Predictor + Upper Confidence bounds applied to Trees):
PUCT(child) = Q + c_puct * P * sqrt(N_parent) / (1 + N_child)
where:
- Q = total_value / visit_count: the average value of this child (exploitation). Same as the W/N term in UCT.
- P = prior: the network's policy probability for this move. A high prior means the network thinks this move is promising before any search.
- N_parent = the parent's visit count.
- N_child = this child's visit count.
- c_puct = a constant controlling exploration vs exploitation. AlphaGo Zero-style implementations commonly use values of roughly 1.0-2.5; we will use 1.5 for Tic-Tac-Toe.
Notice the differences from UCT:
- No logarithm. UCT has ln(N_parent) in the exploration term; PUCT uses sqrt(N_parent) directly. Because sqrt(N_parent) grows much faster than sqrt(ln(N_parent)), the exploration bonus keeps growing as the parent accumulates visits, while the prior P and the (1 + N_child) denominator keep that bonus focused on promising, under-visited moves.
- Prior probability P. This is the key addition. Moves with a high prior get a large exploration bonus, so the search visits them early. Moves with a near-zero prior are effectively pruned: the search may never visit them if better options exist.
- Denominator is (1 + N_child). When a child has zero visits, the exploration term is c_puct * P * sqrt(N_parent), which is finite (unlike UCT, which returns infinity for unvisited nodes). This means the prior determines the order in which unvisited children are explored, instead of every unvisited child being equally attractive.
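To make the contrast concrete, the standalone sketch below scores three unvisited children of a parent with 16 visits under both formulas, then works through PUCT for one visited child. The UCT constant C = sqrt(2), the priors, and the visit counts are illustrative assumptions. The local helpers uct and puct are written for this example; they are not the course's uct_score and puct_score, although puct follows the formula above.

```rust
/// UCT as in §7, using the usual convention that an unvisited
/// child scores +infinity.
fn uct(total_value: f64, visits: u32, parent_visits: u32, c: f64) -> f64 {
    if visits == 0 {
        return f64::INFINITY;
    }
    let n = visits as f64;
    total_value / n + c * ((parent_visits as f64).ln() / n).sqrt()
}

/// PUCT: Q + c_puct * P * sqrt(N_parent) / (1 + N_child), with Q = 0 when unvisited.
fn puct(total_value: f64, visits: u32, parent_visits: u32, prior: f64, c_puct: f64) -> f64 {
    let q = if visits == 0 { 0.0 } else { total_value / visits as f64 };
    q + c_puct * prior * (parent_visits as f64).sqrt() / (1.0 + visits as f64)
}

fn main() {
    let parent_visits = 16;
    // Three unvisited children with different (hypothetical) network priors.
    let priors = [0.6, 0.3, 0.1];
    for (i, &p) in priors.iter().enumerate() {
        println!(
            "child {}: UCT = {}, PUCT = {:.2}",
            i,
            uct(0.0, 0, parent_visits, 2f64.sqrt()),
            puct(0.0, 0, parent_visits, p, 1.5)
        );
    }
    // Worked example for a visited child: 25 parent visits, 4 child visits,
    // total value 3.0, prior 0.3, c_puct 1.5.
    //   Q           = 3.0 / 4                      = 0.75
    //   exploration = 1.5 * 0.3 * sqrt(25) / (1+4) = 0.45
    println!("visited child: PUCT = {:.2}", puct(3.0, 4, 25, 0.3, 1.5));
}
```

Under UCT all three unvisited children print inf and cannot be told apart. Under PUCT they score 3.60, 1.80, and 0.60, so the search tries the 0.6-prior move first. The visited child scores 0.75 + 0.45 = 1.20.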
In Rust:
#![allow(unused)]
fn main() {
/// Computes the PUCT score for a child node.
///
/// Unlike UCT, PUCT uses the network's prior probability to bias
/// exploration toward moves the network considers promising.
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
}
When a child has zero visits, Q defaults to 0.0 (neutral — we have no information yet). The exploration term is then c_puct * P * sqrt(N_parent), which is entirely determined by the prior. The network’s opinion drives the initial exploration order.
Change 3: set priors during expansion
In the original MCTS, expand creates a child node with no prior information. In network-guided MCTS, we query the network when expanding and set each child’s prior from the policy output.
The cleanest approach: when we expand a node, we call network.predict on the node’s state to get the full policy vector, then create all children at once (setting each child’s prior from the corresponding policy entry). This differs from the original MCTS, which expanded one child at a time from unexpanded_moves. Expanding all children at once is simpler and matches the AlphaGo Zero approach — the network evaluates the position once, and we use both outputs (policy for child priors, value for backpropagation).
impl MctsTree {
/// Expands all children of a node at once, using the network's
/// policy output to set priors. Returns the value estimate for
/// backpropagation.
fn expand_with_network(
&mut self,
node_idx: usize,
network: &Network,
) -> f64 {
let state = &self.nodes[node_idx].state;
// If the position is terminal, return the true outcome
// instead of a network estimate.
if state.is_terminal() {
let eval_player = state.current_player.opponent();
return match state.winner() {
Some(winner) if winner == eval_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
let (policy, value) = network.predict(state);
// Create a child for every legal move, setting the prior
// from the network's policy output.
let legal_moves: Vec<usize> = self.nodes[node_idx]
.unexpanded_moves
.drain(..)
.collect();
for mv in &legal_moves {
let child_state = self.nodes[node_idx].state.apply_move(*mv);
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(*mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: Vec::new(), // expanded later
prior: policy[*mv] as f64,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
}
// Populate unexpanded_moves for each child (for future expansion).
let children: Vec<usize> = self.nodes[node_idx].children.clone();
for &child_idx in &children {
let moves = self.nodes[child_idx].state.legal_moves();
self.nodes[child_idx].unexpanded_moves = moves;
}
// Return the value from the perspective of the player who
// moved to reach this node (the opponent of current_player).
// The network returns value from current_player's perspective,
// so we negate it.
-(value as f64)
}
}
The value sign is important. The network’s value output is from the perspective of state.current_player (the player whose turn it is). But total_value at a node is stored from the perspective of the player who moved to reach this node — that is state.current_player.opponent(). So we negate the network’s value before backpropagation.
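The sign bookkeeping is easiest to see on a tiny path. The standalone sketch below copies the logic of backpropagate onto a three-node chain (root, then A, then B) stored in a Vec, as MctsTree does. The trimmed-down Node struct and the value 0.8 are assumptions made for illustration.

```rust
// Each node's total_value is from the perspective of the player who
// moved INTO that node, so the sign flips at every level.
struct Node {
    parent: Option<usize>,
    visit_count: u32,
    total_value: f64,
}

fn backpropagate(nodes: &mut [Node], node_idx: usize, value: f64) {
    let mut current = Some(node_idx);
    let mut current_value = value;
    while let Some(idx) = current {
        nodes[idx].visit_count += 1;
        nodes[idx].total_value += current_value;
        current = nodes[idx].parent;
        current_value = -current_value; // switch to the other player's view
    }
}

fn main() {
    let mut nodes = vec![
        Node { parent: None, visit_count: 0, total_value: 0.0 },    // root
        Node { parent: Some(0), visit_count: 0, total_value: 0.0 }, // A
        Node { parent: Some(1), visit_count: 0, total_value: 0.0 }, // B
    ];
    // Suppose the network rates B at +0.8 for the player to move at B.
    let network_value = 0.8;
    // B stores value for the player who moved INTO B (the opponent),
    // so negate before backpropagating, as expand_with_network does.
    backpropagate(&mut nodes, 2, -network_value);
    for (name, n) in ["root", "A", "B"].iter().zip(&nodes) {
        println!("{:>4}: visits={} total_value={:+.1}", name, n.visit_count, n.total_value);
    }
}
```

After one call, B holds -0.8, A holds +0.8, and the root holds -0.8. The move into A was made by the same player who is to move at B, so that move is credited with the network's +0.8.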
Change 4: replace simulate with evaluate
The simulate function is gone entirely. Instead of playing random moves to a terminal state, the expand_with_network function already returns the value estimate. The search loop becomes:
impl MctsTree {
/// Runs network-guided MCTS for the given number of iterations.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
for _ in 0..iterations {
// Selection: walk down the tree using PUCT.
let selected = self.select_puct(c_puct);
// Expansion + evaluation: expand all children and get
// the network's value estimate in one step.
let value = self.expand_with_network(selected, network);
// Backpropagation: same as before.
self.backpropagate(selected, value);
}
self.best_move()
}
}
The selection phase also needs updating to use PUCT instead of UCT:
impl MctsTree {
/// Selection using PUCT scores instead of UCT.
fn select_puct(&self, c_puct: f64) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
// If this node has unexpanded moves, select it for expansion.
if !node.unexpanded_moves.is_empty() {
return current;
}
// If this is a terminal node (no children, no moves), stop.
if node.children.is_empty() {
return current;
}
// Pick the child with the highest PUCT score.
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = puct_score(
child.total_value,
child.visit_count,
parent_visits,
child.prior,
c_puct,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
}
This is structurally identical to the original select — the only change is calling puct_score instead of uct_score.
Complete modified MCTS
Solution: full network-guided MCTS implementation
#![allow(unused)]
fn main() {
// ─── Network-guided MCTS ──────────────────────────────────────────
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
/// Prior probability from the network's policy head.
prior: f64,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
/// Computes the PUCT score for a child node.
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
impl MctsTree {
/// Creates a new tree rooted at the given game state.
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
prior: 0.0, // root prior is unused
};
MctsTree { nodes: vec![root] }
}
/// Selection using PUCT scores.
fn select_puct(&self, c_puct: f64) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
if !node.unexpanded_moves.is_empty() { return current; }
if node.children.is_empty() { return current; }
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = puct_score(
child.total_value,
child.visit_count,
parent_visits,
child.prior,
c_puct,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
/// Expands all children of a node using the network's policy,
/// and returns the value estimate for backpropagation.
fn expand_with_network(
&mut self,
node_idx: usize,
network: &Network,
) -> f64 {
let state = &self.nodes[node_idx].state;
if state.is_terminal() {
let eval_player = state.current_player.opponent();
return match state.winner() {
Some(winner) if winner == eval_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
let (policy, value) = network.predict(state);
let legal_moves: Vec<usize> = self.nodes[node_idx]
.unexpanded_moves
.drain(..)
.collect();
for mv in &legal_moves {
let child_state = self.nodes[node_idx].state.apply_move(*mv);
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(*mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: Vec::new(),
prior: policy[*mv] as f64,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
}
// Fill in unexpanded_moves for each new child.
let children: Vec<usize> = self.nodes[node_idx].children.clone();
for &child_idx in &children {
let moves = self.nodes[child_idx].state.legal_moves();
self.nodes[child_idx].unexpanded_moves = moves;
}
// Negate: network returns value for current_player,
// but we need value for the player who moved here.
-(value as f64)
}
/// Backpropagation — identical to the pure MCTS version.
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
/// Returns the move of the root's most-visited child.
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have children after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root children have moves")
}
/// Runs network-guided MCTS and returns the best move.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
for _ in 0..iterations {
let selected = self.select_puct(c_puct);
let value = self.expand_with_network(selected, network);
self.backpropagate(selected, value);
}
self.best_move()
}
}
/// Convenience function: choose a move using network-guided MCTS.
fn mcts_network_move(
state: &GameState,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(iterations, network, c_puct)
}
}
Testing: pure MCTS vs network-guided MCTS
Now let’s pit the two approaches against each other. We will run a tournament: pure MCTS (random rollouts) vs network-guided MCTS (using the network trained in §11), and count wins, losses, and draws.
Solution: tournament between pure and network-guided MCTS
/// Plays one game between network-guided MCTS and pure MCTS.
/// `network_player` selects the side the network plays; pure MCTS
/// plays the other side. Returns the outcome from the network's
/// perspective: +1.0 win, -1.0 loss, 0.0 draw.
fn play_one_game(
network: &Network,
network_player: Player,
net_iterations: u32,
pure_iterations: u32,
c_puct: f64,
rng: &mut impl rand::Rng,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let mv = if state.current_player == network_player {
// Network-guided MCTS
mcts_network_move(&state, net_iterations, network, c_puct)
} else {
// Pure MCTS with random rollouts
mcts_move(&state, pure_iterations, rng)
};
state = state.apply_move(mv);
}
match state.winner() {
Some(p) if p == network_player => 1.0,
Some(_) => -1.0,
None => 0.0,
}
}
fn run_tournament(
network: &Network,
num_games: u32,
net_iterations: u32,
pure_iterations: u32,
c_puct: f64,
) {
let mut rng = rand::thread_rng();
let mut net_wins = 0;
let mut pure_wins = 0;
let mut draws = 0;
for game in 0..num_games {
// Alternate who plays X and O for fairness: the network plays X
// in even-numbered games and O in odd-numbered games.
let network_player = if game % 2 == 0 { Player::X } else { Player::O };
// The result is already from the network's perspective.
let result = play_one_game(
network, network_player, net_iterations, pure_iterations,
c_puct, &mut rng,
);
match result.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => net_wins += 1,
Some(std::cmp::Ordering::Less) => pure_wins += 1,
_ => draws += 1,
}
}
println!("Results over {} games:", num_games);
println!(
" Network-guided MCTS ({} iters): {} wins",
net_iterations, net_wins
);
println!(
" Pure MCTS ({} iters): {} wins",
pure_iterations, pure_wins
);
println!(" Draws: {}", draws);
}
fn main() {
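// Load or train the network (from §11). `train_network()` stands for your
// §11 training routine; adapt this call to that routine's actual signature.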
let network = train_network();
// Fair comparison: same iteration budget.
println!("=== Same budget: 200 iterations each ===");
run_tournament(&network, 100, 200, 200, 1.5);
// Efficiency test: can the network do more with less?
println!("\n=== Network 100 iters vs Pure 500 iters ===");
run_tournament(&network, 100, 100, 500, 1.5);
}
What should you expect? With a well-trained network (from §11 with 100+ self-play games), the network-guided MCTS should:
- Win more games than it loses at equal iteration counts. At 200 iterations each, network-guided MCTS should have a clear edge because every iteration is informed by learned knowledge rather than random noise. Expect plenty of draws as well, since Tic-Tac-Toe between competent players often ends level.
- Compete with fewer iterations. Network-guided MCTS at 100 iterations should perform comparably to pure MCTS at 500 iterations. Each leaf evaluation replaces a noisy random rollout with a single forward pass that encodes the patterns learned from training data, so far fewer iterations are wasted.
If the network-guided version performs worse than pure MCTS, the network is not trained well enough. Go back to §11 and train on more games (200+) or more iterations per move (1000+).
Why is this better?
Consider what happens in pure MCTS when it encounters a new position. It must try every legal move at least once, then play random moves to the end of the game, repeating hundreds of times to build up reliable statistics. Most of those random rollouts are wasted — they follow nonsensical move sequences that no reasonable player would choose.
Network-guided MCTS skips all of this. The policy head says “move 4 looks best, move 0 is second, the rest are unlikely” — so the search spends most of its budget on the two or three most promising moves. The value head says “this position is +0.6 for the current player” — one number that summarizes what random rollouts would have (noisily) converged to after thousands of games.
This is the core insight of AlphaGo Zero: a learned evaluation function turns MCTS from a brute-force explorer into a focused, efficient search. The same iteration budget goes much further because the search is guided by knowledge rather than randomness.
But there is a subtlety. The network’s predictions are only as good as its training data. If the network was trained on weak MCTS play (few iterations, few games), its priors may be wrong, and it could steer the search in bad directions. This is why the full AlphaGo Zero loop (§13) keeps iterating: the network guides better MCTS, which generates better training data, which trains a better network, and so on. Ideally, each generation is stronger than the last.
Experiments to try
- Vary c_puct. Try 0.5, 1.0, 1.5, and 3.0. In the standard PUCT formula the prior appears only in the exploration bonus, so low values make the search lean on its current value estimates (greedier, less exploration), while high values weight the prior-scaled bonus and spread visits across the moves the network favors. Which value wins the most games?
- Vary the iteration budget. Run network-guided MCTS at 50, 100, 200, and 500 iterations against pure MCTS at 1000. At what point does the network-guided version start winning consistently?
- Test with an untrained network. Create a network with random weights (for example `Network::new(seed)`) and use it to guide MCTS. It will likely perform worse than pure MCTS, because its random priors and value estimates mislead the search. This confirms that the improvement comes from learned knowledge, not from the PUCT formula itself.
- Inspect the search tree. After running network-guided MCTS, print the visit counts at the root’s children. Compare with pure MCTS on the same position. The network-guided version should concentrate visits on fewer moves: the ones the network considers most promising.
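The effect of c_puct is easiest to see on a single selection step. The sketch below assumes §12's select_puct uses the standard AlphaGo Zero form, Q + c_puct · P · √N_parent / (1 + N_child); the function name puct_score and the two example children are hypothetical, chosen only to show the trade-off.

```rust
/// PUCT score in the standard AlphaGo Zero form (assumed, not taken from §12):
/// Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a)).
fn puct_score(q: f64, prior: f64, parent_visits: u32, child_visits: u32, c_puct: f64) -> f64 {
    // The prior only enters through this exploration bonus.
    let exploration = c_puct * prior * (parent_visits as f64).sqrt() / (1.0 + child_visits as f64);
    q + exploration
}

fn main() {
    // Hypothetical children of a root that has been visited 10 times.
    // Child A: high prior (0.7), 2 visits, mean value 0.0 so far.
    // Child B: low prior (0.1), 8 visits, mean value 0.6 so far.
    let parent_visits = 10;
    for &c in &[0.5, 1.0, 1.5, 3.0] {
        let a = puct_score(0.0, 0.7, parent_visits, 2, c);
        let b = puct_score(0.6, 0.1, parent_visits, 8, c);
        let pick = if a > b { "A (high prior)" } else { "B (high value)" };
        println!("c_puct = {:.1}: A = {:.3}, B = {:.3} -> select {}", c, a, b, pick);
    }
}
```

With these numbers, c_puct = 0.5 selects B because its value estimate dominates, while 1.0 and above select A, the move the network's prior favors.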
Part 5 — Self-Play Loop
13. The full AlphaGo Zero training loop
In §11 we trained a neural network to mimic pure MCTS. In §12 we plugged that network back into MCTS, replacing random rollouts with learned evaluation. The result was a stronger player — but a static one. The network was trained once on data from pure MCTS, and that was it. It could not improve beyond the quality of its training data.
This section closes the loop. Instead of training once and stopping, we iterate: the network guides MCTS to generate better training data, we train a new network on that data, check that it actually improved, and repeat. Each cycle produces a stronger player than the last. This is the full AlphaGo Zero self-improvement pipeline.
The training cycle
The AlphaGo Zero training loop has three phases that repeat indefinitely:
      ┌──────────────────────┐
      │ 1. Self-Play         │
      │ Network-guided       │
      │ MCTS plays games     │
      │ against itself.      │
      │ Collect (state,      │
      │ policy, outcome)     │
      │ triples.             │
      └──────────┬───────────┘
                 │
                 ▼
      ┌──────────────────────┐
      │ 2. Train             │
      │ Update the network   │
      │ on the collected     │
      │ data. Produces a     │
      │ candidate network.   │
      └──────────┬───────────┘
                 │
                 ▼
      ┌──────────────────────┐
┌─No──│ 3. Evaluate          │
│     │ Pit the candidate    │
│     │ against the current  │
│     │ best network.        │
│     │ Does it win > 55%?   │
│     └──────────┬───────────┘
│                │ Yes
│                ▼
│     ┌──────────────────────┐
│     │ Candidate becomes    │
│     │ the new best.        │
│     └──────────┬───────────┘
│                │
└────────────────┤
                 │
                 ▼
           Back to step 1
Step 1 — Self-play. The current best network guides MCTS to play complete games against itself. At each position, MCTS runs for a fixed number of iterations (e.g., 200), using the network’s policy head for move priors and value head for position evaluation. We record every position along with the MCTS visit distribution (the policy target) and, once the game ends, the outcome (the value target). This produces a batch of TrainingExample triples.
Step 2 — Train. We create a copy of the current best network and train it on the self-play data for several epochs. This produces a candidate network that (hopefully) has internalized the patterns discovered during self-play.
Step 3 — Evaluate. We play a set of evaluation games between the candidate network and the current best network, each using MCTS with the same iteration budget. If the candidate wins more than a threshold (typically 55% of games), it becomes the new best. Otherwise, we discard it and try again.
Then we go back to step 1 and generate new self-play data using the (possibly updated) best network.
Why the evaluation gate matters
It might seem wasteful to discard a freshly trained network. Why not always use the latest one? The problem is that training can sometimes make things worse. A few specific failure modes:
- Policy collapse. The network becomes overconfident in a single move and assigns near-zero probability to everything else. MCTS guided by this network effectively stops searching: it always plays the same move. The resulting self-play data is narrow and repetitive, which reinforces the collapse in the next training round.
- Catastrophic forgetting. Training on a new batch of games can overwrite knowledge from previous iterations. The network might learn to handle mid-game positions well but forget how to play openings, or vice versa.
- Overfitting to noise. With small datasets (which is what Tic-Tac-Toe generates), the network can memorize training examples rather than learning general patterns. A memorized network performs well on positions it has seen but poorly on novel ones.
The evaluation gate guards against all of these. If the candidate cannot beat the current best in head-to-head play, something went wrong during training, and we keep the proven network. This acts as a ratchet: measured strength can only go up, never down, up to the noise inherent in a finite set of evaluation games.
With Tic-Tac-Toe, training is usually stable enough that the gate seldom rejects a candidate for being weaker. Later in training, rejections come mostly from drawn evaluation games, which give the candidate no way to prove it is stronger. For larger games, the gate is an essential safeguard.
Temperature in move selection
During self-play, we do not always pick the move with the highest visit count. Instead, we use a temperature parameter to control how deterministic the move selection is.
After MCTS completes a search, the root’s children have visit counts \(N_1, N_2, \ldots, N_k\). We convert these into a probability distribution using temperature \(\tau\):
\[ P(\text{move}_i) = \frac{N_i^{1/\tau}}{\sum_j N_j^{1/\tau}} \]
The temperature controls exploration vs exploitation:
- τ = 1: The probability is proportional to visit counts. If move A got 60% of visits and move B got 40%, we pick A 60% of the time and B 40% of the time. This adds randomness, which creates diverse training data.
- τ → 0 (low temperature): The distribution becomes sharply peaked. The most-visited move gets probability ~1.0, everything else gets ~0.0. This is essentially greedy: always play the best move. Good for strong play, bad for exploration.
- τ > 1 (high temperature): The distribution flattens. Even rarely visited moves get a reasonable chance of being selected. This creates maximum diversity but at the cost of playing some weak moves.
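To see how τ reshapes a distribution, here is a minimal sketch that applies the formula above to a list of visit counts. The helper name temperature_distribution is hypothetical, not one of the course's functions. It assumes at least one count is nonzero and that τ is not so small that N^(1/τ) overflows (the course's selection code switches to greedy play for τ < 0.01).

```rust
/// Converts root visit counts into move probabilities:
/// P(i) = N_i^(1/tau) / sum_j N_j^(1/tau).
/// Assumes at least one count is nonzero and tau is not tiny.
fn temperature_distribution(visits: &[u32], tau: f64) -> Vec<f64> {
    let scaled: Vec<f64> = visits
        .iter()
        .map(|&n| (n as f64).powf(1.0 / tau))
        .collect();
    let total: f64 = scaled.iter().sum();
    scaled.iter().map(|s| s / total).collect()
}

fn main() {
    // Move A received 60 visits, move B received 40.
    let visits = [60, 40];
    for &tau in &[2.0, 1.0, 0.5, 0.1] {
        let p = temperature_distribution(&visits, tau);
        println!("tau = {:.1}: P(A) = {:.3}, P(B) = {:.3}", tau, p[0], p[1]);
    }
}
```

For counts of 60 and 40, τ = 1 reproduces 0.6/0.4, τ = 0.5 sharpens it to about 0.692/0.308, τ = 2 flattens it to about 0.550/0.450, and τ = 0.1 gives the leading move about 0.983.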
AlphaGo Zero uses a schedule: for the first 30 moves of each game it uses τ = 1 (explore), and after that τ → 0 (exploit). For Tic-Tac-Toe, a reasonable choice is τ = 1 for the first 3 moves and τ → 0 (greedy) for the rest.
Why does this matter? If self-play always picks the best move (τ = 0 everywhere), every game follows the same trajectory. The network only sees a handful of positions and learns nothing about the rest of the game tree. With τ = 1 in the opening, games branch into diverse positions, producing a richer training set that covers more of the state space.
Here is the temperature-scaled move selection in Rust:
#![allow(unused)]
fn main() {
/// Selects a move from the MCTS root using temperature-scaled
/// visit counts.
///
/// When temperature is very small (< 0.01), picks the most-visited
/// move deterministically. Otherwise, samples from the distribution
/// P(move_i) = N_i^(1/τ) / Σ N_j^(1/τ).
fn select_move_with_temperature(
tree: &MctsTree,
temperature: f64,
rng: &mut impl rand::Rng,
) -> usize {
let root = &tree.nodes[0];
if temperature < 0.01 {
// Greedy: pick the most-visited child.
return tree.best_move();
}
// Compute temperature-scaled visit counts.
let inv_temp = 1.0 / temperature;
let scaled: Vec<f64> = root
.children
.iter()
.map(|&idx| (tree.nodes[idx].visit_count as f64).powf(inv_temp))
.collect();
let total: f64 = scaled.iter().sum();
if total == 0.0 {
// Fallback: no visits were recorded, so use the greedy choice.
return tree.best_move();
}
// Sample from the distribution.
let mut r: f64 = rng.gen::<f64>() * total;
for (i, &s) in scaled.iter().enumerate() {
r -= s;
if r <= 0.0 {
let child_idx = root.children[i];
return tree.nodes[child_idx]
.move_from_parent
.expect("child must have a move");
}
}
// Rounding edge case — return last child.
let last = *root.children.last().unwrap();
tree.nodes[last]
.move_from_parent
.expect("child must have a move")
}
}
Self-play with network-guided MCTS
The self-play function is similar to generate_game_data from §11, but with two key differences: it uses network-guided MCTS (from §12) instead of pure MCTS, and it uses temperature for move selection.
#![allow(unused)]
fn main() {
/// Plays one complete game using network-guided MCTS and collects
/// training examples.
///
/// Uses temperature = 1.0 for the first `temp_moves` moves
/// (exploration), then temperature → 0 (exploitation) for
/// the rest.
fn self_play_game(
network: &Network,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
let mut move_number = 0;
while !state.is_terminal() {
// Run network-guided MCTS from the current position.
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(mcts_iterations, network, c_puct);
// Record the visit distribution as the policy target.
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
// Select a move using temperature.
let temperature = if move_number < temp_moves {
1.0
} else {
0.0 // greedy
};
let chosen_move =
select_move_with_temperature(&tree, temperature, rng);
state = state.apply_move(chosen_move);
move_number += 1;
}
// Determine the game outcome and build training examples.
let winner = state.winner();
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
}
Notice that the policy target is always the raw visit distribution from MCTS — not the temperature-adjusted distribution. We want the network to learn what MCTS thinks is best, not the exploration noise we added for diversity. The temperature only affects which move is played (and therefore which positions appear later in the game).
Evaluating the candidate network
After training a candidate network, we need to know if it is actually better than the current best. The simplest approach: play a set of games where each side uses network-guided MCTS with its respective network, and count wins.
#![allow(unused)]
fn main() {
/// Plays one evaluation game between two networks.
///
/// `network_x` plays as X, `network_o` plays as O.
/// Returns +1.0 if X wins, -1.0 if O wins, 0.0 for draw.
fn play_evaluation_game(
network_x: &Network,
network_o: &Network,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let network = match state.current_player {
Player::X => network_x,
Player::O => network_o,
};
// Evaluation games use temperature = 0 (always play
// the best move). No exploration — we want to measure
// true strength.
let mv = mcts_network_move(&state, mcts_iterations, network, c_puct);
state = state.apply_move(mv);
}
match state.winner() {
Some(Player::X) => 1.0,
Some(Player::O) => -1.0,
None => 0.0,
}
}
/// Evaluates a candidate network against the current best.
///
/// Plays `num_games` games, alternating which network plays X
/// and which plays O for fairness. Returns the candidate's win
/// rate (0.0 to 1.0).
fn evaluate_networks(
candidate: &Network,
current_best: &Network,
num_games: u32,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut candidate_wins = 0;
let mut total_decisive = 0;
for game in 0..num_games {
let result = if game % 2 == 0 {
// Candidate plays X
play_evaluation_game(
candidate, current_best,
mcts_iterations, c_puct,
)
} else {
// Candidate plays O (negate result so positive = candidate won)
-play_evaluation_game(
current_best, candidate,
mcts_iterations, c_puct,
)
};
if result > 0.0 {
candidate_wins += 1;
total_decisive += 1;
} else if result < 0.0 {
total_decisive += 1;
}
// Draws don't count toward win rate
}
if total_decisive == 0 {
0.5 // All draws — treat as neutral
} else {
candidate_wins as f64 / total_decisive as f64
}
}
}
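A note on draws: Tic-Tac-Toe between strong players ends in a draw most of the time. The evaluation function therefore counts only decisive games (wins and losses) when computing the win rate. If all games are draws, the candidate is treated as equal (50%) and does not replace the current best. This avoids a scoring rule under which a network that draws every game would score 100% simply because it never lost. One caveat: mcts_network_move takes no random number generator, so if your PUCT selection is deterministic, every game between the same two networks on the same sides plays out identically, and 40 evaluation games amount to only two distinct games. Sampling the first move or two with temperature = 1 is one simple way to diversify them.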
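The draw rule is easy to check in isolation. In this minimal sketch, decisive_win_rate mirrors the arithmetic at the end of evaluate_networks, and not_lost_rate is the naive alternative the note warns against; both helper names and the sample results are hypothetical.

```rust
/// Candidate win rate over decisive games only; all draws -> 0.5 (neutral),
/// matching the rule used in evaluate_networks.
fn decisive_win_rate(wins: u32, losses: u32) -> f64 {
    let decisive = wins + losses;
    if decisive == 0 {
        0.5
    } else {
        wins as f64 / decisive as f64
    }
}

/// Naive alternative that counts draws as successes.
/// Assumes at least one game was played.
fn not_lost_rate(wins: u32, draws: u32, losses: u32) -> f64 {
    (wins + draws) as f64 / (wins + draws + losses) as f64
}

fn main() {
    let threshold = 0.55;
    // Hypothetical (wins, draws, losses) for a candidate over 40 games.
    for &(w, d, l) in &[(0u32, 40u32, 0u32), (6, 30, 4)] {
        let rate = decisive_win_rate(w, l);
        let verdict = if rate >= threshold { "accepted" } else { "rejected" };
        println!(
            "W{} D{} L{}: decisive win rate {:.2} ({}), not-lost rate {:.2}",
            w, d, l, rate, verdict, not_lost_rate(w, d, l)
        );
    }
}
```

With 40 draws, the decisive rate is a neutral 0.50, below the 0.55 threshold, while the naive score would be a perfect 1.00. With 6 wins, 30 draws and 4 losses, the decisive rate is 0.60 and the candidate is accepted.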
The main training loop
Now we tie everything together. The outer loop runs for a fixed number of iterations, each time generating self-play data, training a candidate, and evaluating it.
#![allow(unused)]
fn main() {
/// Hyperparameters for the AlphaGo Zero training loop.
struct TrainingConfig {
/// Number of training iterations (cycles of self-play + train + evaluate).
num_iterations: u32,
/// Number of self-play games per iteration.
games_per_iteration: u32,
/// MCTS iterations per move during self-play.
mcts_iterations: u32,
/// PUCT exploration constant.
c_puct: f64,
/// Number of opening moves to play with temperature = 1.
temp_moves: usize,
/// Training epochs per iteration.
training_epochs: u32,
/// Mini-batch size for training.
batch_size: usize,
/// Learning rate for SGD.
learning_rate: f32,
/// Number of evaluation games to play.
eval_games: u32,
/// Win rate threshold to accept the candidate (e.g., 0.55 = 55%).
win_threshold: f64,
}
/// Runs the full AlphaGo Zero training loop.
///
/// Starts with a randomly initialized network and iteratively
/// improves it through self-play, training, and evaluation.
fn alphago_zero_training(config: &TrainingConfig, seed: u64) -> Network {
let mut rng = StdRng::seed_from_u64(seed);
let mut best_network = Network::new(rng.gen());
println!("Starting AlphaGo Zero training loop");
println!(" Iterations: {}", config.num_iterations);
println!(" Games/iter: {}", config.games_per_iteration);
println!(" MCTS iters: {}", config.mcts_iterations);
println!(" Eval games: {}", config.eval_games);
println!(" Win threshold: {:.0}%", config.win_threshold * 100.0);
println!();
for iteration in 1..=config.num_iterations {
println!("═══ Iteration {}/{} ═══", iteration, config.num_iterations);
// ── Phase 1: Self-play ──────────────────────────────────
println!("Phase 1: Generating self-play data...");
let mut all_examples = Vec::new();
for game in 1..=config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
all_examples.extend(examples);
if game % 25 == 0 {
println!(
" Game {}/{} ({} examples so far)",
game, config.games_per_iteration, all_examples.len()
);
}
}
println!(
" Self-play complete: {} examples from {} games",
all_examples.len(),
config.games_per_iteration
);
// ── Phase 2: Train ──────────────────────────────────────
println!("Phase 2: Training candidate network...");
let mut candidate = best_network.clone();
let final_loss = train_network(
&mut candidate,
&mut all_examples,
config.training_epochs,
config.batch_size,
config.learning_rate,
&mut rng,
);
println!(" Training complete: final loss = {:.4}", final_loss);
// ── Phase 3: Evaluate ───────────────────────────────────
println!("Phase 3: Evaluating candidate vs current best...");
let win_rate = evaluate_networks(
&candidate,
&best_network,
config.eval_games,
config.mcts_iterations,
config.c_puct,
);
println!(" Candidate win rate: {:.1}%", win_rate * 100.0);
if win_rate >= config.win_threshold {
println!(" ✓ Candidate accepted as new best network!");
best_network = candidate;
} else {
println!(" ✗ Candidate rejected. Keeping current best.");
}
println!();
}
best_network
}
}
A concrete training schedule for Tic-Tac-Toe
Tic-Tac-Toe has a small state space (at most 5,478 reachable positions) and short games (5-9 moves). The network converges quickly. Here is a training schedule that works well:
fn main() {
let config = TrainingConfig {
num_iterations: 20, // 20 cycles of improvement
games_per_iteration: 100, // 100 self-play games per cycle
mcts_iterations: 200, // 200 MCTS iterations per move
c_puct: 1.5, // PUCT exploration constant
temp_moves: 3, // temperature = 1 for first 3 moves
training_epochs: 20, // 20 epochs of training per cycle
batch_size: 32, // mini-batch size
learning_rate: 0.01, // SGD learning rate
eval_games: 40, // 40 evaluation games
win_threshold: 0.55, // accept if win rate >= 55% of decisive games
};
let trained = alphago_zero_training(&config, 42);
// Test the final network on a few positions.
println!("═══ Final Network Evaluation ═══");
let empty = GameState::new();
let (policy, value) = trained.predict(&empty);
let masked = mask_illegal(&policy, &empty);
println!("Empty board:");
println!(" Policy: {:?}", masked);
println!(" Value: {:.3}", value);
println!(" Best move: cell {}", masked
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0
);
// Test against pure MCTS to verify strength.
println!("\nFinal tournament: trained network vs pure MCTS (500 iters)");
run_tournament(&trained, 100, 200, 500, 1.5);
}
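This schedule produces 2,000 self-play games in total (20 iterations times 100 games). Each game yields 5-9 training examples, so each iteration trains on roughly 500-900 examples. Total training time on a modern machine should be under a minute; run with cargo run --release, because unoptimized debug builds of numeric Rust code are much slower.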
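The 5,478 figure quoted above can be checked by brute force. This sketch uses its own minimal board encoding (0 = empty, 1 = X, 2 = O) rather than the course's GameState; it plays out every legal game from the empty board, stops at wins and full boards, and counts each distinct board once.

```rust
use std::collections::HashSet;

/// The eight winning lines, with cells numbered 0-8 row by row.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
    [0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
    [0, 4, 8], [2, 4, 6],            // diagonals
];

/// True if either player has three in a row.
fn has_winner(board: &[u8; 9]) -> bool {
    LINES.iter().any(|l| {
        board[l[0]] != 0 && board[l[0]] == board[l[1]] && board[l[1]] == board[l[2]]
    })
}

/// Depth-first search over legal play, recording each distinct board once.
fn explore(board: [u8; 9], to_move: u8, seen: &mut HashSet<[u8; 9]>) {
    // `insert` returns false if this board was already visited.
    // A board with a winner is counted, but no moves follow it.
    if !seen.insert(board) || has_winner(&board) {
        return;
    }
    for cell in 0..9 {
        if board[cell] == 0 {
            let mut next = board;
            next[cell] = to_move;
            explore(next, 3 - to_move, seen); // 1 <-> 2 swaps X and O
        }
    }
}

fn count_reachable_positions() -> usize {
    let mut seen = HashSet::new();
    explore([0; 9], 1, &mut seen); // X moves first
    seen.len()
}

fn main() {
    println!("reachable positions: {}", count_reachable_positions());
}
```

Deduplicating by board alone is sound because the side to move is fixed by the number of marks on the board.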
What to expect
As training progresses, you should observe several things:
Iterations 1-3: rapid initial improvement. The network starts with random weights, so its predictions are noise. Even one round of self-play data gives the network some signal. The candidate easily beats the random network, and the evaluation gate typically accepts it.
Iterations 4-8: diminishing returns. The network already plays reasonably. Self-play games between identical copies of a decent player produce less diverse data, since many games follow similar lines. Improvement slows down, and the candidate occasionally fails the evaluation gate.
Iterations 9-15: convergence. The network approaches near-perfect play. Tic-Tac-Toe is a solved game (perfect play from both sides results in a draw), so the network should converge to a strategy that never loses. Most evaluation games end in draws.
Iterations 16-20: plateau. The network has learned the game. Further iterations produce candidates that tie with the current best but rarely beat it. The evaluation gate rejects most candidates, not because they are worse, but because perfect-vs-perfect play always draws, and draws do not count as wins.
The key insight: the network teaches itself to play the game with zero human knowledge. It starts knowing nothing but the rules. Through repeated cycles of self-play, training, and evaluation, it discovers opening strategy, tactical patterns, and endgame play entirely on its own. For Tic-Tac-Toe this happens in minutes. For Go (the original AlphaGo Zero domain), it took about three days of training on large-scale specialized hardware, but the same algorithm produced superhuman play.
Summary
The AlphaGo Zero training loop is a simple cycle with powerful consequences:
- Self-play generates training data from the network’s own games. Temperature in the opening ensures diverse positions.
- Training updates a candidate network on the self-play data.
- Evaluation gates the update — the candidate must prove it is stronger than the current best. This prevents regression.
- Repeat with the improved network, generating higher-quality data each round.
Each component is something we built in earlier sections. §11 showed how to generate training data from MCTS and train a network on it. §12 showed how to use the network to guide MCTS. This section connected them into a loop. The only new pieces are temperature-based move selection (for diverse self-play data) and the evaluation gate (to prevent regression).
The result is an algorithm that, given only the rules of a game, learns to play it at an expert level — by playing against itself, learning from its own games, and steadily improving.
14. Exercise 5: 1000 self-play games; observe improvement
This is the capstone exercise. You have built every piece of the AlphaGo Zero pipeline — the game engine (§6), Monte Carlo Tree Search (§7-8), a neural network with policy and value heads (§9-10), supervised training on MCTS data (§11), network-guided MCTS (§12), and the full training loop (§13). Now you wire them all together, press run, and watch an AI teach itself to play perfect Tic-Tac-Toe from nothing.
The exercise
Run the full AlphaGo Zero training loop for 20 iterations of 50 self-play games each (1,000 games total). After each iteration, print a diagnostics panel so you can observe the network improving in real time. When training finishes, validate that the network has converged to perfect play.
Step 1: Wire up all the modules
Your main.rs needs everything from the previous exercises assembled into one file:
- GameState, Player, and the game logic (§6)
- MctsNode, MctsTree, and MCTS search, both pure and network-guided (§7-8, §12)
- Network with forward pass, backpropagation, and train_step (§9-10)
- TrainingExample, self_play_game, evaluate_networks, and train_network (§11, §13)
- select_move_with_temperature and visit_distribution (§13)
- TrainingConfig and alphago_zero_training (§13)
If you have been building incrementally, you already have all of these. If not, the complete solution at the end of this section provides everything in one place.
Step 2: Add diagnostics to the training loop
Modify the alphago_zero_training function to collect and print diagnostics after each iteration. Here is what to track:
Self-play diagnostics:
#![allow(unused)]
fn main() {
/// Statistics collected during one iteration of self-play.
struct IterationStats {
/// Average number of moves per game.
avg_game_length: f64,
/// Number of games won by X.
x_wins: u32,
/// Number of games won by O.
o_wins: u32,
/// Number of draws.
draws: u32,
/// Total training examples generated.
num_examples: usize,
}
}
To compute these, modify the self-play phase to track game outcomes and lengths:
#![allow(unused)]
fn main() {
// Inside the self-play loop:
let mut total_moves = 0u32;
let mut x_wins = 0u32;
let mut o_wins = 0u32;
let mut draws = 0u32;
for _game in 0..config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
// The number of examples equals the number of moves in the game.
let game_length = examples.len() as u32;
total_moves += game_length;
// Determine the outcome from the last example. Its state is the
// position just before the final move, so its current_player is the
// player who made that move, and its value_target is the game
// outcome from that player's perspective.
if let Some(last) = examples.last() {
match last.value_target.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => {
// Last position's current player won.
match last.state.current_player {
Player::X => x_wins += 1,
Player::O => o_wins += 1,
}
}
Some(std::cmp::Ordering::Less) => {
match last.state.current_player {
Player::X => o_wins += 1,
Player::O => x_wins += 1,
}
}
_ => draws += 1,
}
}
all_examples.extend(examples);
}
let avg_game_length =
total_moves as f64 / config.games_per_iteration as f64;
}
Training diagnostics — policy loss and value loss:
Split the loss reported by train_network into its two components. If you followed §10, the total loss is policy_loss + value_loss. Modify train_network (or add a wrapper) to return both:
#![allow(unused)]
fn main() {
/// Loss components from one training run.
struct TrainingLoss {
/// Cross-entropy loss between predicted and target policy.
policy_loss: f32,
/// Mean squared error between predicted and target value.
value_loss: f32,
}
}
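To make the two components concrete, here is a minimal sketch of the per-example losses. It assumes §10 uses standard cross-entropy for the policy and squared error for the value; your version may also average over the batch or add regularization. The helper names are hypothetical.

```rust
/// Policy loss for one example: cross-entropy
/// H(target, predicted) = -sum_i target_i * ln(predicted_i).
fn policy_cross_entropy(target: &[f32], predicted: &[f32]) -> f32 {
    target
        .iter()
        .zip(predicted)
        // Terms with target_i = 0 contribute nothing, so skip them.
        .filter(|&(&t, _)| t > 0.0)
        // t * ln(1/p) equals -t * ln(p); clamp p to avoid ln(0).
        .map(|(&t, &p)| t * (1.0 / p.max(1e-12)).ln())
        .sum()
}

/// Value loss for one example: squared error (z - v)^2.
/// Averaging this over a batch gives the mean squared error.
fn value_squared_error(target: f32, predicted: f32) -> f32 {
    let diff = target - predicted;
    diff * diff
}

fn main() {
    let uniform = [1.0 / 9.0_f32; 9];
    let mut one_hot = [0.0_f32; 9];
    one_hot[4] = 1.0;
    let mut soft = [0.0_f32; 9];
    soft[0] = 0.5;
    soft[4] = 0.5;

    // A uniform guess costs ln 9 regardless of the target.
    println!("uniform vs one-hot: {:.4}", policy_cross_entropy(&one_hot, &uniform));
    // A perfect guess of a one-hot target costs 0.
    println!("one-hot vs itself:  {:.4}", policy_cross_entropy(&one_hot, &one_hot));
    // A perfect guess of a soft target still costs its entropy (ln 2 here).
    println!("soft vs itself:     {:.4}", policy_cross_entropy(&soft, &soft));
    println!("value, target +1 vs 0: {:.4}", value_squared_error(1.0, 0.0));
}
```

The last policy case is why the policy loss cannot reach 0.0 in this exercise: MCTS visit distributions are soft targets, and cross-entropy bottoms out at their entropy.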
Sample prediction on the empty board:
After each iteration, run a forward pass on the empty board and print the policy and value. This is the simplest way to see the network’s “opening book” evolving:
#![allow(unused)]
fn main() {
// After each iteration:
let empty = GameState::new();
let (policy, value) = best_network.predict(&empty);
let masked = mask_illegal(&policy, &empty);
let best_cell = masked
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
}
Step 3: Print the diagnostics panel
After each iteration, print a summary block like this:
═══ Iteration 7/20 ═══════════════════════════════════════
  Self-play:    50 games, 410 examples
  Game length:  8.2 moves (avg)
  Outcomes:     X wins 12 | O wins 8 | Draws 30
  Training:     policy_loss = 1.2843  value_loss = 0.4521
  Evaluation:   candidate win rate = 62.5% → ACCEPTED
  Empty board:  best move = cell 4 (center)
                policy = [0.06 0.04 0.07 0.05 0.55 0.05 0.08 0.04 0.06]
                value  = +0.031
═══════════════════════════════════════════════════════════
Here is the Rust code that prints this:
#![allow(unused)]
fn main() {
println!("═══ Iteration {}/{} ═══", iteration, config.num_iterations);
println!(" Self-play: {} games, {} examples",
config.games_per_iteration, all_examples.len());
println!(" Game length: {:.1} moves (avg)", avg_game_length);
println!(" Outcomes: X wins {} | O wins {} | Draws {}",
x_wins, o_wins, draws);
println!(" Training: policy_loss = {:.4} value_loss = {:.4}",
loss.policy_loss, loss.value_loss);
println!(" Evaluation: candidate win rate = {:.1}% → {}",
win_rate * 100.0,
if win_rate >= config.win_threshold { "ACCEPTED" } else { "REJECTED" });
let (policy, value) = best_network.predict(&GameState::new());
let masked = mask_illegal(&policy, &GameState::new());
let best_cell = masked
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
let cell_name = match best_cell {
0 => "top-left", 1 => "top-center", 2 => "top-right",
3 => "mid-left", 4 => "center", 5 => "mid-right",
6 => "bot-left", 7 => "bot-center", 8 => "bot-right",
_ => "unknown",
};
println!(" Empty board: best move = cell {} ({})", best_cell, cell_name);
println!(" policy = [{:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2}]",
masked[0], masked[1], masked[2], masked[3], masked[4],
masked[5], masked[6], masked[7], masked[8]);
println!(" value = {:+.3}", value);
println!();
}
What to expect in the diagnostics
Watch for these trends as training progresses:
Average game length should generally rise toward 9 moves. Early on, the network plays semi-random moves, and many games end in a win before the board fills. As both sides learn to block threats, decisive games become rare. A drawn game only ends once the board is full (assuming your engine, like most, does not declare draws early), so as draws come to dominate, the average approaches 9 moves.
Win/draw/loss ratio should shift toward draws. In early iterations, the network is weak and games are chaotic — there are many decisive results (X or O wins). As the network approaches perfect play, most self-play games end in draws because both sides defend correctly. By iteration 15-20, you should see 80-100% draws.
Policy loss should decrease steadily. The network is learning to predict MCTS visit distributions more accurately. Starting from random weights, the policy loss is around 2.0-2.5, close to -ln(1/9) = ln 9 ≈ 2.20, the loss of a uniform guess over nine moves. It should drop below 1.0 within a few iterations. It will not reach 0.0: the targets are soft visit distributions rather than one-hot labels, and cross-entropy cannot fall below the entropy of its targets, so expect a plateau around 0.5-1.0.
Value loss should also decrease. Starting around 0.5-1.0 (random predictions vs outcomes in {-1, 0, +1}), it should drop below 0.3 within a few iterations. As more games end in draws, the value loss may stabilize at a low level because the target is consistently 0.0.
Evaluation results should show early candidates accepted (they easily beat the random initial network), middle candidates sometimes rejected (improvements are marginal), and late candidates mostly rejected (perfect play draws against perfect play, so no candidate can demonstrate superiority).
Empty board policy is the most visually satisfying diagnostic. In early iterations, the policy is roughly uniform: the network has no preference. By iteration 5-10, the center (cell 4) should emerge as the clear favorite. By iteration 15-20, the policy should be heavily concentrated on cell 4, with corners (cells 0, 2, 6, 8) receiving most of the remaining probability and edges (cells 1, 3, 5, 7) getting almost nothing. This matches conventional Tic-Tac-Toe strategy: every opening draws under perfect play, but the center, which lies on the most winning lines, is widely regarded as the strongest first move.
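The center's special role can be checked by counting how many of the eight winning lines pass through each cell. This sketch uses its own hypothetical LINES table, with cells numbered 0-8 row by row as in the diagnostics above.

```rust
/// The eight winning lines, with cells numbered 0-8 row by row.
const LINES: [[usize; 3]; 8] = [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
    [0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
    [0, 4, 8], [2, 4, 6],            // diagonals
];

/// Number of winning lines that pass through each cell.
fn lines_through_each_cell() -> [usize; 9] {
    let mut counts = [0; 9];
    for line in LINES.iter() {
        for &cell in line {
            counts[cell] += 1;
        }
    }
    counts
}

fn main() {
    let counts = lines_through_each_cell();
    // Print the counts laid out like the board.
    for row in 0..3 {
        println!("{} {} {}", counts[3 * row], counts[3 * row + 1], counts[3 * row + 2]);
    }
}
```

The center lies on four lines (the middle row, the middle column and both diagonals), each corner on three, and each edge on only two, which lines up with the center-then-corners ordering the trained policy tends to show.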
Step 4: Validate convergence to perfect play
After training completes, run three validation tests.
Test 1: Trained network vs pure MCTS. Play 100 games between the trained network (using 200 MCTS iterations) and pure MCTS (using 10,000 iterations — a massive search budget). If the network has converged to perfect play, every game should be a draw. Pure MCTS with 10,000 iterations plays near-perfectly through brute force; the trained network should match it with far less search.
// Validation: trained network vs strong pure MCTS.
// Runs inside main(), after `let trained = alphago_zero_training(&config, 42);`.
println!("═══ Validation 1: Trained AI vs Pure MCTS (10k iters) ═══");
let mut draws = 0;
let mut rng = StdRng::seed_from_u64(123);
for game in 0..100 {
    let mut state = GameState::new();
    while !state.is_terminal() {
        // The trained network plays X in even games and O in odd games.
        let mv = match state.current_player {
            Player::X => {
                if game % 2 == 0 {
                    mcts_network_move(&state, 200, &trained, 1.5)
                } else {
                    mcts_move(&state, 10_000, &mut rng)
                }
            }
            Player::O => {
                if game % 2 == 0 {
                    mcts_move(&state, 10_000, &mut rng)
                } else {
                    mcts_network_move(&state, 200, &trained, 1.5)
                }
            }
        };
        state = state.apply_move(mv);
    }
    if state.winner().is_none() {
        draws += 1;
    }
}
println!(" Result: {}/100 draws", draws);
if draws == 100 {
    println!(" PASS — the network plays perfectly.");
} else {
    println!(" Some decisive games — network may need more training.");
}
Test 2: Check the opening move. The center (cell 4) is the conventional first move in Tic-Tac-Toe, and it is where the trained policy should have concentrated. Network-guided MCTS picks the most-visited move with no randomness, so a single call is enough to verify that the network plays it:
// Runs inside main(), after training has produced `trained`.
println!("\n═══ Validation 2: Opening Move ═══");
let empty = GameState::new();
let mv = mcts_network_move(&empty, 200, &trained, 1.5);
println!(" First move: cell {} (expected: 4 = center)", mv);
if mv == 4 {
    println!(" PASS — network plays center.");
} else {
    println!(" UNEXPECTED — center is the optimal opening.");
}
Test 3: Human vs AI. Add an interactive mode where you can play against the trained network. Try to beat it — you should not be able to.
// Runs inside main(), after training has produced `trained`.
println!("\n═══ Validation 3: Human vs Trained AI ═══");
println!("You are X. Enter a cell number (0-8):");
println!(" 0 | 1 | 2");
println!(" ---------");
println!(" 3 | 4 | 5");
println!(" ---------");
println!(" 6 | 7 | 8");
println!();
let mut state = GameState::new();
while !state.is_terminal() {
    state.print_board();
    match state.current_player {
        Player::X => {
            // Human move
            print!("Your move (0-8): ");
            std::io::Write::flush(&mut std::io::stdout()).unwrap();
            let mut input = String::new();
            std::io::stdin().read_line(&mut input).unwrap();
            let mv: usize = input.trim().parse().expect("enter 0-8");
            if !state.legal_moves().contains(&mv) {
                println!("Illegal move. Try again.");
                continue;
            }
            state = state.apply_move(mv);
        }
        Player::O => {
            // AI move
            let mv = mcts_network_move(&state, 200, &trained, 1.5);
            println!("AI plays cell {}.", mv);
            state = state.apply_move(mv);
        }
    }
}
state.print_board();
match state.winner() {
    Some(Player::X) => println!("You win! (The AI has a bug.)"),
    Some(Player::O) => println!("AI wins. Better luck next time."),
    None => println!("Draw. That is the best you can do against perfect play."),
}
What you just built
Take a moment to appreciate what happened. You started with:
- A set of game rules (how to place X’s and O’s).
- A neural network initialized with random noise.
- An algorithm — self-play, train, evaluate, repeat.
And you ended with an AI that plays Tic-Tac-Toe perfectly. No one told it that the center is a good opening move. No one programmed in the concept of a “fork” or a “block.” No one provided any examples of expert play. The network discovered all of this on its own, purely by playing against itself and learning from the results.
The same loop, at vastly larger scale, was used to train AlphaGo Zero, which beat the version of AlphaGo that had defeated world champion Lee Sedol by 100 games to 0. The difference is mostly scale: Go has about 10^170 legal positions instead of 5,478; AlphaGo Zero used 4.9 million self-play games instead of 1,000; it trained on 64 GPU workers and 19 CPU parameter servers instead of your laptop. It also used a deep residual network and a few refinements this implementation omits, such as Dirichlet noise added to the root priors to encourage exploration. But the loop is identical: self-play, train, evaluate, repeat.
This is what makes the AlphaGo Zero approach remarkable. It is a general algorithm. It does not encode any game-specific strategy. Give it the rules of any two-player perfect-information game, and it will learn to play it at an expert level — given enough compute and enough iterations.
Where to go from here
You have a working AlphaGo Zero implementation for Tic-Tac-Toe. Here are some directions to extend it:
Scale up to Connect Four. Connect Four has a 7x6 board (42 cells), branching factor of 7, and games that last up to 42 moves. It is significantly more complex than Tic-Tac-Toe but still tractable on a single machine. You will need to adjust the network architecture (larger hidden layers), increase the MCTS iteration budget, and train for more iterations. The same algorithm applies without changes to its structure.
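A minimal sketch of how the game state might change for Connect Four. It assumes you keep the interface the tree search relies on (is_terminal, legal_moves, apply_move, winner, and the current_player field); C4State, ROWS, and COLS are hypothetical names. A move becomes a column index, so the policy head would shrink from 9 outputs to 7, while the encoded input would grow from 9 values to 42.

```rust
const ROWS: usize = 6;
const COLS: usize = 7;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
    X,
    O,
}

#[derive(Debug, Clone)]
struct C4State {
    /// board[row][col]; row 0 is the top, row ROWS - 1 the bottom.
    board: [[Option<Player>; COLS]; ROWS],
    current_player: Player,
}

impl C4State {
    fn new() -> Self {
        C4State { board: [[None; COLS]; ROWS], current_player: Player::X }
    }

    /// A move is a column index; a column stays legal until its top cell fills.
    fn legal_moves(&self) -> Vec<usize> {
        (0..COLS).filter(|&c| self.board[0][c].is_none()).collect()
    }

    /// Drops the current player's piece into the lowest empty cell of `col`.
    fn apply_move(&self, col: usize) -> C4State {
        let mut next = self.clone();
        let row = (0..ROWS)
            .rev()
            .find(|&r| self.board[r][col].is_none())
            .expect("column must not be full");
        next.board[row][col] = Some(self.current_player);
        next.current_player = match self.current_player {
            Player::X => Player::O,
            Player::O => Player::X,
        };
        next
    }

    /// Looks for four in a row starting at any occupied cell, in four
    /// directions: right, down, down-right, and down-left.
    fn winner(&self) -> Option<Player> {
        const DIRS: [(isize, isize); 4] = [(0, 1), (1, 0), (1, 1), (1, -1)];
        for r in 0..ROWS {
            for c in 0..COLS {
                let Some(p) = self.board[r][c] else { continue };
                for (dr, dc) in DIRS {
                    let four = (1..4).all(|k| {
                        let rr = r as isize + dr * k;
                        let cc = c as isize + dc * k;
                        rr >= 0
                            && cc >= 0
                            && (rr as usize) < ROWS
                            && (cc as usize) < COLS
                            && self.board[rr as usize][cc as usize] == Some(p)
                    });
                    if four {
                        return Some(p);
                    }
                }
            }
        }
        None
    }

    fn is_terminal(&self) -> bool {
        self.winner().is_some() || self.legal_moves().is_empty()
    }
}

fn main() {
    // X builds a bottom-row four in columns 0-3 while O stacks column 6.
    let mut state = C4State::new();
    for col in [0, 6, 1, 6, 2, 6, 3] {
        state = state.apply_move(col);
    }
    println!("winner: {:?}, terminal: {}", state.winner(), state.is_terminal());
}
```

Checking every cell in four directions is simple rather than fast; on a 42-cell board that is cheap enough for a first experiment.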
Add data augmentation via board symmetries. A Tic-Tac-Toe board has 8 symmetries: 4 rotations (0, 90, 180, 270 degrees) and 4 reflections. Every training example can be transformed into 8 equivalent examples (including the original) for free, as long as the policy target is transformed with the same permutation as the board; the value target stays the same. This multiplies your effective dataset by 8x and helps the network learn that the game is symmetric. AlphaGo Zero used the same augmentation to exploit the Go board's 8-fold symmetry.
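A minimal sketch of the augmentation, operating on an encoded board (as produced by GameState::encode) and a policy_target array. The helpers rotate, mirror, symmetries, permute, and augment are hypothetical; each symmetry is represented as a permutation of the 9 cell indices, applied identically to the board and the policy.

```rust
/// Destination of cell i after a 90-degree clockwise turn: (r, c) -> (c, 2 - r).
fn rotate(i: usize) -> usize {
    let (r, c) = (i / 3, i % 3);
    c * 3 + (2 - r)
}

/// Destination of cell i after a left-right mirror: (r, c) -> (r, 2 - c).
fn mirror(i: usize) -> usize {
    let (r, c) = (i / 3, i % 3);
    r * 3 + (2 - c)
}

/// The 8 symmetries as cell permutations: perm[i] is where cell i goes.
/// Four rotations, each with and without a mirror applied first.
fn symmetries() -> Vec<[usize; 9]> {
    let mut perms = Vec::with_capacity(8);
    for mirrored in [false, true] {
        for turns in 0..4 {
            let mut perm = [0usize; 9];
            for i in 0..9 {
                let mut j = if mirrored { mirror(i) } else { i };
                for _ in 0..turns {
                    j = rotate(j);
                }
                perm[i] = j;
            }
            perms.push(perm);
        }
    }
    perms
}

/// Moves the value in cell i to cell perm[i].
fn permute(values: &[f32; 9], perm: &[usize; 9]) -> [f32; 9] {
    let mut out = [0.0f32; 9];
    for i in 0..9 {
        out[perm[i]] = values[i];
    }
    out
}

/// Expands one (board, policy target, value target) example into 8.
fn augment(board: &[f32; 9], policy: &[f32; 9], value: f32) -> Vec<([f32; 9], [f32; 9], f32)> {
    symmetries()
        .iter()
        .map(|perm| (permute(board, perm), permute(policy, perm), value))
        .collect()
}

fn main() {
    // The player to move owns the top-left corner; MCTS favored the opposite corner.
    let mut board = [0.0f32; 9];
    board[0] = 1.0;
    let mut policy = [0.0f32; 9];
    policy[8] = 1.0;
    for (b, p, v) in augment(&board, &policy, 0.0) {
        let piece = b.iter().position(|&x| x == 1.0).unwrap();
        let target = p.iter().position(|&x| x == 1.0).unwrap();
        println!("piece at {piece}, target {target}, value {v}");
    }
}
```

In every transformed copy the target is still the corner diagonally opposite the piece, which is exactly the relationship the augmentation must preserve. Skipping the policy transform would teach the network the wrong move for seven of the eight copies.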
Try different network architectures. The simple two-layer network in this course works for Tic-Tac-Toe but would struggle on larger games. AlphaGo Zero used a deep residual network (ResNet) with 20-40 residual blocks. Try adding residual connections (output = input + layer(input)) and see if training converges faster. If you move to a 2D board game like Connect Four, try convolutional layers — they are natural for grid-structured inputs.
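A minimal sketch of a fully connected residual block, assuming the same Vec<Vec<f32>> weight layout (one row of input weights per output unit) that Network uses. AlphaGo Zero's blocks are convolutional with batch normalization; this dense version, with hypothetical helpers dense, relu, and residual_block, only illustrates the skip connection.

```rust
/// Dense layer y = W x + b, with W stored as one row of weights per output.
fn dense(w: &[Vec<f32>], b: &[f32], x: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f32>() + bias)
        .collect()
}

fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|x| x.max(0.0)).collect()
}

/// Residual block: out = relu(x + W2 * relu(W1 * x + b1) + b2).
/// W2 must map back to x's width so the skip connection can add them.
fn residual_block(x: &[f32], w1: &[Vec<f32>], b1: &[f32], w2: &[Vec<f32>], b2: &[f32]) -> Vec<f32> {
    let h = relu(&dense(w1, b1, x));
    let f = dense(w2, b2, &h);
    let sum: Vec<f32> = x.iter().zip(&f).map(|(a, b)| a + b).collect();
    relu(&sum)
}

fn main() {
    let x: Vec<f32> = vec![1.0, -2.0, 3.0];
    // With all-zero weights the block reduces to relu(x): the skip path
    // carries the input through even before the layers learn anything.
    let zeros = vec![vec![0.0f32; 3]; 3];
    let out = residual_block(&x, &zeros, &[0.0; 3], &zeros, &[0.0; 3]);
    println!("{:?}", out);
}
```

The skip path is the design point: each block only has to learn a correction to its input, which is one common explanation for why deep residual stacks train more easily than plain deep stacks.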
Read the original papers. Now that you understand the algorithm from the ground up, the papers will make much more sense:
- Mastering the Game of Go without Human Knowledge (Silver et al., 2017) — the AlphaGo Zero paper. Describes the full training pipeline you just implemented.
- A General Reinforcement Learning Algorithm that Masters Chess, Shogi, and Go Through Self-Play (Silver et al., 2018) — the AlphaZero paper. Generalizes AlphaGo Zero to chess and shogi with minimal modifications.
- Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model (Schrittwieser et al., 2020) — MuZero. Removes the requirement of knowing the game rules by learning a model of the environment.
Explore other self-play domains. Self-play is not limited to board games. It has been applied to:
- Real-time strategy games (AlphaStar for StarCraft II)
- Card games with hidden information (Pluribus for six-player poker)
- Robotic control (learning dexterous manipulation through simulated self-play)
- Multi-agent emergent behavior (OpenAI’s hide-and-seek experiments)
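The pattern is the same in each case: agents improve by competing against themselves, so the strength of their training opponents grows along with their own.
Solution: complete main.rs for the full training pipeline. It depends on the rand crate and uses its 0.8 API (gen_range, gen, StdRng::seed_from_u64). Build it with the 2021 edition: in the 2024 edition gen is a reserved keyword, so calls such as rng.gen() no longer compile unless rewritten as rng.r#gen().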
use rand::prelude::*;
// ═══════════════════════════════════════════════════════════
// Game logic (§6)
// ═══════════════════════════════════════════════════════════
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
X,
O,
}
impl Player {
fn opponent(self) -> Player {
match self {
Player::X => Player::O,
Player::O => Player::X,
}
}
}
#[derive(Debug, Clone)]
struct GameState {
board: [Option<Player>; 9],
current_player: Player,
}
impl GameState {
fn new() -> Self {
GameState {
board: [None; 9],
current_player: Player::X,
}
}
fn legal_moves(&self) -> Vec<usize> {
(0..9).filter(|&i| self.board[i].is_none()).collect()
}
fn apply_move(&self, cell: usize) -> GameState {
let mut new = self.clone();
new.board[cell] = Some(self.current_player);
new.current_player = self.current_player.opponent();
new
}
fn winner(&self) -> Option<Player> {
const LINES: [[usize; 3]; 8] = [
[0, 1, 2], [3, 4, 5], [6, 7, 8],
[0, 3, 6], [1, 4, 7], [2, 5, 8],
[0, 4, 8], [2, 4, 6],
];
for line in &LINES {
if let Some(p) = self.board[line[0]] {
if self.board[line[1]] == Some(p)
&& self.board[line[2]] == Some(p)
{
return Some(p);
}
}
}
None
}
fn is_terminal(&self) -> bool {
self.winner().is_some() || self.legal_moves().is_empty()
}
fn print_board(&self) {
for row in 0..3 {
for col in 0..3 {
let i = row * 3 + col;
let ch = match self.board[i] {
Some(Player::X) => "X",
Some(Player::O) => "O",
None => ".",
};
print!(" {}", ch);
}
println!();
}
}
/// Encodes the board as a flat array of f32 for the network, from the
/// perspective of the player to move: their pieces = +1.0, the
/// opponent's pieces = -1.0, empty = 0.0.
fn encode(&self) -> [f32; 9] {
let mut encoded = [0.0f32; 9];
for i in 0..9 {
encoded[i] = match self.board[i] {
Some(p) if p == self.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
encoded
}
}
// ═══════════════════════════════════════════════════════════
// Neural network (§9-10)
// ═══════════════════════════════════════════════════════════
#[derive(Clone)]
struct Network {
// Input (9) → Hidden (64) → Policy (9) + Value (1)
w1: Vec<Vec<f32>>, // 64 x 9
b1: Vec<f32>, // 64
w_policy: Vec<Vec<f32>>, // 9 x 64
b_policy: Vec<f32>, // 9
w_value: Vec<Vec<f32>>, // 1 x 64
b_value: Vec<f32>, // 1
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let hidden = 64;
let input = 9;
let w1 = (0..hidden)
.map(|_| (0..input).map(|_| rng.gen_range(-0.5..0.5)).collect())
.collect();
let b1 = vec![0.0; hidden];
let w_policy = (0..9)
.map(|_| (0..hidden).map(|_| rng.gen_range(-0.5..0.5)).collect())
.collect();
let b_policy = vec![0.0; 9];
let w_value = vec![
(0..hidden).map(|_| rng.gen_range(-0.5..0.5)).collect()
];
let b_value = vec![0.0];
Network { w1, b1, w_policy, b_policy, w_value, b_value }
}
fn forward(&self, input: &[f32; 9]) -> (Vec<f32>, [f32; 9], f32) {
// Hidden layer with ReLU
let mut hidden = vec![0.0f32; self.w1.len()];
for (h, (weights, bias)) in hidden.iter_mut()
.zip(self.w1.iter().zip(self.b1.iter()))
{
*h = weights.iter().zip(input.iter())
.map(|(w, x)| w * x).sum::<f32>() + bias;
*h = h.max(0.0); // ReLU
}
// Policy head (softmax applied later)
let mut policy_logits = [0.0f32; 9];
for (i, (weights, bias)) in self.w_policy.iter()
.zip(self.b_policy.iter()).enumerate()
{
policy_logits[i] = weights.iter().zip(hidden.iter())
.map(|(w, h)| w * h).sum::<f32>() + bias;
}
// Value head (tanh)
let value_raw: f32 = self.w_value[0].iter().zip(hidden.iter())
.map(|(w, h)| w * h).sum::<f32>() + self.b_value[0];
let value = value_raw.tanh();
(hidden, policy_logits, value)
}
fn predict(&self, state: &GameState) -> ([f32; 9], f32) {
let input = state.encode();
let (_, logits, value) = self.forward(&input);
// Softmax
let max_logit = logits.iter().cloned()
.fold(f32::NEG_INFINITY, f32::max);
let mut policy = [0.0f32; 9];
let mut sum = 0.0f32;
for i in 0..9 {
policy[i] = (logits[i] - max_logit).exp();
sum += policy[i];
}
for p in policy.iter_mut() {
*p /= sum;
}
(policy, value)
}
fn train_step(
&mut self,
state: &GameState,
policy_target: &[f32; 9],
value_target: f32,
lr: f32,
) -> (f32, f32) {
let input = state.encode();
let (hidden, logits, value) = self.forward(&input);
// Softmax for policy
let max_logit = logits.iter().cloned()
.fold(f32::NEG_INFINITY, f32::max);
let mut policy = [0.0f32; 9];
let mut sum = 0.0f32;
for i in 0..9 {
policy[i] = (logits[i] - max_logit).exp();
sum += policy[i];
}
for p in policy.iter_mut() {
*p /= sum;
}
// Policy loss: cross-entropy
let policy_loss: f32 = -(0..9)
.map(|i| {
if policy_target[i] > 0.0 {
policy_target[i] * policy[i].max(1e-8).ln()
} else {
0.0
}
})
.sum::<f32>();
// Value loss: MSE
let value_loss = (value - value_target).powi(2);
// ── Backprop ──
// Policy gradient: d_logits[i] = policy[i] - target[i]
let mut d_logits = [0.0f32; 9];
for i in 0..9 {
d_logits[i] = policy[i] - policy_target[i];
}
// Value gradient: d_value = 2 * (value - target) * (1 - tanh^2)
let d_value = 2.0 * (value - value_target) * (1.0 - value * value);
// Gradient for hidden layer, computed with the pre-update output
// weights (updating those weights first would skew the backpropagated signal)
let mut d_hidden = vec![0.0f32; hidden.len()];
for j in 0..hidden.len() {
for i in 0..9 {
d_hidden[j] += d_logits[i] * self.w_policy[i][j];
}
d_hidden[j] += d_value * self.w_value[0][j];
// ReLU gradient
if hidden[j] <= 0.0 {
d_hidden[j] = 0.0;
}
}
// Update policy weights
for i in 0..9 {
for j in 0..hidden.len() {
self.w_policy[i][j] -= lr * d_logits[i] * hidden[j];
}
self.b_policy[i] -= lr * d_logits[i];
}
// Update value weights
for j in 0..hidden.len() {
self.w_value[0][j] -= lr * d_value * hidden[j];
}
self.b_value[0] -= lr * d_value;
// Update input-to-hidden weights
for j in 0..hidden.len() {
for k in 0..9 {
self.w1[j][k] -= lr * d_hidden[j] * input[k];
}
self.b1[j] -= lr * d_hidden[j];
}
(policy_loss, value_loss)
}
}
/// Masks illegal moves by zeroing their probabilities and
/// renormalizing.
fn mask_illegal(policy: &[f32; 9], state: &GameState) -> [f32; 9] {
let mut masked = [0.0f32; 9];
let legal = state.legal_moves();
let mut sum = 0.0f32;
for &m in &legal {
masked[m] = policy[m];
sum += policy[m];
}
if sum > 0.0 {
for m in &legal {
masked[*m] /= sum;
}
} else {
// Uniform over legal moves
let uniform = 1.0 / legal.len() as f32;
for &m in &legal {
masked[m] = uniform;
}
}
masked
}
// ═══════════════════════════════════════════════════════════
// MCTS — pure and network-guided (§7-8, §12)
// ═══════════════════════════════════════════════════════════
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
prior: f64,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
impl MctsTree {
fn new(state: GameState) -> Self {
let moves = state.legal_moves();
let root = MctsNode {
state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: moves,
prior: 0.0,
};
MctsTree { nodes: vec![root] }
}
fn best_move(&self) -> usize {
let root = &self.nodes[0];
root.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.map(|&idx| self.nodes[idx].move_from_parent.unwrap())
.expect("root must have children after search")
}
/// Pure MCTS with random rollouts.
fn search(&mut self, iterations: u32, rng: &mut impl Rng) {
for _ in 0..iterations {
let leaf = self.select(true);
let leaf = self.expand_pure(leaf, rng);
let value = self.rollout(leaf, rng);
self.backpropagate(leaf, value);
}
}
/// Network-guided MCTS.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) {
for _ in 0..iterations {
let leaf = self.select_puct(c_puct);
let (leaf, value) =
self.expand_with_network(leaf, network);
self.backpropagate(leaf, value);
}
}
fn select(&mut self, use_uct: bool) -> usize {
let mut node = 0;
loop {
let n = &self.nodes[node];
if n.state.is_terminal() {
return node;
}
if !n.unexpanded_moves.is_empty() {
return node;
}
// Select child with highest UCT
let parent_visits = n.visit_count as f64;
let c = 1.414;
node = *n.children.iter().max_by(|&&a, &&b| {
let sa = self.uct_score(a, parent_visits, c);
let sb = self.uct_score(b, parent_visits, c);
sa.partial_cmp(&sb).unwrap()
}).unwrap();
}
}
fn select_puct(&self, c_puct: f64) -> usize {
let mut node = 0;
loop {
let n = &self.nodes[node];
if n.state.is_terminal() {
return node;
}
if !n.unexpanded_moves.is_empty() {
return node;
}
let parent_visits = n.visit_count;
node = *n.children.iter().max_by(|&&a, &&b| {
let sa = puct_score(
self.nodes[a].total_value,
self.nodes[a].visit_count,
parent_visits,
self.nodes[a].prior,
c_puct,
);
let sb = puct_score(
self.nodes[b].total_value,
self.nodes[b].visit_count,
parent_visits,
self.nodes[b].prior,
c_puct,
);
sa.partial_cmp(&sb).unwrap()
}).unwrap();
}
}
fn uct_score(&self, idx: usize, parent_visits: f64, c: f64) -> f64 {
let node = &self.nodes[idx];
if node.visit_count == 0 {
return f64::INFINITY;
}
let exploitation =
node.total_value / node.visit_count as f64;
let exploration =
c * (parent_visits.ln() / node.visit_count as f64).sqrt();
exploitation + exploration
}
fn expand_pure(
&mut self,
node: usize,
rng: &mut impl Rng,
) -> usize {
if self.nodes[node].state.is_terminal()
|| self.nodes[node].unexpanded_moves.is_empty()
{
return node;
}
let idx = rng.gen_range(
0..self.nodes[node].unexpanded_moves.len()
);
let mv = self.nodes[node].unexpanded_moves.swap_remove(idx);
let new_state = self.nodes[node].state.apply_move(mv);
let moves = new_state.legal_moves();
let child = MctsNode {
state: new_state,
move_from_parent: Some(mv),
parent: Some(node),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: moves,
prior: 0.0,
};
let child_idx = self.nodes.len();
self.nodes.push(child);
self.nodes[node].children.push(child_idx);
child_idx
}
fn expand_with_network(
&mut self,
node: usize,
network: &Network,
) -> (usize, f64) {
if self.nodes[node].state.is_terminal() {
// Return exact value for terminal states.
let value = match self.nodes[node].state.winner() {
Some(p) if p == self.nodes[node].state.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
return (node, value);
}
if self.nodes[node].unexpanded_moves.is_empty() {
// Already fully expanded — use network value.
let (_, value) =
network.predict(&self.nodes[node].state);
return (node, value as f64);
}
// Get network predictions for this state.
let (policy, value) =
network.predict(&self.nodes[node].state);
let masked =
mask_illegal(&policy, &self.nodes[node].state);
// Expand ALL children at once with priors.
let moves: Vec<usize> =
self.nodes[node].unexpanded_moves.drain(..).collect();
for mv in &moves {
let new_state =
self.nodes[node].state.apply_move(*mv);
let child_moves = new_state.legal_moves();
let child = MctsNode {
state: new_state,
move_from_parent: Some(*mv),
parent: Some(node),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: child_moves,
prior: masked[*mv] as f64,
};
let child_idx = self.nodes.len();
self.nodes.push(child);
self.nodes[node].children.push(child_idx);
}
(node, value as f64)
}
fn rollout(&self, node: usize, rng: &mut impl Rng) -> f64 {
let mut state = self.nodes[node].state.clone();
while !state.is_terminal() {
let moves = state.legal_moves();
let mv = moves[rng.gen_range(0..moves.len())];
state = state.apply_move(mv);
}
match state.winner() {
Some(p) if p == self.nodes[node].state.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
}
}
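fn backpropagate(&mut self, mut node: usize, mut value: f64) {
// `value` arrives from the perspective of the player to move at `node`
// (rollout, network value, and terminal scores all use that convention).
// Each node stores value from the perspective of the player who moved
// INTO it, i.e. the parent's player to move, so select and select_puct
// can simply pick the child with the highest Q. Flip before storing.
loop {
value = -value;
self.nodes[node].visit_count += 1;
self.nodes[node].total_value += value;
if let Some(parent) = self.nodes[node].parent {
node = parent;
} else {
break;
}
}
}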
}
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
/// Returns the best move using pure MCTS with random rollouts.
fn mcts_move(
state: &GameState,
iterations: u32,
rng: &mut impl Rng,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search(iterations, rng);
tree.best_move()
}
/// Returns the best move using network-guided MCTS.
fn mcts_network_move(
state: &GameState,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(iterations, network, c_puct);
tree.best_move()
}
// ═══════════════════════════════════════════════════════════
// Training data and self-play (§11, §13)
// ═══════════════════════════════════════════════════════════
struct TrainingExample {
state: GameState,
policy_target: [f32; 9],
value_target: f32,
}
/// Extracts the visit distribution from the root of an MCTS tree.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0f32; 9];
if total == 0 {
return dist;
}
for &idx in &root.children {
let mv = tree.nodes[idx].move_from_parent.unwrap();
dist[mv] = tree.nodes[idx].visit_count as f32 / total as f32;
}
dist
}
fn select_move_with_temperature(
tree: &MctsTree,
temperature: f64,
rng: &mut impl Rng,
) -> usize {
let root = &tree.nodes[0];
if temperature < 0.01 {
return tree.best_move();
}
let inv_temp = 1.0 / temperature;
let scaled: Vec<f64> = root.children.iter()
.map(|&idx| {
(tree.nodes[idx].visit_count as f64).powf(inv_temp)
})
.collect();
let total: f64 = scaled.iter().sum();
if total == 0.0 {
return tree.best_move();
}
let mut r: f64 = rng.gen::<f64>() * total;
for (i, &s) in scaled.iter().enumerate() {
r -= s;
if r <= 0.0 {
let child_idx = root.children[i];
return tree.nodes[child_idx]
.move_from_parent
.expect("child must have a move");
}
}
let last = *root.children.last().unwrap();
tree.nodes[last].move_from_parent
.expect("child must have a move")
}
fn self_play_game(
network: &Network,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
let mut move_number = 0;
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(mcts_iterations, network, c_puct);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let temperature = if move_number < temp_moves {
1.0
} else {
0.0
};
let chosen_move =
select_move_with_temperature(&tree, temperature, rng);
state = state.apply_move(chosen_move);
move_number += 1;
}
let winner = state.winner();
history.into_iter().map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample { state: s, policy_target, value_target }
}).collect()
}
fn shuffle_examples(examples: &mut Vec<TrainingExample>, rng: &mut impl Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
struct TrainingLoss {
policy_loss: f32,
value_loss: f32,
}
fn train_on_examples(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl Rng,
) -> TrainingLoss {
let mut final_policy_loss = 0.0f32;
let mut final_value_loss = 0.0f32;
for _epoch in 0..epochs {
shuffle_examples(examples, rng);
let mut epoch_policy_loss = 0.0f32;
let mut epoch_value_loss = 0.0f32;
let mut count = 0;
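// Note: train_step applies its update immediately, so this is plain
// per-example SGD; batch_size only groups the loop and does not
// average gradients across a batch.
for batch_start in (0..examples.len()).step_by(batch_size) {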
let batch_end =
(batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let (pl, vl) = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_policy_loss += pl;
epoch_value_loss += vl;
count += 1;
}
}
final_policy_loss = epoch_policy_loss / count as f32;
final_value_loss = epoch_value_loss / count as f32;
}
TrainingLoss {
policy_loss: final_policy_loss,
value_loss: final_value_loss,
}
}
fn play_evaluation_game(
network_x: &Network,
network_o: &Network,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let network = match state.current_player {
Player::X => network_x,
Player::O => network_o,
};
let mv = mcts_network_move(
&state, mcts_iterations, network, c_puct,
);
state = state.apply_move(mv);
}
match state.winner() {
Some(Player::X) => 1.0,
Some(Player::O) => -1.0,
None => 0.0,
}
}
fn evaluate_networks(
candidate: &Network,
current_best: &Network,
num_games: u32,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut candidate_wins = 0;
let mut total_decisive = 0;
for game in 0..num_games {
let result = if game % 2 == 0 {
play_evaluation_game(
candidate, current_best,
mcts_iterations, c_puct,
)
} else {
-play_evaluation_game(
current_best, candidate,
mcts_iterations, c_puct,
)
};
if result > 0.0 {
candidate_wins += 1;
total_decisive += 1;
} else if result < 0.0 {
total_decisive += 1;
}
}
if total_decisive == 0 {
0.5
} else {
candidate_wins as f64 / total_decisive as f64
}
}
// ═══════════════════════════════════════════════════════════
// Full training loop with diagnostics (§14)
// ═══════════════════════════════════════════════════════════
struct TrainingConfig {
num_iterations: u32,
games_per_iteration: u32,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
training_epochs: u32,
batch_size: usize,
learning_rate: f32,
eval_games: u32,
win_threshold: f64,
}
fn alphago_zero_training(
config: &TrainingConfig,
seed: u64,
) -> Network {
let mut rng = StdRng::seed_from_u64(seed);
let mut best_network = Network::new(rng.gen());
println!("Starting AlphaGo Zero training loop");
println!(" Iterations: {}", config.num_iterations);
println!(" Games/iter: {}", config.games_per_iteration);
println!(" MCTS iters: {}", config.mcts_iterations);
println!(" Total games: {}",
config.num_iterations * config.games_per_iteration);
println!();
for iteration in 1..=config.num_iterations {
// ── Phase 1: Self-play ─────────────────────────────
let mut all_examples = Vec::new();
let mut total_moves = 0u32;
let mut x_wins = 0u32;
let mut o_wins = 0u32;
let mut draws = 0u32;
for _game in 0..config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
let game_length = examples.len() as u32;
total_moves += game_length;
if let Some(last) = examples.last() {
match last.value_target.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => {
match last.state.current_player {
Player::X => x_wins += 1,
Player::O => o_wins += 1,
}
}
Some(std::cmp::Ordering::Less) => {
match last.state.current_player {
Player::X => o_wins += 1,
Player::O => x_wins += 1,
}
}
_ => draws += 1,
}
}
all_examples.extend(examples);
}
let avg_game_length =
total_moves as f64 / config.games_per_iteration as f64;
// ── Phase 2: Train ─────────────────────────────────
let mut candidate = best_network.clone();
let loss = train_on_examples(
&mut candidate,
&mut all_examples,
config.training_epochs,
config.batch_size,
config.learning_rate,
&mut rng,
);
// ── Phase 3: Evaluate ──────────────────────────────
let win_rate = evaluate_networks(
&candidate,
&best_network,
config.eval_games,
config.mcts_iterations,
config.c_puct,
);
let accepted = win_rate >= config.win_threshold;
if accepted {
best_network = candidate;
}
// ── Print diagnostics ──────────────────────────────
let (policy, value) =
best_network.predict(&GameState::new());
let masked =
mask_illegal(&policy, &GameState::new());
let best_cell = masked.iter().enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap().0;
let cell_name = match best_cell {
0 => "top-left", 1 => "top-center",
2 => "top-right", 3 => "mid-left",
4 => "center", 5 => "mid-right",
6 => "bot-left", 7 => "bot-center",
8 => "bot-right", _ => "unknown",
};
println!(
"═══ Iteration {}/{} ═══",
iteration, config.num_iterations
);
println!(
" Self-play: {} games, {} examples",
config.games_per_iteration, all_examples.len()
);
println!(
" Game length: {:.1} moves (avg)",
avg_game_length
);
println!(
" Outcomes: X wins {} | O wins {} | Draws {}",
x_wins, o_wins, draws
);
println!(
" Training: policy_loss = {:.4} value_loss = {:.4}",
loss.policy_loss, loss.value_loss
);
println!(
" Evaluation: candidate win rate = {:.1}% → {}",
win_rate * 100.0,
if accepted { "ACCEPTED" } else { "REJECTED" }
);
println!(
" Empty board: best move = cell {} ({})",
best_cell, cell_name
);
println!(
" policy = [{:.2} {:.2} {:.2} {:.2} \
{:.2} {:.2} {:.2} {:.2} {:.2}]",
masked[0], masked[1], masked[2], masked[3],
masked[4], masked[5], masked[6], masked[7],
masked[8]
);
println!(
" value = {:+.3}",
value
);
println!();
}
best_network
}
fn main() {
// ── Train ──────────────────────────────────────────────
let config = TrainingConfig {
num_iterations: 20,
games_per_iteration: 50,
mcts_iterations: 200,
c_puct: 1.5,
temp_moves: 3,
training_epochs: 20,
batch_size: 32,
learning_rate: 0.01,
eval_games: 40,
win_threshold: 0.55,
};
let trained = alphago_zero_training(&config, 42);
// ── Validation 1: vs strong pure MCTS ──────────────────
println!("═══ Validation 1: Trained AI vs Pure MCTS (10k) ═══");
let mut rng = StdRng::seed_from_u64(123);
let mut val_draws = 0;
for game in 0..100 {
let mut state = GameState::new();
while !state.is_terminal() {
let mv = match state.current_player {
Player::X => {
if game % 2 == 0 {
mcts_network_move(
&state, 200, &trained, 1.5,
)
} else {
mcts_move(&state, 10_000, &mut rng)
}
}
Player::O => {
if game % 2 == 0 {
mcts_move(&state, 10_000, &mut rng)
} else {
mcts_network_move(
&state, 200, &trained, 1.5,
)
}
}
};
state = state.apply_move(mv);
}
if state.winner().is_none() {
val_draws += 1;
}
}
println!(" Result: {}/100 draws", val_draws);
if val_draws == 100 {
println!(" PASS — the network plays perfectly.");
} else {
println!(
" Some decisive games — network may need \
more training."
);
}
// ── Validation 2: opening move ─────────────────────────
println!("\n═══ Validation 2: Opening Move ═══");
let empty = GameState::new();
let mv = mcts_network_move(&empty, 200, &trained, 1.5);
println!(
" First move: cell {} (expected: 4 = center)", mv
);
if mv == 4 {
println!(" PASS — network plays center.");
} else {
println!(" UNEXPECTED — center is the optimal opening.");
}
// ── Validation 3: human vs AI ──────────────────────────
println!("\n═══ Validation 3: Human vs Trained AI ═══");
println!("You are X. Enter a cell number (0-8):");
println!(" 0 | 1 | 2");
println!(" ---------");
println!(" 3 | 4 | 5");
println!(" ---------");
println!(" 6 | 7 | 8");
println!();
let mut state = GameState::new();
while !state.is_terminal() {
state.print_board();
match state.current_player {
Player::X => {
print!("Your move (0-8): ");
std::io::Write::flush(
&mut std::io::stdout(),
).unwrap();
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.unwrap();
let mv: usize = input
.trim()
.parse()
.expect("enter 0-8");
if !state.legal_moves().contains(&mv) {
println!("Illegal move. Try again.");
continue;
}
state = state.apply_move(mv);
}
Player::O => {
let mv = mcts_network_move(
&state, 200, &trained, 1.5,
);
println!("AI plays cell {}.", mv);
state = state.apply_move(mv);
}
}
}
state.print_board();
match state.winner() {
Some(Player::X) => {
println!("You win! (The AI has a bug.)")
}
Some(Player::O) => {
println!("AI wins. Better luck next time.")
}
None => println!(
"Draw. That is the best you can do against \
perfect play."
),
}
}
Congratulations
You have built a complete AlphaGo Zero implementation from scratch. Starting with nothing but the rules of Tic-Tac-Toe and a bag of random numbers, you watched an AI discover optimal play through pure self-improvement. No human games to study. No handcrafted evaluation functions. No opening books. Just a loop: play, learn, evaluate, repeat.
The network you trained has internalized something remarkable. The network itself does not search game trees; at play time, the MCTS wrapped around it does that. What the network contributes is intuition: it has compressed the experience of thousands of games into a set of floating-point weights. When it looks at an empty board and says “play center,” it is not following a rule someone programmed; it learned that from experience, much as a human player would.
This is the same algorithm that, scaled up with deeper networks and more compute, mastered Go, chess, and shogi at a superhuman level. The architecture changes. The training budget grows by orders of magnitude. But the core idea — self-play generating training data, a neural network learning from that data, and an evaluation gate preventing regression — remains exactly what you built in this chapter.
Well done.
Building a Simple LLM from Scratch
A hands-on course building a small GPT-1-style language model in Rust — from raw text to a trained, sampling transformer.
Part 1 — Language Modeling Basics
§1 What is a Language Model?
A language model is a system that assigns probabilities to sequences of tokens. Given some context — a sequence of words, characters, or subwords that have appeared so far — a language model answers the question:
“What token is most likely to come next?”
Formally, a language model estimates the conditional probability distribution:
\[ P(t_{n+1} \mid t_1, t_2, \ldots, t_n) \]
where \( t_1, t_2, \ldots, t_n \) is the context (the tokens seen so far) and \( t_{n+1} \) is the next token.
A concrete example
Suppose we are building a character-level language model trained on English text. Given the context:
The cat sat on the m
Our model might produce a probability distribution like:
| Next character | Probability |
|---|---|
| a | 0.55 |
| o | 0.15 |
| e | 0.10 |
| i | 0.05 |
| (other) | 0.15 |
The model thinks a is most likely (leading to “mat”, “map”, “man”, etc.), followed by o (“moon”, “mop”, etc.). It learned these patterns from the statistics of its training data — it has never been told English grammar rules.
Autoregressive generation
Language models generate text autoregressively — one token at a time, feeding each generated token back in as context for the next prediction:
Step 1: "The cat"   → predict ' '
Step 2: "The cat "  → predict 's'
Step 3: "The cat s" → predict 'a'
...and so on
This loop of predict-then-append is the core mechanism behind every text-generating AI, from simple bigram models to GPT-4.
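The predict-then-append loop can be sketched in plain Rust. This is only a toy. The hypothetical `predict_next` function stands in for a trained model: it looks up the last character of the context in a small hand-written table. A real model would instead produce a probability distribution from the whole context.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a language model: predicts the next
/// character from the last character of the context, using a fixed
/// hand-written table. Returns None when the table has no entry.
fn predict_next(table: &HashMap<char, char>, context: &str) -> Option<char> {
    let last = context.chars().last()?;
    table.get(&last).copied()
}

/// Autoregressive generation: predict one token, append it to the
/// context, and feed the longer context back in.
fn generate(table: &HashMap<char, char>, prompt: &str, steps: usize) -> String {
    let mut text = prompt.to_string();
    for _ in 0..steps {
        match predict_next(table, &text) {
            Some(c) => text.push(c), // the prediction becomes part of the context
            None => break,           // the toy table has no continuation
        }
    }
    text
}

fn main() {
    // Toy table chosen so that "The cat" continues to "The cat sat".
    let table = HashMap::from([('t', ' '), (' ', 's'), ('s', 'a'), ('a', 't')]);
    println!("{}", generate(&table, "The cat", 4)); // prints "The cat sat"
}
```

Each call to `predict_next` sees the text produced so far, including its own earlier predictions. That feedback is the defining property of autoregressive generation.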
Why is language modeling useful?
Language modeling sounds like a narrow statistical task, but it turns out to be remarkably powerful:
- Text generation. Chatbots, story writers, and code assistants all generate text by sampling from a language model.
- Representation learning. Training a model to predict the next token forces it to learn deep representations of syntax, semantics, and even factual knowledge.
- Foundation for downstream tasks. Models pre-trained on language modeling can be fine-tuned for translation, summarisation, question answering, and more.
- Compression perspective. A good language model is a good compressor. Paired with an entropy coder such as arithmetic coding, it can encode text in few bits: it spends little on tokens it predicts well and more only on surprising ones. This connection to information theory is why language modeling is such a fundamental problem.
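The compression view can be made concrete with one formula: an ideal entropy coder spends about -log₂ p bits on a token the model assigned probability p. The sketch below applies this to the illustrative probabilities from the “The cat sat on the m” table earlier in this section. Those are made-up numbers, not output from a real model, and the function name `surprisal_bits` is just for illustration.

```rust
/// Bits an ideal entropy coder spends on an outcome of probability p.
fn surprisal_bits(p: f64) -> f64 {
    -p.log2()
}

fn main() {
    // Illustrative next-character probabilities after "The cat sat on the m".
    let dist = [('a', 0.55), ('o', 0.15), ('e', 0.10), ('i', 0.05)];
    for (c, p) in dist {
        println!("'{}': p = {:.2}, cost = {:.2} bits", c, p, surprisal_bits(p));
    }
}
```

A well-predicted 'a' costs under 1 bit (about 0.86). The unlikely 'i' costs over 4 bits (about 4.32). A better model assigns higher probability to what actually comes next, so it needs fewer bits overall.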
The spectrum of language models
Language models range from the trivially simple to the extraordinarily complex:
Simple                                                  Complex
  |                                                        |
Bigram ──── N-gram ──── RNN ──── LSTM ──── Transformer ── GPT-4
counts      tables      neural   gated    attention      massive
                        net      RNN      mechanism      scale
In this course, we will build a small Transformer-based language model, the same architecture family that powers GPT, Claude, and other modern LLMs. Ours will be tiny (on the order of 100,000 parameters), but it contains every essential component of its larger cousins.
Key takeaway: A language model predicts the next token given context. Despite sounding simple, this task — when scaled up — gives rise to the capabilities we see in modern AI systems.
§2 Character-Level Tokenisation
Before a language model can process text, we need to convert raw text into numbers. This process is called tokenisation — splitting text into discrete units (tokens) and mapping each to a numerical ID.
Three approaches to tokenisation
There are three main strategies, each with different tradeoffs:
| Strategy | Token unit | Vocabulary size | Example: “cats” |
|---|---|---|---|
| Character-level | Single characters | ~100 | ['c', 'a', 't', 's'] |
| Subword (BPE, etc.) | Character groups | ~30,000-100,000 | ['cat', 's'] |
| Word-level | Whole words | ~100,000+ | ['cats'] |
Word-level tokenisation is simple but struggles with misspellings, rare words, and morphological variation (“run”, “running”, “ran” are all separate tokens).
Subword methods like Byte-Pair Encoding (BPE) — used by GPT models — strike a balance: common words get their own token, while rare words are split into pieces. The word “unhappiness” might become ["un", "happiness"].
Character-level tokenisation is the simplest: every character in the text is its own token. The vocabulary is tiny — just the set of unique characters in the training data.
Why character-level for this course?
We use character-level tokenisation because:
- Simplicity. No external tokeniser library is needed. We can build it from scratch in a few lines of Rust.
- Small vocabulary. A typical English text corpus has fewer than 100 unique characters, which means smaller embedding tables and faster training.
- No unknown tokens. Any character in the input can be represented — there is no “out of vocabulary” problem.
- Educational clarity. It is easy to inspect what the model is learning when each token is a single visible character.
The downsides are that sequences become long (the word “language” is 8 tokens instead of 1) and the model must learn to spell words from individual characters. For our small educational model, these tradeoffs are acceptable.
Building a vocabulary
Given a training corpus, we construct a vocabulary by:
- Collecting all unique characters in the text.
- Sorting them (for deterministic ordering).
- Assigning each character a unique integer ID.
For example, given the text "hello world":
Unique characters (sorted): [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
Character → ID mapping:
' ' → 0
'd' → 1
'e' → 2
'h' → 3
'l' → 4
'o' → 5
'r' → 6
'w' → 7
Encoding and decoding
Encoding converts a string into a sequence of token IDs:
"hello" → [3, 2, 4, 4, 5]
Decoding converts a sequence of token IDs back into a string:
[7, 5, 6, 4, 1] → "world"
These two operations must be perfect inverses — decode(encode(text)) == text — or we lose information.
A note on special tokens
Production tokenisers often include special tokens like <PAD> (padding), <BOS> (beginning of sequence), and <EOS> (end of sequence). For our simple character-level tokeniser, we will not use these — every token corresponds to a real character in the text.
Key takeaway: Character-level tokenisation maps each character to an integer. It is the simplest tokenisation scheme and ideal for learning, though real LLMs use subword methods for efficiency.
§3 Exercise 1: Build a Character-Level Tokeniser in Rust
In this exercise, we build a CharTokeniser struct that can encode text into token IDs and decode them back.
Project setup
Create a new Rust project:
cargo new llm-from-scratch
cd llm-from-scratch
For now the project uses only the standard library, so Cargo.toml needs no dependencies. We will add candle in later exercises:
[package]
name = "llm-from-scratch"
version = "0.1.0"
edition = "2021"
[dependencies]
# We will add candle-core and candle-nn later
[profile.release]
opt-level = "z"
lto = true
strip = true
codegen-units = 1
The CharTokeniser struct
Create src/tokeniser.rs:
use std::collections::{HashMap, HashSet};
/// A character-level tokeniser that maps individual characters
/// to integer IDs and back.
pub struct CharTokeniser {
    /// Maps each character to its token ID.
    char_to_id: HashMap<char, u32>,
    /// Maps each token ID back to its character.
    id_to_char: Vec<char>,
}
impl CharTokeniser {
    /// Build a tokeniser from a training corpus.
    ///
    /// Collects all unique characters, sorts them, and assigns
    /// sequential IDs starting from 0.
    ///
    /// # Example
    /// ```
    /// let tok = CharTokeniser::from_corpus("hello world");
    /// assert_eq!(tok.vocab_size(), 8); // ' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w'
    /// ```
    pub fn from_corpus(text: &str) -> Self {
        // Deduplicate with a HashSet, then sort so the IDs are deterministic.
        let mut chars: Vec<char> = text.chars().collect::<HashSet<_>>().into_iter().collect();
        chars.sort();
        let char_to_id: HashMap<char, u32> = chars
            .iter()
            .enumerate()
            .map(|(i, &c)| (c, i as u32))
            .collect();
        CharTokeniser {
            char_to_id,
            id_to_char: chars,
        }
    }
    /// Returns the number of unique tokens in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.id_to_char.len()
    }
    /// Encode a string into a sequence of token IDs.
    ///
    /// # Panics
    /// Panics if the string contains a character not in the vocabulary.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        text.chars()
            .map(|c| {
                *self
                    .char_to_id
                    .get(&c)
                    .unwrap_or_else(|| panic!("Character '{}' not in vocabulary", c))
            })
            .collect()
    }
    /// Decode a sequence of token IDs back into a string.
    ///
    /// # Panics
    /// Panics if any token ID is out of range.
    pub fn decode(&self, ids: &[u32]) -> String {
        ids.iter()
            .map(|&id| {
                *self
                    .id_to_char
                    .get(id as usize)
                    .unwrap_or_else(|| panic!("Token ID {} out of range", id))
            })
            .collect()
    }
    /// Print the vocabulary mapping for inspection.
    pub fn print_vocab(&self) {
        println!("Vocabulary ({} tokens):", self.vocab_size());
        for (i, c) in self.id_to_char.iter().enumerate() {
            // Make whitespace characters visible in the printout.
            let display = match c {
                '\n' => "\\n".to_string(),
                '\t' => "\\t".to_string(),
                ' ' => "' '".to_string(),
                _ => format!("'{}'", c),
            };
            println!(" {} → {}", display, i);
        }
    }
}
Wiring it up in main.rs
In src/main.rs:
mod tokeniser;
use tokeniser::CharTokeniser;
fn main() {
let corpus = "\
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles.";
let tok = CharTokeniser::from_corpus(corpus);
tok.print_vocab();
let sample = "to be";
let encoded = tok.encode(sample);
println!("\nEncoded \"{}\": {:?}", sample, encoded);
let decoded = tok.decode(&encoded);
println!("Decoded back: \"{}\"", decoded);
// Verify round-trip
assert_eq!(decoded, sample);
println!("\nRound-trip check passed.");
}
Expected output
When you run cargo run, you should see something like:
Vocabulary (28 tokens):
\n → 0
' ' → 1
''' → 2
',' → 3
'.' → 4
':' → 5
'O' → 6
'T' → 7
'W' → 8
'a' → 9
...
Encoded "to be": [25, 21, 1, 10, 12]
Decoded back: "to be"
Round-trip check passed.
(Your exact IDs will depend on the characters present in the corpus.)
Exercises to try
- Extend the corpus. Download a larger text (e.g., from Project Gutenberg) and see how the vocabulary grows.
- Handle unknown characters. Modify encode to return an Option or use a special <UNK> token instead of panicking.
- Measure sequence length. Encode a paragraph and compare the number of tokens to the number of words. How much longer is the character-level encoding?
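For the "Handle unknown characters" exercise, one possible approach is sketched below; it is not the only design. It uses free functions over a `char_to_id` map rather than methods on `CharTokeniser`, so the snippet compiles on its own. The names `try_encode` and `encode_with_unk` are hypothetical.

```rust
use std::collections::HashMap;

/// Hypothetical fallible encoder: returns None if any character is missing
/// from the vocabulary. Collecting an iterator of Option<u32> into
/// Option<Vec<u32>> short-circuits at the first None.
fn try_encode(char_to_id: &HashMap<char, u32>, text: &str) -> Option<Vec<u32>> {
    text.chars().map(|c| char_to_id.get(&c).copied()).collect()
}

/// Hypothetical lossy encoder: maps unknown characters to a reserved `unk_id`.
fn encode_with_unk(char_to_id: &HashMap<char, u32>, text: &str, unk_id: u32) -> Vec<u32> {
    text.chars()
        .map(|c| char_to_id.get(&c).copied().unwrap_or(unk_id))
        .collect()
}

fn main() {
    // Vocabulary of "hello world" (IDs 0..=7, as in §2); ID 8 is reserved for <UNK>.
    let vocab = [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w'];
    let char_to_id: HashMap<char, u32> =
        vocab.iter().enumerate().map(|(i, &c)| (c, i as u32)).collect();
    println!("{:?}", try_encode(&char_to_id, "hello")); // Some([3, 2, 4, 4, 5])
    println!("{:?}", try_encode(&char_to_id, "help")); // None: 'p' is unknown
    println!("{:?}", encode_with_unk(&char_to_id, "help", 8)); // [3, 2, 4, 8]
}
```

If you reserve an <UNK> ID, remember to count it in the vocabulary size, so that the embedding table later has a row for it.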
Key takeaway: A character-level tokeniser is just two lookup tables: a hash map from characters to IDs, and a vector indexed by ID that maps back to characters. The from_corpus method automatically builds the vocabulary from whatever text you give it.
Part 2 — The Transformer Architecture
§4 Embeddings and Positional Encoding
Our tokeniser converts text into a sequence of integer IDs. But neural networks work with continuous vectors, not discrete integers. Embeddings bridge this gap by mapping each token ID to a dense vector of floating-point numbers.
What is an embedding?
An embedding is a lookup table — a matrix of shape (vocab_size, embed_dim) where each row is a learnable vector representing one token.
Embedding table (vocab_size=5, embed_dim=4):
Token ID 0 → [ 0.12, -0.34, 0.56, 0.78]
Token ID 1 → [-0.91, 0.23, 0.45, -0.67]
Token ID 2 → [ 0.33, 0.11, -0.88, 0.54]
Token ID 3 → [ 0.76, -0.55, 0.22, 0.13]
Token ID 4 → [-0.42, 0.89, -0.11, 0.66]
Given the input token IDs [2, 0, 3], the embedding layer simply looks up rows 2, 0, and 3:
Input: [2, 0, 3]
Output: [[ 0.33, 0.11, -0.88, 0.54], ← row 2
[ 0.12, -0.34, 0.56, 0.78], ← row 0
[ 0.76, -0.55, 0.22, 0.13]] ← row 3
The result is a matrix of shape (sequence_length, embed_dim). Initially, these vectors are random. During training, backpropagation adjusts them so that tokens with similar meanings end up with similar vectors.
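Because the lookup is just row indexing, it fits in a few lines of plain Rust. Below is a minimal sketch using the illustrative 5 × 4 table above. The function name `embed` is hypothetical; tensor libraries typically implement the same operation as an index-select (gather) on the weight matrix. Indexing panics if a token ID is out of range.

```rust
/// Embedding lookup: row `id` of the table is the vector for token `id`.
fn embed(table: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter().map(|&id| table[id].clone()).collect()
}

fn main() {
    // The 5 x 4 example table from the text (values are illustrative).
    let table: Vec<Vec<f32>> = vec![
        vec![0.12, -0.34, 0.56, 0.78],
        vec![-0.91, 0.23, 0.45, -0.67],
        vec![0.33, 0.11, -0.88, 0.54],
        vec![0.76, -0.55, 0.22, 0.13],
        vec![-0.42, 0.89, -0.11, 0.66],
    ];
    // Input IDs [2, 0, 3] select rows 2, 0 and 3, in that order.
    for row in embed(&table, &[2, 0, 3]) {
        println!("{:?}", row);
    }
}
```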
Why do we need embeddings?
Integer IDs have no inherent structure — the fact that “a” is token 0 and “b” is token 1 does not mean they are “close” in any useful sense. Embeddings give the model a continuous space where it can represent relationships. After training, you might find that:
- Vowel characters cluster together.
- Uppercase and lowercase versions of the same letter are nearby.
- Punctuation characters form their own cluster.
The embedding dimension
The embedding dimension (embed_dim or d_model) is a hyperparameter you choose. It controls how much information each token vector can carry:
| Embed dim | Capacity | Training cost | Typical use |
|---|---|---|---|
| 16-64 | Low | Very fast | Toy models (like ours) |
| 128-512 | Medium | Moderate | Small-scale experiments |
| 768-4096 | High | Expensive | Production LLMs |
For our tiny model, we will use embed_dim = 64.
The position problem
Consider two sentences:
"The cat ate the fish"
"The fish ate the cat"
These contain the exact same tokens but have very different meanings. If we only embed the tokens, the model has no way to tell the two sentences apart — it does not know the order of the tokens.
Positional encoding
To inject position information, we add a positional encoding vector to each token embedding. The original Transformer paper (“Attention Is All You Need”) proposed using fixed sinusoidal functions:
\[ PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \] \[ PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \]
Where:
- pos is the position in the sequence (0, 1, 2, …).
- i indexes pairs of dimensions: dimensions 2i and 2i+1 share the same frequency (0 ≤ i < d_model/2).
- d_model is the embedding dimension.
These functions produce a unique pattern for each position. Low-frequency sinusoids change slowly across positions (capturing coarse position), while high-frequency sinusoids change rapidly (capturing fine position).
Position 0: [sin(0), cos(0), sin(0),  cos(0) ] = [0.00,  1.00, 0.00, 1.00]
Position 1: [sin(1), cos(1), sin(ε),  cos(ε) ] = [0.84,  0.54, 0.01, 1.00]
Position 2: [sin(2), cos(2), sin(2ε), cos(2ε)] = [0.91, -0.42, 0.02, 1.00]
(This example uses d_model = 4, so ε = 1/10000^(2/4) = 0.01. In general the frequency 1/10000^(2i/d_model) falls as i grows, so dimensions further along the vector change more slowly with position.)
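The sinusoidal formulas translate directly into code. The sketch below uses a hypothetical function name, `sinusoidal_pe`, and assumes an even d_model. With d_model = 4, it reproduces the rounded values of the first three positions in the example above.

```rust
/// Sinusoidal positional encoding from "Attention Is All You Need":
/// PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(same angle).
/// Assumes d_model is even.
fn sinusoidal_pe(pos: usize, d_model: usize) -> Vec<f64> {
    let mut pe = vec![0.0; d_model];
    for i in 0..d_model / 2 {
        // Pair i uses frequency 1 / 10000^(2i / d_model).
        let angle = pos as f64 / 10000f64.powf(2.0 * i as f64 / d_model as f64);
        pe[2 * i] = angle.sin();
        pe[2 * i + 1] = angle.cos();
    }
    pe
}

fn main() {
    for pos in 0..3 {
        let row: Vec<String> = sinusoidal_pe(pos, 4)
            .iter()
            .map(|v| format!("{:.2}", v))
            .collect();
        println!("Position {}: [{}]", pos, row.join(", "));
    }
}
```

This encoding has no learnable parameters. That is the main contrast with the learned position table described next.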
Learned positional embeddings
GPT-1 and GPT-2 use a simpler approach: a second embedding table of shape (max_seq_len, embed_dim) where each position gets its own learnable vector, just like tokens do. This is what we will implement:
Final input = token_embedding[token_id] + position_embedding[position]
Both the token embeddings and position embeddings are learned during training.
Putting it all together
The input pipeline for our model looks like this:
"hello" → [3, 2, 4, 4, 5] (tokenisation)
→ [[0.12, ...], [-0.91, ...], (token embedding lookup)
[0.33, ...], [0.33, ...],
[-0.42, ...]]
+ [[0.05, ...], [0.11, ...], (position embedding lookup)
[0.22, ...], [0.08, ...],
[0.17, ...]]
= [[0.17, ...], [-0.80, ...], (element-wise addition)
[0.55, ...], [0.41, ...],
[-0.25, ...]]
The result is a (seq_len, embed_dim) matrix that carries both what each token is and where it appears. This matrix is what we feed into the Transformer blocks.
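The whole pipeline (token lookup, position lookup, element-wise add) can be sketched with plain vectors standing in for the two learned tables. The function name `embed_with_positions` and the tiny tables are hypothetical. Notice that the same token ID at two different positions ends up with two different input vectors. That is what lets the model tell the two “l”s in “hello” apart.

```rust
/// Hypothetical helper: look up each token's vector, then add the learned
/// vector for its position. Panics if an ID is out of range or the sequence
/// is longer than the position table.
fn embed_with_positions(
    tok_table: &[Vec<f32>],
    pos_table: &[Vec<f32>],
    ids: &[usize],
) -> Vec<Vec<f32>> {
    ids.iter()
        .enumerate()
        .map(|(pos, &id)| {
            tok_table[id]
                .iter()
                .zip(&pos_table[pos])
                .map(|(t, p)| t + p) // element-wise addition
                .collect::<Vec<f32>>()
        })
        .collect()
}

fn main() {
    // vocab_size = 2, max_seq_len = 3, embed_dim = 2 (toy values).
    let tok_table: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let pos_table: Vec<Vec<f32>> = vec![vec![0.5, 0.5], vec![0.25, 0.0], vec![0.0, 0.0]];
    // Token 1 appears twice but gets a different input vector each time.
    for row in embed_with_positions(&tok_table, &pos_table, &[1, 1, 0]) {
        println!("{:?}", row);
    }
}
```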
Key takeaway: Embeddings convert token IDs into learnable vectors. Positional encodings add position information so the model can distinguish “the cat ate the fish” from “the fish ate the cat”. Together, they form the input to the Transformer.
§5 Self-Attention: Queries, Keys, and Values
Self-attention is the mechanism that makes Transformers work. It allows every token in a sequence to look at every other token and decide which ones are relevant. This is the single most important concept in this course.
The intuition: a database lookup
Think of self-attention as a soft database lookup:
- Each token formulates a query: “What kind of information am I looking for?”
- Each token advertises a key: “Here is what kind of information I have.”
- Each token holds a value: “Here is my actual information.”
To process a token, we compare its query against all keys. Where there is a strong match, we pull in the corresponding value. The result is a weighted combination of all values, where the weights reflect how relevant each token is.
Token: "sat"
Query: "I'm a verb — who is my subject?"
Keys available:
"The" → "I'm a determiner" (low match)
"cat" → "I'm a noun/subject" (HIGH match)
"sat" → "I'm a verb" (medium match)
"on" → "I'm a preposition" (low match)
Result: mostly attend to "cat", somewhat to "sat", barely to others
Of course, the model does not literally think in words — these “queries” and “keys” are learned vector representations. But the analogy captures the mechanism.
The math: scaled dot-product attention
Given a sequence of n token embeddings (each of dimension d_model), self-attention works as follows:
Step 1: Project into Q, K, V.
We use three learned weight matrices \( W_Q, W_K, W_V \) (each of shape d_model × d_k) to produce:
\[ Q = X W_Q, \quad K = X W_K, \quad V = X W_V \]
Where \( X \) is the input matrix of shape (n, d_model) and \( Q, K, V \) are each of shape (n, d_k).
Step 2: Compute attention scores.
We compute the dot product of each query with all keys:
\[ \text{scores} = Q K^T \]
This produces an (n, n) matrix where entry (i, j) measures how much token i should attend to token j.
Step 3: Scale.
We divide by \( \sqrt{d_k} \) to prevent the dot products from becoming too large (which would push softmax into regions with tiny gradients):
\[ \text{scaled_scores} = \frac{Q K^T}{\sqrt{d_k}} \]
Step 4: Softmax.
We apply softmax row-wise so that each row sums to 1, giving us a proper probability distribution:
\[ \text{attention_weights} = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) \]
Step 5: Weighted sum of values.
We multiply the attention weights by the values:
\[ \text{output} = \text{attention_weights} \times V \]
The complete formula is:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V \]
A small numeric example
Let us walk through attention with a tiny example. Suppose we have 3 tokens with d_k = 2:
Q (queries): K (keys): V (values):
[1.0, 0.0] [1.0, 0.0] [1.0, 0.0]
[0.0, 1.0] [0.0, 1.0] [0.0, 1.0]
[1.0, 1.0] [1.0, 1.0] [0.5, 0.5]
Step 2: QK^T
K^T = [1.0  0.0  1.0]
      [0.0  1.0  1.0]
Each entry (i, j) is the dot product of query row i with key row j:
QK^T = [1.0  0.0  1.0]    e.g. row 0: [1,0]·[1,0] = 1, [1,0]·[0,1] = 0, [1,0]·[1,1] = 1
       [0.0  1.0  1.0]
       [1.0  1.0  2.0]    e.g. row 2, column 2: [1,1]·[1,1] = 2
Step 3: Scale by √d_k = √2 ≈ 1.414
Scaled = [0.71  0.00  0.71]
         [0.00  0.71  0.71]
         [0.71  0.71  1.41]
Step 4: Softmax (row-wise)
Row 0: softmax([0.71, 0.00, 0.71]) ≈ [0.40, 0.20, 0.40]
Row 1: softmax([0.00, 0.71, 0.71]) ≈ [0.20, 0.40, 0.40]
Row 2: softmax([0.71, 0.71, 1.41]) ≈ [0.25, 0.25, 0.50]
(For row 0: e^0.71 ≈ 2.03 and e^0 = 1, so the weights are 2.03/5.06 ≈ 0.40, 1/5.06 ≈ 0.20 and 2.03/5.06 ≈ 0.40. Equal scores always receive equal weights.)
Step 5: Multiply by V
Output row 0 = 0.40*[1,0] + 0.20*[0,1] + 0.40*[0.5,0.5] = [0.60, 0.40]
Output row 1 = 0.20*[1,0] + 0.40*[0,1] + 0.40*[0.5,0.5] = [0.40, 0.60]
Output row 2 = 0.25*[1,0] + 0.25*[0,1] + 0.50*[0.5,0.5] = [0.50, 0.50]
Token 2 (whose query was [1,1] — “I want everything”) ends up with a balanced mixture [0.50, 0.50]. Token 0 (query [1,0]) ends up leaning toward the first dimension [0.60, 0.40]. The attention mechanism has routed information according to what each token asked for.
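The hand calculation above can be double-checked in code. This is a minimal sketch on plain `Vec<Vec<f64>>` matrices rather than candle tensors, and the function name `attention` is hypothetical. It computes softmax(QK^T / sqrt(d_k)) V one query row at a time. Before exponentiating, it subtracts each row's maximum: this leaves the softmax result unchanged but avoids overflow. Run on the Q, K and V above, it prints rows that round to [0.60, 0.40], [0.40, 0.60] and [0.50, 0.50].

```rust
/// Scaled dot-product attention on plain row-major matrices:
/// softmax(Q K^T / sqrt(d_k)) V, with the softmax applied row by row.
fn attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let d_k = q[0].len() as f64;
    q.iter()
        .map(|qi| {
            // Scores of query i against every key, scaled by sqrt(d_k).
            let scores: Vec<f64> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() / d_k.sqrt())
                .collect();
            // Numerically stable softmax: subtract the row maximum before exp.
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            // Weighted sum of the value rows.
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / total * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    for row in attention(&q, &k, &v) {
        println!("[{:.2}, {:.2}]", row[0], row[1]);
    }
}
```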
Why self-attention works
Self-attention has two properties that make it powerful:
- Global context. Every token can attend to every other token in a single step. In an RNN, information must flow through many sequential steps to get from one end of the sequence to the other.
- Content-based routing. Which tokens to attend to is determined by the content (via Q and K), not by fixed connectivity patterns. The model learns to route information dynamically.
RNN: information flows sequentially       Attention: direct connections
t1 → t2 → t3 → t4 → t5                    any ti ←→ any tj
(4 steps from t1 to t5)                   (1 step between any pair)
Key takeaway: Self-attention computes a weighted sum of value vectors, where the weights are determined by the similarity between query and key vectors. The formula \(\text{softmax}(QK^T / \sqrt{d_k}) V\) is the mathematical heart of the Transformer.
§6 The Transformer Block
A single self-attention layer is powerful but not sufficient. The Transformer block wraps attention in a series of components that make training stable and learning more expressive.
Components of a Transformer block
A Transformer block contains four components, applied in order:
- Multi-head self-attention
- Residual connection + Layer normalisation
- Feed-forward network (FFN)
- Residual connection + Layer normalisation
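Here is the full block in ASCII art. It is drawn in the original post-norm order (layer norm after each residual add); the pre-norm variant we will actually implement is described at the end of this section: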
    ┌──────────────────────┐
    │      Input (x)       │
    └──────────┬───────────┘
               │
               ├──────────────────────┐
               ▼                      │ (residual)
    ┌──────────────────────┐          │
    │      Multi-Head      │          │
    │    Self-Attention    │          │
    └──────────┬───────────┘          │
               │                      │
               ▼                      │
              (+) ◄───────────────────┘
               │
               ▼
    ┌──────────────────────┐
    │      Layer Norm      │
    └──────────┬───────────┘
               │
               ├──────────────────────┐
               ▼                      │ (residual)
    ┌──────────────────────┐          │
    │     Feed-Forward     │          │
    │    Network (FFN)     │          │
    └──────────┬───────────┘          │
               │                      │
               ▼                      │
              (+) ◄───────────────────┘
               │
               ▼
    ┌──────────────────────┐
    │      Layer Norm      │
    └──────────┬───────────┘
               │
               ▼
    ┌──────────────────────┐
    │        Output        │
    └──────────────────────┘
Let us look at each component.
Multi-head self-attention
Instead of running a single attention computation, multi-head attention runs multiple attention heads in parallel, each with its own \( W_Q, W_K, W_V \) matrices. Each head can learn to focus on different types of relationships:
- Head 1 might learn syntactic relationships (subject-verb agreement).
- Head 2 might learn positional proximity (nearby characters).
- Head 3 might learn semantic similarity.
If d_model = 64 and we use 4 heads, each head operates on d_k = d_model / num_heads = 16 dimensions. The outputs of all heads are concatenated and projected back to d_model:
\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O \]
where each \( \text{head}_i = \text{Attention}(X W_Q^i, X W_K^i, X W_V^i) \).
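The bookkeeping behind d_k = d_model / num_heads amounts to slicing each token vector into equal chunks, one per head. Below is a minimal sketch on plain vectors; the helper names `split_heads` and `concat_heads` are hypothetical. Tensor libraries do this with a reshape and transpose rather than by copying, and the W_O projection that follows the concatenation is not shown.

```rust
/// Split (seq_len, d_model) into (num_heads, seq_len, d_k) by giving
/// head h the columns h*d_k .. (h+1)*d_k of every token vector.
fn split_heads(x: &[Vec<f32>], num_heads: usize) -> Vec<Vec<Vec<f32>>> {
    let d_model = x[0].len();
    assert!(d_model % num_heads == 0, "d_model must be divisible by num_heads");
    let d_k = d_model / num_heads;
    (0..num_heads)
        .map(|h| {
            x.iter()
                .map(|row| row[h * d_k..(h + 1) * d_k].to_vec())
                .collect::<Vec<Vec<f32>>>()
        })
        .collect()
}

/// Inverse: concatenate per-head outputs back to (seq_len, d_model).
fn concat_heads(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let seq_len = heads[0].len();
    (0..seq_len)
        .map(|t| {
            heads
                .iter()
                .flat_map(|head| head[t].iter().copied())
                .collect::<Vec<f32>>()
        })
        .collect()
}

fn main() {
    // 3 tokens with d_model = 64, split across 4 heads (d_k = 16).
    let x: Vec<Vec<f32>> = (0..3)
        .map(|t| (0..64).map(|j| (t * 64 + j) as f32).collect::<Vec<f32>>())
        .collect();
    let heads = split_heads(&x, 4);
    println!("{} heads, each ({}, {})", heads.len(), heads[0].len(), heads[0][0].len());
    assert_eq!(concat_heads(&heads), x); // concatenation undoes the split
}
```

Each head then runs the attention computation on its own (seq_len, 16) slice, after its own Q, K and V projections.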
Residual connections
A residual connection (or skip connection) adds the input of a sublayer directly to its output:
\[ \text{output} = \text{sublayer}(x) + x \]
This seemingly simple trick is crucial for training deep networks. It gives gradients a direct path through the network during backpropagation, which greatly mitigates the vanishing gradient problem that plagues deep architectures. Even if a sublayer contributes nothing useful (outputs near zero), the residual connection lets the input signal pass through unchanged.
Layer normalisation
Layer normalisation normalises the values across the feature dimension for each token independently:
\[ \text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \]
Where \( \mu \) and \( \sigma^2 \) are the mean and variance computed across the feature dimension, and \( \gamma, \beta \) are learnable scale and shift parameters.
Layer norm keeps activations in a reasonable range, which stabilises training. Without it, values can explode or collapse as they pass through many layers.
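The LayerNorm formula for a single token's feature vector can be sketched in a few lines. The function name `layer_norm` and the choice of eps are illustrative. It uses the biased (divide-by-n) variance, the usual convention for layer norm.

```rust
/// LayerNorm over one token's features:
/// (x - mean) / sqrt(var + eps) * gamma + beta, with var computed by dividing by n.
fn layer_norm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(gamma.iter().zip(beta))
        .map(|(v, (g, b))| (v - mean) * inv_std * g + b)
        .collect()
}

fn main() {
    let x: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    // gamma = 1 and beta = 0 (their usual initial values): output has mean 0, variance ~1.
    let y = layer_norm(&x, &[1.0; 4], &[0.0; 4], 1e-5);
    println!("{:?}", y);
}
```

During training, the model can learn gamma and beta to rescale and shift the normalised values wherever a different range helps.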
Feed-forward network
The FFN is a simple two-layer neural network applied independently to each token position:
\[ \text{FFN}(x) = \text{GELU}(x W_1 + b_1) W_2 + b_2 \]
The inner dimension is typically 4x the model dimension (e.g., d_model = 64 → d_ff = 256). The GELU activation function is a smooth approximation of ReLU used in GPT models.
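Below is a sketch of the FFN for one token position. GELU uses its tanh approximation, the form found in OpenAI's GPT-2 code; exact GELU is x·Φ(x), where Φ is the standard normal CDF. The helper names `gelu`, `linear` and `ffn` are hypothetical. Weights are stored as (in_dim, out_dim) row-major `Vec<Vec<f32>>`.

```rust
/// GELU, tanh approximation: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))).
fn gelu(x: f32) -> f32 {
    let c = (2.0 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// y = x W + b, with W stored as (in_dim, out_dim).
fn linear(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    b.iter()
        .enumerate()
        .map(|(j, bj)| bj + x.iter().zip(w).map(|(xi, row)| xi * row[j]).sum::<f32>())
        .collect()
}

/// Position-wise FFN for one token: GELU(x W1 + b1) W2 + b2.
fn ffn(x: &[f32], w1: &[Vec<f32>], b1: &[f32], w2: &[Vec<f32>], b2: &[f32]) -> Vec<f32> {
    let hidden: Vec<f32> = linear(x, w1, b1).into_iter().map(gelu).collect();
    linear(&hidden, w2, b2)
}

fn main() {
    // GELU is close to 0 for large negative inputs and close to x for large positive ones.
    for x in [-3.0f32, -1.0, 0.0, 1.0, 3.0] {
        println!("gelu({:>4}) = {:.4}", x, gelu(x));
    }
    // Tiny FFN with d_model = 2 and d_ff = 2 (the course model uses d_ff = 4 * d_model).
    let w: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    println!("{:?}", ffn(&[1.0, -1.0], &w, &[0.0, 0.0], &w, &[0.5, 0.5]));
}
```

Because the same weights are applied independently at every position, a sequence of length n just runs this function n times, or does one matrix multiply over all rows at once.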
The FFN is where much of the model’s “knowledge” is stored. While attention determines what information to combine, the FFN transforms that information.
Pre-norm vs post-norm
The original Transformer used post-norm (normalise after the residual add). GPT-2 and most modern models use pre-norm (normalise before the sublayer). Pre-norm is more stable to train:
Post-norm (original): LayerNorm(x + Sublayer(x))
Pre-norm (GPT-2): x + Sublayer(LayerNorm(x))
We will use pre-norm in our implementation.
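The difference between the two orderings is easiest to see with a sublayer that outputs zeros, standing in for one that has not learned anything yet. This is a minimal sketch: the names `pre_norm`, `post_norm` and `norm` are hypothetical, and `norm` is a layer norm with gamma = 1 and beta = 0. With pre-norm, the input passes through untouched; with post-norm, it is always re-normalised.

```rust
/// Layer norm with gamma = 1, beta = 0 and eps = 1e-5.
fn norm(x: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + 1e-5).sqrt()).collect()
}

fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Post-norm (original Transformer): LayerNorm(x + Sublayer(x)).
fn post_norm(x: &[f32], sublayer: impl Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    norm(&add(x, &sublayer(x)))
}

/// Pre-norm (GPT-2 style): x + Sublayer(LayerNorm(x)).
fn pre_norm(x: &[f32], sublayer: impl Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    add(x, &sublayer(&norm(x)))
}

fn main() {
    let x: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    // A sublayer that outputs zeros, i.e. one that has "learned nothing".
    let zero = |v: &[f32]| vec![0.0f32; v.len()];
    println!("pre-norm:  {:?}", pre_norm(&x, zero)); // x passes through unchanged
    println!("post-norm: {:?}", post_norm(&x, zero)); // x is re-normalised
}
```

In pre-norm, the residual path from input to output never passes through a normalisation. This clean identity path is one common explanation for why pre-norm stacks train more stably.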
Key takeaway: A Transformer block combines multi-head attention (for context mixing), a feed-forward network (for per-token transformation), residual connections (for gradient flow), and layer normalisation (for training stability). These four ingredients, repeated many times, give Transformers their power.
§7 Exercise 2: Implement Self-Attention in Rust
In this exercise, we implement single-head scaled dot-product self-attention using the candle tensor library. This is the core computation from §5 translated into Rust.
Adding candle to your project
Update your Cargo.toml:
[package]
name = "llm-from-scratch"
version = "0.1.0"
edition = "2021"
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
anyhow = "1"
[profile.release]
opt-level = "z"
lto = true
strip = true
codegen-units = 1
Single-head self-attention
Create src/attention.rs:
use candle_core::{Result, Tensor, D};
/// Compute single-head scaled dot-product self-attention.
///
/// # Arguments
/// * `q` - Query tensor of shape `(seq_len, d_k)`
/// * `k` - Key tensor of shape `(seq_len, d_k)`
/// * `v` - Value tensor of shape `(seq_len, d_k)`
///
/// # Returns
/// Output tensor of shape `(seq_len, d_k)`: the attention-weighted
/// combination of values.
pub fn scaled_dot_product_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
) -> Result<Tensor> {
    let d_k = q.dim(D::Minus1)? as f64;
    // Step 1: QK^T, the raw attention scores
    // q: (seq_len, d_k), k^T: (d_k, seq_len) → scores: (seq_len, seq_len)
    let scores = q.matmul(&k.t()?)?;
    // Step 2: Scale by sqrt(d_k)
    let scaled = (scores / d_k.sqrt())?;
    // Step 3: Softmax along the last dimension (row-wise)
    let weights = candle_nn::ops::softmax(&scaled, D::Minus1)?;
    // Step 4: Weighted sum of values
    // weights: (seq_len, seq_len) × v: (seq_len, d_k) → (seq_len, d_k)
    let output = weights.matmul(v)?;
    Ok(output)
}
/// Project input through a linear layer (matrix multiply) to produce Q, K, or V.
///
/// # Arguments
/// * `x` - Input tensor of shape `(seq_len, d_model)`
/// * `w` - Weight matrix of shape `(d_model, d_k)`
///
/// # Returns
/// Projected tensor of shape `(seq_len, d_k)`.
pub fn project(x: &Tensor, w: &Tensor) -> Result<Tensor> {
    x.matmul(w)
}
Testing it in main.rs
Add the module and a test function to src/main.rs:
mod attention;
mod tokeniser;
use anyhow::Result;
use candle_core::{Device, Tensor};
fn demo_attention() -> Result<()> {
let device = &Device::Cpu;
// Simulate 4 token embeddings, each of dimension 8
let seq_len = 4;
let d_model = 8;
let d_k = 8; // Same as d_model for single-head
// Random input "embeddings"
let x = Tensor::randn(0f32, 1.0, (seq_len, d_model), device)?;
// Random projection weights (in a real model, these are learned)
let w_q = Tensor::randn(0f32, 1.0, (d_model, d_k), device)?;
let w_k = Tensor::randn(0f32, 1.0, (d_model, d_k), device)?;
let w_v = Tensor::randn(0f32, 1.0, (d_model, d_k), device)?;
// Project input into Q, K, V
let q = attention::project(&x, &w_q)?;
let k = attention::project(&x, &w_k)?;
let v = attention::project(&x, &w_v)?;
println!("Input shape: {:?}", x.shape());
println!("Q shape: {:?}", q.shape());
println!("K shape: {:?}", k.shape());
println!("V shape: {:?}", v.shape());
// Compute attention
let output = attention::scaled_dot_product_attention(&q, &k, &v)?;
println!("Output shape: {:?}", output.shape());
println!("\nAttention output:\n{}", output);
Ok(())
}
fn main() {
if let Err(e) = demo_attention() {
eprintln!("Error: {}", e);
}
}
Expected output
Input shape: [4, 8]
Q shape: [4, 8]
K shape: [4, 8]
V shape: [4, 8]
Output shape: [4, 8]
Attention output:
[[ 0.1234, -0.5678, ...],
[ 0.2345, -0.4567, ...],
[ 0.3456, -0.3456, ...],
[ 0.4567, -0.2345, ...]]
(Your exact numbers will differ because of random initialisation.)
What to observe
After running the code, notice:
- Shape preservation. The output has shape (seq_len, d_k), which here equals the input shape because d_k = d_model. Each token position gets a new vector that is a weighted combination of all value vectors.
- Row similarity. The output rows tend to be more similar to each other than the input rows, because attention mixes information across all positions.
- Softmax effect. If you print the attention weights (the output of softmax), you will see that each row sums to 1.0 and typically has one or two dominant values.
Exercises to try
- Print the attention weights matrix. After the softmax step, print the (seq_len, seq_len) weight matrix. Which tokens attend most strongly to which?
- Add a causal mask. Before softmax, set the entries of the score matrix above the diagonal to negative infinity. This prevents each position from attending to future positions. (Hint: Tensor::tril2 builds a lower-triangular matrix of ones, and where_cond can then select between the scores and a tensor filled with negative infinity.)
- Compare with and without scaling. Remove the / d_k.sqrt() and observe how the attention weights change. They should become more "peaky" (concentrated on one token).
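You can preview the effect of the scaling comparison without candle. The sketch below applies softmax to one row of made-up raw scores, with and without dividing by sqrt(d_k), using d_k = 8 as in the demo. The scores are illustrative, not taken from a real model.

```rust
/// Numerically stable softmax over one row of scores.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn main() {
    let d_k = 8.0f64; // same d_k as the demo above
    // Made-up raw scores (q · k) for one query against four keys.
    let raw: [f64; 4] = [8.0, 4.0, 0.0, -4.0];
    let scaled: Vec<f64> = raw.iter().map(|s| s / d_k.sqrt()).collect();
    println!("unscaled: {:.3?}", softmax(&raw));
    println!("scaled:   {:.3?}", softmax(&scaled));
}
```

For these scores, the largest weight falls from about 0.98 unscaled to about 0.76 after scaling. The scaled distribution is softer, so gradients still reach the other tokens.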
Key takeaway: Self-attention in code is just three matrix multiplications (to project Q, K, V), one more multiply (QK^T), a scale, a softmax, and a final multiply by V. The candle crate provides all the tensor operations we need.
Part 3 — Assembling the Model
§8 A Decoder-Only LM: Stacking Blocks and the Causal Mask
We now have all the pieces — embeddings, positional encoding, and Transformer blocks. It is time to assemble them into a complete language model. We will build a decoder-only Transformer, the architecture used by GPT-1, GPT-2, GPT-3, and many other LLMs.
Why “decoder-only”?
The original Transformer (2017) had two halves:
- An encoder that processes an input sequence (e.g., a French sentence).
- A decoder that generates an output sequence (e.g., the English translation), attending to both itself and the encoder output.
GPT (2018) showed that you only need the decoder half. By training the decoder to predict the next token in a single sequence, you get a general-purpose language model. No encoder, no cross-attention — just self-attention with a causal mask.
Original Transformer:              Decoder-only (GPT):
┌──────────┐ ┌──────────┐          ┌──────────┐
│ Encoder  │→│ Decoder  │          │ Decoder  │
│ (bidir.) │ │ (causal) │          │ (causal) │
└──────────┘ └──────────┘          └──────────┘
Used for: translation              Used for: generation
The causal mask
In a decoder-only model, each token can only attend to tokens at or before its position — never to future tokens. This is essential because during generation, future tokens do not exist yet.
We enforce this with a causal mask (also called a “look-ahead mask”) — a lower-triangular matrix that blocks attention to future positions:
Causal mask for sequence length 5:
t0 t1 t2 t3 t4
t0 [ 1 0 0 0 0 ] ← t0 can only see t0
t1 [ 1 1 0 0 0 ] ← t1 can see t0, t1
t2 [ 1 1 1 0 0 ] ← t2 can see t0, t1, t2
t3 [ 1 1 1 1 0 ] ← t3 can see t0, t1, t2, t3
t4 [ 1 1 1 1 1 ] ← t4 can see everything
In practice, we set the masked positions (zeros above) to \( -\infty \) before the softmax step. Since \( \text{softmax}(-\infty) = 0 \), those positions get zero attention weight.
\[ \text{MaskedAttention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}} + M\right) V \]
where \( M \) has 0 for allowed positions and \( -\infty \) for blocked positions.
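The masked softmax can be checked on plain vectors before wiring it into candle. Below is a minimal sketch; the function name `causal_softmax` is hypothetical. It adds the mask M row by row, setting entry (i, j) to negative infinity for j > i, then applies a stable softmax. With all-zero scores, row i spreads its weight evenly over positions 0..=i and gives exactly zero to every future position.

```rust
/// Row-wise softmax of `scores` with a causal mask: entry (i, j) with j > i
/// is replaced by -infinity first, so it receives exactly zero weight.
fn causal_softmax(scores: &[Vec<f64>]) -> Vec<Vec<f64>> {
    scores
        .iter()
        .enumerate()
        .map(|(i, row)| {
            // Apply the mask M: 0 for j <= i, -inf for j > i.
            let masked: Vec<f64> = row
                .iter()
                .enumerate()
                .map(|(j, &s)| if j <= i { s } else { f64::NEG_INFINITY })
                .collect();
            // Column 0 is never masked, so the row maximum is finite.
            let max = masked.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = masked.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    // Equal raw scores everywhere, so the mask alone shapes the weights.
    let scores: Vec<Vec<f64>> = vec![vec![0.0; 4]; 4];
    for row in causal_softmax(&scores) {
        println!("{:.2?}", row);
    }
}
```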
The full model architecture
Our GPT-1-style model stacks all the components in this order:
┌─────────────────────────────────────┐
│ Input Token IDs │
│ [4, 2, 7, 1] │
└──────────────┬──────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ Token Embedding (lookup) │
│ + Position Embedding │
│ → (seq_len, d_model) │
└──────────────┬──────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ Transformer Block 1 │
│ ┌─ Masked Multi-Head Attention ─┐ │
│ │ + Residual + LayerNorm │ │
│ │ FFN + Residual + LayerNorm │ │
│ └───────────────────────────────┘ │
└──────────────┬──────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ Transformer Block 2 │
│ (same structure as Block 1) │
└──────────────┬──────────────────────┘
│
▼
... (N blocks total)
│
▼
┌─────────────────────────────────────┐
│ Final Layer Norm │
└──────────────┬──────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ Linear Projection (no bias) │
│ (d_model → vocab_size) │
│ Output: logits per token │
│ → (seq_len, vocab_size) │
└─────────────────────────────────────┘
The output logits are raw (unnormalised) scores for each token in the vocabulary, at each position in the sequence. To get probabilities, we apply softmax. To get a loss, we compare these logits against the actual next tokens using cross-entropy.
Hyperparameters for our model
We keep things small enough to train on a CPU in seconds:
| Hyperparameter | Value | Description |
|---|---|---|
| vocab_size | ~65 | Number of unique characters (depends on corpus) |
| d_model | 64 | Embedding dimension |
| n_heads | 4 | Number of attention heads |
| n_layers | 2 | Number of Transformer blocks |
| d_ff | 256 | FFN inner dimension (4 × d_model) |
| max_seq_len | 128 | Maximum sequence length |
This gives roughly 100K parameters — tiny by modern standards, but sufficient to learn character-level patterns from a small corpus.
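As a sanity check on that estimate, the sketch below counts parameters layer by layer. It assumes the layer layout of the model defined in Exercise 3: a fused Q/K/V projection, biases on every linear layer except the final output projection, a scale and a shift in each LayerNorm, and learned position embeddings. `Hyper` and `count_params` are hypothetical names for this illustration. Note that `n_heads` does not appear, because splitting `d_model` into heads does not change the number of weights.

```rust
/// The hyperparameters that determine the parameter count
/// (a standalone copy of the table's values; `n_heads` is omitted on purpose).
struct Hyper {
    vocab_size: usize,
    d_model: usize,
    n_layers: usize,
    d_ff: usize,
    max_seq_len: usize,
}

/// Count trainable parameters, assuming the Exercise 3 layout.
fn count_params(h: &Hyper) -> usize {
    let d = h.d_model;
    // A linear layer with bias has in * out weights plus out biases.
    let linear = |inp: usize, out: usize| inp * out + out;
    // LayerNorm has a scale and a shift per feature.
    let layer_norm = 2 * d;

    let embeddings = h.vocab_size * d + h.max_seq_len * d; // token + position tables
    let attention = linear(d, 3 * d) + linear(d, d); // fused Q/K/V + output projection
    let ffn = linear(d, h.d_ff) + linear(h.d_ff, d); // expand, then project back
    let per_block = attention + ffn + 2 * layer_norm;
    let lm_head = d * h.vocab_size; // final projection has no bias

    // Embeddings + all blocks + final LayerNorm + output projection
    embeddings + h.n_layers * per_block + layer_norm + lm_head
}
```

With `vocab_size = 65` and the other values from the table, `count_params` returns 116,608, so "roughly 100K" is the right order of magnitude. The embeddings and the output projection grow with the vocabulary; everything else is fixed by `d_model`, `d_ff`, and `n_layers`.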
How generation works
Once trained, we generate text autoregressively:
1. Start with a prompt (e.g., “The ”).
2. Encode it to token IDs.
3. Run the model to get logits for the next position.
4. Sample a token from the probability distribution (or take the argmax).
5. Append the sampled token to the sequence.
6. Repeat from step 3.
Step 1: "The " → model → next token probabilities → sample 'c'
Step 2: "The c" → model → next token probabilities → sample 'a'
Step 3: "The ca" → model → next token probabilities → sample 't'
Step 4: "The cat" → model → next token probabilities → sample ' '
...
We only need the logits at the last position to generate the next token, but the model computes logits for every position at once. That parallelism is what makes training efficient; during generation, the outputs at earlier positions are simply discarded.
Key takeaway: A decoder-only language model is a stack of Transformer blocks with causal masking, sandwiched between an embedding layer and a linear output projection. The causal mask ensures each position can only attend to past tokens, enabling autoregressive generation.
§9 Exercise 3: Define the GPT-1-Style Model in candle
In this exercise, we define the full model architecture in Rust using candle. We will build the model struct by struct, from the bottom up.
The overall structure
We need these components:
- CausalSelfAttention: multi-head attention with causal masking
- FeedForward: the two-layer FFN
- TransformerBlock: attention + FFN with residual connections and layer norm
- Gpt1Model: the full model (embeddings, N blocks, final projection)
Configuration
First, define a config struct in src/model.rs:
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm, linear, linear_no_bias, Embedding, LayerNorm,
    Linear, Module, VarBuilder,
};
/// Configuration for our GPT-1-style model.
#[derive(Clone)]
pub struct GptConfig {
    pub vocab_size: usize,
    pub d_model: usize,
    pub n_heads: usize,
    pub n_layers: usize,
    pub d_ff: usize,
    pub max_seq_len: usize,
}
impl GptConfig {
    /// A tiny configuration suitable for CPU training.
    pub fn tiny(vocab_size: usize) -> Self {
        GptConfig {
            vocab_size,
            d_model: 64,
            n_heads: 4,
            n_layers: 2,
            d_ff: 256,
            max_seq_len: 128,
        }
    }
}
Causal self-attention
/// Multi-head causal self-attention.
pub struct CausalSelfAttention {
    qkv_proj: Linear,
    out_proj: Linear,
    n_heads: usize,
    d_k: usize,
}
impl CausalSelfAttention {
    pub fn new(cfg: &GptConfig, vb: VarBuilder) -> Result<Self> {
        let d_k = cfg.d_model / cfg.n_heads;
        // Project Q, K, V in a single linear layer for efficiency.
        let qkv_proj = linear(cfg.d_model, 3 * cfg.d_model, vb.pp("qkv_proj"))?;
        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
        Ok(Self { qkv_proj, out_proj, n_heads: cfg.n_heads, d_k })
    }
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (seq_len, d_model) = (x.dim(0)?, x.dim(1)?);
        // Project to Q, K, V in one operation, then split along the feature dimension
        let qkv = self.qkv_proj.forward(x)?; // (seq_len, 3 * d_model)
        let q = qkv.narrow(1, 0, d_model)?;
        let k = qkv.narrow(1, d_model, d_model)?;
        let v = qkv.narrow(1, 2 * d_model, d_model)?;
        // Reshape for multi-head: (seq_len, n_heads, d_k), then transpose to
        // (n_heads, seq_len, d_k). contiguous() gives matmul a dense memory layout.
        let q = q.reshape((seq_len, self.n_heads, self.d_k))?.transpose(0, 1)?.contiguous()?;
        let k = k.reshape((seq_len, self.n_heads, self.d_k))?.transpose(0, 1)?.contiguous()?;
        let v = v.reshape((seq_len, self.n_heads, self.d_k))?.transpose(0, 1)?.contiguous()?;
        // Scaled dot-product scores: (n_heads, seq_len, seq_len)
        let scale = (self.d_k as f64).sqrt();
        let scores = q.matmul(&k.transpose(1, 2)?)?.affine(1.0 / scale, 0.0)?;
        // Additive causal mask M of shape (seq_len, seq_len): 0 where key j <= query i,
        // -1e9 (a stand-in for -inf) where j > i. It broadcasts across all heads.
        let mask: Vec<f32> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { -1e9 } else { 0.0 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), x.device())?;
        let masked_scores = scores.broadcast_add(&mask)?;
        let weights = candle_nn::ops::softmax(&masked_scores, D::Minus1)?;
        // Weighted sum of values: (n_heads, seq_len, d_k)
        let attn_out = weights.matmul(&v)?;
        // Merge heads: transpose to (seq_len, n_heads, d_k), then reshape to (seq_len, d_model)
        let attn_out = attn_out.transpose(0, 1)?.reshape((seq_len, d_model))?;
        // Output projection
        self.out_proj.forward(&attn_out)
    }
}
Feed-forward network
/// Position-wise feed-forward network with GELU activation.
pub struct FeedForward {
    up: Linear,
    down: Linear,
}
impl FeedForward {
    pub fn new(cfg: &GptConfig, vb: VarBuilder) -> Result<Self> {
        let up = linear(cfg.d_model, cfg.d_ff, vb.pp("up"))?;
        let down = linear(cfg.d_ff, cfg.d_model, vb.pp("down"))?;
        Ok(Self { up, down })
    }
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Expand to d_ff, apply GELU, then project back to d_model
        let h = self.up.forward(x)?.gelu()?;
        self.down.forward(&h)
    }
}
Transformer block (pre-norm)
/// A single Transformer block with pre-norm residual connections.
pub struct TransformerBlock {
    attn: CausalSelfAttention,
    ffn: FeedForward,
    ln1: LayerNorm,
    ln2: LayerNorm,
}
impl TransformerBlock {
    pub fn new(cfg: &GptConfig, vb: VarBuilder) -> Result<Self> {
        let attn = CausalSelfAttention::new(cfg, vb.pp("attn"))?;
        let ffn = FeedForward::new(cfg, vb.pp("ffn"))?;
        // layer_norm accepts anything convertible into LayerNormConfig;
        // an f64 sets the epsilon (here 1e-5) with default scale and shift.
        let ln1 = layer_norm(cfg.d_model, 1e-5, vb.pp("ln1"))?;
        let ln2 = layer_norm(cfg.d_model, 1e-5, vb.pp("ln2"))?;
        Ok(Self { attn, ffn, ln1, ln2 })
    }
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Pre-norm: x + Attn(LayerNorm(x))
        let residual = x;
        let h = self.ln1.forward(x)?;
        let h = self.attn.forward(&h)?;
        let x = (residual + h)?;
        // Pre-norm: x + FFN(LayerNorm(x))
        let residual = &x;
        let h = self.ln2.forward(&x)?;
        let h = self.ffn.forward(&h)?;
        residual + h
    }
}
The full GPT model
/// A small GPT-1-style language model.
pub struct Gpt1Model {
    token_emb: Embedding,
    pos_emb: Embedding,
    blocks: Vec<TransformerBlock>,
    ln_f: LayerNorm,
    lm_head: Linear,
}
impl Gpt1Model {
    pub fn new(cfg: &GptConfig, vb: VarBuilder) -> Result<Self> {
        let token_emb = embedding(cfg.vocab_size, cfg.d_model, vb.pp("token_emb"))?;
        let pos_emb = embedding(cfg.max_seq_len, cfg.d_model, vb.pp("pos_emb"))?;
        let mut blocks = Vec::with_capacity(cfg.n_layers);
        for i in 0..cfg.n_layers {
            blocks.push(TransformerBlock::new(cfg, vb.pp(format!("block_{}", i)))?);
        }
        // Final LayerNorm (epsilon 1e-5) before the output projection
        let ln_f = layer_norm(cfg.d_model, 1e-5, vb.pp("ln_f"))?;
        let lm_head = linear_no_bias(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self { token_emb, pos_emb, blocks, ln_f, lm_head })
    }
    /// Forward pass: token IDs → logits.
    ///
    /// # Arguments
    /// * `token_ids` - 1D tensor of shape `(seq_len,)` with token IDs;
    ///   `seq_len` must not exceed `max_seq_len`.
    ///
    /// # Returns
    /// Logits tensor of shape `(seq_len, vocab_size)`.
    pub fn forward(&self, token_ids: &Tensor) -> Result<Tensor> {
        let seq_len = token_ids.dim(0)?;
        // Create position indices [0, 1, 2, ..., seq_len-1]
        let positions = Tensor::arange(0u32, seq_len as u32, token_ids.device())?;
        // Embed tokens and positions, then add
        let tok_emb = self.token_emb.forward(token_ids)?;
        let pos_emb = self.pos_emb.forward(&positions)?;
        let mut x = (tok_emb + pos_emb)?;
        // Pass through all Transformer blocks
        for block in &self.blocks {
            x = block.forward(&x)?;
        }
        // Final layer norm + projection to vocabulary
        let x = self.ln_f.forward(&x)?;
        self.lm_head.forward(&x)
    }
}
Testing the model
In main.rs:
mod model;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarMap;
fn main() -> anyhow::Result<()> {
    let device = &Device::Cpu;
    let varmap = VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, device);
    let cfg = model::GptConfig::tiny(65); // 65 distinct characters in a typical Shakespeare corpus
    let model = model::Gpt1Model::new(&cfg, vb)?;
    // Create a dummy input: 10 token IDs
    let input = Tensor::new(&[0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9], device)?;
    let logits = model.forward(&input)?;
    println!("Input shape: {:?}", input.shape());
    println!("Output logits shape: {:?}", logits.shape());
    // Expect 10 positions with 65 vocabulary scores each
    Ok(())
}
What to observe
- The output shape should be (seq_len, vocab_size): one set of logits per input position.
- With random weights, the logits will be meaningless noise. Training (next section) will make them meaningful.
- The model processes every position of the sequence in parallel. This is a key advantage of Transformers over RNNs, which must step through tokens one at a time.
Key takeaway: Our GPT model is built from composable structs: CausalSelfAttention, FeedForward, TransformerBlock, and Gpt1Model. Each handles one concern, and the candle crate provides the tensor operations and automatic differentiation we need for training.
Part 4 — Training
§10 Cross-Entropy Loss and the Training Loop
We have a model that takes token IDs and outputs logits. Now we need to teach it to output the right logits — the ones that predict the next token accurately. This is where training comes in.
The training objective
Recall that our model outputs logits of shape (seq_len, vocab_size): a score for every token in the vocabulary, at every position. The training target is simple: at each position i, the correct prediction is the token at position i + 1 of the text.
Input:    ['T', 'h', 'e', ' ', 'c', 'a', 't']
Target:   ['h', 'e', ' ', 'c', 'a', 't', '.']
Position:   0    1    2    3    4    5    6
At position 0, where the input is “T”, the model should predict “h”. At position 1, it should predict “e”, and so on. We shift the input by one to create the targets.
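In code, building the targets is a one-position shift. The helper `shift_for_targets` below is a hypothetical sketch; the usage note feeds it each character's Unicode code point as a stand-in token ID, purely for illustration, rather than the course's tokeniser.

```rust
/// Build next-token-prediction pairs from one token sequence:
/// target[i] is the token that follows input[i].
fn shift_for_targets(tokens: &[u32]) -> (Vec<u32>, Vec<u32>) {
    if tokens.len() < 2 {
        // Fewer than two tokens leaves nothing to predict.
        return (Vec::new(), Vec::new());
    }
    let input = tokens[..tokens.len() - 1].to_vec(); // drop the last token
    let target = tokens[1..].to_vec(); // drop the first token
    (input, target)
}
```

For example, with IDs from `"The cat.".chars().map(|c| c as u32)`, the input covers "The cat" and the target covers "he cat.", so both have seven entries.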
Cross-entropy loss
Cross-entropy loss measures how far the model’s predicted probability distribution is from the true distribution (where all probability mass is on the correct token).
For a single position where the correct token has index \( y \):
\[ \mathcal{L} = -\log P(y) = -\log \frac{e^{z_y}}{\sum_j e^{z_j}} \]
Where \( z_j \) are the logits. Intuitively:
- If the model assigns high probability to the correct token, \( -\log P(y) \) is small (close to 0). Good.
- If the model assigns low probability, \( -\log P(y) \) is large. Bad.
We average this loss over all positions in the sequence and all sequences in the batch.
A concrete example
Suppose our vocabulary is ['a', 'b', 'c'] and the correct next token is 'b' (index 1). The model outputs logits:
Logits: [2.0, 5.0, 1.0]
After softmax: [0.05, 0.94, 0.02] (e^2 / sum, e^5 / sum, e^1 / sum)
Loss: -log(0.936) ≈ 0.07 (low loss: the model is confident and correct)
If the model were wrong:
Logits: [5.0, 1.0, 2.0]
After softmax: [0.94, 0.02, 0.05]
Loss: -log(0.0171) ≈ 4.07 (high loss: the model is confident but wrong)
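These numbers can be checked with a few lines of plain Rust. This sketch uses hypothetical helper names and no candle code. It computes the loss directly from the logits as log-sum-exp minus the correct logit, which equals \( -\log P(y) \) but avoids taking the logarithm of a tiny probability.

```rust
/// Softmax with the maximum subtracted for numerical stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Cross-entropy for one position, -log P(target),
/// computed as log-sum-exp(z) - z_target.
fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum_exp: f64 = logits.iter().map(|&z| (z - max).exp()).sum();
    let log_sum_exp = max + sum_exp.ln();
    log_sum_exp - logits[target]
}
```

`cross_entropy(&[2.0, 5.0, 1.0], 1)` gives about 0.066, and `cross_entropy(&[5.0, 1.0, 2.0], 1)` gives about 4.066.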
Gradient descent
To minimise the loss, we use gradient descent. The idea is:
- Compute the loss for a batch of data.
- Compute the gradient of the loss with respect to every model parameter (backpropagation).
- Update each parameter by subtracting a small multiple of its gradient: \[ \theta \leftarrow \theta - \eta \nabla_\theta \mathcal{L} \] where \( \eta \) is the learning rate.
The learning rate is a critical hyperparameter:
- Too high: training is unstable, loss oscillates or diverges.
- Too low: training is painfully slow.
- A typical starting value for small models: 1e-3 to 3e-4.
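To make the update rule concrete, the sketch below runs plain gradient descent on the wrong-prediction example from the cross-entropy section (logits [5.0, 1.0, 2.0], correct index 1), treating the three logits themselves as the parameters \( \theta \). It relies on the standard result that the gradient of cross-entropy with respect to the logits is \( \partial \mathcal{L} / \partial z_j = P(j) - \mathbb{1}[j = y] \). The helper names and the learning rate of 0.5 are illustrative choices; in a real model the parameters sit behind the network, but the update step has the same form.

```rust
/// Softmax with the maximum subtracted for numerical stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Cross-entropy -log P(target), computed as log-sum-exp(z) - z_target.
fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum_exp: f64 = logits.iter().map(|&z| (z - max).exp()).sum();
    max + sum_exp.ln() - logits[target]
}

/// Gradient of the loss with respect to each logit: P(j) - 1[j == target].
fn ce_grad(logits: &[f64], target: usize) -> Vec<f64> {
    softmax(logits)
        .into_iter()
        .enumerate()
        .map(|(j, p)| if j == target { p - 1.0 } else { p })
        .collect()
}

/// Plain gradient descent on the logits: theta <- theta - eta * grad.
/// Returns the loss recorded after each update.
fn descend(mut logits: Vec<f64>, target: usize, eta: f64, steps: usize) -> Vec<f64> {
    let mut losses = Vec::with_capacity(steps);
    for _ in 0..steps {
        let grad = ce_grad(&logits, target);
        for (z, g) in logits.iter_mut().zip(&grad) {
            *z -= eta * g;
        }
        losses.push(cross_entropy(&logits, target));
    }
    losses
}
```

Calling `descend(vec![5.0, 1.0, 2.0], 1, 0.5, 50)` gives a loss that goes down at every step, from about 4.07 before the first update to below 0.1 after the fiftieth.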
The training loop
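The training loop repeats the following steps, typically for many epochs (full passes through the dataset):
for epoch in 1..=num_epochs:
    for batch in dataset:
        1. Forward pass: logits = model(input_tokens)
        2. Compute loss: loss = cross_entropy(logits, target_tokens)
        3. Backward pass: compute gradients via backpropagation
        4. Update weights: optimizer.step()
        5. Zero gradients: optimizer.zero_grad()
    print epoch loss
In candle, opt.backward_step(&loss) performs steps 3 and 4 in one call. It computes a fresh set of gradients on every call, so there is no separate zeroing step. The exercise below also draws random batches for a fixed number of steps instead of looping over epochs.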
Batching
For efficiency, we process multiple sequences at once in a batch. Instead of feeding one sequence at a time, we stack batch_size sequences into a matrix:
- Input shape: (batch_size, seq_len)
- Output logits: (batch_size, seq_len, vocab_size)
For our small model training on CPU, a batch size of 32-64 works well.
The AdamW optimiser
We will use AdamW — a variant of the Adam optimiser with decoupled weight decay. Adam adapts the learning rate for each parameter based on the history of its gradients, which generally works much better than plain gradient descent. candle provides AdamW out of the box.
AdamW hyperparameters:
learning_rate: 3e-4
beta1: 0.9 (decay rate of the running average of gradients, i.e. momentum)
beta2: 0.999 (decay rate of the running average of squared gradients)
weight_decay: 0.1
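The sketch below writes out one AdamW update for a single scalar parameter, using the hyperparameters listed above. It follows the standard published form of AdamW (bias-corrected moment estimates plus decoupled weight decay). `EPS` = 1e-8 is an assumed constant, and candle's internal implementation may differ in small details. The names `AdamWState` and `adamw_step` are hypothetical.

```rust
/// Hyperparameters from the list above; EPS is an assumed small constant.
const LR: f64 = 3e-4;
const BETA1: f64 = 0.9;
const BETA2: f64 = 0.999;
const EPS: f64 = 1e-8;
const WEIGHT_DECAY: f64 = 0.1;

/// Optimiser state for one scalar parameter.
struct AdamWState {
    m: f64, // running average of gradients (first moment)
    v: f64, // running average of squared gradients (second moment)
    t: i32, // number of steps taken so far
}

/// One AdamW step for parameter `theta` with gradient `grad`; returns the new theta.
fn adamw_step(theta: f64, grad: f64, s: &mut AdamWState) -> f64 {
    s.t += 1;
    s.m = BETA1 * s.m + (1.0 - BETA1) * grad;
    s.v = BETA2 * s.v + (1.0 - BETA2) * grad * grad;
    // Bias correction: m and v start at zero, so early averages are too small.
    let m_hat = s.m / (1.0 - BETA1.powi(s.t));
    let v_hat = s.v / (1.0 - BETA2.powi(s.t));
    // Decoupled weight decay shrinks theta directly instead of changing the gradient.
    let decayed = theta * (1.0 - LR * WEIGHT_DECAY);
    decayed - LR * m_hat / (v_hat.sqrt() + EPS)
}
```

On the first step the bias-corrected ratio \( \hat{m} / \sqrt{\hat{v}} \) is close to \( \pm 1 \) whatever the gradient's size, so a gradient of 1000 and a gradient of 0.001 both move the parameter by about the learning rate (after weight decay). This per-parameter normalisation is what adapting the learning rate for each parameter means in practice.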
Key takeaway: Cross-entropy loss measures how well the model’s predictions match the true next tokens. The training loop repeatedly computes this loss, computes gradients via backpropagation, and updates the model’s parameters to reduce the loss.
§11 Exercise 4: Train on a Small Text Corpus
In this exercise, we put everything together: load a text corpus, create training data, and train our model.
Preparing the data
For training data, we will use a small text corpus — a few kilobytes of Shakespeare works well. Create a file data/input.txt with some text, or use this approach to embed the data directly:
/// Load and prepare training data.
/// Returns (tokeniser, input_ids) where input_ids is the entire
/// corpus encoded as a vector of token IDs.
fn load_data() -> (CharTokeniser, Vec<u32>) {
    let text = "\
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
All:
We know't, we know't.
First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?
All:
No more talking on't; let it be done: away, away!
";
    let tok = CharTokeniser::from_corpus(text);
    let ids = tok.encode(text);
    (tok, ids)
}
You can replace this with a longer text for better results. The more data, the more patterns the model can learn.
Creating batches
We need to extract fixed-length chunks from the corpus for training:
use candle_core::{DType, Device, Tensor};
/// Create a batch of (input, target) pairs from the corpus.
///
/// Each input is a random window of `seq_len` tokens.
/// Each target is the same window shifted one position to the right.
fn create_batch(
    data: &[u32],
    batch_size: usize,
    seq_len: usize,
    device: &Device,
) -> anyhow::Result<(Tensor, Tensor)> {
    use rand::Rng;
    // Every window needs seq_len input tokens plus one more for the last target.
    anyhow::ensure!(
        data.len() > seq_len,
        "corpus has {} tokens but at least {} are needed",
        data.len(),
        seq_len + 1
    );
    let mut rng = rand::thread_rng();
    // Valid starts are 0..=data.len() - seq_len - 1, i.e. the range 0..data.len() - seq_len
    let max_start = data.len() - seq_len;
    let mut inputs = Vec::with_capacity(batch_size * seq_len);
    let mut targets = Vec::with_capacity(batch_size * seq_len);
    for _ in 0..batch_size {
        let start = rng.gen_range(0..max_start);
        for j in 0..seq_len {
            inputs.push(data[start + j]);
            targets.push(data[start + j + 1]);
        }
    }
    let inputs = Tensor::new(inputs.as_slice(), device)?
        .reshape((batch_size, seq_len))?;
    let targets = Tensor::new(targets.as_slice(), device)?
        .reshape((batch_size, seq_len))?;
    Ok((inputs, targets))
}
The training loop
Here is the complete training loop. Our model's forward pass takes a single 1D sequence, not a 2D batch. Rather than adding batch support to the model, we keep things simple: run each sequence in the batch through the model separately and stack the resulting logits:
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarMap, VarBuilder};
/// Train the model and return it together with its tokeniser.
fn train() -> anyhow::Result<(Gpt1Model, CharTokeniser)> {
    let device = &Device::Cpu;
    let (tok, data) = load_data();
    println!("Corpus size: {} characters", data.len());
    println!("Vocabulary size: {}", tok.vocab_size());
    // Model setup
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
    let cfg = GptConfig::tiny(tok.vocab_size());
    let model = Gpt1Model::new(&cfg, vb)?;
    // Optimiser
    let params = ParamsAdamW {
        lr: 3e-4,
        weight_decay: 0.1,
        ..Default::default()
    };
    let mut opt = AdamW::new(varmap.all_vars(), params)?;
    // Training hyperparameters
    let batch_size = 16;
    let seq_len = 64;
    let num_steps = 1000;
    let vocab_size = tok.vocab_size();
    println!("\nTraining for {} steps...\n", num_steps);
    for step in 1..=num_steps {
        let (inputs, targets) = create_batch(&data, batch_size, seq_len, device)?;
        // Forward pass: process each sequence in the batch separately
        let mut all_logits = Vec::with_capacity(batch_size);
        for b in 0..batch_size {
            let input_b = inputs.get(b)?; // (seq_len,)
            let logits_b = model.forward(&input_b)?; // (seq_len, vocab_size)
            all_logits.push(logits_b);
        }
        let logits = Tensor::stack(&all_logits, 0)?; // (batch_size, seq_len, vocab_size)
        // Flatten batch and sequence dimensions for the loss
        let logits_flat = logits.reshape((batch_size * seq_len, vocab_size))?;
        let targets_flat = targets.reshape(batch_size * seq_len)?; // u32 token IDs
        // cross_entropy takes raw logits (it applies log-softmax internally)
        // and u32 class indices, so no one-hot encoding is needed.
        let loss = candle_nn::loss::cross_entropy(&logits_flat, &targets_flat)?;
        // Backward pass + optimiser step
        opt.backward_step(&loss)?;
        if step % 100 == 0 || step == 1 {
            let loss_val: f32 = loss.to_scalar()?;
            println!("Step {:>4} | Loss: {:.4}", step, loss_val);
        }
    }
    println!("\nTraining complete!");
    // Return the trained model so it can be used for generation
    Ok((model, tok))
}
Add rand to your dependencies
Update Cargo.toml:
[dependencies]
candle-core = "0.8"
candle-nn = "0.8"
anyhow = "1"
rand = "0.8"
Expected output
Corpus size: 482 characters
Vocabulary size: 42
Training for 1000 steps...
Step 1 | Loss: 3.7376
Step 100 | Loss: 2.8412
Step 200 | Loss: 2.3567
Step 300 | Loss: 2.0134
Step 400 | Loss: 1.8223
Step 500 | Loss: 1.6891
Step 600 | Loss: 1.5744
Step 700 | Loss: 1.4832
Step 800 | Loss: 1.4102
Step 900 | Loss: 1.3523
Step 1000 | Loss: 1.2987
Training complete!
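The loss should decrease steadily. A randomly initialised model spreads probability roughly evenly across the vocabulary, so it starts with a loss of about \( \ln V \), where \( V \) is the vocabulary size (for 42 tokens, \( \ln 42 \approx 3.74 \)). As the model trains, it learns character patterns and the loss drops. Your exact numbers will differ from run to run, because the weights are initialised randomly and the batches are sampled randomly.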
Tips for better results
- Use more data. Even a few pages of Shakespeare (50KB+) will dramatically improve generation quality.
- Train longer. 1000 steps is a minimum — try 5000 or 10000 for better results.
- Adjust the learning rate. If loss plateaus, try reducing the learning rate.
- Increase model size. With more data, you can increase d_model to 128 and n_layers to 4.
Key takeaway: The training loop repeatedly samples batches, computes the forward pass and cross-entropy loss, and updates weights via backpropagation. Watching the loss decrease is satisfying confirmation that the model is learning.
§12 Exercise 5: Sample from the Model
The payoff for all our work — generating text from the trained model. In this exercise, we implement temperature-based sampling and generate text character by character.
Temperature sampling
After the model produces logits for the next token, we convert them to probabilities using softmax. The temperature parameter controls the randomness of sampling:
\[ P(t_i) = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}} \]
Where \( T \) is the temperature:
| Temperature | Effect |
|---|---|
| T < 1.0 | Sharper distribution — model picks high-probability tokens more often. More deterministic, less creative. |
| T = 1.0 | Unmodified distribution — sample directly from learned probabilities. |
| T > 1.0 | Flatter distribution — lower-probability tokens get a bigger share. More random, more creative. |
| T → 0 | Equivalent to argmax — always pick the most likely token. |
Logits: [2.0, 5.0, 1.0]
T = 1.0 → P: [0.05, 0.94, 0.02] (normal)
T = 0.5 → P: [0.00, 1.00, 0.00] (very peaked)
T = 2.0 → P: [0.16, 0.74, 0.10] (flattened)
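The temperature formula is a small change to softmax: divide each logit by \( T \) first. This plain-Rust sketch (the function name is hypothetical) computes the distributions for the logits used in the example.

```rust
/// Softmax of `logits / t`, with the maximum subtracted for numerical stability.
/// `t` must be positive; smaller values sharpen the distribution.
fn softmax_with_temperature(logits: &[f64], t: f64) -> Vec<f64> {
    assert!(t > 0.0, "temperature must be positive");
    let scaled: Vec<f64> = logits.iter().map(|&z| z / t).collect();
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Printed to two decimal places, `[2.0, 5.0, 1.0]` gives [0.05, 0.94, 0.02] at T = 1.0, [0.00, 1.00, 0.00] at T = 0.5, and [0.16, 0.74, 0.10] at T = 2.0.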
Top-k sampling
Top-k sampling restricts the choice to the k most probable tokens, setting all other probabilities to zero. This prevents the model from choosing extremely unlikely tokens (which can produce gibberish):
Logits (sorted): [5.0, 3.0, 2.0, 0.5, -1.0, -3.0]
Top-k (k=3): [5.0, 3.0, 2.0, -inf, -inf, -inf]
After softmax: [0.84, 0.11, 0.04, 0.0, 0.0, 0.0]
Combining temperature and top-k is the standard approach in practice.
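Here is a plain-Rust sketch of the top-k step followed by the softmax, applied to the example logits above. The helper names are hypothetical. Real logits are not sorted, so the filter finds the k-th largest value and uses it as a threshold.

```rust
/// Keep logits >= the k-th largest value and set the rest to -inf,
/// so they get zero probability after the softmax.
/// Values tied with the k-th largest are kept.
fn top_k_filter(logits: &[f64], k: usize) -> Vec<f64> {
    assert!(k >= 1, "k must be at least 1");
    if k >= logits.len() {
        return logits.to_vec(); // nothing to filter out
    }
    let mut sorted = logits.to_vec();
    sorted.sort_by(|a, b| b.total_cmp(a)); // descending order
    let threshold = sorted[k - 1];
    logits
        .iter()
        .map(|&z| if z >= threshold { z } else { f64::NEG_INFINITY })
        .collect()
}

/// Softmax; entries equal to -inf come out as exactly 0.0.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Filtering `[5.0, 3.0, 2.0, 0.5, -1.0, -3.0]` with k = 3 and applying the softmax gives roughly [0.84, 0.11, 0.04, 0, 0, 0]: the three survivors share all of the probability.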
Implementation
Add a sampling function to your project:
use rand::distributions::Distribution;
/// Sample a token ID from logits with temperature and optional top-k.
///
/// # Arguments
/// * `logits` - 1D tensor of shape `(vocab_size,)`: raw model output
/// * `temperature` - Controls randomness (lower = more deterministic); must be > 0
/// * `top_k` - If Some(k), only consider the k most likely tokens
fn sample_token(
    logits: &Tensor,
    temperature: f64,
    top_k: Option<usize>,
) -> anyhow::Result<u32> {
    // T = 0 would divide by zero; use top_k = Some(1) for greedy decoding instead.
    anyhow::ensure!(temperature > 0.0, "temperature must be positive");
    let vocab_size = logits.dim(0)?;
    // Apply temperature: divide every logit by T
    let scaled = (logits / temperature)?;
    // Convert to Vec for manipulation
    let mut logit_vec: Vec<f32> = scaled.to_vec1()?;
    // Apply top-k: set everything below the k-th largest logit to -inf.
    // Values tied with the k-th largest are kept, so slightly more than k may survive.
    if let Some(k) = top_k {
        let k = k.min(vocab_size).max(1);
        let mut sorted = logit_vec.clone();
        sorted.sort_by(|a, b| b.total_cmp(a)); // descending order
        let threshold = sorted[k - 1];
        for val in logit_vec.iter_mut() {
            if *val < threshold {
                *val = f32::NEG_INFINITY;
            }
        }
    }
    // Softmax to get probabilities (subtract the max for numerical stability)
    let max_val = logit_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logit_vec.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probs: Vec<f64> = exps.iter().map(|&x| (x / sum) as f64).collect();
    // Sample an index with probability proportional to its weight
    let dist = rand::distributions::WeightedIndex::new(&probs)?;
    let mut rng = rand::thread_rng();
    Ok(dist.sample(&mut rng) as u32)
}
The generation loop
/// Generate text from the model autoregressively.
///
/// # Arguments
/// * `model` - The trained GPT model
/// * `tok` - The tokeniser
/// * `prompt` - Starting text (must encode to at least one token)
/// * `max_tokens` - Number of tokens to generate
/// * `temperature` - Sampling temperature
/// * `top_k` - Optional top-k filtering
fn generate(
    model: &Gpt1Model,
    tok: &CharTokeniser,
    prompt: &str,
    max_tokens: usize,
    temperature: f64,
    top_k: Option<usize>,
) -> anyhow::Result<String> {
    let device = &Device::Cpu;
    // Encode the prompt
    let mut token_ids = tok.encode(prompt);
    anyhow::ensure!(!token_ids.is_empty(), "prompt must contain at least one token");
    let max_seq_len = 128; // Must match GptConfig::max_seq_len
    for _ in 0..max_tokens {
        // Keep only the most recent max_seq_len tokens as context
        let start = token_ids.len().saturating_sub(max_seq_len);
        let context = &token_ids[start..];
        // Forward pass
        let input = Tensor::new(context, device)?;
        let logits = model.forward(&input)?;
        // Get logits for the last position
        let last_logits = logits.get(context.len() - 1)?; // (vocab_size,)
        // Sample next token
        let next_token = sample_token(&last_logits, temperature, top_k)?;
        token_ids.push(next_token);
    }
    Ok(tok.decode(&token_ids))
}
Putting it together
After training, add generation:
fn main() -> anyhow::Result<()> {
    // ... training code from Exercise 4 ...
    println!("\n--- Generating text ---\n");
    // Try different temperatures
    for temp in [0.5, 0.8, 1.0, 1.5] {
        let text = generate(&model, &tok, "First", 200, temp, Some(10))?;
        println!("Temperature {:.1}:", temp);
        println!("{}\n", text);
    }
    Ok(())
}
Expected output
Illustrative output from a well-trained model on Shakespeare text (sampling is random, so your output will differ):
Temperature 0.5:
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
Temperature 1.0:
First Citizen:
Let us know the people are resolved to
the corn at our own kill him, and we
speak.
Temperature 1.5:
First Civkzl:
aNo moye, ws arl't; he proceiw
Le usdn ferktie corn at mork
At low temperature, the model reproduces memorised text. At high temperature, it becomes creative but error-prone. Temperature 0.8-1.0 is usually the sweet spot.
Exercises to try
- Experiment with top-k. Try k = 1 (greedy), k = 5, k = 10, and None (no filtering). How does it affect output quality?
- Implement top-p (nucleus) sampling. Instead of a fixed k, include the most probable tokens, in order, until their cumulative probability exceeds a threshold p (e.g., 0.9).
- Try different prompts. How does the model respond to prompts it has seen vs. novel prompts?
- Measure perplexity. Compute \( e^{\text{average loss}} \) on a held-out test set to quantify model quality.
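For the perplexity exercise, the final step is just an exponential of the mean loss. This minimal sketch assumes you already have one natural-log cross-entropy value per predicted position on the held-out text; `perplexity` is a hypothetical helper name. A perplexity of V means the model is, on average, as uncertain as a uniform guess among V tokens, which is why an untrained model scores close to the vocabulary size.

```rust
/// Perplexity from per-position cross-entropy losses (natural log):
/// exp of the mean loss. Lower is better; 1.0 is a perfect model.
fn perplexity(losses: &[f64]) -> f64 {
    assert!(!losses.is_empty(), "need at least one loss value");
    let mean = losses.iter().sum::<f64>() / losses.len() as f64;
    mean.exp()
}
```

A list of losses all equal to \( \ln 42 \) gives a perplexity of 42, and a model with zero loss everywhere gives 1.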
Key takeaway: Text generation works by repeatedly running the model, sampling a token from the output distribution, and appending it. Temperature and top-k control the tradeoff between coherence and creativity.
Part 5 — Reflection
§13 What Limits This Model?
We have built a working language model — it can learn patterns in text and generate new text. But if you compare its output to ChatGPT or Claude, the gap is enormous. Let us understand why.
Context length
Our model has a maximum context window of 128 tokens (characters). It literally cannot “see” anything beyond the last 128 characters. Modern LLMs have context windows of 8K to 200K+ tokens (and those are subword tokens, each representing several characters). This means our model:
- Cannot maintain coherence over long passages.
- Cannot reason about information that appeared more than 128 characters (roughly 20 words) ago.
- Has no ability to follow instructions that exceed its window.
Model size
Our model has roughly 100,000 parameters. For comparison:
| Model | Parameters | Ratio to ours |
|---|---|---|
| Our model | ~100K | 1x |
| GPT-1 (2018) | 117M | 1,170x |
| GPT-2 (2019) | 1.5B | 15,000x |
| GPT-3 (2020) | 175B | 1,750,000x |
| GPT-4 (2023) | ~1.8T (rumoured) | ~18,000,000x |
With only 100K parameters, our model can memorise short character patterns but cannot learn grammar, semantics, or world knowledge. Larger models have more capacity to store and compose information.
Training data
We trained on a few hundred characters. Real LLMs train on trillions of tokens — essentially the entire public internet, books, code repositories, and more. The sheer volume and diversity of data is what allows large models to:
- Learn the structure of many languages.
- Absorb factual knowledge.
- Understand code, math, and reasoning patterns.
Tokenisation
Character-level tokenisation means our model sees one character per step. A 10-word sentence is ~50 tokens for us, but only ~10-15 tokens for a BPE tokeniser. This means:
- Our model needs a longer context window for the same effective range.
- Longer sequences are more expensive to process (attention is \( O(n^2) \) in sequence length).
- The model must learn to spell — it cannot take words as atomic units.
Real LLMs use BPE (GPT) or SentencePiece (Llama) tokenisers with vocabularies of 32K-100K tokens.
Training techniques we skipped
Production LLMs use many techniques we did not cover:
- Learning rate scheduling. A warm-up phase followed by cosine decay.
- Gradient clipping. Preventing exploding gradients by capping their magnitude.
- Mixed precision training. Using float16/bfloat16 for speed and memory efficiency.
- Data parallelism and model parallelism. Distributing training across hundreds of GPUs.
- RLHF (Reinforcement Learning from Human Feedback). Fine-tuning the model to follow instructions and be helpful, using human preference data. This is what makes ChatGPT and Claude conversational, rather than just completing text.
- Supervised fine-tuning (SFT). Training on curated instruction-response pairs before RLHF.
What our model CAN do
Despite its limitations, our model demonstrates every fundamental component of a modern LLM:
- Tokenisation — converting text to numbers and back.
- Embeddings — learned vector representations of tokens and positions.
- Self-attention with causal masking — the core Transformer mechanism.
- Stacked Transformer blocks — depth through repeated application.
- Cross-entropy training — learning from next-token prediction.
- Autoregressive generation — producing text one token at a time.
The jump from our model to GPT-4 is primarily one of scale (more parameters, more data, more compute) plus careful engineering and alignment techniques. The core decoder-only Transformer design is broadly the same, although the exact architecture of GPT-4 has not been published.
Key takeaway: Our model is limited by context length, model size, training data, and tokenisation. But it contains every core component of production LLMs. The path from here to GPT-4 is primarily scaling and engineering, not architectural revolution.
§14 Further Reading
This chapter covered the foundations. Here are resources to go deeper, organised by topic.
Foundational papers
- “Attention Is All You Need” (Vaswani et al., 2017) — arxiv.org/abs/1706.03762 — The paper that introduced the Transformer architecture. Essential reading.
- “Improving Language Understanding by Generative Pre-Training” (Radford et al., 2018) — CDN link — The GPT-1 paper. Showed that a decoder-only Transformer pre-trained on language modeling can be fine-tuned for many tasks.
- “Language Models are Unsupervised Multitask Learners” (Radford et al., 2019), CDN link. The GPT-2 paper. Showed that a scaled-up GPT-1 trained on a larger, more diverse corpus can perform many tasks zero-shot, without task-specific fine-tuning.
- “Language Models are Few-Shot Learners” (Brown et al., 2020) — arxiv.org/abs/2005.14165 — The GPT-3 paper. Demonstrated that massive scale enables in-context learning.
Tutorials and courses
- Andrej Karpathy’s “Let’s build GPT” — youtube.com/watch?v=kCc8FmEb1nY — A two-hour video building a GPT from scratch in Python/PyTorch. Excellent companion to this chapter, covering the same ideas in a different language.
- Andrej Karpathy’s “makemore” series — Builds character-level language models of increasing complexity, from bigrams to Transformers. Available on YouTube.
- 3Blue1Brown “But what is a neural network?” — youtube.com/watch?v=aircAruvnKk — Beautiful visual explanations of the basics of neural networks, backpropagation, and gradient descent.
- “The Illustrated Transformer” by Jay Alammar — jalammar.github.io/illustrated-transformer/ — The best visual guide to the Transformer architecture.
Rust ML ecosystem
- Candle — github.com/huggingface/candle — The tensor framework we used in this chapter. Supports CPU and GPU, with a PyTorch-like API.
- Candle documentation — docs.rs/candle-core — API reference for tensor operations.
- Burn — burn.dev — Another Rust deep learning framework, with a different design philosophy (backend-agnostic).
- tch-rs — github.com/LaurentMazare/tch-rs — Rust bindings for PyTorch’s C++ library (libtorch). More mature but requires a C++ dependency.
Books
- “Deep Learning” by Goodfellow, Bengio, and Courville — deeplearningbook.org — The comprehensive textbook on deep learning fundamentals.
- “Dive into Deep Learning” — d2l.ai — Interactive, code-first textbook with implementations in multiple frameworks.
- “Speech and Language Processing” by Jurafsky and Martin — web.stanford.edu/~jurafsky/slp3/ — Covers NLP foundations including language modeling, with chapters on neural approaches.
Topics to explore next
Now that you understand the basics, here are natural next steps:
- Subword tokenisation. Implement BPE (Byte-Pair Encoding) to handle larger vocabularies efficiently. See the tokenizers crate by Hugging Face.
- GPU training. Switch from Device::Cpu to a CUDA device in candle (for example Device::new_cuda(0), with candle's cuda feature enabled) to train on a GPU. This enables much larger models and datasets.
- Positional encodings. Experiment with RoPE (Rotary Position Embeddings), which is used in Llama and most modern models.
- KV caching. During generation, cache the key and value tensors from previous tokens to avoid redundant computation. This is essential for fast inference.
- Fine-tuning a pre-trained model. Load a pre-trained model (e.g., a small Llama) in candle and fine-tune it on your own data.
- RLHF. Study how reinforcement learning from human feedback transforms a language model into an assistant.
Key takeaway: The field of language modeling is vast and evolving rapidly. The fundamentals you learned in this chapter — tokenisation, embeddings, attention, training — are the foundation everything else builds on. Pick a direction that interests you and keep building.