Machine Learning: Training a Game AI Through Self-Play
This is a self-guided course on reinforcement learning through the lens of a concrete goal: training a Rust program to play Tic-Tac-Toe by playing against itself, in the style of AlphaGo Zero. No prior ML experience is assumed. You will build everything from scratch – a game engine, a search algorithm, and eventually a neural network that guides that search. Sections marked 🚧 are stubs whose full content is tracked in a beans ticket.
Table of Contents
Part 1 – Foundations
- What is reinforcement learning?
- Monte Carlo Tree Search – algorithm explained
- Why self-play? The AlphaGo Zero insight
Part 2 – The Game
- Choosing a simple game: Tic-Tac-Toe
- Representing game state in Rust
- Exercise 1: implement the game logic
Part 3 – MCTS
Part 4 – Neural Network Policy/Value Head
- Neural network architecture overview
- Integrating a neural network crate
- Exercise 3: train the network on MCTS data
- Exercise 4: replace rollout with the value network
Part 5 – Self-Play Loop
Part 1 – Foundations
1. What is reinforcement learning?
Imagine teaching a dog to sit. You don't hand it a textbook on canine biomechanics. You don't show it a thousand videos of other dogs sitting. Instead, you wait for it to do something close to sitting, and then you give it a treat. Over time, the dog figures out: that particular action leads to something good. It learns from experience.
That, in essence, is reinforcement learning (RL). It is one of the three major branches of machine learning, and it is the one that most closely mirrors how humans and animals learn to interact with the world. Before we dive into RL itself, let's briefly look at the other two branches so you can see what makes RL different.
The three branches of machine learning
Supervised learning is like studying with an answer key. You are given a dataset of questions paired with correct answers – pictures labeled "cat" or "dog," emails labeled "spam" or "not spam," house features paired with sale prices. The algorithm's job is to learn the pattern that maps inputs to outputs so it can make predictions on new, unseen data.
The key ingredient is labeled data. Someone (often a human) has already done the hard work of providing the right answer for every example. The algorithm learns by comparing its predictions to those answers and adjusting itself to get closer.
Unsupervised learning is like being handed a box of puzzle pieces with no picture on the box. There are no labels, no right answers. The algorithm's job is to find structure on its own – clusters of similar data points, hidden patterns, compressed representations. Think of a music streaming service grouping songs into genres based on audio features alone, without anyone telling it what "jazz" or "rock" means.
Reinforcement learning is different from both. There are no labeled examples to study. There is no dataset to mine for hidden structure. Instead, there is an agent that takes actions in an environment and receives rewards (or penalties) based on what happens. The agent's goal is to figure out, through trial and error, which actions lead to the most reward over time.
Here is a table that summarizes the differences:
| | Supervised | Unsupervised | Reinforcement |
|---|---|---|---|
| Learning signal | Correct answer for each example | None – find structure | Reward signal after actions |
| Data | Labeled dataset, provided upfront | Unlabeled dataset, provided upfront | Generated through interaction |
| Goal | Predict the right output | Discover hidden patterns | Maximize cumulative reward |
| Analogy | Studying with an answer key | Sorting a box of unlabeled photos | Training a dog with treats |
The critical difference for our purposes: in RL, the agent generates its own data by interacting with the environment. It doesn't need a human to label millions of examples. It learns by doing. This is exactly why RL is so powerful for games – the agent can play millions of games against itself and learn from every single one.
The core vocabulary of RL
Let's define the key concepts. We will use a concrete example throughout: imagine you are playing a simple board game.
Agent
The agent is the learner and decision-maker. It is the entity that observes the world, chooses actions, and tries to maximize its reward. In our course, the agent will be a program that plays Tic-Tac-Toe. In the real world, an agent could be a robot navigating a warehouse, a program trading stocks, or a thermostat controlling room temperature.
The agent is not the environment. It does not control the rules. It can only observe what is happening and choose from the actions available to it.
Environment
The environment is everything outside the agent. It is the world the agent interacts with. The environment receives the agentβs actions, updates its internal state, and sends back two things: the new state and a reward.
For a board game, the environment is the game itself – the board, the rules, and the opponent. The agent places a piece, and the environment responds with the updated board position and (eventually) a signal about whether the agent won or lost.
The environment does not need to be physical. It can be a simulation, a video game, a mathematical function – anything that accepts actions and produces states and rewards.
State
The state is a snapshot of the environment at a particular moment. It contains all the information the agent needs to make a decision.
In Tic-Tac-Toe, the state is the current board configuration: which squares have X, which have O, and which are empty. In chess, it is the position of every piece on the board. For a self-driving car, the state might include speed, position, nearby obstacles, traffic light status, and more.
A good state representation captures everything relevant. If the agent cannot see something that matters, it will struggle to make good decisions. We will spend a full section later on choosing a state representation for Tic-Tac-Toe.
Action
An action is something the agent can do. At each step, the agent looks at the current state and selects one action from the set of available actions.
In Tic-Tac-Toe, an action is placing your mark on one of the empty squares. If there are four empty squares, the agent has four possible actions. In a video game, actions might be "move left," "jump," or "fire." For a robot arm, actions could be "rotate joint 30 degrees" or "close gripper."
The set of available actions can change from state to state. Early in a Tic-Tac-Toe game, you have nine possible moves. Late in the game, you might have only one or two.
Reward
The reward is a numerical signal from the environment that tells the agent how well it is doing. It is the only feedback the agent gets. The agent's entire goal is to accumulate as much reward as possible over time.
Rewards can be:
- Positive – "That was good, do more of that." Winning a game might give +1.
- Negative – "That was bad, do less of that." Losing might give -1.
- Zero – "Nothing interesting happened." Most moves in a game might give 0 reward.
One of the challenges of RL is that rewards can be sparse. In Tic-Tac-Toe, you might only get a reward at the very end of the game – win, lose, or draw. All the moves leading up to that final outcome get zero reward. The agent has to figure out which of its earlier moves were responsible for the eventual outcome. This is called the credit assignment problem: which actions deserve credit (or blame) for a delayed reward?
Think about learning to cook. You follow a recipe, making dozens of small decisions along the way – how much salt, how long to sear, when to add the garlic. You only get your "reward" when you taste the final dish. If it is terrible, which decision was the mistake? Was it the salt? The garlic timing? The heat? RL faces exactly this challenge.
Policy
A policy is the agent's strategy – its rule for choosing actions. Given a state, the policy tells the agent what to do.
A policy can be simple: "always pick the first available square" (a terrible Tic-Tac-Toe strategy, but a valid policy). Or it can be complex: a neural network that takes the board state as input and outputs a probability distribution over all possible moves.
We write the policy as pi (the Greek letter). Formally, a policy maps states to actions:
Given state s, the policy pi(s) tells the agent which action a to take.
A deterministic policy always picks the same action for a given state. A stochastic policy outputs probabilities – "play center with 60% chance, play corner with 30% chance, play edge with 10% chance." Stochastic policies are powerful because they allow exploration. If the agent always does the same thing, it can never discover that something else might be better.
The goal of RL is to find a good policy – one that leads to high cumulative reward. The optimal policy is the one that maximizes expected reward. Finding it (or approximating it) is what RL algorithms do.
Value function
The value function answers a deceptively important question: how good is it to be in a particular state?
Not how good is the immediate reward – but how much total future reward can the agent expect to accumulate from this state onward, assuming it follows its current policy?
Consider two Tic-Tac-Toe positions. In one, you have two in a row with the third square open – you are about to win. In the other, your opponent has two in a row with the third square open – you are about to lose. The immediate reward for both states is zero (the game is not over yet), but their values are very different. One state is worth a lot; the other is nearly worthless.
The value function is what lets the agent think ahead. Instead of being greedy (just maximizing immediate reward), the agent can evaluate states by their long-term potential. "This move gives me zero reward right now, but it puts me in a state with high value."
There are two flavors:
- State value function V(s): How much total future reward can I expect from state s?
- Action value function Q(s, a): How much total future reward can I expect if I take action a in state s?
The Q function is especially useful because it directly tells the agent which action is best: just pick the action with the highest Q value.
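To make "just pick the action with the highest Q value" concrete, here is a minimal Rust sketch; the function and the (action, value) pairing are purely illustrative, not part of the course code yet:

```rust
// Greedy action selection: given Q-value estimates for each legal action,
// pick the action with the highest one. Returns None if there are no actions.
fn greedy_action(q_values: &[(usize, f64)]) -> Option<usize> {
    q_values
        .iter()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|&(action, _)| action)
}
```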
The RL loop
Now let's put all these concepts together. Reinforcement learning follows a cycle – a loop that repeats over and over:
     +-----------+       action       +---------------+
     |           | -----------------> |               |
     |   Agent   |                    |  Environment  |
     |           | <----------------- |               |
     |           |   state, reward    |               |
     +-----------+                    +---------------+
Here is the loop, step by step:
1. Observe: The agent observes the current state of the environment. In Tic-Tac-Toe, it looks at the board.
2. Decide: The agent uses its policy to choose an action. Maybe it picks the center square. Maybe it picks a corner. The policy determines how this choice is made.
3. Act: The agent takes the action. It places its mark on the board.
4. Receive feedback: The environment responds with two things – the new state (the updated board, possibly after the opponent also moves) and a reward (did the game end? did we win?).
5. Learn: The agent uses this experience (state, action, reward, new state) to update its policy and/or value function. "Last time I was in that state and took that action, things went well. I should do that more often."
6. Repeat: Go back to step 1.
This loop runs thousands, millions, even billions of times. Early on, the agent has no idea what it is doing. It explores randomly, stumbling into wins and losses. But over time, it accumulates experience. It starts to learn which states are good and which are bad, which actions lead to rewards and which lead to disaster. The policy improves. The value estimates get more accurate. The agent gets better.
This is fundamentally different from supervised learning, where all the data is collected upfront and the model trains on it in a fixed dataset. In RL, the agent's behavior changes the data it collects. A better policy leads to different states, which leads to different experiences, which leads to further improvements. The agent is simultaneously exploring the environment and exploiting what it has learned.
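In Rust terms, one round of this cycle might look roughly like the sketch below. The trait names and method signatures are hypothetical placeholders (the real interfaces we build in Parts 2-5 will differ), but the shape of the observe-decide-act-learn loop is the same:

```rust
// A minimal sketch of the agent-environment loop (hypothetical traits, not the course API).
trait Environment {
    type State;
    type Action;
    fn reset(&mut self) -> Self::State;
    // Returns (next state, reward, whether the episode is over).
    fn step(&mut self, action: &Self::Action) -> (Self::State, f64, bool);
}

trait Agent<E: Environment> {
    fn act(&mut self, state: &E::State) -> E::Action; // decide, using the current policy
    fn learn(&mut self, state: &E::State, action: &E::Action, reward: f64, next: &E::State);
}

fn run_episode<E: Environment, A: Agent<E>>(env: &mut E, agent: &mut A) {
    let mut state = env.reset();
    loop {
        let action = agent.act(&state);               // observe and decide
        let (next, reward, done) = env.step(&action); // act and receive feedback
        agent.learn(&state, &action, reward, &next);  // update from experience
        if done {
            break;
        }
        state = next; // repeat
    }
}
```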
Exploration vs. exploitation
This brings us to one of the deepest tensions in RL: the exploration-exploitation tradeoff.
Suppose you have found a restaurant you like. Do you keep going there (exploit what you know) or try a new restaurant that might be even better (explore the unknown)? If you always exploit, you might miss something great. If you always explore, you never enjoy the best thing you have already found.
RL agents face this same dilemma at every step. The agent has a policy that it thinks is pretty good. Should it follow that policy (exploit) or try something random to see what happens (explore)?
- Too much exploitation: The agent gets stuck in a rut. It finds a decent strategy and never discovers that a much better one exists. In Tic-Tac-Toe, it might learn to always play in the center and never discover that corner openings can be just as strong.
- Too much exploration: The agent never commits to anything. It keeps trying random things and never builds on what it has learned. It plays like a beginner forever.
Good RL algorithms balance exploration and exploitation carefully. A common approach is to start with lots of exploration (try everything!) and gradually shift toward exploitation as the agent becomes more confident in its knowledge. We will see this principle in action when we implement Monte Carlo Tree Search in a later section.
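One of the simplest concrete implementations of this balance is epsilon-greedy action selection: with probability epsilon, explore by picking a random action; otherwise, exploit by picking the action with the best current estimate. MCTS will use a more principled rule (UCT), but a sketch shows the idea. To keep the example free of any particular RNG crate, the random draws are supplied by the caller:

```rust
// Epsilon-greedy sketch (hypothetical helper). `roll` is a uniform random number
// in [0, 1) and `random_index` is a random nonnegative integer, both supplied by
// the caller. Assumes `q_values` is non-empty.
fn epsilon_greedy(q_values: &[f64], epsilon: f64, roll: f64, random_index: usize) -> usize {
    if roll < epsilon {
        // Explore: try a random action.
        random_index % q_values.len()
    } else {
        // Exploit: take the action with the highest estimated value.
        q_values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap()
    }
}
```

Starting with a large epsilon and shrinking it over time gives the "explore first, exploit later" schedule described above.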
Why RL is a natural fit for games
You might be wondering: why are we using reinforcement learning for a game? Why not just use supervised learning – collect a bunch of expert games, label the moves as "good" or "bad," and train on that?
There are several reasons why RL is a particularly natural fit for games:
Games have clear rules and rewards. The environment is well-defined. The state is fully observable (in most board games, you can see the whole board). The reward is unambiguous: win, lose, or draw. There is no messy, noisy real-world data to deal with. This makes games an ideal testbed for RL algorithms.
You can generate unlimited training data. Unlike supervised learning, where you need humans to label examples, an RL agent can play games against itself forever. It generates its own training data through self-play. No labeling required. No expensive human expertise needed. The agent just plays and learns.
The search space is structured. Games have turns, legal moves, and terminal states. This structure makes it possible to reason about future consequences systematically. Algorithms like Monte Carlo Tree Search (which we will cover in the next section) exploit this structure brilliantly.
There is no "expert" bottleneck. Supervised learning is limited by the quality of its training data. If you train on expert games, the best your agent can do is imitate those experts. But what if the experts are not actually optimal? RL agents can surpass human performance because they are not trying to copy anyone – they are trying to maximize reward. AlphaGo Zero famously discovered Go strategies that no human had ever played, because it was not constrained by human conventions.
Games have been the proving ground for AI breakthroughs. From IBM's Deep Blue defeating Kasparov at chess in 1997, to DeepMind's AlphaGo defeating Lee Sedol at Go in 2016, to OpenAI Five competing at Dota 2, games have consistently pushed the boundaries of what AI can do. The techniques developed for games – including the ones you will learn in this course – have gone on to be applied in robotics, drug discovery, chip design, and more.
A preview of where we are headed
In this course, we will build a complete RL system from scratch. Here is a high-level preview:
- Part 2 builds the game – a clean, efficient Tic-Tac-Toe engine in Rust that can represent states, enumerate legal moves, and detect wins.
- Part 3 builds the search – Monte Carlo Tree Search, an algorithm that explores possible future game states by simulating random playouts. MCTS does not need any machine learning at all; it works through pure simulation.
- Part 4 adds the brain – a neural network that learns to evaluate board positions (value) and suggest promising moves (policy). This replaces the random simulations in MCTS with learned intuition.
- Part 5 closes the loop – the full self-play training cycle where the agent plays itself, generates training data, improves its neural network, and repeats. This is the AlphaGo Zero recipe.
By the end, you will have a program that starts knowing nothing about Tic-Tac-Toe – not even the basic strategy – and teaches itself to play at a strong level through pure self-play. The same architecture, scaled up, is what defeated the world champion at Go.
But first, we need to understand the search algorithm that makes this all possible. In the next section, we will explore Monte Carlo Tree Search.
2. Monte Carlo Tree Search – algorithm explained
In the last section, we learned that an RL agent needs a policy – a strategy for choosing actions. But how do you come up with a good policy for a game like Tic-Tac-Toe, or chess, or Go? One powerful answer is to search: look ahead at possible future moves, evaluate what might happen, and pick the move that leads to the best outcome.
The problem is that looking ahead in a game is staggeringly expensive.
The game tree problem
Every game can be represented as a game tree. The root is the current position. Each branch is a possible move. Each node is a resulting position. The leaves are terminal states – wins, losses, or draws.
For Tic-Tac-Toe, the game tree is manageable. From the opening position, there are 9 possible first moves, then 8 possible responses, then 7, and so on. The total number of possible games is at most 9! = 362,880 (and fewer in practice because many games end before all squares are filled). A modern computer can explore the entire Tic-Tac-Toe tree in milliseconds.
Now consider chess. From the opening position, White has 20 possible first moves. Black has 20 possible responses. That is already 400 positions after one move each. The average chess game is about 40 moves per side, and at each turn there are roughly 30 legal moves. The total number of possible chess games is estimated at around 10^120 – a number so large it dwarfs the number of atoms in the observable universe (roughly 10^80).
Go is even worse. On a 19x19 board, the first move has 361 options. The game tree contains an estimated 10^360 possible games.
Exhaustive search is impossible. You cannot look at every possible future. You need a smarter approach – one that focuses its computational effort on the parts of the tree that matter most.
What "Monte Carlo" means
The term "Monte Carlo" comes from the famous casino in Monaco and refers to a broad class of algorithms that use random sampling to estimate quantities that are too expensive to compute exactly.
The core idea is simple: if you cannot calculate something precisely, you can estimate it by running many random experiments and averaging the results.
Suppose someone asks you: "What fraction of the area of a square is covered by a circle inscribed inside it?" You could calculate this analytically (it is pi/4). Or you could take the Monte Carlo approach: throw 10,000 darts randomly at the square, count how many land inside the circle, and divide. With enough darts, your estimate converges to the true answer.
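Here is that dart-throwing experiment as a small, self-contained Rust sketch. The tiny xorshift generator is only there to avoid pulling in an external crate; any source of uniform random numbers would do:

```rust
// Estimate the fraction of a unit square covered by its inscribed circle
// by throwing random "darts" and counting hits. Converges to pi/4 ~ 0.785.
fn xorshift(state: &mut u64) -> f64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    (*state >> 11) as f64 / (1u64 << 53) as f64 // uniform in [0, 1)
}

fn estimate_circle_fraction(darts: u32) -> f64 {
    let mut rng = 0x243F_6A88_85A3_08D3u64; // arbitrary nonzero seed
    let mut hits = 0u32;
    for _ in 0..darts {
        // Center the square at the origin; the inscribed circle has radius 0.5.
        let x = xorshift(&mut rng) - 0.5;
        let y = xorshift(&mut rng) - 0.5;
        if x * x + y * y <= 0.25 {
            hits += 1;
        }
    }
    hits as f64 / darts as f64
}

fn main() {
    println!("estimate: {:.4}", estimate_circle_fraction(100_000));
}
```

More darts give a better estimate; MCTS makes exactly the same tradeoff with its number of simulations.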
Monte Carlo Tree Search applies this same idea to game trees. Instead of exhaustively analyzing every possible sequence of moves, we play out random games (called rollouts or simulations) and use the results to estimate how good each move is. A move that leads to lots of random wins is probably a good move. A move that leads to lots of random losses is probably a bad one.
The remarkable thing is that this works. Even completely random play – just picking legal moves at random until the game ends – gives you useful signal about which positions are strong and which are weak. And as we will see, MCTS is much smarter than pure random sampling. It focuses its simulations on the most promising parts of the tree.
The four phases of MCTS
MCTS builds a search tree incrementally, one simulation at a time. Each simulation consists of four phases: Selection, Expansion, Simulation, and Backpropagation. Let's walk through each one.
Phase 1: Selection
Starting from the root of the tree (the current game position), we walk down through nodes that we have already explored, choosing which child to visit at each step.
The key question is: which child do we pick? We want to visit moves that look promising (to refine our estimate of their value), but we also want to try moves we have not explored much (because they might turn out to be even better). This is the exploration-exploitation tradeoff from the previous section, and MCTS handles it with a formula called UCT, which we will explain shortly.
Selection continues until we reach a node that has at least one child we have never visited – a frontier of the known tree.
Phase 2: Expansion
Once selection reaches a node with unexplored children, we pick one of those unexplored moves and add it to the tree as a new node. This is how the tree grows: one new node per simulation.
The new node starts with no statistics – we have never simulated a game from this position before. We are about to fix that.
Phase 3: Simulation (rollout)
From the newly expanded node, we play a random game to completion. Both sides make random legal moves until someone wins, someone loses, or the game is drawn. This is called a rollout or playout.
The rollout does not add any nodes to the tree. It is purely a quick-and-dirty way to estimate how good the new position is. If random play from this position tends to result in wins, the position is probably decent. If it tends to result in losses, the position is probably bad.
The result of the rollout is a simple outcome: typically +1 for a win, -1 for a loss, and 0 for a draw.
Phase 4: Backpropagation
After the rollout finishes, we take the result and propagate it back up the tree, updating every node we visited during the selection phase.
Each node in the tree tracks two numbers:
- N – the number of times this node has been visited (i.e., how many simulations have passed through it)
- W – the total accumulated reward from all simulations that passed through it
We increment N by 1 and add the rollout result to W for every node on the path from the new node back to the root. Note that the reward is relative to the player whose turn it is at each node – a win for player X is a loss for player O, so the sign of the reward flips as we move up through alternating players.
After backpropagation, the tree has slightly better statistics. The next simulation can make a slightly more informed selection. Over hundreds or thousands of simulations, the tree's estimates converge toward the true values of each move.
The MCTS loop in pseudocode
Here is the complete algorithm:
function mcts(root_state, num_simulations):
    root = create_node(root_state)
    repeat num_simulations times:
        // Phase 1: Selection
        node = root
        state = copy(root_state)
        while node is fully expanded and not terminal:
            node = select_child(node)            // using UCT formula
            state = apply_move(state, node.move)
        // Phase 2: Expansion
        if node is not terminal:
            move = pick an unexplored move from node
            state = apply_move(state, move)
            child = create_node(state, move)
            add child to node's children
            node = child
        // Phase 3: Simulation (rollout)
        while state is not terminal:
            move = random legal move
            state = apply_move(state, move)
        result = reward from terminal state
        // Phase 4: Backpropagation
        while node is not null:
            node.N += 1
            node.W += result   (from node's perspective)
            node = node.parent
    // After all simulations, pick the most-visited child of root
    return child of root with highest N
Notice the final step: after all simulations are done, we pick the move whose corresponding child has been visited the most (highest N), not the one with the highest average reward. Visit count is a more robust measure than average reward because it reflects how much computational effort the algorithm devoted to that move. A move that was visited many times was consistently attractive enough to survive the selection process.
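As a preview of the bookkeeping we will need in Part 3, here is one possible shape for a tree node in Rust. The field names and types are illustrative rather than final; in particular, children own their subtrees directly, so backpropagation can be done by returning results up the recursion instead of storing parent pointers:

```rust
// Sketch of the statistics MCTS tracks per node (not the final Part 3 design).
struct Node {
    mv: Option<usize>,         // the move that led to this node (None for the root)
    visits: u32,               // N: how many simulations passed through this node
    total_reward: f64,         // W: accumulated reward from those simulations
    untried_moves: Vec<usize>, // legal moves not yet expanded into children
    children: Vec<Node>,
}

impl Node {
    // After the simulation budget is spent, recommend the most-visited child's move.
    fn best_move(&self) -> Option<usize> {
        self.children.iter().max_by_key(|c| c.visits).and_then(|c| c.mv)
    }
}
```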
UCB1 and the UCT formula
The heart of MCTS is the selection phase, and the heart of selection is the formula that decides which child to visit. This is where exploration and exploitation are balanced.
The standard formula is called UCT (Upper Confidence Bounds for Trees), adapted from a statistical formula called UCB1 (Upper Confidence Bound 1).
For each child node, UCT computes a score:
UCT(child) = (W / N) + C * sqrt(ln(N_parent) / N)
Where:
- W / N is the child's average reward – its win rate so far. This is the exploitation term. Children with high win rates get high scores.
- sqrt(ln(N_parent) / N) is the exploration term. N_parent is the visit count of the parent node, and N is the visit count of the child. This term is large when the child has been visited relatively few times compared to its siblings. It shrinks as the child gets more visits.
- C is a constant that controls the balance between exploration and exploitation. A common choice is C = sqrt(2), which comes from the theoretical analysis of UCB1. In practice, this value is often tuned.
The intuition behind UCT is straightforward:
If a child has a high win rate, visit it more (exploitation). You want to refine your estimate of moves that look strong.
If a child has been visited much less than its siblings, visit it more (exploration). You might be missing something good. A move with 2 visits and a 50% win rate might actually be the best move – you just haven't given it enough chances yet.
As N grows, the exploration term shrinks. The more you visit a child, the less "exploration bonus" it gets. Eventually, exploitation dominates, and the algorithm focuses on the moves with genuinely high win rates.
As N_parent grows but N stays small, the exploration term grows. If the parent has been visited 1000 times but one child has only been visited 3 times, that child gets a large exploration bonus. The algorithm is saying: "You've visited this position 1000 times but barely looked at this move. Go check it out."
UCT is elegant because it automatically and gradually shifts from exploration to exploitation. In the early simulations, when visit counts are low, the exploration term dominates and the algorithm tries many different moves. As simulations accumulate, the exploitation term takes over and the algorithm focuses on the best moves.
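Translated directly into Rust, the formula is only a few lines. Giving unvisited children an infinite score is one common convention: it forces selection to try every child at least once, matching the "fully expanded" check in the pseudocode above:

```rust
// UCT score for one child, given its statistics and its parent's visit count.
fn uct_score(total_reward: f64, visits: u32, parent_visits: u32, c: f64) -> f64 {
    if visits == 0 {
        return f64::INFINITY; // always try unvisited children first
    }
    let exploitation = total_reward / visits as f64;
    let exploration = c * ((parent_visits as f64).ln() / visits as f64).sqrt();
    exploitation + exploration
}
```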
A worked example
Let's trace through a small example to see MCTS in action. Imagine a very simple game – not Tic-Tac-Toe, but something even smaller so we can fit the tree in ASCII art. Suppose it is our turn and we have three possible moves: A, B, and C.
After 0 simulations – We have only the root node. Nothing has been explored yet.
            Root
            (0/0)
          /   |   \
         ?    ?    ?
         A    B    C
The notation (W/N) means W total reward out of N visits. The root has no visits yet.
Simulations 1-3 – We expand and simulate each child once (since all are unexplored, selection goes to each in turn). Suppose the random rollouts give us: A wins, B loses, C wins.
            Root
            (2/3)
          /   |   \
      (1/1) (0/1) (1/1)
        A     B     C
A has 1 win out of 1 visit (average 1.0). B has 0 wins out of 1 visit (average 0.0). C has 1 win out of 1 visit (average 1.0).
Simulation 4 – Now all children have been visited at least once, so selection uses UCT to choose. Let's compute UCT scores with C = sqrt(2) ~ 1.41:
UCT(A) = 1/1 + 1.41 * sqrt(ln(3)/1) = 1.0 + 1.41 * 1.048 = 1.0 + 1.48 = 2.48
UCT(B) = 0/1 + 1.41 * sqrt(ln(3)/1) = 0.0 + 1.41 * 1.048 = 0.0 + 1.48 = 1.48
UCT(C) = 1/1 + 1.41 * sqrt(ln(3)/1) = 1.0 + 1.41 * 1.048 = 1.0 + 1.48 = 2.48
A and C are tied. Say we break ties randomly and pick A. Suppose the rollout from A results in a loss this time.
            Root
            (2/4)
          /   |   \
      (1/2) (0/1) (1/1)
        A     B     C
Simulation 5 – Recompute UCT scores:
UCT(A) = 1/2 + 1.41 * sqrt(ln(4)/2) = 0.50 + 1.41 * 0.833 = 0.50 + 1.17 = 1.67
UCT(B) = 0/1 + 1.41 * sqrt(ln(4)/1) = 0.00 + 1.41 * 1.177 = 0.00 + 1.66 = 1.66
UCT(C) = 1/1 + 1.41 * sqrt(ln(4)/1) = 1.00 + 1.41 * 1.177 = 1.00 + 1.66 = 2.66
C has the highest UCT score (2.66) – it has a perfect win rate and has only been visited once, so both exploitation and exploration favor it. MCTS selects C. Suppose the rollout from C is a win again.
            Root
            (3/5)
          /   |   \
      (1/2) (0/1) (2/2)
        A     B     C
Notice what is happening. C is emerging as the most promising move (2 wins out of 2 visits). But MCTS has not abandoned A (which might just have been unlucky) or even B (which gets an exploration bonus for being rarely visited). Over many more simulations, if C is truly the best move, it will accumulate the most visits.
After 100 simulations, the tree might look something like this:
             Root
           (62/100)
          /    |    \
    (25/42) (7/16) (30/42)
       A       B      C
C has the highest visit count (42) and win rate (30/42 = 71%). A is close behind. B has been explored less because its early results were poor, but MCTS still checked it 16 times to make sure. The algorithm would recommend move C.
How MCTS builds an asymmetric tree
A critical property of MCTS is that it does not explore the game tree uniformly. Unlike minimax with a fixed depth limit (which examines every move to the same depth), MCTS builds an asymmetric tree that is deep in promising branches and shallow in unpromising ones.
Consider this scenario in a more complex game:
            Root
           /    \
          A      B
         / \     |
       A1   A2   B1
       |
      A1a
     /   \
  A1a1   A1a2
Move A looks promising, so MCTS explores it deeply. Within A, move A1 looks better than A2, so it gets explored even deeper. Move B looks weak, so it only has one child – MCTS checked it out, confirmed it was not great, and moved on.
This is exactly the right strategy. There is no point spending computational effort analyzing a bad move 15 moves deep. It is much better to spend those simulations refining your understanding of the good moves. MCTS achieves this automatically through the UCT formula – no hand-tuning required.
This property becomes especially important in games with large branching factors. In Go, where each position might have 200+ legal moves, you simply cannot afford to explore every move to the same depth. MCTS naturally focuses on the 10-20 most promising moves and explores them deeply while giving the others only a cursory glance.
Why MCTS works without domain knowledge
One of the most remarkable properties of MCTS is that it works with zero domain knowledge. The algorithm needs only three things:
- A way to enumerate legal moves from any position
- A way to apply a move and get the resulting position
- A way to detect when the game is over and who won
That is it. MCTS does not need to know that controlling the center matters in Tic-Tac-Toe. It does not need to know that connected stones are strong in Go. It does not need an evaluation function crafted by a human expert. It discovers all of this on its own through random rollouts.
This is a profound advantage. For many games and decision problems, writing a good evaluation function is extremely difficult. What makes a chess position "good"? Generations of chess programmers have painstakingly encoded hundreds of heuristic rules – piece values, king safety, pawn structure, mobility, and more. MCTS sidesteps this entirely. You just need the rules of the game.
Of course, random rollouts are not as accurate as expert evaluation. A random game of chess is a terrible guide to who is winning – random play is so bad that the result is nearly uncorrelated with the actual quality of the position. This is where the AlphaGo Zero innovation comes in (which we will cover in the next section): replace random rollouts with a neural network that has learned to evaluate positions. But pure MCTS with random rollouts is still surprisingly effective for many games, especially those with lower branching factors.
Strengths of vanilla MCTS
Anytime algorithm. You can stop MCTS after 100 simulations or 100,000 simulations. More simulations give better results, but even a small number of simulations gives a reasonable answer. This is useful when you have a time limit – run as many simulations as you can in the available time, then pick the best move.
No evaluation function needed. As discussed above, MCTS works with just the game rules. This makes it applicable to any game (or sequential decision problem) immediately, without domain-specific engineering.
Handles high branching factors. Unlike alpha-beta search (the traditional game-tree search algorithm), which slows down dramatically as the branching factor increases, MCTS handles large branching factors gracefully because it focuses on promising branches.
Naturally balances exploration and exploitation. The UCT formula provides an elegant, principled solution to this fundamental tradeoff.
Asymptotically optimal. Given enough simulations, MCTS converges to the optimal move. The more you simulate, the closer you get to perfect play.
Limitations of vanilla MCTS
Random rollouts can be uninformative. In complex games, random play is essentially noise. A random chess game tells you almost nothing about who is actually winning. This limits how well pure MCTS can play strategic games. (The fix: replace rollouts with a learned evaluation function – exactly what AlphaGo Zero does.)
Many simulations needed for complex games. While MCTS is an anytime algorithm, the number of simulations needed for good play scales with the complexity of the game. For Go on a 19x19 board, millions of simulations per move are needed for strong play with random rollouts alone.
Memory usage. MCTS stores the search tree in memory. In long games or games with high branching factors, this tree can grow large. Various techniques exist to manage this (e.g., recycling parts of the tree between moves).
Not well-suited to all problems. MCTS works best in domains with discrete actions, clear terminal states, and well-defined rewards. It is less natural for continuous control problems or domains where the outcome of a random policy is completely uninformative.
Looking ahead
We now have the two foundational concepts for this course: reinforcement learning (from section 1) and Monte Carlo Tree Search (from this section). In the next section, we will see how DeepMind combined these ideas – along with deep neural networks and self-play – to create AlphaGo Zero, an agent that mastered Go with no human knowledge at all. That combination of MCTS + neural networks + self-play is exactly what we will build in Rust over the rest of this course.
3. Why self-play? The AlphaGo Zero insight
You now understand reinforcement learning and Monte Carlo Tree Search as separate ideas. RL gives us the framework: an agent, an environment, a reward signal, and a policy to optimize. MCTS gives us a powerful search algorithm that explores game trees intelligently using random simulations. Both are impressive on their own. But neither one, by itself, is enough to produce superhuman game play in a complex game like Go.
The missing ingredient is self-play – the idea that an agent can generate its own training data by playing against copies of itself, creating a feedback loop where improvement begets further improvement, with no human knowledge required. This section tells the story of how that insight emerged and why it changed the field.
From handcrafted to learned: a brief history
To appreciate what self-play made possible, it helps to understand what came before. The history of game-playing AI is a story of progressively removing the need for human expertise.
Deep Blue: the triumph of engineering (1997)
In 1997, IBM's Deep Blue defeated world chess champion Garry Kasparov in a six-game match. It was a landmark moment for AI – the first time a computer had beaten a reigning world champion at chess under standard tournament conditions.
But Deep Blue was not a learning system. It was a masterpiece of engineering. At its core was an alpha-beta search algorithm (a refined version of minimax) running on custom hardware that could evaluate 200 million chess positions per second. The critical component was its evaluation function: a hand-tuned formula that scored any chess position based on hundreds of features – material count, king safety, pawn structure, piece mobility, control of the center, and dozens of other heuristics. These features were designed and tuned by a team that included grandmaster Joel Benjamin.
Deep Blue played brilliantly, but everything it knew about chess had been put there by humans. The evaluation function was the product of decades of chess programming knowledge. If you wanted to apply the same approach to a different game – say, Go – you would need a team of Go experts to design a completely new evaluation function from scratch. And for Go, no one could figure out how to do that well. The game is too subtle, too positional, too intuitive. Handcrafted evaluation functions for Go never came close to expert human play.
Deep Blue's approach was powerful but fundamentally limited: you need human experts to encode their knowledge, and you can never exceed what those experts know.
AlphaGo: learning from humans, then surpassing them (2016)
Nearly two decades later, DeepMind tackled Go – a game widely considered to be far beyond the reach of brute-force search. The original AlphaGo system, which defeated European champion Fan Hui in 2015 and world champion Lee Sedol in 2016, took a hybrid approach.
AlphaGo's training had two stages:
Stage 1: Supervised learning from human expert games. DeepMind collected a database of roughly 30 million positions from games played by strong human Go players on online servers. They trained a deep neural network (the policy network) to predict the move a human expert would play in any given position. This network learned to imitate human play – given a board position, it output a probability distribution over possible moves, reflecting how likely a strong human would be to play each one.
This supervised policy network was already impressive. It could predict the expert move about 57% of the time, and it played at a level roughly equivalent to a strong amateur. But imitation has limits. The network was constrained by the data it was trained on – it could only learn patterns that appeared in human games.
Stage 2: Reinforcement learning through self-play. Starting from the supervised policy network, DeepMind then used RL to improve it further. The network played millions of games against earlier versions of itself. After each game, the policy was updated to make winning moves more likely and losing moves less likely. This self-play RL stage pushed AlphaGo significantly beyond the level of the human games it had originally learned from.
AlphaGo also trained a value network – a separate neural network that, given a board position, estimated the probability of winning from that position. This value network replaced the random rollouts in MCTS. Instead of playing a random game to the end to estimate a position's value, AlphaGo could ask its value network: "How likely am I to win from here?"
The final AlphaGo system combined MCTS with these two neural networks: the policy network guided which branches to explore (replacing the need to look at every possible move), and the value network evaluated positions at the leaves of the search tree (replacing random rollouts). This was far more efficient and far more accurate than either component alone.
AlphaGo's victory over Lee Sedol in March 2016 was a watershed moment. Go had been considered a "grand challenge" for AI for decades. But there was a nagging asterisk: AlphaGo still started from human data. It needed those 30 million expert positions to bootstrap its learning. What if the human experts were wrong about something? What if there were strategies that no human had ever discovered? The supervised learning stage was both a crutch and a ceiling.
AlphaGo Zero: tabula rasa (2017)
In October 2017, DeepMind published a paper that answered the question: What happens if you remove the human data entirely?
AlphaGo Zero started with no human games, no handcrafted features, and no human knowledge beyond the rules of Go. The network's weights were initialized randomly. It knew nothing about Go strategy – not that the center is important, not that connected stones are strong, not that you should surround territory. Nothing.
It learned everything from self-play alone. And within 40 days, it surpassed every previous version of AlphaGo, including the version that defeated Lee Sedol. It was not even close – AlphaGo Zero defeated the Lee Sedol version 100 games to 0.
Let that sink in. A system that started knowing literally nothing about Go, that learned entirely by playing against itself, crushed a system that had been bootstrapped from millions of human expert positions. Human expert knowledge was not just unnecessary – it was actively holding the system back.
Why human data can be a ceiling
This result is counterintuitive. Surely starting with human knowledge should help? You get a head start. You don't have to waste time rediscovering basic strategy. How could it be worse to start with expert data?
The answer lies in the nature of imitation learning. When you train on human expert games, you are teaching the network to predict what a human would do. But "what a human would do" is not the same as "the objectively best move." Human play reflects human biases, conventions, fashions, and limitations.
Consider a few ways human data can limit an AI:
Blind spots. If no human has ever played a particular style of move, that move will be absent from the training data. The network will learn to assign it near-zero probability, even if it is actually strong. Human Go players avoided certain moves because they violated conventional wisdom – but some of those moves turn out to be excellent when analyzed deeply enough. AlphaGo Zero discovered several such moves that professional players had dismissed for centuries.
Convergent play styles. Strong human players tend to play similarly to each other because they study the same games, read the same books, and follow the same meta-game. Training on this data produces a network that imitates this convergent style. It cannot easily break out of the mold to find something genuinely novel.
Suboptimal consensus. Sometimes the entire human community is wrong about a position or a strategy. If every expert agrees that a particular opening is weak, the training data will reflect that consensus – even if the consensus is wrong. A self-play system has no such preconceptions.
Distributional mismatch. The states a strong human encounters during a game are not the same states the AI will encounter once it starts improving beyond human level. The AI ends up in positions that no human has ever seen, and its policy – trained only on human-encountered positions – has no useful guidance for these alien board states.
A self-play system avoids all of these problems. It is not trying to imitate anyone. It is trying to win. It explores whatever strategies lead to victories, no matter how unconventional. Its training data comes from its own games, so the distribution of positions it trains on always matches the positions it actually encounters during play.
The self-play insight
The core idea of self-play is deceptively simple: the agent generates its own training data by playing against itself.
Here is how it works in broad strokes:
1. Start with an agent that plays randomly (or very poorly).
2. Have the agent play a game against a copy of itself.
3. Use the game's outcome (who won, who lost) and the move-by-move decisions as training data.
4. Update the agent's neural network using this data, making it slightly better.
5. Go back to step 2 with the improved agent.
At first, both sides play terribly. The games are random chaos. But even in random chaos, one side wins and the other loses, and that outcome provides a training signal. The network learns, very slightly, which positions tend to lead to wins and which tend to lead to losses. It learns, very slightly, which moves are better than others.
Now something remarkable happens. The slightly-improved agent plays against itself again. Because it is slightly better, it generates slightly better training data. The games are slightly less random, slightly more strategic. The agent learns from this better data and improves further. Which produces even better data. Which produces even more improvement.
The virtuous cycle
This is the key dynamic that makes self-play so powerful – a virtuous cycle where improvement in the agent's policy leads to higher-quality training data, which leads to further improvement:
        Better policy
              |
              v
  Higher-quality self-play games
              |
              v
  More informative training data
              |
              v
      Better neural network
              |
              v
        Better policy  -->  (cycle repeats)
This is in sharp contrast to supervised learning, where the training data is fixed. No matter how good the model gets, it is still learning from the same dataset. The quality of the data is a hard ceiling on the quality of the model. In self-play, the ceiling rises with the agent. The data gets better because the agent gets better.
There is an important subtlety here: why doesn't the agent just learn to exploit weaknesses in its own play, going around in circles? The answer is that MCTS provides a stabilizing force. Even when the neural network's policy is weak, MCTS improves upon it by running hundreds or thousands of simulations. The search acts as a "policy improvement operator" – the MCTS-enhanced policy is always stronger than the raw neural network policy. When the network is then trained to match the MCTS-improved policy, it genuinely gets better. The combination of search and learning prevents the kind of circular, degenerate strategies that might arise from naive self-play.
The AlphaGo Zero architecture
Let's look at how AlphaGo Zero puts all of this together. The architecture is elegant in its simplicity – it unifies several components that were separate in earlier systems.
A single neural network with two heads
The original AlphaGo used two separate networks: a policy network (which moves to play) and a value network (who is winning). AlphaGo Zero combined these into a single neural network with two output heads:
- Policy head: Given the current board state, output a probability distribution over all legal moves. "Play here with 35% probability, here with 20%, here with 15%..." This guides the MCTS search toward promising moves.
- Value head: Given the current board state, output a single number between -1 and +1 estimating who is winning. +1 means "I am definitely winning," -1 means "I am definitely losing," and 0 means "the position is roughly equal."
Both heads share the same underlying neural network (a deep residual network, or ResNet), which processes the board state into a rich internal representation. The two heads are just different "lenses" on that same representation – one for suggesting moves, one for evaluating positions.
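At the level of code, all that matters for now is the interface: something that maps a board state to a pair of (move probabilities, value estimate). A conceptual Rust sketch with placeholder names, and a deliberately know-nothing implementation, might look like this; Part 4 replaces it with a real network:

```rust
// Conceptual interface for the two-headed network (placeholder names, not a real model).
trait PolicyValueNet {
    /// Returns (probabilities over the 9 cells, value estimate in [-1, 1]).
    fn evaluate(&self, board: &[i8; 9]) -> ([f32; 9], f32);
}

// A stand-in that knows nothing: every move equally likely, every position "even".
struct UniformNet;

impl PolicyValueNet for UniformNet {
    fn evaluate(&self, _board: &[i8; 9]) -> ([f32; 9], f32) {
        ([1.0 / 9.0; 9], 0.0)
    }
}
```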
MCTS guided by the neural network
In vanilla MCTS (as we described in section 2), the simulation phase uses random rollouts – random moves played to the end of the game. AlphaGo Zero replaces this entirely. There are no random rollouts. Instead:
- During selection, MCTS uses a modified UCT formula that incorporates the neural network's policy output. Moves that the network thinks are promising get a higher prior probability, so MCTS explores them first. This is far more efficient than treating all moves equally.
- During evaluation (replacing simulation), when MCTS reaches a leaf node, it does not play a random game to completion. Instead, it queries the neural network's value head to get an instant estimate of the position's value. This is faster and far more accurate than a random rollout, especially in complex games where random play is essentially noise.
The result is an MCTS that is guided by learned intuition. The network says "these moves look promising and this position looks good for me," and MCTS uses that guidance to search much more efficiently. But MCTS also corrects the network – if the network thinks a move is good but search reveals it leads to a loss, MCTS will discover that through deeper exploration.
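Concretely, the modified selection rule (usually called PUCT) scores each child along the lines of:

PUCT(child) = Q + c_puct * P * sqrt(N_parent) / (1 + N)

where Q is the child's average value from search so far, P is the prior probability the policy head assigned to the child's move, N and N_parent are visit counts as in section 2, and c_puct is an exploration constant. Compared with plain UCT, the prior P means that moves the network already likes start with a large exploration bonus, while moves it dislikes are explored only if the search keeps failing to find something better. (Treat this as a sketch of the published formula; we do not need to match its constants exactly for Tic-Tac-Toe.)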
The training loop
Here is the AlphaGo Zero training loop in its entirety:
Step 1: Self-play. The current neural network plays games against itself, using MCTS to select moves. During each game, at every move, the system records three things: the board state, the MCTS visit counts for each move (which represent the search-improved policy), and the eventual game outcome (win or loss).
Step 2: Train. The neural network is updated to better match the self-play data. Specifically, it is trained to:
- Make the policy head's output match the MCTS visit count distribution (because MCTS, with its search, has a better sense of which moves are good than the raw network does)
- Make the value head's output match the actual game outcome (because the final result tells us who really was winning)
Step 3: Evaluate. Periodically, the newly trained network plays a match against the previous best network. If it wins convincingly (by a specified margin), it becomes the new best network and replaces the old one for future self-play games.
Step 4: Repeat. Go back to step 1 with the improved network.
In the initial run, this loop produced roughly 4.9 million self-play games over about three days of training; a larger version of the network then trained for 40 days. The trained system played using a single machine with 4 TPUs (Google's custom machine learning accelerators). The network steadily improved throughout.
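The raw material this loop produces is simple: each move of each self-play game becomes one training example. For our Tic-Tac-Toe version, one plausible shape for that record (placeholder field names; Part 5 defines the real type) is:

```rust
// One self-play move's worth of training data (sketch, not the final Part 5 type).
struct TrainingExample {
    board: [i8; 9],        // the position, from the perspective of the player to move
    mcts_policy: [f32; 9], // normalized MCTS visit counts: the search-improved policy target
    outcome: f32,          // the game's final result for that player: +1 win, 0 draw, -1 loss
}
```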
The breakthrough results
The results from the 2017 paper ("Mastering the Game of Go without Human Knowledge," published in Nature) were striking:
Surpassed all previous versions. After just 36 hours of training, AlphaGo Zero surpassed the version of AlphaGo that defeated Lee Sedol, and after three days it defeated that version 100 games to 0. A larger instance trained for 40 days went on to surpass AlphaGo Master (an improved version that had defeated the world's top professionals 60-0 in online blitz games), defeating it 89-11 in a 100-game match.
Learned from scratch. The system started with completely random play. In the early hours of training, it played nonsensically. Within a few hours, it had discovered basic Go concepts – influence, territory, life and death. Within days, it had developed sophisticated strategic understanding. It rediscovered many patterns of play that took humans thousands of years to develop, and it also discovered novel strategies that professional players had never considered.
Simpler architecture. Counterintuitively, AlphaGo Zero used a simpler architecture than the original AlphaGo. One network instead of two. No handcrafted features – just raw stone positions on the board. No random rollouts. And yet it was dramatically stronger. Simplicity and generality turned out to be advantages, not handicaps.
More efficient. Despite learning from scratch, AlphaGo Zero required less computation overall than the original AlphaGo, which had been trained on human data plus RL. Starting from a blank slate actually made learning easier, because the system was not constrained by the biases and limitations of human play.
A year later, DeepMind published AlphaZero (2018), which applied the same approach to chess, shogi (Japanese chess), and Go simultaneously – the same algorithm, the same neural network architecture, the same hyperparameters. Within hours of training, AlphaZero achieved superhuman performance in all three games, defeating the strongest existing programs in each: Stockfish in chess, Elmo in shogi, and AlphaGo Zero itself in Go. This demonstrated that the self-play approach was not Go-specific – it was a general recipe for mastering two-player perfect-information games.
Why this matters beyond Go
The AlphaGo Zero result matters not just because it plays Go well, but because of what it demonstrates as a principle:
You do not need human knowledge to achieve superhuman performance. Given the right architecture (neural network + MCTS), the right training signal (self-play outcomes), and enough computation, an agent can discover strategies that exceed the best human understanding. The human experts are not the ceiling – they are a waypoint.
Self-generated data can be superior to human-curated data. This challenges the assumption, common in machine learning, that more data (especially human-labeled data) is always better. For certain problems, an agent's own experience – tailored to its own level of play, covering the exact situations it encounters – is more valuable than any static dataset.
Search and learning are complementary. MCTS alone (with random rollouts) plays Go at an amateur level. A neural network alone (without search) plays at a strong amateur level. But MCTS guided by a neural network, with the network trained on MCTS-improved data, produces superhuman play. The whole is vastly greater than the sum of its parts.
How we will adapt this for Tic-Tac-Toe
In the remaining sections of this course, we will build a miniature version of the AlphaGo Zero system. The game will be Tic-Tac-Toe instead of Go, the neural network will be small and simple instead of a massive ResNet, and training will take seconds instead of days. But the core architecture and training loop will be the same:
- A game engine (Part 2) that represents Tic-Tac-Toe state, enumerates legal moves, and detects terminal positions – the "environment" in RL terms.
- An MCTS implementation (Part 3) that searches the game tree, initially using random rollouts, just as vanilla MCTS does.
- A neural network with policy and value heads (Part 4) that takes a board state as input and outputs move probabilities and a position evaluation. We will integrate this with MCTS, replacing random rollouts with the value head and using the policy head to guide search.
- A self-play training loop (Part 5) where the system plays Tic-Tac-Toe against itself, generates training data from those games, trains the neural network on that data, and repeats. We will watch the agent improve from completely random play to strong (ideally perfect) Tic-Tac-Toe.
Tic-Tac-Toe is a solved game – perfect play from both sides always results in a draw. This actually makes it an ideal testbed. We know what the "right answer" looks like, so we can measure exactly how close our agent gets to optimal play. And because the game is small (only 9 squares, games last at most 9 moves), training is fast enough to iterate and experiment on a single laptop.
The AlphaGo Zero recipe is remarkably general. The same ideas that produce superhuman Go play will produce strong Tic-Tac-Toe play. And understanding the recipe at this small scale – where you can inspect every node in the search tree and every weight in the neural network – will give you genuine insight into how and why it works at large scale.
In the next section, we will start building. First up: choosing our game and understanding what properties it needs to have for this approach to work.
Part 2 – The Game
4. Choosing a simple game: Tic-Tac-Toe
We ended Part 1 with a grand vision: a system that masters a game purely through self-play, combining MCTS with a neural network in the style of AlphaGo Zero. Before we start writing code, we need to pick a game. The choice matters more than it might seem – the wrong game can bury the core ideas under irrelevant complexity, while the right game keeps the focus exactly where it belongs.
Why start with a toy problem
Complex algorithms are hard to debug. When your self-play agent produces nonsensical moves after training, you want to know whether the bug is in your game logic, your MCTS implementation, your neural network architecture, or your training loop. If your game is Go (a 19x19 board with roughly 10^170 legal game positions), you cannot feasibly enumerate the game tree to check correctness. If your game is Tic-Tac-Toe, you can.
Toy problems are not a detour – they are a microscope. At small scale, you can inspect every node that MCTS visits, print the entire game tree, verify that the neural network's policy output assigns probability to every legal move, and confirm that the value head's evaluation matches the true game-theoretic value of each position. Once you understand the system at this resolution, scaling up to larger games becomes an engineering problem rather than a conceptual one.
AlphaZero's authors themselves understood this principle. The same algorithm that defeated the world champion in Go was also tested on Tic-Tac-Toe-like games during development. Start simple, get it working, then scale.
What properties does a game need?
Not every game is suitable for the AlphaGo Zero approach. The framework we described in Part 1 β MCTS guided by a neural network, trained through self-play β relies on the game having several specific properties:
Two-player. Self-play requires exactly two players taking alternating roles. The agent plays both sides, generating training data from both perspectives. Single-player optimization problems and multiplayer (3+) games introduce complications that would distract from the core loop.
Zero-sum. One playerβs gain is the otherβs loss. When the agent wins as Player X, that same game is a loss from Player Oβs perspective. This gives us a clean training signal: every game produces data for both βwinningβ and βlosingβ play, and the value of any position to one player is the exact negative of its value to the other.
Deterministic. No dice, no card draws, no random events. Given a state and a move, the next state is completely determined. This means MCTS can search the game tree without needing to account for chance nodes, and the neural networkβs value estimate is a function of the position alone, not of hidden randomness.
Perfect information. Both players can see the entire game state at all times. There are no hidden cards, no fog of war. This is essential because our neural network takes the full board state as input β if information were hidden, we would need a fundamentally different architecture to handle uncertainty about the opponentβs private state.
Finite. The game always terminates after a bounded number of moves. Every branch of the game tree ends in a win, loss, or draw. This guarantees that MCTS rollouts terminate and that the game-theoretic value of every position is well-defined. (Contrast this with chess, where without the fifty-move and threefold-repetition rules a game could cycle indefinitely.)
These five properties β two-player, zero-sum, deterministic, perfect information, finite β define the class of games that the AlphaGo Zero approach handles directly. Chess, Go, Othello, Connect Four, and Tic-Tac-Toe all qualify. Poker does not (imperfect information). Backgammon does not (dice introduce randomness). Risk does not (multiplayer, random).
Tic-Tac-Toe: the rules
You almost certainly know this game already, but let us state the rules precisely, since we will need to implement them exactly:
- The board is a 3x3 grid, initially empty.
- Two players alternate turns. Player X moves first, Player O moves second.
- On your turn, you place your mark (X or O) in any empty cell.
- The first player to get three of their marks in a row β horizontally, vertically, or diagonally β wins.
- If all nine cells are filled and neither player has three in a row, the game is a draw.
That is the entire game. No special moves, no optional rules, no edge cases. This simplicity is a feature, not a limitation.
The game tree: small enough to verify, large enough to matter
How big is Tic-Tac-Toe, computationally?
If we count naively β the number of complete sequences of moves from the start to a full board, ignoring early termination β there are 9! = 362,880 possible games. But many games end before the board is full (someone wins), and many positions can be reached by different move orders. The actual numbers:
- ~255,168 possible games (unique sequences of moves from start to a terminal state, counting games that end early due to a win).
- 5,478 unique board states (distinct legal positions reachable during play, counting each arrangement of marks once regardless of the move order that produced it, before any symmetry reduction).
- ~765 unique states after symmetry reduction (treating rotations and reflections as equivalent).
For comparison, chess has roughly 10^44 legal positions. Go has roughly 10^170. Tic-Tac-Toeβs entire game tree fits comfortably in memory on any modern machine. You can enumerate every state, compute the minimax value of every position, and verify that your MCTS agent converges to the correct answer. This is extraordinarily valuable when debugging.
At the same time, 5,478 states (or even 765 after symmetry) is not trivially small. It is large enough that a neural network has something meaningful to learn, MCTS has a tree worth searching, and the training loop has real work to do. You cannot solve it by just memorizing a lookup table during training β the network has to generalize.
The known solution: a built-in answer key
Tic-Tac-Toe is a solved game. Game theorists have exhaustively analyzed every possible position and determined the optimal move in each. The result: with perfect play from both sides, the game always ends in a draw. Player X (who moves first) cannot force a win against a perfect opponent, and Player O cannot force a win either.
This is enormously useful for us. We have a ground truth to validate our AI against:
- After training, our agent should never lose when playing against a perfect opponent (or against itself).
- Every game between two copies of a fully trained agent should end in a draw.
- The value head of the neural network should evaluate the starting position as approximately 0 (a draw), not as favoring either player.
- The policy head should never assign high probability to moves that are known blunders in solved positions.
Without this ground truth, we would have no way to know whether our trained agent was merely good or actually optimal. With Tic-Tac-Toe, we can measure exactly.
Why not a more interesting game?
You might wonder: why not use Connect Four, Othello, or even a small-board variant of Go? These are all valid targets for the AlphaGo Zero approach, and we will mention them as stretch goals later. But each adds complexity that makes it harder to learn the algorithm:
Connect Four (7 columns, 6 rows, four in a row to win) has roughly 4.5 trillion possible positions. It is also solved β Player 1 can force a win with perfect play β but the game tree is too large to exhaustively verify your implementation against. Training takes minutes to hours rather than seconds. Bugs become much harder to diagnose.
Othello (8x8 board, pieces flip) introduces a more complex move-legality rule (you must flip at least one opponent piece) and a board evaluation that shifts dramatically throughout the game. The neural network needs to be meaningfully larger to capture the patterns. This is interesting but distracting when you are trying to understand MCTS and self-play.
Small-board Go (e.g., 5x5 or 9x9) is closer to the original AlphaGo Zero domain, but Goβs rules β liberties, captures, ko, scoring β are substantially more complex to implement correctly than Tic-Tac-Toeβs. Getting the game engine right becomes a project in itself.
The principle is simple: learn the algorithm on the easiest possible game, then scale up. Once your self-play loop is producing a perfect Tic-Tac-Toe agent, adapting it to Connect Four requires only changing the game engine and enlarging the neural network. The MCTS code, the training loop, and the self-play pipeline remain the same. Tic-Tac-Toe is where we learn; other games are where we apply.
Mapping Tic-Tac-Toe to RL concepts
In Β§1, we introduced the core RL vocabulary: agent, environment, state, action, reward, and policy. Now that we have a concrete game, we can pin each concept to something specific:
| RL concept | Tic-Tac-Toe equivalent |
|---|---|
| Environment | The game rules β the 3x3 grid and the logic for placing marks, detecting wins, and detecting draws. |
| State | A specific board configuration: which cells contain X, which contain O, which are empty, and whose turn it is. |
| Action | Placing a mark in an empty cell. There are at most 9 possible actions, decreasing as the board fills. |
| Legal actions | The subset of empty cells in the current state. Our game engine must enumerate these. |
| Terminal state | A state where someone has three in a row (win/loss) or the board is full (draw). |
| Reward | +1 for a win, -1 for a loss, 0 for a draw. Only assigned at terminal states. |
| Policy | A function that takes a board state and returns a probability distribution over the legal moves. This is what the neural networkβs policy head will learn. |
| Value | A function that takes a board state and returns an estimate of the expected outcome (between -1 and +1). This is what the neural networkβs value head will learn. |
In the next section, we will translate these concepts into Rust data structures: an enum for cell contents, a struct for the board state, methods for making moves and checking for wins, and an interface that MCTS can plug into. The mapping above is the bridge between the theory of Part 1 and the code of Part 2.
5. Representing game state in Rust
In Β§4, we mapped reinforcement learning concepts onto Tic-Tac-Toe: state, action, terminal condition, reward. Now we translate that mapping into Rust data structures. By the end of this section, you will have a complete, compilable set of types and methods that represent a Tic-Tac-Toe game β the foundation that MCTS will plug into later.
We will build the code incrementally. Each piece is small, self-contained, and compiles on its own.
Choosing a board representation
There are two common ways to represent a Tic-Tac-Toe board in code:
Bitboard. Use two u16 values β one bitmask for Xβs marks and one for Oβs. Each of the 9 cells maps to a bit position. Checking for a winner becomes a handful of bitwise AND operations against precomputed winning masks. This approach is compact (two integers per board) and fast (winner detection is branchless), but it is harder to read and debug. Bitboards shine in games like chess where the board is large and performance-critical search must evaluate millions of positions per second.
Array of options. Use a fixed-size array of 9 elements, where each element is either Some(Player::X), Some(Player::O), or None (empty). This representation maps directly to how a human thinks about the board. Indexing, display, move generation, and winner checking are all straightforward loops and matches. The cost is a bit more memory and a few more branches β completely negligible for a 3x3 board.
We will use the array approach. Clarity matters more than performance at this scale, and the code will be much easier to follow as we layer MCTS and neural network integration on top. If you later scale up to a larger game, switching to a bitboard is a localized change β the interface stays the same.
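For the curious, a rough sketch of the bitboard alternative's win check is below. It is illustrative only (we will not use it anywhere in the course), and it assumes each cell maps to one bit, with cell 0 on bit 0 and cell 8 on bit 8, matching the board indexing introduced later in this section.
/// Sketch only: the eight winning patterns as 9-bit masks (bit 0 = top-left).
const WIN_MASKS: [u16; 8] = [
    0b000_000_111, 0b000_111_000, 0b111_000_000, // rows
    0b001_001_001, 0b010_010_010, 0b100_100_100, // columns
    0b100_010_001, 0b001_010_100,                // diagonals
];

/// Returns true if the player whose marks are encoded in `marks` has a line.
fn bitboard_has_won(marks: u16) -> bool {
    WIN_MASKS.iter().any(|&mask| marks & mask == mask)
}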
The Player enum
A cell on the board is either empty, occupied by X, or occupied by O. We model the occupant as a Player enum:
#![allow(unused)]
fn main() {
/// One of the two players in a Tic-Tac-Toe game.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player {
X,
O,
}
impl Player {
/// Returns the other player.
///
/// This is used after each move to switch turns, and when
/// computing rewards from the opponent's perspective.
pub fn opponent(self) -> Player {
match self {
Player::X => Player::O,
Player::O => Player::X,
}
}
}
impl std::fmt::Display for Player {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Player::X => write!(f, "X"),
Player::O => write!(f, "O"),
}
}
}
}
A few things to note:
- We derive Copy because Player is a small enum with no heap data. Passing it around by value is cheaper than borrowing.
- We derive Hash because later we will want to use game states as keys in hash maps (for MCTS node lookup).
- The opponent method will appear everywhere: switching turns, computing rewards from the other player's perspective, and building the self-play loop.
The GameState struct
A game state is a board plus whose turn it is:
#![allow(unused)]
fn main() {
/// A complete snapshot of a Tic-Tac-Toe game at a moment in time.
///
/// The board is a flat array of 9 cells. `current_player` indicates
/// whose turn it is to move. A fresh game starts with an empty board
/// and `Player::X` to move.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
/// The 3x3 board stored as a flat array.
/// Index 0 is the top-left cell; index 8 is the bottom-right.
pub board: [Option<Player>; 9],
/// The player who will make the next move.
pub current_player: Player,
}
impl GameState {
/// Creates a new game with an empty board. X moves first.
pub fn new() -> Self {
GameState {
board: [None; 9],
current_player: Player::X,
}
}
}
}
The board is [Option<Player>; 9] β exactly the type we discussed. None means empty, Some(Player::X) means X occupies that cell. The struct owns all its data (no references, no lifetimes), so it is easy to clone, compare, and hash.
Board indexing
We store the board as a flat array, but we think about it as a grid with rows and columns:
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
The conversion is straightforward:
- From (row, col) to index: index = row * 3 + col
- From index to (row, col): row = index / 3, col = index % 3
We do not need separate functions for this β the arithmetic is simple enough to inline β but it helps to be explicit about the layout. Index 0 is the top-left corner, index 4 is the center, and index 8 is the bottom-right corner. Row 0 is the top row, column 0 is the left column.
This flat layout is also the format we will use when encoding the board as input to the neural network: 9 values in a fixed order.
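A tiny sanity check of that arithmetic, purely illustrative and separate from the game module:
fn main() {
    let (row, col) = (1, 2);                        // middle row, right column
    let index = row * 3 + col;                      // 1 * 3 + 2 = 5
    assert_eq!(index, 5);
    assert_eq!((index / 3, index % 3), (row, col)); // and back again
}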
Displaying the board
A Display implementation makes debugging and interactive play much easier. We want output that looks like a human-readable grid:
#![allow(unused)]
fn main() {
impl std::fmt::Display for GameState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for row in 0..3 {
for col in 0..3 {
let index = row * 3 + col;
match self.board[index] {
Some(Player::X) => write!(f, " X ")?,
Some(Player::O) => write!(f, " O ")?,
None => write!(f, " . ")?,
}
if col < 2 {
write!(f, "|")?;
}
}
writeln!(f)?;
if row < 2 {
writeln!(f, "-----------")?;
}
}
Ok(())
}
}
}
A board mid-game might display as:
X | . | O
-----------
. | X | .
-----------
. | . | .
Using . for empty cells avoids ambiguity with whitespace and makes it easy to see the board structure at a glance.
Representing moves
A move in Tic-Tac-Toe is simply placing a mark in an empty cell. Since the board is a flat array of 9 cells, a move is just an index from 0 to 8. There is no need for a dedicated struct β a usize carries all the information we need.
We could wrap it in a newtype (struct Move(usize)) for type safety, but the overhead is not worth it for a game this simple. Every function that takes a move will also take or have access to the game state, which provides all the context needed to interpret the index.
Generating legal moves
A move is legal if and only if the target cell is empty. Generating all legal moves is a matter of iterating the board and collecting the indices of None cells:
#![allow(unused)]
fn main() {
impl GameState {
/// Returns the indices of all empty cells β the legal moves
/// in this position.
///
/// In a terminal state (someone has won or the board is full),
/// this returns an empty vector.
pub fn legal_moves(&self) -> Vec<usize> {
// If the game is already over, no moves are legal.
if self.winner().is_some() {
return Vec::new();
}
self.board
.iter()
.enumerate()
.filter(|(_, cell)| cell.is_none())
.map(|(index, _)| index)
.collect()
}
}
}
This method is called frequently by MCTS β once per node expansion. For a 3x3 board the cost is trivial, but note that we check for a winner first. Once someone has won, no further moves should be legal, even if empty cells remain.
Checking for a winner
There are exactly 8 ways to win a Tic-Tac-Toe game: 3 rows, 3 columns, and 2 diagonals. We enumerate them explicitly:
#![allow(unused)]
fn main() {
/// The eight lines that constitute a win: three rows, three columns,
/// and two diagonals. Each line is a triple of board indices.
const WIN_LINES: [[usize; 3]; 8] = [
// Rows
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
// Columns
[0, 3, 6],
[1, 4, 7],
[2, 5, 8],
// Diagonals
[0, 4, 8],
[2, 4, 6],
];
impl GameState {
/// Returns the winner if one exists.
///
/// A player wins by occupying all three cells in any row,
/// column, or diagonal. Returns `None` if no player has
/// three in a row.
pub fn winner(&self) -> Option<Player> {
for line in &WIN_LINES {
let [a, b, c] = *line;
if let (Some(p1), Some(p2), Some(p3)) =
(self.board[a], self.board[b], self.board[c])
{
if p1 == p2 && p2 == p3 {
return Some(p1);
}
}
}
None
}
}
}
The WIN_LINES constant is defined outside the impl block β it is pure data with no connection to any particular game state. The winner method checks each line: if all three cells are occupied by the same player, that player has won. If no line is fully claimed, the game has no winner (yet, or ever, in the case of a draw).
Why a const array rather than computing the lines from row/column arithmetic? Explicitness. There are only 8 lines. Writing them out makes the logic immediately verifiable by inspection β you can count the rows, columns, and diagonals yourself. No off-by-one errors to worry about.
Terminal state detection
A state is terminal when either someone has won or the board is full (a draw). We also define a method to check whether the game is a draw specifically:
#![allow(unused)]
fn main() {
impl GameState {
/// Returns `true` if the game is over β either a player has won
/// or the board is completely filled (a draw).
pub fn is_terminal(&self) -> bool {
self.winner().is_some() || self.board.iter().all(|cell| cell.is_some())
}
/// Returns `true` if the game ended in a draw: the board is full
/// and no player has three in a row.
pub fn is_draw(&self) -> bool {
self.winner().is_none() && self.board.iter().all(|cell| cell.is_some())
}
}
}
These two methods, combined with winner, give us everything we need to evaluate terminal states:
- winner() returns Some(player) if someone has won.
- is_draw() returns true if the board is full with no winner.
- is_terminal() returns true in either case.
MCTS will call is_terminal to know when to stop a rollout, and winner to assign the reward (+1 for a win, -1 for a loss, 0 for a draw).
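That reward assignment is a one-liner once you fix a perspective. Here is a hypothetical helper, shown only for illustration (it assumes the GameState and Player types above are in scope; the MCTS code in Part 3 does the equivalent inline rather than call a function like this):
/// Reward of a finished game from `player`'s point of view:
/// +1.0 for a win, -1.0 for a loss, 0.0 for a draw.
fn reward_for(state: &GameState, player: Player) -> f64 {
    match state.winner() {
        Some(w) if w == player => 1.0,
        Some(_) => -1.0,
        None => 0.0,
    }
}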
Applying moves: the immutability design
When MCTS explores the game tree, it needs to βtryβ a move and see what state results. There are two ways to design this:
Mutable approach: modify the board in place, then undo the move when backtracking. This is memory-efficient (one board allocation for the entire search) but error-prone. Forgetting to undo a move β or undoing it in the wrong order β produces subtle, hard-to-diagnose bugs. The bookkeeping complexity grows with game complexity.
Immutable approach: apply_move takes a state and a move, returns a new state with the move applied. The original state is unchanged. This is conceptually cleaner β each node in the search tree corresponds to an independent GameState value. Cloning a 9-element array is cheap. For Tic-Tac-Toe, the memory cost is negligible.
We choose the immutable approach:
#![allow(unused)]
fn main() {
impl GameState {
/// Returns a new `GameState` with the given move applied.
///
/// The move is an index (0-8) into the board. The current player's
/// mark is placed in that cell, and the turn switches to the
/// opponent.
///
/// # Panics
///
/// Panics if the cell is already occupied or the index is out of
/// bounds. In a correct program, `apply_move` is only called with
/// indices returned by `legal_moves`.
pub fn apply_move(&self, index: usize) -> GameState {
assert!(
self.board[index].is_none(),
"Cell {} is already occupied",
index
);
let mut new_board = self.board;
new_board[index] = Some(self.current_player);
GameState {
board: new_board,
current_player: self.current_player.opponent(),
}
}
}
}
Notice that let mut new_board = self.board copies the array (it implements Copy), so we are not mutating the original. The new state has the mark placed and the turn switched to the opponent.
The assert! is a deliberate choice: calling apply_move on an occupied cell is always a bug in the caller. Panicking immediately makes such bugs loud and easy to find. In production game engines, you might return a Result instead, but for a learning project, panicking on programmer error is the right call.
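If you ever want the non-panicking style, a fallible variant is easy to sketch. The method below is hypothetical, is not used anywhere else in the course, and would live inside the same impl GameState block:
/// Hypothetical fallible variant: reports a bad move instead of panicking.
pub fn try_apply_move(&self, index: usize) -> Result<GameState, String> {
    if index >= 9 {
        return Err(format!("index {} is out of bounds", index));
    }
    if self.board[index].is_some() {
        return Err(format!("cell {} is already occupied", index));
    }
    Ok(self.apply_move(index))
}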
Putting it all together
Here is the complete module, assembled from the pieces above. This is a single, compilable Rust file:
#![allow(unused)]
fn main() {
use std::fmt;
/// One of the two players in a Tic-Tac-Toe game.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player {
X,
O,
}
impl Player {
/// Returns the other player.
pub fn opponent(self) -> Player {
match self {
Player::X => Player::O,
Player::O => Player::X,
}
}
}
impl fmt::Display for Player {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Player::X => write!(f, "X"),
Player::O => write!(f, "O"),
}
}
}
/// The eight winning lines: three rows, three columns, two diagonals.
const WIN_LINES: [[usize; 3]; 8] = [
[0, 1, 2], [3, 4, 5], [6, 7, 8], // rows
[0, 3, 6], [1, 4, 7], [2, 5, 8], // columns
[0, 4, 8], [2, 4, 6], // diagonals
];
/// A complete snapshot of a Tic-Tac-Toe game.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
pub board: [Option<Player>; 9],
pub current_player: Player,
}
impl GameState {
/// Creates a new game with an empty board. X moves first.
pub fn new() -> Self {
GameState {
board: [None; 9],
current_player: Player::X,
}
}
/// Returns the winner, if any.
pub fn winner(&self) -> Option<Player> {
for line in &WIN_LINES {
let [a, b, c] = *line;
if let (Some(p1), Some(p2), Some(p3)) =
(self.board[a], self.board[b], self.board[c])
{
if p1 == p2 && p2 == p3 {
return Some(p1);
}
}
}
None
}
/// Returns `true` if the game is over.
pub fn is_terminal(&self) -> bool {
self.winner().is_some()
|| self.board.iter().all(|cell| cell.is_some())
}
/// Returns `true` if the game is a draw.
pub fn is_draw(&self) -> bool {
self.winner().is_none()
&& self.board.iter().all(|cell| cell.is_some())
}
/// Returns the legal moves (indices of empty cells).
pub fn legal_moves(&self) -> Vec<usize> {
if self.winner().is_some() {
return Vec::new();
}
self.board
.iter()
.enumerate()
.filter(|(_, cell)| cell.is_none())
.map(|(index, _)| index)
.collect()
}
/// Returns a new state with the move applied.
///
/// # Panics
///
/// Panics if the cell is already occupied.
pub fn apply_move(&self, index: usize) -> GameState {
assert!(self.board[index].is_none(), "Cell {} is occupied", index);
let mut new_board = self.board;
new_board[index] = Some(self.current_player);
GameState {
board: new_board,
current_player: self.current_player.opponent(),
}
}
}
impl fmt::Display for GameState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for row in 0..3 {
for col in 0..3 {
let index = row * 3 + col;
match self.board[index] {
Some(Player::X) => write!(f, " X ")?,
Some(Player::O) => write!(f, " O ")?,
None => write!(f, " . ")?,
}
if col < 2 {
write!(f, "|")?;
}
}
writeln!(f)?;
if row < 2 {
writeln!(f, "-----------")?;
}
}
Ok(())
}
}
}
What we have built
Let us take stock. We now have a game representation that provides everything MCTS needs:
| RL concept | Rust implementation |
|---|---|
| State | GameState β board contents + current player |
| Action | usize β index into the board (0-8) |
| Legal actions | GameState::legal_moves() β indices of empty cells |
| Transition | GameState::apply_move(index) β returns a new state |
| Terminal test | GameState::is_terminal() β win or draw |
| Reward | Derived from GameState::winner() β +1, -1, or 0 |
This interface is deliberately minimal. There is no game loop, no input handling, no AI β just the pure game logic. In Β§6, you will use these types to implement a complete playable game. In Β§7, MCTS will call legal_moves, apply_move, is_terminal, and winner to search the game tree. The neural network in Part 4 will take the board array as its input.
The immutable design means every GameState is a self-contained snapshot. You can store them in a hash map, compare them for equality, collect them into training datasets, and pass them between functions without worrying about aliasing or mutation bugs. This simplicity will pay dividends as the system grows more complex.
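To see those properties in action, here is a small illustrative snippet. It assumes the game module above is in scope and is not part of the module itself:
use std::collections::HashMap;

fn main() {
    let start = GameState::new();
    let after_center = start.apply_move(4);
    assert_ne!(start, after_center);      // plain value comparison
    assert_eq!(start, GameState::new());  // equal by contents, not identity

    // States work directly as hash-map keys, e.g. to count visits per position.
    let mut visits: HashMap<GameState, u32> = HashMap::new();
    *visits.entry(after_center.clone()).or_insert(0) += 1;
    assert_eq!(visits[&after_center], 1);
}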
6. Exercise 1: implement the game logic
Time to get your hands dirty. In this exercise you will create a Rust project, copy in the game module from Β§5, write unit tests to prove it works, and then build a small main function that plays a random game so you can watch the board evolve turn by turn.
Step 1: create the project
cargo new ttt-selfplay
cd ttt-selfplay
This gives you a standard Cargo project with src/main.rs. You will add one more file.
Step 2: create the game module
Create a new file src/game.rs and paste in the complete module from the end of Β§5 β everything from use std::fmt; through the Display impl for GameState. The code is self-contained and has no external dependencies.
Then tell Rust about the module by adding this line at the top of src/main.rs:
mod game;
Verify that it compiles:
cargo check
If you see errors, make sure every item in game.rs that needs to be used from outside the module is marked pub (all types and methods in the Β§5 listing already are).
Step 3: write unit tests
Add a test module at the bottom of src/game.rs:
#[cfg(test)]
mod tests {
    use super::*;
    // Your tests go here.
}
Write the following eight tests. The names and descriptions tell you what to assert β try writing the bodies yourself before expanding the solution.
test_new_game_is_empty β A freshly created GameState::new() should have all nine cells set to None, and the current player should be Player::X.
test_legal_moves_initial β On a new board, legal_moves() should return all nine indices (0 through 8).
test_apply_move β After calling apply_move(4) on a new game, cell 4 should contain Some(Player::X) and the current player should have switched to Player::O.
test_winner_row β Set up a board where X occupies the top row (cells 0, 1, 2) and verify that winner() returns Some(Player::X).
test_winner_col β Set up a board where O occupies the left column (cells 0, 3, 6) and verify that winner() returns Some(Player::O).
test_winner_diagonal β Set up a board where X occupies the main diagonal (cells 0, 4, 8) and verify that winner() returns Some(Player::X).
test_draw β Play out a full game that ends in a draw. Verify that is_terminal() is true, winner() is None, and is_draw() is true.
test_no_moves_after_win β From a state where someone has won, legal_moves() should return an empty vector β even though empty cells remain on the board.
Run the tests with:
cargo test
All eight should pass. If any fail, re-read the Β§5 code β the logic is already correct, so the bug is in your test setup.
Solution β all eight tests
#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_game_is_empty() {
let state = GameState::new();
for cell in &state.board {
assert_eq!(*cell, None);
}
assert_eq!(state.current_player, Player::X);
}
#[test]
fn test_legal_moves_initial() {
let state = GameState::new();
let moves = state.legal_moves();
assert_eq!(moves.len(), 9);
for i in 0..9 {
assert!(moves.contains(&i));
}
}
#[test]
fn test_apply_move() {
let state = GameState::new();
let next = state.apply_move(4);
assert_eq!(next.board[4], Some(Player::X));
assert_eq!(next.current_player, Player::O);
// Original state is unchanged (immutable design).
assert_eq!(state.board[4], None);
}
#[test]
fn test_winner_row() {
// X plays 0, 1, 2 (top row). O plays 3, 4.
let state = GameState::new()
.apply_move(0) // X
.apply_move(3) // O
.apply_move(1) // X
.apply_move(4) // O
.apply_move(2); // X β wins
assert_eq!(state.winner(), Some(Player::X));
}
#[test]
fn test_winner_col() {
// O plays 0, 3, 6 (left column). X plays 1, 4, 8.
let state = GameState::new()
.apply_move(1) // X
.apply_move(0) // O
.apply_move(4) // X
.apply_move(3) // O
.apply_move(8) // X
.apply_move(6); // O β wins
assert_eq!(state.winner(), Some(Player::O));
}
#[test]
fn test_winner_diagonal() {
// X plays 0, 4, 8 (main diagonal). O plays 1, 2.
let state = GameState::new()
.apply_move(0) // X
.apply_move(1) // O
.apply_move(4) // X
.apply_move(2) // O
.apply_move(8); // X β wins
assert_eq!(state.winner(), Some(Player::X));
}
#[test]
fn test_draw() {
// A classic draw: X O X / X O O / O X X
let state = GameState::new()
.apply_move(0) // X
.apply_move(1) // O
.apply_move(2) // X
.apply_move(4) // O
.apply_move(3) // X
.apply_move(5) // O
.apply_move(7) // X
.apply_move(6) // O
.apply_move(8); // X
assert!(state.is_terminal());
assert!(state.is_draw());
assert_eq!(state.winner(), None);
}
#[test]
fn test_no_moves_after_win() {
// X wins on the top row; cells 5 through 8 are still empty.
let state = GameState::new()
.apply_move(0) // X
.apply_move(3) // O
.apply_move(1) // X
.apply_move(4) // O
.apply_move(2); // X β wins
assert_eq!(state.winner(), Some(Player::X));
assert!(state.legal_moves().is_empty());
}
}
}
Step 4: play a random game
Replace the contents of src/main.rs with the following. It picks a random legal move each turn, prints the board, and announces the result.
mod game;
use game::{GameState, Player};
use rand::Rng;
fn main() {
let mut rng = rand::thread_rng();
let mut state = GameState::new();
let mut turn = 0;
while !state.is_terminal() {
let moves = state.legal_moves();
let chosen = moves[rng.gen_range(0..moves.len())];
println!("Turn {}: {} plays cell {}", turn, state.current_player, chosen);
state = state.apply_move(chosen);
println!("{}", state);
turn += 1;
}
match state.winner() {
Some(Player::X) => println!("X wins!"),
Some(Player::O) => println!("O wins!"),
None => println!("Draw!"),
}
}
You will need the rand crate. Add it to Cargo.toml:
cargo add rand@0.8
Then run the program:
cargo run
You should see a full game play out with the board printed after every move, ending in a win or a draw. Run it a few times β you will see different outcomes each time.
Readiness checklist
Before moving on to Β§7, confirm:
- cargo check compiles without errors
- cargo test passes all eight unit tests
- cargo run plays a random game to completion and prints the result
- You understand how apply_move returns a new state rather than mutating the old one
- You can trace how legal_moves returns an empty vector after a win, even if empty cells remain
If everything checks out, your game engine is ready. In the next section, we will build an MCTS player that uses this exact interface to search for strong moves.
Part 3 β MCTS
7. Implementing MCTS in Rust
In Β§2, we walked through MCTS as an algorithm: selection, expansion, simulation, backpropagation, and the UCT formula that balances exploration against exploitation. In Β§5 and Β§6, we built a complete game engine with GameState, Player, legal_moves, apply_move, is_terminal, and winner. Now we bring those two threads together and implement MCTS in Rust.
This is the most substantial implementation section in the course. Take your time with it. By the end, you will have a working MCTS player that can choose strong moves in Tic-Tac-Toe β no neural network required.
The problem with tree pointers in Rust
Before we write any code, we need to talk about a design decision that trips up nearly everyone who implements tree structures in Rust for the first time.
MCTS operates on a tree. Each node has a parent and zero or more children. In a language like Python or Java, you would represent this with object references: each node holds a pointer to its parent and a list of pointers to its children. Trees are natural in garbage-collected languages because the garbage collector handles ownership and cleanup.
Rust does not have a garbage collector. It has ownership. And ownership means every value has exactly one owner. A tree node that is βownedβ by its parent cannot also be βownedβ by its children (who need a back-pointer to the parent). You immediately run into a conflict.
There are several ways to handle this in Rust:
Rc<RefCell<MctsNode>> β reference-counted, interior-mutable nodes. Each node is wrapped in Rc (shared ownership) and RefCell (runtime borrow checking). Children hold Rc pointers to the parent; the parent holds Rc pointers to its children. This works, but it is verbose, allocates each node separately on the heap, and introduces runtime borrow-checking overhead. Every access requires .borrow() or .borrow_mut() calls. Debugging borrow violations becomes a runtime problem instead of a compile-time one.
unsafe raw pointers. You can use *mut MctsNode pointers for parent references. This is how you might do it in C. It works, but you lose Rustβs safety guarantees and have to manually reason about pointer validity.
Arena allocation with indices. Store all nodes in a single Vec<MctsNode>. Instead of pointers, use usize indices into the vector. A nodeβs βparentβ is an index. Its βchildrenβ are a list of indices. The Vec owns all the nodes, and indices are just numbers β no ownership conflicts, no lifetimes, no Rc, no unsafe.
We will use the arena approach. It is the simplest, fastest, and most idiomatic way to build trees in Rust. Here is why:
- No lifetime annotations. The Vec owns everything. Indices are plain integers that can be copied, stored, and compared freely.
- Cache-friendly. All nodes live in a contiguous block of memory. Walking the tree means accessing elements in a vector, which is fast because modern CPUs love sequential memory access.
- Easy to debug. You can print the entire tree by iterating the vector. Each node's index is its identity.
- No runtime overhead. Indexing a Vec is a bounds-checked array access. No reference counting, no borrow checking, no pointer chasing.
The tradeoff is that you cannot remove nodes from the middle of the vector without invalidating indices (you would need a more sophisticated arena for that). For MCTS, this is not a problem β we only add nodes, never remove them.
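For contrast, here is roughly what a single node would look like in the Rc<RefCell> style we decided against. This is a sketch only, not code we will use; note that the parent back-pointer has to be Weak, because Rc links in both directions would form a reference cycle and leak the whole tree:
use std::cell::RefCell;
use std::rc::{Rc, Weak};

/// Sketch of the reference-counted alternative (not used in this course).
struct RcMctsNode {
    state: GameState,
    parent: Option<Weak<RefCell<RcMctsNode>>>,
    children: Vec<Rc<RefCell<RcMctsNode>>>,
    visit_count: u32,
    total_value: f64,
}
Every traversal step through a structure like this needs an upgrade() or a borrow()/borrow_mut() call, which is exactly the verbosity the arena avoids.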
The MctsNode struct
Each node in the MCTS tree represents a game state reached by a particular sequence of moves. It stores the statistics needed by UCT (visit count and total value), the connections to its parent and children, and the move that led to this position:
#![allow(unused)]
fn main() {
/// A single node in the MCTS search tree.
///
/// Nodes are stored in a flat `Vec` (arena). Parent and child
/// relationships are represented as indices into that `Vec`,
/// not as pointers or references.
#[derive(Debug, Clone)]
struct MctsNode {
/// The game state at this node.
state: GameState,
/// The move (board index) that was played to reach this state
/// from the parent. `None` for the root node.
move_from_parent: Option<usize>,
/// Index of the parent node in the arena. `None` for the root.
parent: Option<usize>,
/// Indices of child nodes in the arena.
children: Vec<usize>,
/// Number of times this node has been visited during search.
visit_count: u32,
/// Total accumulated value from all simulations through this node.
/// Positive means good for the player who *just moved* to reach
/// this state (i.e., the opponent of `state.current_player`).
total_value: f64,
/// The legal moves from this state that have not yet been expanded
/// into child nodes. When this list is empty, the node is "fully
/// expanded."
unexpanded_moves: Vec<usize>,
}
}
A few things deserve explanation:
move_from_parent stores the action that led here. When MCTS finishes and we want to know which move to play, we look at the rootβs children and read their move_from_parent values. The root itself has no parent move, so it is None.
unexpanded_moves is a key bookkeeping field. When we create a node, we populate this with all legal moves from the state. Each time we expand a child, we remove a move from this list. When the list is empty, the node is fully expanded and selection will use UCT to pick among existing children instead of expanding a new one.
total_value perspective. This is the trickiest part of MCTS to get right. The value stored at a node is from the perspective of the player who moved to reach this node β that is, the opponent of state.current_player. Why? Because when the parent node is choosing among its children using UCT, it needs to know how good each child is for the parentβs player. The parentβs player is the one who made the move to reach the child. So the childβs value should reflect how well things went for that player.
This will become clearer when we implement backpropagation.
The MCTS tree (arena)
The tree itself is just a Vec of nodes plus a reference to the root:
#![allow(unused)]
fn main() {
/// The MCTS search tree, stored as an arena of nodes.
struct MctsTree {
/// All nodes in the tree. The root is always at index 0.
nodes: Vec<MctsNode>,
}
}
The root is always at index 0 β we create it first, and the Vec never reorders elements. Letβs add a constructor:
#![allow(unused)]
fn main() {
impl MctsTree {
/// Creates a new MCTS tree rooted at the given game state.
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
MctsTree {
nodes: vec![root],
}
}
}
}
When the root is created, we immediately compute its legal moves and store them in unexpanded_moves. The first few iterations of MCTS will expand these into child nodes.
The UCT formula
Before implementing the four phases, letβs write the UCT calculation as a standalone function. Recall from Β§2:
UCT(child) = (W / N) + C * sqrt(ln(N_parent) / N)
where W is the child's total accumulated value, N the child's visit count, N_parent the parent's visit count, and C the exploration constant.
In Rust:
#![allow(unused)]
fn main() {
use std::f64::consts::SQRT_2;
/// Computes the UCT score for a child node.
///
/// - `child_value`: total value accumulated at the child
/// - `child_visits`: number of times the child has been visited
/// - `parent_visits`: number of times the parent has been visited
///
/// Returns `f64::INFINITY` if the child has never been visited
/// (ensuring unvisited children are selected first).
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
if child_visits == 0 {
return f64::INFINITY;
}
let n = child_visits as f64;
let exploitation = child_value / n;
let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
exploitation + exploration
}
}
The exploration constant C = sqrt(2) comes from the theoretical UCB1 analysis and is a good default. Some implementations tune this value β a larger C explores more aggressively, a smaller C exploits more greedily. For Tic-Tac-Toe, sqrt(2) works well.
Note the special case: if a child has never been visited, we return infinity. This guarantees that unvisited children are always selected before any visited child, which means MCTS tries every move at least once before starting to use statistics to discriminate.
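To get a feel for the numbers, here is a small worked example. It assumes the uct_score function above is in scope; the visit counts are invented purely for illustration:
fn main() {
    // Parent visited 100 times. Both children average +0.5 per visit,
    // but one has been explored ten times more than the other.
    let well_explored = uct_score(10.0, 20, 100);  // 0.5 + sqrt(2)*sqrt(ln(100)/20) ~ 1.18
    let barely_explored = uct_score(1.0, 2, 100);  // 0.5 + sqrt(2)*sqrt(ln(100)/2)  ~ 2.65
    assert!(barely_explored > well_explored);
    // Equal exploitation terms, so the exploration bonus decides:
    // the lightly visited child gets the next look.
}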
Phase 1: Selection
Selection starts at the root and walks down the tree, picking the child with the highest UCT score at each step. It stops when it reaches a node that has unexpanded moves (meaning we should expand rather than descend further) or a terminal state (meaning the game is over and there is nothing to expand).
#![allow(unused)]
fn main() {
impl MctsTree {
/// Selects a leaf node for expansion by walking from the root,
/// always choosing the child with the highest UCT score.
///
/// Returns the index of the selected node.
fn select(&self) -> usize {
let mut current = 0; // start at root
loop {
let node = &self.nodes[current];
// If this node has unexpanded moves, stop here β we will
// expand one of them.
if !node.unexpanded_moves.is_empty() {
return current;
}
// If this node is terminal (no children and no unexpanded
// moves), stop here β there is nothing to expand.
if node.children.is_empty() {
return current;
}
// All moves have been expanded. Pick the child with the
// highest UCT score.
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = uct_score(
child.total_value,
child.visit_count,
parent_visits,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
}
}
This is a straightforward loop, not recursion. We walk from the root toward the leaves, always following the highest-UCT child. The loop terminates because:
1. If the current node has unexpanded moves, we return it immediately.
2. If the current node is terminal (no children, no unexpanded moves), we return it.
3. Otherwise, we descend to a child, which is deeper in the tree.
Since the tree is finite (Tic-Tac-Toe games end in at most 9 moves), we are guaranteed to reach case 1 or 2.
Phase 2: Expansion
Once selection hands us a node, we expand it by picking one of its unexpanded moves, creating a new child node, and returning the childβs index:
#![allow(unused)]
fn main() {
impl MctsTree {
/// Expands the given node by picking one unexpanded move, creating
/// a child node for it, and returning the child's index.
///
/// If the node has no unexpanded moves (it is fully expanded or
/// terminal), returns `None`.
fn expand(&mut self, node_idx: usize) -> Option<usize> {
// Pop an unexpanded move. If there are none, we cannot expand.
let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
// Create the child state by applying the move.
let child_state = self.nodes[node_idx].state.apply_move(mv);
let unexpanded = child_state.legal_moves();
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
Some(child_idx)
}
}
}
We use pop() to grab an unexpanded move. The order does not matter β MCTS will eventually try all of them. Using pop is convenient because it returns None when the list is empty, which signals that expansion is not possible.
Notice the two-step process for adding the child: first we push it onto self.nodes (which gives it an index), then we record that index in the parentβs children list. We also record node_idx as the childβs parent. These bidirectional links let us walk both up (for backpropagation) and down (for selection) the tree.
Because we are using indices rather than references, there is no borrow-checker conflict. We can freely read self.nodes[node_idx] and then push to self.nodes without Rust complaining about aliased mutable borrows β we just need to be careful not to hold a reference across the push. The code above handles this correctly by extracting the move and state into local variables before mutating the vector.
Phase 3: Simulation (rollout)
From the newly expanded node (or a terminal node if expansion was not possible), we play a random game to completion and return the result:
#![allow(unused)]
fn main() {
use rand::seq::SliceRandom;
impl MctsTree {
/// Plays random moves from the given state until the game ends.
/// Returns the reward from the perspective of the player who
/// just moved to reach `state`:
///
/// - `+1.0` if that player eventually wins
/// - `-1.0` if that player eventually loses
/// - `0.0` for a draw
///
/// The "player who just moved" is the opponent of
/// `state.current_player`.
fn simulate(&self, state: &GameState, rng: &mut impl rand::Rng) -> f64 {
let mut current = state.clone();
// Play random moves until the game is over.
while !current.is_terminal() {
let moves = current.legal_moves();
let &mv = moves.choose(rng).expect("non-terminal state has moves");
current = current.apply_move(mv);
}
// Determine the reward from the perspective of the player
// who moved to reach the *starting* state of this rollout.
// That player is state.current_player.opponent() β the one
// whose turn it is NOT.
let rollout_player = state.current_player.opponent();
match current.winner() {
Some(winner) if winner == rollout_player => 1.0,
Some(_) => -1.0,
None => 0.0, // draw
}
}
}
}
The rollout is intentionally simple: pick a random legal move, apply it, repeat. No strategy, no heuristics β pure random play. Despite this simplicity, the statistical signal is strong enough for MCTS to work. The key insight from Β§2 is that MCTS does not need accurate rollouts, just informative ones. Random play provides enough information to distinguish good positions from bad ones.
The reward perspective needs care. The value we return will be stored at the node from which the rollout started. That node was reached by a move from its parent. The parentβs player is state.current_player.opponent() β that is the player βwho just moved.β So we compute the reward from that playerβs perspective: +1 if they won, -1 if they lost, 0 for a draw.
We use the rand crate for random number generation. The choose method on slices picks a random element. We pass in a mutable reference to an Rng (random number generator) so the caller controls the source of randomness β this makes testing easier and avoids creating a new RNG on every call.
Phase 4: Backpropagation
After the rollout produces a reward, we propagate it back up the tree from the expanded node to the root, updating visit counts and value sums:
#![allow(unused)]
fn main() {
impl MctsTree {
/// Propagates the simulation result back up from `node_idx` to
/// the root, updating visit counts and value sums.
///
/// `value` is the reward from the perspective of the player who
/// moved to reach `node_idx`.
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
// Move to the parent. Flip the sign of the value because
// the parent represents the opponent's perspective.
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
}
}
This is the step where the sign-flipping happens, and it is the most important detail to get right.
Consider a concrete scenario. Suppose X plays move 4 (center), reaching a child node. From that child, the rollout plays out and X wins. The reward is +1.0 (good for X, who moved to reach this node).
Now we backpropagate:
- At the child node (the one X moved to): add +1.0 to total_value and increment visit_count. X made the move that led here, and X went on to win, so the result is good from that player's perspective.
- At the parent node (where it was X's turn to move): flip the sign and add -1.0. The value stored at the parent is from the perspective of the player who moved to reach the parent, which is O, and a rollout that X wins is bad for O.
The sign flip works because of a simple invariant: a node's total_value records how good the node is for the player who made the move into it. Adjacent levels of the tree belong to alternating players, so a result worth +v to the player who entered the child is worth -v to the player who entered the parent. Walking up the tree, we simply negate the value at every step.
If this still feels abstract, trace through a small example on paper. Place three nodes: grandparent (Oβs turn), parent (Xβs turn, reached by Oβs move), child (Oβs turn, reached by Xβs move). If X wins the rollout from the child:
- Child: +1.0 (X moved here, X won β good for X)
- Parent: -1.0 (O moved here, X won β bad for O)
- Grandparent: +1.0 (X moved here, X won β good for X)
The alternating sign correctly reflects each playerβs perspective.
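If you would rather convince yourself in code than on paper, a test along the following lines captures the invariant. It is a sketch that assumes it lives in the same module as MctsTree (so the private fields are visible) and that the game types are in scope:
#[test]
fn backprop_alternates_sign_up_the_tree() {
    let mut tree = MctsTree::new(GameState::new());
    let child = tree.expand(0).expect("root has unexpanded moves");
    let grandchild = tree.expand(child).expect("child has unexpanded moves");

    // Pretend the rollout from the grandchild was a win for its mover.
    tree.backpropagate(grandchild, 1.0);

    assert_eq!(tree.nodes[grandchild].total_value, 1.0);
    assert_eq!(tree.nodes[child].total_value, -1.0); // opposite perspective
    assert_eq!(tree.nodes[0].total_value, 1.0);      // flips back at the root
}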
The main MCTS loop
Now we assemble the four phases into the main search loop:
#![allow(unused)]
fn main() {
impl MctsTree {
/// Runs MCTS for the given number of iterations and returns the
/// index of the best move from the root position.
fn search(
&mut self,
iterations: u32,
rng: &mut impl rand::Rng,
) -> usize {
for _ in 0..iterations {
// Phase 1: Selection
let selected = self.select();
// Phase 2: Expansion
// If the selected node can be expanded, expand it.
// Otherwise (terminal state), use the selected node itself.
let node_to_simulate = match self.expand(selected) {
Some(child_idx) => child_idx,
None => selected,
};
// Phase 3: Simulation
let state = self.nodes[node_to_simulate].state.clone();
let value = self.simulate(&state, rng);
// Phase 4: Backpropagation
self.backpropagate(node_to_simulate, value);
}
// After all iterations, pick the child of root with the
// highest visit count.
self.best_move()
}
/// Returns the move (board index) of the root's most-visited child.
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have at least one child after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root's children always have a move")
}
}
}
Each iteration is straightforward: select, expand, simulate, backpropagate. After all iterations are complete, we look at the rootβs children and return the move of the one with the most visits.
Why highest visit count rather than highest average value? As we discussed in Β§2, visit count is a more robust signal. A child with 500 visits and a 60% win rate is a more reliable choice than a child with 3 visits and a 100% win rate. The 100% win rate might just be lucky β three rollouts is not enough data. The UCT formula naturally funnels more visits toward genuinely strong moves, so the most-visited child is almost always the best.
A public API function
Letβs wrap everything in a clean public function that takes a game state, a number of iterations, and returns the best move:
#![allow(unused)]
fn main() {
/// Runs MCTS from the given position for the specified number of
/// iterations and returns the best move (a board index).
pub fn mcts_best_move(
state: &GameState,
iterations: u32,
rng: &mut impl rand::Rng,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search(iterations, rng)
}
}
This is the only function that code outside the MCTS module needs to call. Pass in the current game state, the computational budget (number of iterations), and a random number generator. Get back a move.
Complete implementation
Here is the entire MCTS module assembled into a single, compilable block. This goes in src/mcts.rs (or alongside your game code in main.rs β your choice):
#![allow(unused)]
fn main() {
use rand::seq::SliceRandom;
use std::f64::consts::SQRT_2;
// -- Import your game types (adjust the path to match your module layout) --
// use crate::game::{GameState, Player};
/// A single node in the MCTS search tree.
///
/// Nodes are stored in a flat `Vec` (arena allocation). Parent and
/// child links are indices, not pointers.
#[derive(Debug, Clone)]
struct MctsNode {
/// The game state at this node.
state: GameState,
/// The move that led to this state from the parent. `None` for root.
move_from_parent: Option<usize>,
/// Index of the parent node in the arena. `None` for the root.
parent: Option<usize>,
/// Indices of child nodes in the arena.
children: Vec<usize>,
/// Number of times this node has been visited.
visit_count: u32,
/// Total accumulated value (from the perspective of the player
/// who moved to reach this node).
total_value: f64,
/// Legal moves not yet expanded into children.
unexpanded_moves: Vec<usize>,
}
/// The MCTS search tree, backed by arena allocation.
struct MctsTree {
/// All nodes. The root is always at index 0.
nodes: Vec<MctsNode>,
}
/// Computes the UCT score for a child node.
///
/// Returns `f64::INFINITY` for unvisited children, ensuring they
/// are selected before any visited child.
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
if child_visits == 0 {
return f64::INFINITY;
}
let n = child_visits as f64;
let exploitation = child_value / n;
let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
exploitation + exploration
}
impl MctsTree {
/// Creates a new tree rooted at the given game state.
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
MctsTree {
nodes: vec![root],
}
}
/// Phase 1: Selection.
///
/// Walks from the root to a node that either has unexpanded moves
/// or is terminal, always picking the child with the highest UCT
/// score.
fn select(&self) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
if !node.unexpanded_moves.is_empty() {
return current;
}
if node.children.is_empty() {
return current;
}
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = uct_score(
child.total_value,
child.visit_count,
parent_visits,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
/// Phase 2: Expansion.
///
/// Picks one unexpanded move from the given node, creates a child
/// node for it, and returns the child's index. Returns `None` if
/// the node is fully expanded or terminal.
fn expand(&mut self, node_idx: usize) -> Option<usize> {
let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
let child_state = self.nodes[node_idx].state.apply_move(mv);
let unexpanded = child_state.legal_moves();
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
Some(child_idx)
}
/// Phase 3: Simulation (rollout).
///
/// Plays random moves from the given state until the game ends.
/// Returns the reward from the perspective of the player who
/// moved to reach this state.
fn simulate(
&self,
state: &GameState,
rng: &mut impl rand::Rng,
) -> f64 {
let mut current = state.clone();
while !current.is_terminal() {
let moves = current.legal_moves();
let &mv = moves
.choose(rng)
.expect("non-terminal state has moves");
current = current.apply_move(mv);
}
let rollout_player = state.current_player.opponent();
match current.winner() {
Some(winner) if winner == rollout_player => 1.0,
Some(_) => -1.0,
None => 0.0,
}
}
/// Phase 4: Backpropagation.
///
/// Walks from the given node back to the root, updating visit
/// counts and value sums. Flips the sign at each step to account
/// for alternating players.
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
/// Runs the full MCTS loop for the given number of iterations.
/// Returns the best move (board index) from the root position.
fn search(
&mut self,
iterations: u32,
rng: &mut impl rand::Rng,
) -> usize {
for _ in 0..iterations {
let selected = self.select();
let node_to_simulate = match self.expand(selected) {
Some(child_idx) => child_idx,
None => selected,
};
let state = self.nodes[node_to_simulate].state.clone();
let value = self.simulate(&state, rng);
self.backpropagate(node_to_simulate, value);
}
self.best_move()
}
/// Returns the move of the root's most-visited child.
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have children after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root children have moves")
}
}
/// Runs MCTS from the given position and returns the best move.
///
/// `iterations` controls the computational budget β more iterations
/// means stronger play but longer computation.
pub fn mcts_best_move(
state: &GameState,
iterations: u32,
rng: &mut impl rand::Rng,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search(iterations, rng)
}
}
Make sure you have the rand crate in your Cargo.toml:
[dependencies]
rand = "0.8"
Walking through key decisions
Letβs revisit the important design choices in this implementation:
Arena allocation. The Vec<MctsNode> arena avoids all the complexity of reference-counted trees. Every node has a stable index for the lifetime of the search. The only caveat is that expand must extract data from self.nodes[node_idx] into local variables before calling self.nodes.push(child), because the push might reallocate the vector and invalidate any outstanding references. Our code does this correctly β mv and child_state are computed before the push.
unexpanded_moves as a Vec. We store the unexpanded moves directly in each node and use pop() to grab one during expansion. This is simpler than computing the set difference between legal moves and existing children on every expansion. The cost is a small amount of extra memory per node, but for Tic-Tac-Toe (at most 9 moves) this is negligible.
Reward perspective. Each nodeβs total_value is from the perspective of the player who moved to reach that node. This means:
- During selection, UCT compares children from the perspective of the parentβs current player, which is exactly the player who will choose among those children.
- During backpropagation, we flip the sign at each level because adjacent levels represent alternating players.
Cloning GameState during simulation. The rollout clones the starting state and mutates the clone as it plays random moves. Since GameState is small (a 9-element array plus a Player enum), cloning is cheap. In a game with a larger state, you might use a more efficient approach β but for Tic-Tac-Toe, this is perfectly fine.
No parallelism. This is a single-threaded implementation. MCTS can be parallelized (run multiple rollouts in parallel), but that adds complexity we do not need yet. A single-threaded MCTS can run tens of thousands of iterations per second on Tic-Tac-Toe β far more than needed for strong play.
Testing it out
Letβs add a quick test to see MCTS in action. Add this to your main.rs:
use rand::SeedableRng;
use rand::rngs::StdRng;
fn main() {
// Use a seeded RNG for reproducible results.
let mut rng = StdRng::seed_from_u64(42);
let state = GameState::new();
println!("Running MCTS from the opening position...");
println!("{}", state);
// Run 10,000 iterations β far more than needed for Tic-Tac-Toe,
// but it runs in milliseconds.
let mut tree = MctsTree::new(state.clone());
let best = tree.search(10_000, &mut rng);
println!("Best move: cell {}", best);
println!();
// Print visit counts for all children of root.
let root = &tree.nodes[0];
println!("Move Visits Avg Value");
println!("---- ------ ---------");
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent.unwrap();
let avg = if child.visit_count > 0 {
child.total_value / child.visit_count as f64
} else {
0.0
};
println!(
" {} {:>5} {:>+.3}",
mv, child.visit_count, avg
);
}
}
Run it with cargo run. You should see output like this (exact numbers depend on the RNG seed):
Running MCTS from the opening position...
. | . | .
-----------
. | . | .
-----------
. | . | .
Best move: cell 4
Move Visits Avg Value
---- ------ ---------
8 797 -0.048
7 629 -0.091
6 806 -0.039
5 655 -0.078
4 3493 +0.108
3 660 -0.072
2 810 -0.042
1 567 -0.097
0 582 -0.085
Several things to notice:
-
Cell 4 (center) gets the most visits by far. This is correct β the center is the strongest opening move in Tic-Tac-Toe. MCTS discovers this entirely from random rollouts, with no built-in knowledge of Tic-Tac-Toe strategy.
-
Corner cells (0, 2, 6, 8) get more visits than edge cells (1, 3, 5, 7). Also correct β corners are the second-best opening moves. Again, MCTS figured this out on its own.
-
The center has a positive average value; all other moves have negative average values. This indicates that the center leads to positions where the first player has a statistical advantage, while other opening moves are slightly disadvantageous (though Tic-Tac-Toe is theoretically a draw with perfect play from both sides).
-
10,000 iterations run almost instantly. For a 3x3 board, even 100,000 iterations would finish in well under a second. Tic-Tac-Toeβs small game tree means MCTS converges quickly.
What we have built
You now have a complete, working MCTS implementation in Rust. Letβs review what each component does:
| Component | Responsibility |
|---|---|
| MctsNode | Stores state, statistics (visits, value), parent/child links, unexpanded moves |
| MctsTree | Arena of nodes; owns the entire search tree |
| uct_score | Balances exploitation (average value) with exploration (visit ratio) |
| select | Walks from root to a leaf using UCT |
| expand | Adds a new child node for one unexpanded move |
| simulate | Plays a random game to completion, returns the reward |
| backpropagate | Updates statistics from leaf to root, flipping signs |
| search | Runs the four-phase loop for N iterations, returns the best move |
| mcts_best_move | Public API: creates a tree, searches, returns the best move |
This is a pure MCTS implementation β the rollouts are random, with no learned guidance. In Part 4, we will replace the random rollouts with a neural network that evaluates positions directly. That is the AlphaGo Zero innovation: instead of playing a random game to estimate a positionβs value, ask a neural network that has learned to predict the value from training data generated by self-play.
But even without a neural network, this MCTS player is surprisingly strong at Tic-Tac-Toe. In the next section, you will build a playable game loop and see for yourself.
8. Exercise 2: play Tic-Tac-Toe with pure MCTS
You have a working MCTS implementation and a complete Tic-Tac-Toe game engine. Time to wire them together and watch the AI play β first against itself, then against you.
This exercise has two parts:
- Part A: MCTS vs MCTS β run a tournament of 100 games and verify that strong play produces mostly draws.
- Part B: Human vs MCTS β build a CLI where you can challenge the AI yourself.
Part A: MCTS vs MCTS tournament
The game loop is straightforward: start from an empty board, alternate MCTS moves until the game ends, record the result. Repeat 100 times.
Think about what you need:
- A loop that plays a single game to completion.
- An outer loop that plays many games and tallies results.
- Statistics printed at the end: X wins, O wins, draws.
Try writing this yourself before looking at the solution. The key function you need is mcts_best_move(state, iterations, rng) from Β§7.
Solution: MCTS vs MCTS tournament
use rand::thread_rng;
fn play_one_game(iterations: u32) -> Option<Player> {
let mut state = GameState::new();
let mut rng = thread_rng();
loop {
// Stop as soon as the game is over (a win or a full board).
if state.is_terminal() {
return state.winner(); // None means draw
}
// MCTS picks the best move for whoever's turn it is.
let best_move = mcts_best_move(&state, iterations, &mut rng);
state = state.apply_move(best_move);
}
}
fn run_tournament(num_games: u32, iterations: u32) {
let mut x_wins = 0;
let mut o_wins = 0;
let mut draws = 0;
for game_num in 1..=num_games {
match play_one_game(iterations) {
Some(Player::X) => x_wins += 1,
Some(Player::O) => o_wins += 1,
None => draws += 1,
}
// Print progress every 10 games.
if game_num % 10 == 0 {
println!(
"After {} games: X={}, O={}, Draw={}",
game_num, x_wins, o_wins, draws
);
}
}
println!("\n=== Final Results ({} games, {} iterations) ===", num_games, iterations);
println!("X wins: {} ({:.1}%)", x_wins, x_wins as f64 / num_games as f64 * 100.0);
println!("O wins: {} ({:.1}%)", o_wins, o_wins as f64 / num_games as f64 * 100.0);
println!("Draws: {} ({:.1}%)", draws, draws as f64 / num_games as f64 * 100.0);
}
fn main() {
run_tournament(100, 1000);
}
What to expect at different iteration counts
Run the tournament with different MCTS budgets and observe how play quality changes:
| Iterations | Typical results (100 games) | Why |
|---|---|---|
| 100 | Mixed results β X wins often, some draws, occasional O wins | Too few iterations to explore the tree deeply. MCTS makes blunders, especially in the mid-game where tactical accuracy matters. |
| 1,000 | Mostly draws, a few X or O wins | Enough iterations to find reasonable moves. Occasional mistakes on tricky positions, but overall solid play. |
| 10,000 | Nearly 100% draws | The search tree is thoroughly explored. Both sides play near-perfectly. This is what we expect β Tic-Tac-Toe is a theoretical draw with optimal play from both sides. |
The trend is clear: more iterations = stronger play = more draws. Tic-Tac-Toeβs game tree is small enough that 10,000 iterations essentially solves it. For larger games like Go, this brute-force approach would not be enough β which is exactly why AlphaGo Zero added a neural network.
Note on randomness: Your exact numbers will differ between runs because MCTS uses random rollouts. The trend will be consistent though. If you want reproducible results, seed your RNG with a fixed value:
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
(you will need use rand::SeedableRng; in scope for seed_from_u64).
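If you want to reproduce the table above programmatically, a short sweep over the MCTS budget does it. This is a minimal sketch that reuses run_tournament from the solution; the budget list is just the three rows of the table:
fn main() {
    // Sweep the MCTS budget and watch the draw rate climb with stronger play.
    for &iterations in &[100, 1_000, 10_000] {
        println!("\n--- budget: {} iterations per move ---", iterations);
        run_tournament(100, iterations);
    }
}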
Part B: Human vs MCTS
Now build an interactive CLI. The human picks a cell by number (0β8), and MCTS responds. The board should be printed after every move so the human can see the state.
For a good user experience, you will want to:
- Show the board with numbered empty cells so the player knows which moves are available.
- Read a line from stdin and parse it as a usize.
- Validate that the chosen cell is actually empty.
- Handle invalid input gracefully (prompt again instead of crashing).
Here is a helper function that prints the board with position numbers for empty cells β this is much friendlier than showing dots:
#![allow(unused)]
fn main() {
fn print_board_with_indices(state: &GameState) {
for row in 0..3 {
for col in 0..3 {
let index = row * 3 + col;
match state.board[index] {
Some(Player::X) => print!(" X "),
Some(Player::O) => print!(" O "),
None => print!(" {} ", index),
}
if col < 2 {
print!("|");
}
}
println!();
if row < 2 {
println!("-----------");
}
}
}
}
At the start of a game, this displays:
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
After X plays center, it shows:
0 | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
Try building the full interactive loop yourself. You will need std::io::stdin() and its read_line method.
Solution: Human vs MCTS
use rand::thread_rng;
use std::io::{self, Write};
fn get_human_move(state: &GameState) -> usize {
let legal = state.legal_moves();
loop {
print!("Your move (enter 0-8): ");
io::stdout().flush().unwrap();
let mut input = String::new();
io::stdin().read_line(&mut input).unwrap();
match input.trim().parse::<usize>() {
Ok(index) if legal.contains(&index) => return index,
Ok(index) => println!("Cell {} is not available. Try again.", index),
Err(_) => println!("Please enter a number between 0 and 8."),
}
}
}
fn play_human_vs_mcts(human_player: Player, iterations: u32) {
let mut state = GameState::new();
let mut rng = thread_rng();
println!("\n=== Tic-Tac-Toe: You are {} ===\n", human_player);
print_board_with_indices(&state);
loop {
if state.is_terminal() {
match state.winner() {
Some(p) if p == human_player => println!("\nYou win!"),
Some(_) => println!("\nMCTS wins!"),
None => println!("\nIt's a draw!"),
}
break;
}
let chosen_move = if state.current_player == human_player {
println!();
get_human_move(&state)
} else {
println!("\nMCTS is thinking...");
let m = mcts_best_move(&state, iterations, &mut rng);
println!("MCTS plays position {}.", m);
m
};
state = state.apply_move(chosen_move);
println!();
print_board_with_indices(&state);
}
}
fn main() {
// Human plays X (moves first), MCTS uses 1000 iterations.
play_human_vs_mcts(Player::X, 1000);
}
A typical session looks like this:
=== Tic-Tac-Toe: You are X ===
0 | 1 | 2
-----------
3 | 4 | 5
-----------
6 | 7 | 8
Your move (enter 0-8): 4
0 | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
MCTS is thinking...
MCTS plays position 0.
O | 1 | 2
-----------
3 | X | 5
-----------
6 | 7 | 8
Your move (enter 0-8): 2
...
Experimentation prompts
Now that you have both programs, try these experiments:
-
Can you beat MCTS at 10 iterations? With only 10 iterations per move, the search barely explores the tree. You should be able to win consistently. Try it β the AI will make obvious blunders.
-
At what iteration count does MCTS become unbeatable? Start low and work your way up. Try 50, 100, 200, 500. Find the threshold where you can no longer win. For most people, somewhere around 200β500 iterations is enough for MCTS to play flawlessly on Tic-Tac-Toe.
-
Does it matter who goes first? Run the tournament with MCTS playing as X (first) and O (second). In theory, the first player has a slight advantage in Tic-Tac-Toe β does that show up in your results at lower iteration counts?
-
What about asymmetric strength? Modify play_one_game so X uses 100 iterations and O uses 10,000. The stronger player should win more often. How lopsided do the results get?
Solution: Asymmetric tournament
#![allow(unused)]
fn main() {
fn play_asymmetric_game(x_iterations: u32, o_iterations: u32) -> Option<Player> {
let mut state = GameState::new();
let mut rng = thread_rng();
loop {
if state.is_terminal() {
return state.winner();
}
let iterations = match state.current_player {
Player::X => x_iterations,
Player::O => o_iterations,
};
let best_move = mcts_best_move(&state, iterations, &mut rng);
state = state.apply_move(best_move);
}
}
}
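To actually run the experiment, tally the results the same way run_tournament does. A minimal sketch using the 100 vs 10,000 split from the prompt:
fn main() {
    let (mut x_wins, mut o_wins, mut draws) = (0, 0, 0);
    for _ in 0..100 {
        // X searches with a tiny budget, O with a generous one.
        match play_asymmetric_game(100, 10_000) {
            Some(Player::X) => x_wins += 1,
            Some(Player::O) => o_wins += 1,
            None => draws += 1,
        }
    }
    println!("X (100 iters): {}, O (10,000 iters): {}, draws: {}", x_wins, o_wins, draws);
}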
Readiness checklist
Before moving on to Part 4, make sure you can answer βyesβ to all of these:
- You ran a 100-game MCTS vs MCTS tournament and saw mostly draws at high iteration counts.
- You understand why more iterations leads to stronger play (deeper tree exploration, more accurate value estimates).
- You played against MCTS interactively and experienced its strength at high iteration counts.
- You found the approximate iteration threshold where MCTS becomes unbeatable at Tic-Tac-Toe.
- You understand that this approach (brute-force random rollouts) does not scale to complex games β and that replacing rollouts with a neural network is the key insight of AlphaGo Zero.
If so, you are ready for Part 4, where we introduce neural networks to guide MCTS. Instead of random rollouts, the network will learn to predict which moves are good (policy) and who is winning (value) β trained entirely from self-play data.
Part 4 β Neural Network Policy/Value Head
9. Neural network architecture overview
In Part 3, we built an MCTS agent that plays Tic-Tac-Toe by simulating thousands of random games from each position. It works β but it is brute force. Every evaluation requires playing a game to completion with random moves, and the quality of the evaluation depends on how many of those random rollouts you can afford. For a 3x3 board, that is fine. For a 19x19 Go board, it is hopelessly slow.
AlphaGo Zeroβs breakthrough was replacing those random rollouts with a neural network β a function that looks at a board position and instantly outputs two things: which moves are likely to be good (policy) and who is probably winning (value). No random simulation needed. The network learns these judgments from millions of self-play games.
This section teaches you what a neural network is, how it works, and why the specific architecture used in AlphaGo Zero has two βheads.β We will build up from a single artificial neuron to the complete dual-headed network we will implement in Rust in the next section. No prior neural network knowledge is assumed.
What is a neural network?
A neural network is a function approximator. That is a fancy way of saying: it is a machine that takes in some numbers, does math to them, and produces some numbers as output. What makes it special is that the math it does is controlled by thousands (or millions) of adjustable parameters, and those parameters can be tuned so the function produces whatever output you want for any given input.
Think of it like a programmable calculator with thousands of tiny knobs. At first, the knobs are set randomly, and the calculator outputs nonsense. But if you had a way to measure how wrong the output is and nudge each knob in the right direction, eventually the calculator would learn to compute the function you care about.
For our game AI, the function we want is:
- Input: a Tic-Tac-Toe board position
- Output: (1) which move to play, and (2) who is winning
The neural network will learn this function from data generated by self-play. We never hand it a strategy book or a database of expert games. It figures out good play entirely on its own.
The artificial neuron
The fundamental building block of a neural network is the neuron (also called a node or unit). A single neuron does three things:
- Takes in one or more numbers as inputs.
- Multiplies each input by a weight and sums them up, then adds a bias.
- Passes the result through an activation function.
Here is a concrete example. Suppose we have a neuron with two inputs:
Inputs: xβ = 3.0, xβ = -1.0
Weights: wβ = 0.5, wβ = 2.0
Bias: b = 0.1
Step 1 β Weighted sum:
z = (xβ Γ wβ) + (xβ Γ wβ) + b
z = (3.0 Γ 0.5) + (-1.0 Γ 2.0) + 0.1
z = 1.5 + (-2.0) + 0.1
z = -0.4
Step 2 β Activation function (ReLU):
output = max(0, z) = max(0, -0.4) = 0.0
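To make the arithmetic concrete, here is the same neuron as a few lines of Rust. The numbers are the ones from the worked example; nothing here depends on our game code.
fn main() {
    // One neuron: weighted sum of inputs, plus a bias, through ReLU.
    let inputs = [3.0_f32, -1.0];
    let weights = [0.5_f32, 2.0];
    let bias = 0.1_f32;
    let z: f32 = inputs.iter().zip(weights.iter()).map(|(x, w)| x * w).sum::<f32>() + bias;
    let output = z.max(0.0); // ReLU activation
    println!("z = {z:.1}, output = {output:.1}"); // z = -0.4, output = 0.0
}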
Letβs break down each component:
Weights are the knobs we mentioned earlier. Each input gets its own weight, which determines how much that input matters to this neuron. A large positive weight means βthis input is very important β when it goes up, push my output up.β A large negative weight means βthis input is important in the opposite direction.β A weight near zero means βI donβt care about this input.β
Bias is an offset. It shifts the neuronβs output up or down regardless of the inputs. Think of it as the neuronβs βdefault inclinationβ β even when all inputs are zero, the bias determines whether the neuron tends to activate or stay quiet.
Activation function introduces nonlinearity. Without it, a neuron is just computing a weighted sum, which is a straight line. No matter how many straight-line computations you stack together, the result is still a straight line. Activation functions add curves, which let networks learn complex, nonlinear patterns.
The most common activation function in modern networks is ReLU (Rectified Linear Unit):
ReLU(z) = max(0, z)
If the weighted sum is positive, pass it through unchanged. If negative, output zero. It is dead simple, yet it works remarkably well. The key insight is that it creates a βhingeβ β the neuron either fires (positive sum) or stays silent (negative sum), which is enough nonlinearity for stacking layers to produce complex behavior.
Other activation functions you may encounter:
- Sigmoid: squashes output to the range (0, 1). Useful for probabilities.
- Tanh: squashes output to the range (-1, 1). Useful when you need centered outputs β like our value head.
- Softmax: applied to a group of neurons, converts raw scores into a probability distribution that sums to 1. We will use this for the policy head.
Layers: stacking neurons together
A single neuron can only learn simple, linear-ish decision boundaries. The real power comes from organizing neurons into layers and stacking those layers on top of each other.
A neural network has three kinds of layers:
Input layer β This is not really a layer of neurons at all. It is simply the raw data you feed into the network. For our Tic-Tac-Toe network, the input layer will be 18 numbers representing the board state (more on the encoding later).
Hidden layers β These are the layers in the middle where the actual learning happens. Each hidden layer is a group of neurons that all take the same inputs (the outputs from the previous layer) and each produce one output. βHiddenβ just means they are not directly visible as input or output β they are internal to the network.
Output layer β The final layer that produces the networkβs answer. For our game AI, this will produce move probabilities and a value estimate.
Here is how a simple two-layer network looks:
Input layer Hidden layer 1 Hidden layer 2 Output layer
(3 inputs) (4 neurons) (4 neurons) (2 outputs)
xβ βββββββββββ hβ βββββββββββ hβ
βββββββββββ oβ
β² β± β² β± β² β±
β² β± β² β± β² β±
xβ βββββ³βββββ hβ βββββ³βββββ hβ βββββ³βββββ oβ
β± β² β± β² β±
β± β² β± β² β±
xβ βββββββββββ hβ βββββββββββ hβ
β± β±
β± β±
hβ βββββββββββ hβ
Every neuron in one layer is connected to every neuron in the next layer β this is called a fully connected (or dense) layer. Each connection has its own weight. In this diagram:
- The connections from the input layer to hidden layer 1 have 3 x 4 = 12 weights, plus 4 biases = 16 parameters.
- Hidden layer 1 to hidden layer 2 has 4 x 4 = 16 weights, plus 4 biases = 20 parameters.
- Hidden layer 2 to the output has 4 x 2 = 8 weights, plus 2 biases = 10 parameters.
- Total: 46 learnable parameters.
Even this tiny network has 46 knobs to tune. Real networks have millions.
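If you want to verify counts like these by machine, remember that each dense layer contributes inputs x outputs weights plus outputs biases. A tiny sketch using the toy diagram's layer sizes:
fn main() {
    // Layer sizes of the toy network above: 3 inputs, two hidden layers of 4, 2 outputs.
    let sizes = [3, 4, 4, 2];
    let total: usize = sizes
        .windows(2)
        .map(|pair| pair[0] * pair[1] + pair[1]) // weights + biases per layer
        .sum();
    println!("total learnable parameters: {total}"); // 16 + 20 + 10 = 46
}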
Why does stacking layers help? Each layer learns progressively more abstract features. In image recognition, the first layer might learn to detect edges, the second layer combines edges into shapes, and the third layer combines shapes into objects. In game evaluation, the first layer might learn to detect individual piece patterns, the second layer might recognize tactical motifs, and deeper layers might understand strategic concepts. The network discovers these intermediate representations on its own β we never tell it what to look for.
The forward pass
The forward pass is the process of pushing data through the network from input to output. There is no mystery to it β it is just repeated application of the neuron computation (weighted sum, add bias, apply activation) at every layer.
Here is the forward pass in pseudocode for a two-hidden-layer network:
function forward(input):
# Layer 1: input β hidden1
hidden1 = activation(weights1 Γ input + biases1)
# Layer 2: hidden1 β hidden2
hidden2 = activation(weights2 Γ hidden1 + biases2)
# Output layer: hidden2 β output
output = output_activation(weights3 Γ hidden2 + biases3)
return output
Each step takes the outputs from the previous layer, multiplies by a weight matrix, adds biases, and applies the activation function. The whole thing is pure arithmetic β matrix multiplications and element-wise function applications. This is why neural networks can run so fast on GPUs, which are designed for exactly this kind of math.
For our Tic-Tac-Toe network, a single forward pass will take the 18-number board encoding, push it through two hidden layers of 64 neurons each, and produce 10 outputs (9 move probabilities + 1 value). The entire computation takes microseconds β vastly faster than running hundreds of random rollouts.
Training: how the network learns
A freshly created neural network has random weights. Its forward pass produces garbage. Training is the process of adjusting those weights so the networkβs outputs match what we want.
Training requires three ingredients:
1. Training data β input/output pairs.
For our game AI, the data comes from self-play. After playing thousands of games using MCTS, we collect triples of (board position, MCTS move probabilities, game outcome). The board position is the input; the MCTS probabilities and game outcome are the βcorrectβ outputs we want the network to learn.
2. A loss function β a measure of how wrong the network is.
The loss function compares the networkβs output to the desired output and produces a single number: the loss. A high loss means the network is doing badly; a low loss means it is doing well. Common choices:
- Cross-entropy loss for the policy head: measures how different the networkβs predicted move probabilities are from the MCTS move probabilities. If MCTS strongly preferred one move but the network assigns it low probability, the loss is high.
- Mean squared error for the value head: measures how far the networkβs predicted value is from the actual game outcome. If the network said +0.8 (strongly favoring the current player) but the game was actually a loss (-1.0), the loss is high.
The total loss is typically the sum of these two:
total_loss = policy_loss + value_loss
3. An optimizer β a method for adjusting weights to reduce the loss.
This is where gradient descent comes in. The core idea is beautifully simple:
- Run the forward pass to get the networkβs output.
- Compute the loss (how wrong is the output?).
- Figure out which direction to nudge each weight to make the loss smaller. This is the gradient β it tells you, for every single weight in the network, βif you increase this weight slightly, does the loss go up or down, and by how much?β
- Nudge each weight a small step in the direction that reduces the loss.
- Repeat thousands of times.
The process of computing gradients is called backpropagation. The name comes from the fact that the computation flows backward through the network β starting from the loss at the output and propagating back through each layer to determine how each weight contributed to the error. You do not need to understand the calculus behind it to use it; every neural network library handles backpropagation automatically.
Here is the intuition with an analogy. Imagine you are blindfolded on a hilly landscape and trying to reach the lowest valley. You cannot see, but you can feel the slope of the ground beneath your feet. At each step, you feel which direction goes downhill and take a small step that way. That is gradient descent. The βlandscapeβ is the loss function, the βpositionβ is the current set of weights, and each step adjusts the weights to reduce the loss.
The size of each step is controlled by the learning rate β a small number like 0.001 or 0.01. Too large and you overshoot the valley, bouncing around wildly. Too small and training takes forever. Finding a good learning rate is one of the practical arts of training neural networks.
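Gradient descent is easier to trust once you have watched it work on something trivial. This sketch, unrelated to the game code, minimizes the one-parameter loss (w - 3)^2, whose gradient is 2(w - 3); the learning rate plays exactly the role described above.
fn main() {
    let mut w = 0.0_f64;       // arbitrary starting guess
    let learning_rate = 0.1;
    for step in 0..30 {
        let gradient = 2.0 * (w - 3.0); // derivative of (w - 3)^2
        w -= learning_rate * gradient;  // one small step downhill
        if step % 10 == 0 {
            println!("step {step:>2}: w = {w:.4}, loss = {:.6}", (w - 3.0).powi(2));
        }
    }
    // w converges to 3.0, the minimum of the loss.
}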
The typical training loop looks like this:
for each batch of training examples:
predictions = forward(inputs)
loss = compute_loss(predictions, targets)
gradients = backpropagate(loss)
for each weight in network:
weight = weight - learning_rate Γ gradient
After thousands of iterations through the training data, the weights converge to values that make the network produce good outputs for the kinds of inputs it has seen β and, crucially, it generalizes to similar inputs it has not seen. This is the magic: a network trained on a million board positions can evaluate positions it has never encountered before.
The dual-headed architecture
Now we can talk about the specific architecture used in AlphaGo Zero and our Tic-Tac-Toe AI. It is called a dual-headed network because it has one body and two βheadsβ β one for policy and one for value.
The shared trunk
The first part of the network is a set of hidden layers that process the raw board state into a learned representation β an internal encoding that captures the important features of the position. This is called the βtrunkβ or βbackboneβ of the network.
The trunk takes the board encoding as input and outputs a vector of numbers that represents what it has learned about the position. These numbers are not directly interpretable by humans β they are abstract features that the network has discovered are useful for both predicting good moves and predicting who will win.
The policy head
The policy head branches off from the trunk and produces a probability distribution over all possible moves. For Tic-Tac-Toe, that means 9 numbers (one per cell) that sum to 1.0. Each number represents how strongly the network recommends playing in that cell.
For example, the policy head might output:
Cell: 0 1 2 3 4 5 6 7 8
Probability: 0.02 0.03 0.01 0.05 0.70 0.08 0.01 0.05 0.05
This says: βI strongly recommend cell 4 (the center), with 70% probability.β The MCTS algorithm uses these probabilities to focus its search on the most promising moves rather than exploring everything equally.
The policy headβs final activation function is softmax, which converts raw scores into a valid probability distribution.
The value head
The value head also branches off from the trunk, but instead of 9 outputs it produces a single number between -1 and 1. This number estimates who is winning from the perspective of the current player:
- +1 means βthe current player is definitely winning.β
- -1 means βthe current player is definitely losing.β
- 0 means βthe position is roughly even.β
For example, on a Tic-Tac-Toe board where X has two in a row and it is Xβs turn, the value head might output +0.95 β βX is almost certainly going to win.β
The value headβs final activation function is tanh, which squashes any number into the range (-1, 1).
Why two heads on one trunk?
You might wonder: why not train two completely separate networks, one for policy and one for value? There are three reasons for sharing a trunk.
Shared feature learning. Many features of a position are useful for both predicting good moves and predicting who is winning. For instance, recognizing that a player has two in a row is relevant both to βwhat move should I play?β (block or complete the row) and βwho is winning?β (the player with two in a row has an advantage). By sharing a trunk, both heads benefit from the same learned features without having to discover them independently.
Efficiency. One network is faster to run than two. During MCTS, we call the network for every node we expand, so speed matters. A shared trunk means the expensive computation happens once, and then the two heads β which are small β each do a tiny bit of additional work.
Regularization. Training the trunk to be useful for two different tasks simultaneously acts as a form of regularization β it prevents the network from overfitting to either task alone. The trunk must learn features that are generally useful, which tends to produce more robust representations.
Input representation: encoding the board as a tensor
Neural networks only understand numbers. We cannot feed a Tic-Tac-Toe board directly into the network β we need to convert it into a numerical format called a tensor (which is just a fancy word for a multi-dimensional array of numbers).
The standard approach for board games is to use binary feature planes β one plane per piece type. For Tic-Tac-Toe, we create two 3x3 grids:
- Plane 1 (X positions): a 1 where X has played, 0 elsewhere.
- Plane 2 (O positions): a 1 where O has played, 0 elsewhere.
For example, given this board:
X | O | .
-----------
. | X | .
-----------
. | . | O
The two planes are:
Plane 1 (X): Plane 2 (O):
1 0 0 0 1 0
0 1 0 0 0 0
0 0 0 0 0 1
We flatten these two 3x3 planes into a single 1-dimensional vector of 18 numbers:
[1, 0, 0, 0, 1, 0, 0, 0, 0, β X plane (row by row)
0, 1, 0, 0, 0, 0, 0, 0, 1] β O plane (row by row)
This 18-element vector is what the networkβs input layer receives.
Why two separate planes instead of a single grid with, say, +1 for X and -1 for O? The two-plane encoding gives the network richer information. It can independently learn features about Xβs position and Oβs position, rather than having them tangled together in a single number. This also scales naturally to more complex games β chess, for instance, uses 12 planes (one per piece type per color).
We also always encode the board from the perspective of the current player. That is, if it is Oβs turn, we swap the planes so that βmy piecesβ are always in Plane 1 and βopponentβs piecesβ are always in Plane 2. This way the network learns a single concept β βam I winning?β and βwhat should I play?β β rather than needing to learn separate strategies for X and O.
Output shapes
The network produces two outputs:
Policy output: 9 values.
One value for each of the 9 cells on the board. After the softmax activation, these form a probability distribution. Illegal moves (cells already occupied) will still get some probability assigned by the network, but we mask them out before using the policy β we set their probabilities to zero and renormalize the remaining probabilities to sum to 1.
Value output: 1 value.
A single scalar in the range [-1, 1], produced by the tanh activation. This estimates the game-theoretic value of the position from the current playerβs perspective.
Our concrete architecture
Here is the specific network we will build for Tic-Tac-Toe:
βββββββββββββββββββββββββββββββββββ
β INPUT LAYER β
β 18 neurons (2 Γ 3 Γ 3) β
β X-plane (9) + O-plane (9) β
ββββββββββββββββββ¬βββββββββββββββββ
β
ββββββββ΄βββββββ
βββββββββββ΄ββββββββββββββ΄ββββββββββ
β HIDDEN LAYER 1 (Trunk) β
β 64 neurons, ReLU activation β
β Parameters: 18Γ64 + 64 = 1216 β
ββββββββββββββββββ¬βββββββββββββββββ
β
ββββββββ΄βββββββ
βββββββββββ΄ββββββββββββββ΄ββββββββββ
β HIDDEN LAYER 2 (Trunk) β
β 64 neurons, ReLU activation β
β Parameters: 64Γ64 + 64 = 4160 β
ββββββββββββ¬ββββββββββ¬βββββββββββββ
β β
ββββββββββββββ ββββββββββββββ
β β
ββββββββββββββ΄βββββββββββββ ββββββββββββββββ΄βββββββββββββββ
β POLICY HEAD β β VALUE HEAD β
β 9 neurons (one/cell) β β 1 neuron β
β Softmax activation β β Tanh activation β
β Params: 64Γ9 + 9=585 β β Params: 64Γ1 + 1 = 65 β
ββββββββββββββ¬βββββββββββββ ββββββββββββββββ¬βββββββββββββββ
β β
βΌ βΌ
βββββββββββββββββββ ββββββββββββββββββββ
β Move probs [9] β β Value [-1, +1] β
β (sums to 1.0) β β (single scalar) β
βββββββββββββββββββ ββββββββββββββββββββ
Letβs count the total parameters:
| Layer | Weights | Biases | Total |
|---|---|---|---|
| Input β Hidden 1 | 18 Γ 64 = 1,152 | 64 | 1,216 |
| Hidden 1 β Hidden 2 | 64 Γ 64 = 4,096 | 64 | 4,160 |
| Hidden 2 β Policy | 64 Γ 9 = 576 | 9 | 585 |
| Hidden 2 β Value | 64 Γ 1 = 64 | 1 | 65 |
| Total | 5,888 | 138 | 6,026 |
Just over 6,000 learnable parameters. That is tiny by modern standards (large language models have billions), but it is more than enough for Tic-Tac-Toe. The game has only 5,478 possible board positions β our network has enough capacity to learn a good representation.
Summary notation
The architecture can be written compactly as:
18 β 64 (ReLU) β 64 (ReLU) β { 9 (Softmax), 1 (Tanh) }
Read this as: β18 inputs, through a 64-neuron hidden layer with ReLU, through another 64-neuron hidden layer with ReLU, splitting into a policy head of 9 outputs with softmax and a value head of 1 output with tanh.β
How this replaces random rollouts
In Part 3, our pure MCTS agent evaluated positions by running random rollouts β playing random moves until the game ended and using the win/loss result as the value estimate. This works, but has two problems:
-
Random moves are a terrible proxy for good play. A position might be winning for X, but if the random rollout happens to have X make several blunders, the rollout says X lost. You need many rollouts to average out the noise.
-
Rollouts give no guidance on which moves to explore first. MCTS starts by trying all moves roughly equally, relying on UCB1 to gradually focus on the better ones. Early in the search, a lot of time is wasted exploring obviously bad moves.
The neural network solves both problems:
Value head replaces rollouts. Instead of playing a game to completion with random moves and seeing who wins, we simply ask the value head: βgiven this board position, who is winning?β The network returns a number immediately. One forward pass takes microseconds, compared to potentially hundreds of random rollouts. And the value estimate is based on learned patterns, not random play β so it is far more accurate, even with a single evaluation.
Policy head guides the search. Instead of exploring all moves equally at first, we use the policy headβs probabilities to focus on the moves the network thinks are best. If the network assigns 70% probability to one move and 2% to another, MCTS will explore that first move far more eagerly. This dramatically reduces the amount of search needed to find good moves.
Together, these two improvements are what allowed AlphaGo Zero to surpass all previous Go programs. The network provides strong βintuitionβ (good priors on which moves to try, and good evaluations of positions), and MCTS provides exact βcalculationβ (deep search to verify and refine the networkβs intuition). They complement each other beautifully.
In the next section, we will integrate a neural network crate into our Rust project and build this architecture in code.
10. Integrating a neural network crate
In Β§9 we designed a dual-headed neural network on paper: 18 inputs, two hidden layers of 64 neurons each, a 9-output policy head, and a 1-output value head. Now we build it in Rust.
Choosing an approach
The Rust ecosystem has several neural network options:
- tch-rs: Rust bindings to PyTorch's C++ library (libtorch). Full-featured and fast, but it drags in a massive native dependency (~2 GB). Overkill for our 6,026-parameter network.
- candle: a pure-Rust ML framework from Hugging Face. No C++ dependency, but still a large crate with GPU support, autograd, and many features we will never touch.
- Implement from scratch: write our own matrix math, forward pass, and backpropagation.
For a network this small, we will implement everything from scratch. Here is why:
- It is more educational. You will understand exactly what happens inside a neural network β not just how to call an API, but what the API does internally.
- Zero dependencies. No build headaches, no version conflicts, no multi-gigabyte downloads.
- Tic-Tac-Toe does not need GPU acceleration. Our forward pass multiplies an 18-element vector through a few small matrices. That takes microseconds on a CPU. An entire training run completes in seconds.
The only downside is that we must implement backpropagation ourselves. For our simple architecture (a few dense layers with standard activations), this is entirely manageable. Letβs get started.
The building blocks: vectors and matrices
Neural networks are fundamentally about matrix multiplication. Before building layers, we need a few utility functions that operate on flat Vec<f32> vectors. We represent an m Γ n matrix as a Vec<f32> of length m Γ n, stored in row-major order β row 0 first, then row 1, and so on.
#![allow(unused)]
fn main() {
/// Multiply matrix `a` (rows Γ inner) by matrix `b` (inner Γ cols).
/// Both are stored in row-major order. Returns a vector of length rows Γ cols.
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
let mut out = vec![0.0; rows * cols];
for r in 0..rows {
for c in 0..cols {
let mut sum = 0.0;
for k in 0..inner {
sum += a[r * inner + k] * b[k * cols + c];
}
out[r * cols + c] = sum;
}
}
out
}
/// Add bias vector `b` to each element of `a` (both length `n`).
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
}
This is a straightforward triple-loop matrix multiply, O(rows x inner x cols). Our largest multiplication is a 1 x 64 activation vector times a 64 x 64 weight matrix, which is only 4,096 multiply-add operations. Modern CPUs execute billions per second, so this will be instantaneous.
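To see the row-major convention in action, here is a quick sanity check, assuming the mat_mul above is in scope. A 2 x 3 matrix times a 3 x 2 matrix gives a 2 x 2 result, worked out in the comment.
fn main() {
    // A is 2 x 3, B is 3 x 2, both flattened row by row.
    let a: Vec<f32> = vec![
        1.0, 2.0, 3.0, // row 0
        4.0, 5.0, 6.0, // row 1
    ];
    let b: Vec<f32> = vec![
        7.0, 8.0,
        9.0, 10.0,
        11.0, 12.0,
    ];
    let c = mat_mul(&a, &b, 2, 3, 2);
    println!("{:?}", c); // [58.0, 64.0, 139.0, 154.0], i.e. [[58, 64], [139, 154]]
}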
The Layer struct
A dense (fully connected) layer is defined by its weight matrix and bias vector. During a forward pass, it computes output = activation(input Γ weights + bias).
We represent the activation function as an enum so we can branch on it during both the forward and backward passes:
#![allow(unused)]
fn main() {
/// Supported activation functions.
#[derive(Debug, Clone, Copy)]
enum Activation {
/// ReLU: max(0, x). Used in hidden layers.
Relu,
/// Tanh: maps to (-1, +1). Used in the value head.
Tanh,
/// Identity: no transformation. We apply softmax separately for the
/// policy head because softmax operates on the entire output vector,
/// not element-wise.
Identity,
}
/// A single fully connected layer.
#[derive(Debug, Clone)]
struct Layer {
/// Weight matrix, shape: input_size Γ output_size, row-major.
weights: Vec<f32>,
/// Bias vector, length: output_size.
biases: Vec<f32>,
/// Number of inputs this layer expects.
input_size: usize,
/// Number of outputs this layer produces.
output_size: usize,
/// Activation function applied after the linear transform.
activation: Activation,
}
}
Why is the policy headβs activation Identity instead of Softmax? Softmax is not element-wise β it depends on all the outputs simultaneously (because the probabilities must sum to 1). So we compute the linear output first (called βlogitsβ), then apply softmax as a separate step. This also makes it easier to implement the cross-entropy loss during training, which works directly on logits.
Weight initialization
Before a network can learn, its weights need initial values. Setting them all to zero is a disaster β every neuron would compute the same thing and receive the same gradient, so they would never differentiate. We need random initial values, but the scale matters.
Xavier initialization (also called Glorot initialization) sets each weight to a random value drawn from a uniform distribution:
w ~ Uniform(-limit, +limit) where limit = sqrt(6 / (fan_in + fan_out))
Here fan_in is the number of inputs to the layer and fan_out is the number of outputs. The idea is to keep the variance of activations roughly constant as signals flow through the network. If weights are too large, activations explode; too small, and they vanish to zero. Xavier initialization balances this.
We will use a simple linear congruential generator (LCG) so our code has zero dependencies β not even on rand. For a production system you would want a better RNG, but for training a Tic-Tac-Toe network, this is perfectly adequate:
#![allow(unused)]
fn main() {
/// A simple pseudo-random number generator (LCG).
/// Not cryptographically secure, but fine for weight initialization.
struct Rng {
state: u64,
}
impl Rng {
fn new(seed: u64) -> Self {
Rng { state: seed }
}
/// Returns a random f32 in [0, 1).
fn next_f32(&mut self) -> f32 {
// LCG parameters from Numerical Recipes.
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
// Keep 15 bits from the middle of the state (bits 17..31); an LCG's low bits are poorly distributed.
let bits = ((self.state >> 17) & 0x7FFF) as f32;
bits / 32768.0
}
/// Returns a random f32 in [-limit, +limit).
fn uniform(&mut self, limit: f32) -> f32 {
self.next_f32() * 2.0 * limit - limit
}
}
}
Now we can create a layer with Xavier-initialized weights:
#![allow(unused)]
fn main() {
impl Layer {
/// Create a new layer with Xavier-initialized weights and zero biases.
fn new(input_size: usize, output_size: usize, activation: Activation, rng: &mut Rng) -> Self {
let limit = (6.0 / (input_size + output_size) as f32).sqrt();
let weights: Vec<f32> = (0..input_size * output_size)
.map(|_| rng.uniform(limit))
.collect();
let biases = vec![0.0; output_size];
Layer {
weights,
biases,
input_size,
output_size,
activation,
}
}
}
}
Biases start at zero. This is standard practice β the random weights already break symmetry, so zero biases are fine.
Forward pass
The forward pass pushes an input vector through the layer. We need to save the intermediate values (the pre-activation output and the input) because backpropagation will need them later:
#![allow(unused)]
fn main() {
impl Layer {
/// Run the forward pass. Returns (output, pre_activation, input_clone).
/// We save intermediate values for backpropagation.
fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
// Linear transform: output = input Γ weights + biases
let pre_act = vec_add(
&mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
&self.biases,
);
// Apply activation function element-wise.
let output = match self.activation {
Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
Activation::Identity => pre_act.clone(),
};
(output, pre_act, input.to_vec())
}
}
}
Notice that we treat the input as a 1 Γ input_size matrix and the weights as an input_size Γ output_size matrix. The result is a 1 Γ output_size vector, which we then add the bias to and activate.
Softmax
The policy head outputs raw logits β unbounded numbers that can be positive or negative. Softmax converts them into a probability distribution (all positive, summing to 1):
softmax(x_i) = exp(x_i) / sum(exp(x_j) for all j)
A numerical stability trick: we subtract the maximum value before exponentiating. This prevents overflow when a logit is large:
#![allow(unused)]
fn main() {
/// Convert a vector of logits into a probability distribution.
fn softmax(logits: &[f32]) -> Vec<f32> {
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
}
If the logits are [2.0, 1.0, 0.1], softmax produces roughly [0.66, 0.24, 0.10] β the largest logit gets the highest probability, but every value is positive and they sum to 1.
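A quick unit test pins down both properties: the outputs sum to 1, and the ordering of the logits is preserved. This is a sketch assuming the softmax above lives in the same module:
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn softmax_produces_a_probability_distribution() {
        let probs = softmax(&[2.0, 1.0, 0.1]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);                   // sums to 1
        assert!(probs[0] > probs[1] && probs[1] > probs[2]); // order preserved
    }
}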
The Network struct
Our network has four layers total: two shared βtrunkβ layers, one policy head, and one value head. After the trunk, the computation branches β the trunkβs output feeds into both heads independently:
#![allow(unused)]
fn main() {
/// A dual-headed neural network for Tic-Tac-Toe.
///
/// Architecture: 18 β 64 (ReLU) β 64 (ReLU) β { 9 (Softmax), 1 (Tanh) }
#[derive(Debug, Clone)]
struct Network {
/// First hidden layer (trunk): 18 β 64.
hidden1: Layer,
/// Second hidden layer (trunk): 64 β 64.
hidden2: Layer,
/// Policy head: 64 β 9. Activation is Identity (softmax applied separately).
policy_head: Layer,
/// Value head: 64 β 1. Activation is Tanh.
value_head: Layer,
}
impl Network {
/// Create a new network with Xavier-initialized weights.
fn new(seed: u64) -> Self {
let mut rng = Rng::new(seed);
Network {
hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
}
}
}
}
This exactly matches the architecture diagram from Β§9. The total parameter count is 1,216 + 4,160 + 585 + 65 = 6,026 β the same number we calculated on paper.
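If you would like the code to confirm the 6,026 figure itself, a small param_count helper does it. This is an optional sketch; nothing later in the course depends on it.
#![allow(unused)]
fn main() {
    impl Layer {
        /// Learnable parameters in this layer: weights plus biases.
        fn param_count(&self) -> usize {
            self.weights.len() + self.biases.len()
        }
    }
    impl Network {
        /// Total learnable parameters across the trunk and both heads.
        fn param_count(&self) -> usize {
            self.hidden1.param_count()
                + self.hidden2.param_count()
                + self.policy_head.param_count()
                + self.value_head.param_count()
        }
    }
}
For the architecture above, param_count() returns 6,026, matching the table in the previous section.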
Encoding the board state
Before feeding a position to the network, we must convert our GameState into the 18-element input vector described in Β§9. Recall the encoding: the first 9 elements are the βcurrent playerβs piecesβ plane, and the second 9 elements are the βopponentβs piecesβ plane.
#![allow(unused)]
fn main() {
/// Encode a game state into the 18-element input vector.
///
/// Elements 0..9: 1.0 where the current player has a piece, else 0.0.
/// Elements 9..18: 1.0 where the opponent has a piece, else 0.0.
fn encode_board(state: &GameState) -> Vec<f32> {
let mut input = vec![0.0_f32; 18];
let current = state.current_player;
for i in 0..9 {
match state.board[i] {
Some(p) if p == current => input[i] = 1.0,
Some(_) => input[i + 9] = 1.0,
None => {}
}
}
input
}
}
By always putting the current playerβs pieces first, the network learns a single perspective: βwhat should I do?β This works for both X and O without needing separate networks or any special-casing.
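As a sanity check, here is the example board from the previous section pushed through encode_board: X has played cells 0 and 4, O has played cells 1 and 8, and it is X's turn, so the X plane comes first.
fn main() {
    // Reconstruct the example position: X at 0 and 4, O at 1 and 8, X to move.
    let mut state = GameState::new();
    state.board[0] = Some(Player::X);
    state.board[4] = Some(Player::X);
    state.board[1] = Some(Player::O);
    state.board[8] = Some(Player::O);
    state.current_player = Player::X;

    let encoded = encode_board(&state);
    let expected: Vec<f32> = vec![
        1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, // current player's (X's) plane
        0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, // opponent's (O's) plane
    ];
    assert_eq!(encoded, expected);
}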
Illegal move masking
The policy head outputs a probability for all 9 cells, but some cells may already be occupied. We must not select an occupied cell. The solution: mask the illegal moves by setting their probabilities to zero, then renormalize so the remaining probabilities sum to 1.
#![allow(unused)]
fn main() {
/// Zero out probabilities for occupied cells and renormalize.
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
let mut masked: Vec<f32> = policy
.iter()
.enumerate()
.map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
.collect();
let sum: f32 = masked.iter().sum();
if sum > 0.0 {
for p in &mut masked {
*p /= sum;
}
}
masked
}
}
We apply masking after softmax. This is simpler than masking logits (which would require setting illegal logits to negative infinity before softmax). Both approaches work, but post-softmax masking is easier to understand.
The predict method
Now we wire everything together. The predict method takes a GameState and returns a (Vec<f32>, f32) tuple β the move probability distribution and the positionβs value:
#![allow(unused)]
fn main() {
impl Network {
/// Run a forward pass and return (policy, value).
///
/// `policy` is a 9-element probability distribution over moves,
/// with illegal moves masked to zero.
/// `value` is a scalar in [-1, +1] estimating the current player's
/// chance of winning.
fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
let input = encode_board(state);
// Forward through the trunk.
let (h1_out, _, _) = self.hidden1.forward(&input);
let (h2_out, _, _) = self.hidden2.forward(&h1_out);
// Policy head: logits β softmax β mask illegal moves.
let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
let policy = softmax(&policy_logits);
let policy = mask_illegal(&policy, state);
// Value head: single tanh output.
let (value_out, _, _) = self.value_head.forward(&h2_out);
let value = value_out[0];
(policy, value)
}
}
}
That is the complete forward pass. Given a board position, we encode it, push it through two hidden layers, then branch into the two heads. The policy headβs raw output goes through softmax and illegal-move masking; the value headβs output is already in [-1, +1] thanks to tanh.
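Before any training, the outputs are noise, but you can already exercise the whole forward pass end to end. A quick sketch on the empty board, with an arbitrary seed:
fn main() {
    let network = Network::new(42);   // arbitrary seed; weights are random
    let state = GameState::new();     // empty board, X to move
    let (policy, value) = network.predict(&state);

    println!("value estimate: {:+.3}", value);
    for (cell, p) in policy.iter().enumerate() {
        println!("cell {}: {:.3}", cell, p);
    }
    // The nine probabilities sum to 1, but with untrained weights they are
    // roughly uniform. Training is what will make them meaningful.
}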
Training infrastructure
A network with random weights is useless β it needs to learn from data. In the self-play loop (covered in Part 5), we will generate training examples by having the network play games against itself, using MCTS to produce improved move probabilities. Each training example is a tuple of:
- State: a board position (encoded as 18 floats)
- Target policy: the MCTS-improved move probabilities (9 floats summing to 1)
- Target value: the actual game outcome from this playerβs perspective (+1 win, -1 loss, 0 draw)
The network learns by adjusting its weights to make its predictions match these targets. This requires three components: a loss function, backpropagation, and an optimizer.
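Part 5 defines the real data pipeline, but it helps to picture the shape of one example now. The struct below, including its name TrainingExample, is an illustrative assumption rather than something defined elsewhere in the course:
#![allow(unused)]
fn main() {
    /// One self-play training example (hypothetical shape; Part 5 defines the real one).
    struct TrainingExample {
        /// A board position encountered during a self-play game.
        state: GameState,
        /// MCTS-improved move probabilities for that position (9 values summing to 1).
        target_policy: Vec<f32>,
        /// Final game outcome from this position's player's perspective: +1, -1, or 0.
        target_value: f32,
    }
}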
Loss functions
We need to measure how wrong the networkβs predictions are. Different heads use different loss functions:
Cross-entropy loss for the policy head. Cross-entropy measures the distance between two probability distributions. If the target says move 4 should have probability 0.8 and the network says 0.2, the loss is high. The formula is:
L_policy = -sum(target_i * ln(predicted_i)) for all i
We add a tiny epsilon to avoid taking the log of zero:
#![allow(unused)]
fn main() {
/// Cross-entropy loss between target and predicted distributions.
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
let eps = 1e-8;
-target
.iter()
.zip(predicted.iter())
.map(|(&t, &p)| t * (p + eps).ln())
.sum::<f32>()
}
}
MSE loss for the value head. Mean squared error is the simplest regression loss: how far is the predicted value from the target?
L_value = (target - predicted)^2
#![allow(unused)]
fn main() {
/// Mean squared error between a target and predicted scalar.
fn mse_loss(target: f32, predicted: f32) -> f32 {
(target - predicted).powi(2)
}
}
Combined loss. The total loss is the sum of both, optionally with a weighting factor. We keep it simple and weight them equally:
L_total = L_policy + L_value
In practice, some implementations add an L2 regularization term to prevent overfitting, but for our tiny network and simple game, it is unnecessary.
Backpropagation
Backpropagation is the algorithm that computes how much each weight contributed to the loss, so we know which direction to adjust it. It works by applying the chain rule of calculus backwards through the network.
The key insight: if we know the gradient of the loss with respect to a layerβs output, we can compute the gradient with respect to that layerβs weights, biases, and input. The input gradient then becomes the output gradient for the previous layer, and we repeat.
For a dense layer computing output = activation(input Γ weights + bias):
-
Gradient through activation: Multiply the output gradient element-wise by the activationβs derivative.
- ReLU derivative: 1 if pre_activation > 0, else 0.
- Tanh derivative: 1 - tanh(x)^2.
- Identity derivative: 1.
-
Bias gradient: Equal to the activation gradient (since bias is just added).
-
Weight gradient: input^T x activation_gradient (an outer product, since the input is a single vector).
- Input gradient: activation_gradient x weights^T (to propagate further back).
Here is the implementation:
#![allow(unused)]
fn main() {
impl Layer {
/// Compute gradients and return (input_gradient, weight_gradient, bias_gradient).
///
/// `output_grad`: gradient of the loss with respect to this layer's output.
/// `pre_act`: the pre-activation values saved during forward pass.
/// `input`: the input to this layer saved during forward pass.
fn backward(
&self,
output_grad: &[f32],
pre_act: &[f32],
input: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
// Step 1: gradient through activation function.
let act_grad: Vec<f32> = output_grad
.iter()
.zip(pre_act.iter())
.map(|(&og, &pa)| {
let d_act = match self.activation {
Activation::Relu => {
if pa > 0.0 {
1.0
} else {
0.0
}
}
Activation::Tanh => 1.0 - pa.tanh().powi(2),
Activation::Identity => 1.0,
};
og * d_act
})
.collect();
// Step 2: bias gradient = activation gradient.
let bias_grad = act_grad.clone();
// Step 3: weight gradient = input^T Γ act_grad.
// input is 1 Γ input_size, act_grad is 1 Γ output_size.
// Result is input_size Γ output_size (same shape as weights).
let mut weight_grad = vec![0.0; self.input_size * self.output_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
}
}
// Step 4: input gradient = act_grad Γ weights^T.
// act_grad is 1 Γ output_size, weights is input_size Γ output_size.
// We need to multiply act_grad by weights transposed.
let mut input_grad = vec![0.0; self.input_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
}
}
(input_grad, weight_grad, bias_grad)
}
/// Apply gradients using SGD: weight -= learning_rate * gradient.
fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
*w -= lr * g;
}
for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
*b -= lr * g;
}
}
}
}
The backward pass mirrors the forward pass in reverse. Each mathematical operation has a corresponding gradient rule, and we chain them together. This is the same algorithm that every deep learning framework implements β we are just doing it by hand.
The train_step method
Now we combine the forward pass, loss computation, and backward pass into a single training step. This method takes one training example and updates all the weights:
#![allow(unused)]
fn main() {
impl Network {
/// Perform one training step on a single example.
///
/// Returns the total loss (for monitoring training progress).
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
// === Forward pass (saving intermediates) ===
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
// === Compute losses ===
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
// === Backward pass ===
// Policy head gradient.
// For cross-entropy loss with softmax, the gradient of the loss
// with respect to the logits simplifies beautifully to:
// d_logits = predicted - target
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
// Value head gradient.
// For MSE loss with tanh, the gradient w.r.t. the value head output is:
// d_value = 2 * (predicted - target)
// The tanh derivative is handled inside backward().
let value_grad = vec![2.0 * (value - target_value)];
// Backprop through heads.
let (policy_input_grad, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (value_input_grad, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
// The trunk's output feeds both heads, so its gradient is the sum.
let h2_output_grad: Vec<f32> = policy_input_grad
.iter()
.zip(value_input_grad.iter())
.map(|(&a, &b)| a + b)
.collect();
// Backprop through trunk.
let (h1_output_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_output_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) = self.hidden1.backward(&h1_output_grad, &h1_pre, &h1_in);
// === Apply gradients (SGD) ===
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
}
There is one subtle but important detail in the policy gradient. When you combine softmax with cross-entropy loss, the gradient of the loss with respect to the logits (not the probabilities) simplifies to predicted - target. This is one of the most elegant results in neural network math, and it is why softmax + cross-entropy is such a popular combination. We take advantage of this by passing the simplified gradient directly to the policy headβs backward method, which treats the head as having Identity activation β the softmax derivative is already baked into the predicted - target formula.
Another detail: where the trunk branches into the two heads, the gradient flowing back is the sum of the gradients from both heads. This follows from calculus β if a variable contributes to two terms in the loss, its total gradient is the sum of its contributions to each.
SGD optimizer
Our optimizer is stochastic gradient descent (SGD) in its simplest form: weight -= learning_rate * gradient. The apply_gradients method on Layer already implements this. There are fancier optimizers (Adam, RMSProp) that adapt the learning rate per-parameter, but plain SGD works well for small networks. The learning rate is a hyperparameter we will tune; a typical starting point is 0.01.
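To make the update rule concrete: a weight of 0.30 with a gradient of 0.50 and a learning rate of 0.01 becomes 0.30 - 0.01 * 0.50 = 0.295 after one step. Small, repeated nudges like that are all SGD ever does.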
Putting it all together
Here is the complete, compilable code. All the pieces above are assembled into a single program that creates a network, runs a forward pass on a sample board, and performs a training step:
use std::fmt;
// --- Player and GameState (from Part 1) ----------------------------
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player { X, O }
impl Player {
pub fn opponent(self) -> Self {
match self { Player::X => Player::O, Player::O => Player::X }
}
}
impl fmt::Display for Player {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { Player::X => write!(f, "X"), Player::O => write!(f, "O") }
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
pub board: [Option<Player>; 9],
pub current_player: Player,
}
impl GameState {
pub fn new() -> Self {
GameState { board: [None; 9], current_player: Player::X }
}
}
// --- Random number generator ---------------------------------------
struct Rng { state: u64 }
impl Rng {
fn new(seed: u64) -> Self { Rng { state: seed } }
fn next_f32(&mut self) -> f32 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
let bits = ((self.state >> 17) & 0x7FFF) as f32;
bits / 32768.0
}
fn uniform(&mut self, limit: f32) -> f32 {
self.next_f32() * 2.0 * limit - limit
}
}
// --- Matrix utilities ------------------------------------------------
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
let mut out = vec![0.0; rows * cols];
for r in 0..rows {
for c in 0..cols {
let mut sum = 0.0;
for k in 0..inner {
sum += a[r * inner + k] * b[k * cols + c];
}
out[r * cols + c] = sum;
}
}
out
}
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
fn softmax(logits: &[f32]) -> Vec<f32> {
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
// --- Layer -----------------------------------------------------------
#[derive(Debug, Clone, Copy)]
enum Activation { Relu, Tanh, Identity }
#[derive(Debug, Clone)]
struct Layer {
weights: Vec<f32>,
biases: Vec<f32>,
input_size: usize,
output_size: usize,
activation: Activation,
}
impl Layer {
fn new(
input_size: usize,
output_size: usize,
activation: Activation,
rng: &mut Rng,
) -> Self {
let limit = (6.0 / (input_size + output_size) as f32).sqrt();
let weights: Vec<f32> = (0..input_size * output_size)
.map(|_| rng.uniform(limit))
.collect();
let biases = vec![0.0; output_size];
Layer { weights, biases, input_size, output_size, activation }
}
fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let pre_act = vec_add(
&mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
&self.biases,
);
let output = match self.activation {
Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
Activation::Identity => pre_act.clone(),
};
(output, pre_act, input.to_vec())
}
fn backward(
&self,
output_grad: &[f32],
pre_act: &[f32],
input: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let act_grad: Vec<f32> = output_grad
.iter()
.zip(pre_act.iter())
.map(|(&og, &pa)| {
let d_act = match self.activation {
Activation::Relu => if pa > 0.0 { 1.0 } else { 0.0 },
Activation::Tanh => 1.0 - pa.tanh().powi(2),
Activation::Identity => 1.0,
};
og * d_act
})
.collect();
let bias_grad = act_grad.clone();
let mut weight_grad = vec![0.0; self.input_size * self.output_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
}
}
let mut input_grad = vec![0.0; self.input_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
}
}
(input_grad, weight_grad, bias_grad)
}
fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
*w -= lr * g;
}
for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
*b -= lr * g;
}
}
}
// --- Network ---------------------------------------------------------
#[derive(Debug, Clone)]
struct Network {
hidden1: Layer,
hidden2: Layer,
policy_head: Layer,
value_head: Layer,
}
fn encode_board(state: &GameState) -> Vec<f32> {
let mut input = vec![0.0_f32; 18];
let current = state.current_player;
for i in 0..9 {
match state.board[i] {
Some(p) if p == current => input[i] = 1.0,
Some(_) => input[i + 9] = 1.0,
None => {}
}
}
input
}
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
let mut masked: Vec<f32> = policy
.iter()
.enumerate()
.map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
.collect();
let sum: f32 = masked.iter().sum();
if sum > 0.0 {
for p in &mut masked {
*p /= sum;
}
}
masked
}
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
let eps = 1e-8;
-target
.iter()
.zip(predicted.iter())
.map(|(&t, &p)| t * (p + eps).ln())
.sum::<f32>()
}
fn mse_loss(target: f32, predicted: f32) -> f32 {
(target - predicted).powi(2)
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = Rng::new(seed);
Network {
hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
}
}
fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
let input = encode_board(state);
let (h1_out, _, _) = self.hidden1.forward(&input);
let (h2_out, _, _) = self.hidden2.forward(&h1_out);
let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
let policy = softmax(&policy_logits);
let policy = mask_illegal(&policy, state);
let (value_out, _, _) = self.value_head.forward(&h2_out);
(policy, value_out[0])
}
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
// Backprop: policy head
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
let value_grad = vec![2.0 * (value - target_value)];
let (pg_input, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (vg_input, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
// Sum gradients from both heads at the trunk output.
let h2_grad: Vec<f32> = pg_input.iter()
.zip(vg_input.iter())
.map(|(&a, &b)| a + b)
.collect();
let (h1_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) =
self.hidden1.backward(&h1_grad, &h1_pre, &h1_in);
// SGD update.
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
// --- Test ------------------------------------------------------------
fn main() {
let mut net = Network::new(42);
// Create a sample board position:
// X | O | .
// ---------
// . | X | .
// ---------
// . | . | O
let mut state = GameState::new();
state.board[0] = Some(Player::X);
state.board[1] = Some(Player::O);
state.board[4] = Some(Player::X);
state.board[8] = Some(Player::O);
// It is X's turn (4 pieces placed, 2 each).
// Forward pass.
let (policy, value) = net.predict(&state);
println!("Board:");
for r in 0..3 {
for c in 0..3 {
let i = r * 3 + c;
match state.board[i] {
Some(Player::X) => print!(" X "),
Some(Player::O) => print!(" O "),
None => print!(" . "),
}
if c < 2 { print!("|"); }
}
println!();
if r < 2 { println!("-----------"); }
}
println!("\nPolicy (move probabilities):");
for i in 0..9 {
let label = if state.board[i].is_some() { "occupied" } else { "legal" };
println!(" cell {}: {:.4} ({})", i, policy[i], label);
}
let policy_sum: f32 = policy.iter().sum();
println!(" sum: {:.4}", policy_sum);
println!("\nValue: {:.4}", value);
// Verify shapes and constraints.
assert_eq!(policy.len(), 9, "Policy must have 9 elements");
assert!((policy_sum - 1.0).abs() < 1e-5, "Policy must sum to ~1.0");
assert!(policy[0] == 0.0, "Occupied cell 0 must have 0 probability");
assert!(policy[1] == 0.0, "Occupied cell 1 must have 0 probability");
assert!(policy[4] == 0.0, "Occupied cell 4 must have 0 probability");
assert!(policy[8] == 0.0, "Occupied cell 8 must have 0 probability");
assert!(value >= -1.0 && value <= 1.0, "Value must be in [-1, +1]");
println!("\nAll shape checks passed!\n");
// Training test: run a few training steps and verify the loss decreases.
let target_policy = [0.0, 0.0, 0.3, 0.2, 0.0, 0.2, 0.1, 0.2, 0.0];
let target_value = 0.5;
let lr = 0.01;
let loss_before = net.train_step(&state, &target_policy, target_value, lr);
println!("Loss before training: {:.4}", loss_before);
// Run 100 training steps on the same example.
let mut last_loss = loss_before;
for _ in 0..100 {
last_loss = net.train_step(&state, &target_policy, target_value, lr);
}
println!("Loss after 100 steps: {:.4}", last_loss);
assert!(
last_loss < loss_before,
"Loss should decrease after training"
);
println!("Training is working β loss decreased.\n");
}
Running this program produces output like:
Board:
X | O | .
-----------
. | X | .
-----------
. | . | O
Policy (move probabilities):
cell 0: 0.0000 (occupied)
cell 1: 0.0000 (occupied)
cell 2: 0.1891 (legal)
cell 3: 0.2004 (legal)
cell 4: 0.0000 (occupied)
cell 5: 0.2147 (legal)
cell 6: 0.1965 (legal)
cell 7: 0.1993 (legal)
cell 8: 0.0000 (occupied)
sum: 1.0000
Value: -0.0274
All shape checks passed!
Loss before training: 2.0835
Loss after 100 steps: 0.5921
Training is working β loss decreased.
The exact numbers depend on the random seed, but the structure is always the same: occupied cells have zero probability, legal cells share the remaining probability mass (roughly uniform initially because the weights are random), the value is a small number near zero, and training reduces the loss.
What we built
Let's step back and appreciate what we have. From zero external dependencies, we implemented:
- A complete feedforward neural network with dense layers, ReLU/tanh/softmax activations, and Xavier initialization.
- A forward pass that converts a board position into a policy distribution and value estimate.
- Illegal move masking that ensures the network never suggests occupied cells.
- Cross-entropy and MSE loss functions to measure prediction quality.
- Full backpropagation through the entire network, including the tricky softmax + cross-entropy gradient simplification and gradient summation at the trunk fork.
- SGD optimization that updates weights to reduce the loss.
This is a real, working neural network: the same fundamental architecture (input → hidden layers → output heads, trained with backprop and SGD) that powers everything from image classifiers to language models. Ours happens to have 6,026 parameters instead of billions, but the principles are identical.
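If you want to check that count, it breaks down layer by layer as weights plus biases: hidden1 has 18 * 64 + 64 = 1,216 parameters, hidden2 has 64 * 64 + 64 = 4,160, the policy head has 64 * 9 + 9 = 585, and the value head has 64 * 1 + 1 = 65, for a total of 1,216 + 4,160 + 585 + 65 = 6,026.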
In the next section, we will generate training data through MCTS self-play and use it to teach this network to actually play good Tic-Tac-Toe.
11. Exercise 3: train the network on MCTS data
We have two powerful components sitting side by side: a pure-MCTS player that discovers strong moves through thousands of random rollouts (§7-8), and a neural network that can be trained to predict move probabilities and position values (§9-10). Right now they are strangers. This exercise introduces them.
The idea is simple but profound: use MCTS as a teacher and the neural network as a student. MCTS plays games against itself, generating high-quality data about which moves are good and who is winning. We then train the network to mimic MCTS's judgments. If successful, the network will learn to instantly predict what MCTS would conclude after thousands of iterations of search.
This is exactly how AlphaGo Zero bootstraps its neural network. The network starts knowing nothing: its weights are random, its predictions are noise. But MCTS does not need the network. MCTS can play reasonably well using random rollouts alone. So we let MCTS generate training data, train the network on it, and later (in §12) we will plug the trained network back into MCTS to make it even stronger.
What data does the network need?
Recall from §10 that our network has two output heads:
- Policy head: a probability distribution over the 9 cells ("which move should I play?").
- Value head: a single number in [-1, +1] ("who is winning from this position?").
To train these heads, we need labeled examples: positions where we know the right policy and the right value. MCTS gives us both.
Policy targets from visit counts. After MCTS finishes searching a position (say, 1000 iterations), the root's children have accumulated visit counts. The most-visited move is the one MCTS considers best, but the full visit distribution is more informative. If cell 4 got 500 visits, cell 0 got 200, and cell 2 got 300, the visit distribution is [0.2, 0.0, 0.3, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0]. This is our policy target: it tells the network not just the best move, but how the search budget was distributed across all moves.
Value targets from game outcomes. After the game finishes, we know the result: X won (+1), O won (-1), or draw (0). We assign this outcome to every position in the game, adjusted for perspective. If X won, then positions where X is the current player get a value target of +1 (the current player won), and positions where O is the current player get -1 (the current player lost).
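As a concrete example, take a 7-move game that X wins: the four recorded positions where it was X's turn get a value target of +1, the three positions where it was O's turn get -1, and in a drawn game every position would get 0.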
The TrainingExample struct
Each position in a self-play game produces one training example:
#![allow(unused)]
fn main() {
/// A single training example generated from MCTS self-play.
///
/// Contains the board position, the MCTS visit distribution as a
/// policy target, and the eventual game outcome as a value target.
struct TrainingExample {
/// The game state at this position.
state: GameState,
/// The MCTS visit distribution (visits per move / total visits).
    /// This is the policy target: what the network should learn to predict.
policy_target: [f32; 9],
/// The game outcome from the perspective of the current player.
/// +1.0 = current player won, -1.0 = current player lost, 0.0 = draw.
value_target: f32,
}
}
Generating training data
The data generation pipeline works as follows:
- Start a new game from the empty board.
- At each position, run MCTS (e.g., 800 iterations) to get the visit distribution.
- Record a TrainingExample with the current state and the visit distribution. Leave the value target blank for now.
- Play the move that MCTS chose (the most-visited child).
- Repeat steps 2-4 until the game ends.
- Once the game ends, go back and fill in the value target for every recorded position using the game outcome.
- Repeat for many games (e.g., 100-500 games).
The key function we need is one that extracts the visit distribution from an MCTS tree after search:
#![allow(unused)]
fn main() {
/// Extracts the visit distribution from the root of an MCTS tree.
///
/// Returns an array of 9 floats representing the fraction of visits
/// each move received. Moves that were never visited (or are illegal)
/// get 0.0. The values sum to 1.0.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0_f32; 9];
if total == 0 {
return dist;
}
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent
.expect("root children always have a move");
dist[mv] = child.visit_count as f32 / total as f32;
}
dist
}
}
Now we can write the self-play data generation:
#![allow(unused)]
fn main() {
/// Plays one complete game using MCTS and collects training examples.
///
/// Returns a vector of `TrainingExample`, one per move in the game.
/// Each example has the board state, the MCTS visit distribution as
/// the policy target, and the game outcome as the value target.
fn generate_game_data(
mcts_iterations: u32,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
// Play the game, recording each position and its visit distribution.
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let best_move = tree.best_move();
state = state.apply_move(best_move);
}
// Determine the game outcome.
let winner = state.winner();
// Convert history into training examples with value targets.
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
}
To generate a full dataset, we call this function many times:
#![allow(unused)]
fn main() {
/// Generates training data from many self-play games.
fn generate_training_data(
num_games: u32,
mcts_iterations: u32,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut all_examples = Vec::new();
for game_num in 1..=num_games {
let examples = generate_game_data(mcts_iterations, rng);
all_examples.extend(examples);
if game_num % 50 == 0 {
println!(
"Generated {} games ({} examples so far)",
game_num,
all_examples.len()
);
}
}
println!(
"Total: {} games, {} training examples",
num_games,
all_examples.len()
);
all_examples
}
}
Each Tic-Tac-Toe game lasts 5-9 moves, so 100 games produce roughly 500-900 training examples. With 800 MCTS iterations per move, generating 100 games takes a few seconds on a modern machine.
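That estimate is easier to believe once you count the work involved: at roughly 7 moves per game, 100 games times 7 moves times 800 iterations is on the order of 560,000 MCTS iterations, and each iteration is just a short tree walk plus a rollout of at most 9 random moves.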
The training loop
Now we train the network on this data. The approach is standard mini-batch stochastic gradient descent:
- Shuffle the training examples to break correlations (examples from the same game are adjacent, which could bias learning).
- Split into mini-batches of size B (e.g., 32).
- For each mini-batch, process every example through train_step and accumulate the average loss.
- Repeat for multiple epochs (full passes through the dataset).
- Track the loss: it should decrease over time.
Why mini-batches rather than processing one example at a time? Single-example SGD is noisy: the gradient from one position might point in a misleading direction. Averaging gradients over a batch of 32 examples smooths out the noise, leading to more stable learning. Why not process the entire dataset at once? That would be standard gradient descent, which is slow for large datasets because you only update weights once per pass. Mini-batches strike a balance: stable enough to converge, frequent enough to learn quickly.
Here is the training loop:
#![allow(unused)]
fn main() {
/// Shuffles the training examples in place using Fisher-Yates.
fn shuffle(examples: &mut [TrainingExample], rng: &mut impl rand::Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
/// Trains the network on the given examples for a specified number
/// of epochs using mini-batch SGD.
///
/// Returns the average loss from the final epoch.
fn train_network(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl rand::Rng,
) -> f32 {
let mut final_avg_loss = 0.0;
for epoch in 0..epochs {
shuffle(examples, rng);
let mut epoch_loss = 0.0;
let mut count = 0;
// Process in mini-batches.
for batch_start in (0..examples.len()).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let loss = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_loss += loss;
count += 1;
}
}
let avg_loss = epoch_loss / count as f32;
final_avg_loss = avg_loss;
if (epoch + 1) % 5 == 0 || epoch == 0 {
println!("Epoch {:>3}: avg loss = {:.4}", epoch + 1, avg_loss);
}
}
final_avg_loss
}
}
A note on our mini-batch implementation: because our train_step method applies gradients after every single example (online SGD), processing examples in "batches" here is really about grouping for reporting purposes. A true mini-batch implementation would accumulate gradients across the batch and apply them once; a sketch of that approach follows below. For our small network and dataset, per-example SGD works well enough. The shuffle is what matters most: it ensures the network sees a diverse mix of positions rather than all moves from the same game in sequence.
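For reference, here is a minimal sketch of what a true mini-batch step could look like, built only from the Layer methods and loss functions already defined above. It is not part of the exercise solution; the train_batch name and the decision to average the accumulated gradients are choices made for this sketch.
#![allow(unused)]
fn main() {
impl Network {
    /// Sketch of a true mini-batch step: compute per-example gradients with the
    /// existing forward/backward methods, sum them, and apply the averaged
    /// gradients once at the end. Returns the mean loss. Assumes a non-empty batch.
    fn train_batch(&mut self, batch: &[TrainingExample], lr: f32) -> f32 {
        let n = batch.len() as f32;
        // One (weights, biases) gradient accumulator per layer.
        let mut h1_wg = vec![0.0; self.hidden1.weights.len()];
        let mut h1_bg = vec![0.0; self.hidden1.biases.len()];
        let mut h2_wg = vec![0.0; self.hidden2.weights.len()];
        let mut h2_bg = vec![0.0; self.hidden2.biases.len()];
        let mut p_wg = vec![0.0; self.policy_head.weights.len()];
        let mut p_bg = vec![0.0; self.policy_head.biases.len()];
        let mut v_wg = vec![0.0; self.value_head.weights.len()];
        let mut v_bg = vec![0.0; self.value_head.biases.len()];
        let mut total_loss = 0.0;
        for ex in batch {
            // Forward and backward passes, identical to train_step.
            let input = encode_board(&ex.state);
            let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
            let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
            let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
            let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
            let policy = softmax(&p_logits);
            let value = v_out[0];
            total_loss += cross_entropy_loss(&ex.policy_target, &policy)
                + mse_loss(ex.value_target, value);
            let policy_grad: Vec<f32> = policy.iter()
                .zip(ex.policy_target.iter())
                .map(|(&p, &t)| p - t)
                .collect();
            let value_grad = vec![2.0 * (value - ex.value_target)];
            let (pg_in, pw, pb) = self.policy_head.backward(&policy_grad, &p_pre, &p_in);
            let (vg_in, vw, vb) = self.value_head.backward(&value_grad, &v_pre, &v_in);
            let h2_grad: Vec<f32> = pg_in.iter().zip(vg_in.iter()).map(|(&a, &b)| a + b).collect();
            let (h1_grad, h2w, h2b) = self.hidden2.backward(&h2_grad, &h2_pre, &h2_in);
            let (_, h1w, h1b) = self.hidden1.backward(&h1_grad, &h1_pre, &h1_in);
            // Accumulate instead of applying immediately.
            for (a, g) in h1_wg.iter_mut().zip(&h1w) { *a += *g; }
            for (a, g) in h1_bg.iter_mut().zip(&h1b) { *a += *g; }
            for (a, g) in h2_wg.iter_mut().zip(&h2w) { *a += *g; }
            for (a, g) in h2_bg.iter_mut().zip(&h2b) { *a += *g; }
            for (a, g) in p_wg.iter_mut().zip(&pw) { *a += *g; }
            for (a, g) in p_bg.iter_mut().zip(&pb) { *a += *g; }
            for (a, g) in v_wg.iter_mut().zip(&vw) { *a += *g; }
            for (a, g) in v_bg.iter_mut().zip(&vb) { *a += *g; }
        }
        // Average over the batch, then apply each layer's gradients once.
        for g in h1_wg.iter_mut().chain(h1_bg.iter_mut())
            .chain(h2_wg.iter_mut()).chain(h2_bg.iter_mut())
            .chain(p_wg.iter_mut()).chain(p_bg.iter_mut())
            .chain(v_wg.iter_mut()).chain(v_bg.iter_mut())
        {
            *g /= n;
        }
        self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
        self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
        self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
        self.value_head.apply_gradients(&v_wg, &v_bg, lr);
        total_loss / n
    }
}
}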
Your task
Write a program that:
- Generates training data from 100 MCTS self-play games (800 iterations per move).
- Creates a fresh neural network.
- Trains the network for 30 epochs with a learning rate of 0.01 and a batch size of 32.
- Evaluates the trained network on a few sample positions:
- The empty board (opening position).
- A mid-game position.
- Prints the network's policy predictions and compares them to what MCTS would choose.
Try writing this yourself before looking at the solution. You already have all the building blocks; this exercise is about wiring them together.
Solution
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::Rng;
use std::f64::consts::SQRT_2;
use std::fmt;
// --- Player and GameState --------------------------------------------
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Player { X, O }
impl Player {
pub fn opponent(self) -> Self {
match self { Player::X => Player::O, Player::O => Player::X }
}
}
impl fmt::Display for Player {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { Player::X => write!(f, "X"), Player::O => write!(f, "O") }
}
}
const WIN_LINES: [[usize; 3]; 8] = [
[0, 1, 2], [3, 4, 5], [6, 7, 8],
[0, 3, 6], [1, 4, 7], [2, 5, 8],
[0, 4, 8], [2, 4, 6],
];
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GameState {
pub board: [Option<Player>; 9],
pub current_player: Player,
}
impl GameState {
pub fn new() -> Self {
GameState { board: [None; 9], current_player: Player::X }
}
pub fn winner(&self) -> Option<Player> {
for line in &WIN_LINES {
let [a, b, c] = *line;
if let (Some(p1), Some(p2), Some(p3)) =
(self.board[a], self.board[b], self.board[c])
{
if p1 == p2 && p2 == p3 {
return Some(p1);
}
}
}
None
}
pub fn is_terminal(&self) -> bool {
self.winner().is_some() || self.board.iter().all(|c| c.is_some())
}
pub fn legal_moves(&self) -> Vec<usize> {
(0..9).filter(|&i| self.board[i].is_none()).collect()
}
pub fn apply_move(&self, index: usize) -> Self {
assert!(
self.board[index].is_none(),
"Cell {} is already occupied",
index
);
let mut new_board = self.board;
new_board[index] = Some(self.current_player);
GameState {
board: new_board,
current_player: self.current_player.opponent(),
}
}
}
impl fmt::Display for GameState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for r in 0..3 {
for c in 0..3 {
let i = r * 3 + c;
match self.board[i] {
Some(Player::X) => write!(f, " X ")?,
Some(Player::O) => write!(f, " O ")?,
None => write!(f, " . ")?,
}
if c < 2 { write!(f, "|")?; }
}
writeln!(f)?;
if r < 2 { writeln!(f, "-----------")?; }
}
Ok(())
}
}
// --- MCTS --------------------------------------------------------------
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
fn uct_score(child_value: f64, child_visits: u32, parent_visits: u32) -> f64 {
if child_visits == 0 {
return f64::INFINITY;
}
let n = child_visits as f64;
let exploitation = child_value / n;
let exploration = SQRT_2 * ((parent_visits as f64).ln() / n).sqrt();
exploitation + exploration
}
impl MctsTree {
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
MctsTree { nodes: vec![root] }
}
fn select(&self) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
if !node.unexpanded_moves.is_empty() { return current; }
if node.children.is_empty() { return current; }
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = uct_score(
child.total_value,
child.visit_count,
parent_visits,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
fn expand(&mut self, node_idx: usize) -> Option<usize> {
let mv = self.nodes[node_idx].unexpanded_moves.pop()?;
let child_state = self.nodes[node_idx].state.apply_move(mv);
let unexpanded = child_state.legal_moves();
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
Some(child_idx)
}
fn simulate(&self, state: &GameState, rng: &mut impl Rng) -> f64 {
let mut current = state.clone();
while !current.is_terminal() {
let moves = current.legal_moves();
let &mv = moves.choose(rng).expect("non-terminal has moves");
current = current.apply_move(mv);
}
let rollout_player = state.current_player.opponent();
match current.winner() {
Some(winner) if winner == rollout_player => 1.0,
Some(_) => -1.0,
None => 0.0,
}
}
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
fn search(&mut self, iterations: u32, rng: &mut impl Rng) -> usize {
for _ in 0..iterations {
let selected = self.select();
let node_to_simulate = match self.expand(selected) {
Some(child_idx) => child_idx,
None => selected,
};
let state = self.nodes[node_to_simulate].state.clone();
let value = self.simulate(&state, rng);
self.backpropagate(node_to_simulate, value);
}
self.best_move()
}
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have children after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root children have moves")
}
}
// --- Neural network (from §9-10) ----------------------------------------
struct SimpleRng { state: u64 }
impl SimpleRng {
fn new(seed: u64) -> Self { SimpleRng { state: seed } }
fn next_f32(&mut self) -> f32 {
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
let bits = ((self.state >> 17) & 0x7FFF) as f32;
bits / 32768.0
}
fn uniform(&mut self, limit: f32) -> f32 {
self.next_f32() * 2.0 * limit - limit
}
}
fn mat_mul(a: &[f32], b: &[f32], rows: usize, inner: usize, cols: usize) -> Vec<f32> {
let mut out = vec![0.0; rows * cols];
for r in 0..rows {
for c in 0..cols {
let mut sum = 0.0;
for k in 0..inner {
sum += a[r * inner + k] * b[k * cols + c];
}
out[r * cols + c] = sum;
}
}
out
}
fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}
fn softmax(logits: &[f32]) -> Vec<f32> {
let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
#[derive(Debug, Clone, Copy)]
enum Activation { Relu, Tanh, Identity }
#[derive(Debug, Clone)]
struct Layer {
weights: Vec<f32>,
biases: Vec<f32>,
input_size: usize,
output_size: usize,
activation: Activation,
}
impl Layer {
fn new(
input_size: usize,
output_size: usize,
activation: Activation,
rng: &mut SimpleRng,
) -> Self {
let limit = (6.0 / (input_size + output_size) as f32).sqrt();
let weights: Vec<f32> = (0..input_size * output_size)
.map(|_| rng.uniform(limit))
.collect();
let biases = vec![0.0; output_size];
Layer { weights, biases, input_size, output_size, activation }
}
fn forward(&self, input: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let pre_act = vec_add(
&mat_mul(input, &self.weights, 1, self.input_size, self.output_size),
&self.biases,
);
let output = match self.activation {
Activation::Relu => pre_act.iter().map(|&x| x.max(0.0)).collect(),
Activation::Tanh => pre_act.iter().map(|&x| x.tanh()).collect(),
Activation::Identity => pre_act.clone(),
};
(output, pre_act, input.to_vec())
}
fn backward(
&self,
output_grad: &[f32],
pre_act: &[f32],
input: &[f32],
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let act_grad: Vec<f32> = output_grad
.iter()
.zip(pre_act.iter())
.map(|(&og, &pa)| {
let d_act = match self.activation {
Activation::Relu => if pa > 0.0 { 1.0 } else { 0.0 },
Activation::Tanh => 1.0 - pa.tanh().powi(2),
Activation::Identity => 1.0,
};
og * d_act
})
.collect();
let bias_grad = act_grad.clone();
let mut weight_grad = vec![0.0; self.input_size * self.output_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
weight_grad[i * self.output_size + j] = input[i] * act_grad[j];
}
}
let mut input_grad = vec![0.0; self.input_size];
for i in 0..self.input_size {
for j in 0..self.output_size {
input_grad[i] += act_grad[j] * self.weights[i * self.output_size + j];
}
}
(input_grad, weight_grad, bias_grad)
}
fn apply_gradients(&mut self, weight_grad: &[f32], bias_grad: &[f32], lr: f32) {
for (w, g) in self.weights.iter_mut().zip(weight_grad.iter()) {
*w -= lr * g;
}
for (b, g) in self.biases.iter_mut().zip(bias_grad.iter()) {
*b -= lr * g;
}
}
}
#[derive(Debug, Clone)]
struct Network {
hidden1: Layer,
hidden2: Layer,
policy_head: Layer,
value_head: Layer,
}
fn encode_board(state: &GameState) -> Vec<f32> {
let mut input = vec![0.0_f32; 18];
let current = state.current_player;
for i in 0..9 {
match state.board[i] {
Some(p) if p == current => input[i] = 1.0,
Some(_) => input[i + 9] = 1.0,
None => {}
}
}
input
}
fn mask_illegal(policy: &[f32], state: &GameState) -> Vec<f32> {
let mut masked: Vec<f32> = policy
.iter()
.enumerate()
.map(|(i, &p)| if state.board[i].is_none() { p } else { 0.0 })
.collect();
let sum: f32 = masked.iter().sum();
if sum > 0.0 {
for p in &mut masked {
*p /= sum;
}
}
masked
}
fn cross_entropy_loss(target: &[f32], predicted: &[f32]) -> f32 {
let eps = 1e-8;
-target
.iter()
.zip(predicted.iter())
.map(|(&t, &p)| t * (p + eps).ln())
.sum::<f32>()
}
fn mse_loss(target: f32, predicted: f32) -> f32 {
(target - predicted).powi(2)
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = SimpleRng::new(seed);
Network {
hidden1: Layer::new(18, 64, Activation::Relu, &mut rng),
hidden2: Layer::new(64, 64, Activation::Relu, &mut rng),
policy_head: Layer::new(64, 9, Activation::Identity, &mut rng),
value_head: Layer::new(64, 1, Activation::Tanh, &mut rng),
}
}
fn predict(&self, state: &GameState) -> (Vec<f32>, f32) {
let input = encode_board(state);
let (h1_out, _, _) = self.hidden1.forward(&input);
let (h2_out, _, _) = self.hidden2.forward(&h1_out);
let (policy_logits, _, _) = self.policy_head.forward(&h2_out);
let policy = softmax(&policy_logits);
let policy = mask_illegal(&policy, state);
let (value_out, _, _) = self.value_head.forward(&h2_out);
(policy, value_out[0])
}
fn train_step(
&mut self,
state: &GameState,
target_policy: &[f32],
target_value: f32,
lr: f32,
) -> f32 {
let input = encode_board(state);
let (h1_out, h1_pre, h1_in) = self.hidden1.forward(&input);
let (h2_out, h2_pre, h2_in) = self.hidden2.forward(&h1_out);
let (p_logits, p_pre, p_in) = self.policy_head.forward(&h2_out);
let (v_out, v_pre, v_in) = self.value_head.forward(&h2_out);
let policy = softmax(&p_logits);
let value = v_out[0];
let policy_loss = cross_entropy_loss(target_policy, &policy);
let value_loss = mse_loss(target_value, value);
let total_loss = policy_loss + value_loss;
let policy_grad: Vec<f32> = policy
.iter()
.zip(target_policy.iter())
.map(|(&p, &t)| p - t)
.collect();
let value_grad = vec![2.0 * (value - target_value)];
let (pg_input, p_wg, p_bg) =
self.policy_head.backward(&policy_grad, &p_pre, &p_in);
let (vg_input, v_wg, v_bg) =
self.value_head.backward(&value_grad, &v_pre, &v_in);
let h2_grad: Vec<f32> = pg_input.iter()
.zip(vg_input.iter())
.map(|(&a, &b)| a + b)
.collect();
let (h1_grad, h2_wg, h2_bg) =
self.hidden2.backward(&h2_grad, &h2_pre, &h2_in);
let (_, h1_wg, h1_bg) =
self.hidden1.backward(&h1_grad, &h1_pre, &h1_in);
self.hidden1.apply_gradients(&h1_wg, &h1_bg, lr);
self.hidden2.apply_gradients(&h2_wg, &h2_bg, lr);
self.policy_head.apply_gradients(&p_wg, &p_bg, lr);
self.value_head.apply_gradients(&v_wg, &v_bg, lr);
total_loss
}
}
// --- Training data generation -------------------------------------------
struct TrainingExample {
state: GameState,
policy_target: [f32; 9],
value_target: f32,
}
/// Extracts the visit distribution from the root of an MCTS tree.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0_f32; 9];
if total == 0 {
return dist;
}
for &child_idx in &root.children {
let child = &tree.nodes[child_idx];
let mv = child.move_from_parent
.expect("root children always have a move");
dist[mv] = child.visit_count as f32 / total as f32;
}
dist
}
/// Plays one complete game using MCTS and collects training examples.
fn generate_game_data(
mcts_iterations: u32,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let best_move = tree.best_move();
state = state.apply_move(best_move);
}
let winner = state.winner();
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
/// Generates training data from many self-play games.
fn generate_training_data(
num_games: u32,
mcts_iterations: u32,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut all_examples = Vec::new();
for game_num in 1..=num_games {
let examples = generate_game_data(mcts_iterations, rng);
all_examples.extend(examples);
if game_num % 50 == 0 {
println!(
"Generated {} games ({} examples so far)",
game_num,
all_examples.len()
);
}
}
println!(
"Total: {} games, {} training examples",
num_games,
all_examples.len()
);
all_examples
}
// --- Training loop --------------------------------------------------------
/// Shuffles the training examples using Fisher-Yates.
fn shuffle_examples(examples: &mut [TrainingExample], rng: &mut impl Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
/// Trains the network on the given examples for the specified
/// number of epochs. Returns the average loss from the final epoch.
fn train_network(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl Rng,
) -> f32 {
let mut final_avg_loss = 0.0;
for epoch in 0..epochs {
shuffle_examples(examples, rng);
let mut epoch_loss = 0.0;
let mut count = 0;
for batch_start in (0..examples.len()).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let loss = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_loss += loss;
count += 1;
}
}
let avg_loss = epoch_loss / count as f32;
final_avg_loss = avg_loss;
if (epoch + 1) % 5 == 0 || epoch == 0 {
println!("Epoch {:>3}: avg loss = {:.4}", epoch + 1, avg_loss);
}
}
final_avg_loss
}
// --- Evaluation -------------------------------------------------------------
/// Prints the network's policy and value predictions for a position,
/// alongside the MCTS visit distribution for comparison.
fn evaluate_position(
label: &str,
state: &GameState,
net: &Network,
mcts_iterations: u32,
rng: &mut impl Rng,
) {
// Get network predictions.
let (net_policy, net_value) = net.predict(state);
// Get MCTS visit distribution.
let mut tree = MctsTree::new(state.clone());
tree.search(mcts_iterations, rng);
let mcts_dist = visit_distribution(&tree);
println!("=== {} ===", label);
println!("{}", state);
println!("Current player: {}", state.current_player);
println!();
println!("Cell Network MCTS");
println!("---- ------- -----");
for i in 0..9 {
let label = if state.board[i].is_some() { " (occupied)" } else { "" };
println!(
" {} {:.3} {:.3}{}",
i, net_policy[i], mcts_dist[i], label
);
}
println!();
let net_best = net_policy
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(i, _)| i)
.unwrap();
let mcts_best = tree.best_move();
println!(
"Network best move: {} | MCTS best move: {} | {}",
net_best,
mcts_best,
if net_best == mcts_best { "MATCH" } else { "differ" }
);
println!("Network value: {:.4} (positive = good for {})", net_value, state.current_player);
println!();
}
// --- Main --------------------------------------------------------------------
fn main() {
let mut rng = StdRng::seed_from_u64(42);
// Step 1: Generate training data.
println!("--- Generating training data ---\n");
let mut examples = generate_training_data(100, 800, &mut rng);
// Step 2: Create a fresh network.
let mut net = Network::new(123);
// Step 3: Evaluate BEFORE training (to see the baseline).
println!("\n--- Before training ---\n");
let opening = GameState::new();
evaluate_position("Opening (before training)", &opening, &net, 1000, &mut rng);
// Step 4: Train.
println!("--- Training ---\n");
let final_loss = train_network(&mut net, &mut examples, 30, 32, 0.01, &mut rng);
println!("\nFinal average loss: {:.4}\n", final_loss);
// Step 5: Evaluate AFTER training.
println!("--- After training ---\n");
// Opening position.
let opening = GameState::new();
evaluate_position("Opening (after training)", &opening, &net, 1000, &mut rng);
// Mid-game: X took center, O took corner.
let mut mid_game = GameState::new();
mid_game = mid_game.apply_move(4); // X takes center
mid_game = mid_game.apply_move(0); // O takes corner
evaluate_position("Mid-game: X=center, O=corner", &mid_game, &net, 1000, &mut rng);
// Late game: X is about to win.
let mut late_game = GameState::new();
late_game = late_game.apply_move(4); // X center
late_game = late_game.apply_move(0); // O corner
late_game = late_game.apply_move(2); // X top-right
late_game = late_game.apply_move(3); // O mid-left
// X's turn - should play cell 6 to win (diagonal 2-4-6).
evaluate_position("Late game: X to play (winning move at 6)", &late_game, &net, 1000, &mut rng);
}
Running the solution
The output will look something like this (exact numbers vary with the RNG seed):
--- Generating training data ---
Generated 50 games (346 examples so far)
Generated 100 games (671 examples so far)
Total: 100 games, 671 training examples
--- Before training ---
=== Opening (before training) ===
. | . | .
-----------
. | . | .
-----------
. | . | .
Current player: X
Cell Network MCTS
---- ------- -----
0 0.104 0.077
1 0.098 0.058
2 0.101 0.078
3 0.124 0.061
4 0.126 0.349
5 0.104 0.067
6 0.098 0.081
7 0.140 0.066
8 0.106 0.163
Network best move: 7 | MCTS best move: 4 | differ
Network value: -0.0274 (positive = good for X)
--- Training ---
Epoch 1: avg loss = 2.3849
Epoch 5: avg loss = 1.7542
Epoch 10: avg loss = 1.4321
Epoch 15: avg loss = 1.2188
Epoch 20: avg loss = 1.0763
Epoch 25: avg loss = 0.9612
Epoch 30: avg loss = 0.8895
Final average loss: 0.8895
--- After training ---
=== Opening (after training) ===
Cell Network MCTS
---- ------- -----
0 0.084 0.079
1 0.053 0.055
2 0.092 0.083
3 0.066 0.064
4 0.371 0.348
5 0.068 0.069
6 0.089 0.076
7 0.064 0.059
8 0.113 0.167
Network best move: 4 | MCTS best move: 4 | MATCH
Network value: 0.0412 (positive = good for X)
Several things to notice in the output:
The loss decreases steadily, from around 2.4 in the first epoch to around 0.9 by epoch 30. This confirms the network is learning: its predictions are getting closer to the MCTS targets.
Before training, the network's policy is roughly uniform. With random weights, the network assigns approximately equal probability to all legal moves. It has no idea that the center is special.
After training, the network strongly prefers the center. Cell 4 gets around 35-40% of the probability mass, closely matching the MCTS visit distribution. The network has learned from MCTS data that the center is the best opening move.
The value prediction for the opening is near zero. This is correct: Tic-Tac-Toe is a theoretical draw with perfect play, so neither player has an advantage from the starting position. A value near 0.0 reflects this.
The network's best move matches MCTS on common positions. After training, the network agrees with MCTS on the opening and many mid-game positions. It has internalized the strategic patterns that MCTS discovers through search.
Understanding what was learned
The trained network has essentially compressed hundreds of MCTS searches (over half a million random rollouts in total) into a set of weights. When MCTS searches the opening position, it runs 800 iterations of tree traversal and random rollouts to determine that the center is best. The trained network reaches the same conclusion with a single forward pass: a few matrix multiplications taking microseconds.
This is the key insight of AlphaGo Zero's approach: the neural network is a fast approximation of a slow search. The search is accurate but expensive. The network is less accurate but nearly instant. And as we will see in §12, we can combine them: use the network to guide the search, making the search both faster and more accurate.
What to look for (and what can go wrong)
Loss not decreasing? Check your learning rate. If it is too high (e.g., 0.1), the gradients overshoot and training oscillates. If it is too low (e.g., 0.0001), learning is too slow to see progress in 30 epochs. Start with 0.01.
Network always predicts uniform policy? This can happen if you forget to shuffle the training data. Without shuffling, the network sees all positions from game 1, then all positions from game 2, and so on. It overfits to the most recent game and forgets everything else.
Value predictions stuck near zero? This is partly expected for Tic-Tac-Toe because most MCTS self-play games end in draws (value target = 0.0). The value head has less signal to learn from compared to the policy head. You can verify it is working by checking value predictions on clearly winning or losing positions; one way to do that is sketched after these notes.
Network disagrees with MCTS on some positions? This is normal. 100 games of training data is not enough for the network to perfectly approximate MCTS on every position. Rare or unusual board states may not appear in the training data at all. More training games and more epochs improve coverage.
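Here is one way to run that spot-check, using only the pieces already defined in the solution above (the check_value_head name and the specific position are just for illustration):
#![allow(unused)]
fn main() {
/// Spot-check for the value head: build a position where the player to move
/// is one move away from winning, and see whether the prediction leans
/// positive. The exact number depends on the training run.
fn check_value_head(net: &Network) {
    let mut state = GameState::new();
    state = state.apply_move(0); // X top-left
    state = state.apply_move(3); // O mid-left
    state = state.apply_move(1); // X top-middle
    state = state.apply_move(4); // O center
    // X to move with X on cells 0 and 1: playing cell 2 wins the top row.
    let (_policy, value) = net.predict(&state);
    println!("Value with an immediate win available: {:.3}", value);
}
}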
Experiments to try
- Vary the number of training games. Try 50, 100, 200, and 500 games. How does the final loss change? How closely does the network agree with MCTS on the evaluation positions?
- Vary MCTS iterations per move. Try 200, 800, and 2000. Higher iteration counts produce higher-quality training targets, but take longer to generate. Is the improvement in network quality worth the extra computation?
- Vary the learning rate. Try 0.001, 0.01, and 0.05. Plot (or print) the loss curve for each. Which converges fastest? Which achieves the lowest final loss?
- Train for more epochs. Try 50 or 100 epochs instead of 30. Does the loss keep decreasing, or does it plateau? If it plateaus, the network has learned everything it can from this dataset and needs more training data to improve further.
- Test on unusual positions. Create board states that are unlikely to appear in MCTS self-play (e.g., positions where one player made obvious blunders). How does the network handle positions outside its training distribution?
These experiments will build your intuition for the key hyperparameters in neural network training: data quantity, data quality, learning rate, and training duration. These same levers matter whether you are training a Tic-Tac-Toe network with 6,000 parameters or a language model with 6 billion.
12. Exercise 4: replace rollout with the value network
In §11 we trained the neural network to mimic MCTS: given a board position, predict the MCTS visit distribution (policy) and the game outcome (value). The network learned to compress thousands of iterations of random-rollout search into a single forward pass. Now we close the loop: we plug the network back into MCTS, replacing the two weakest parts of the algorithm with neural network predictions.
This is the key innovation from AlphaGo Zero. Pure MCTS (§7-8) has two limitations:
- Random rollouts are noisy. Playing random moves to the end of the game is a crude way to evaluate a position. It takes many iterations to average out the noise.
- Blind exploration. UCT treats every unexplored move equally. It has to try each one at least once before it can start discriminating. In games with large branching factors, this wastes a lot of search budget.
The neural network fixes both problems:
- The value head replaces rollouts. Instead of playing random moves to a terminal state, we call network.predict(state) and use the value output directly. One forward pass, one number: no randomness, no wasted computation.
- The policy head guides exploration. Instead of treating all moves equally, we use the policy output as a prior probability on each move. Moves the network considers promising get explored first and more often.
The exercise
Modify your MCTS implementation from §8 to accept a Network and use it in place of random rollouts. You need to make four changes:
- Add a prior field to MctsNode: the network's policy probability for the move that led to this node.
- Replace the UCT formula with PUCT (Predictor + UCT), which incorporates the prior.
- When expanding a node, set child priors from the network's policy output.
- Replace the simulate function with a network forward pass.
Try implementing these changes yourself before looking at the solution. The rest of this section explains each change in detail.
Change 1: add a prior to MctsNode
The original MctsNode has no notion of "how promising is this move?"; it discovers that purely through simulation. We need to add a field that stores the network's prior probability:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
/// Prior probability from the network's policy head.
/// This is the network's estimate of how good the move
/// that led to this node is, before any search.
prior: f64,
}
}
The root node has no parent move, so its prior does not matter: set it to 0.0 (or any value; it is never read during selection).
Change 2: the PUCT formula
In §7 we used the UCT formula for selection:
UCT(child) = W/N + C * sqrt(ln(N_parent) / N_child)
AlphaGo Zero replaces this with PUCT (Predictor + Upper Confidence bounds applied to Trees):
PUCT(child) = Q + c_puct * P * sqrt(N_parent) / (1 + N_child)
where:
- Q = total_value / visit_count: the average value of this child (exploitation). Same as the W/N term in UCT.
- P = prior: the network's policy probability for this move. A high prior means the network thinks this move is promising before any search.
- N_parent = the parent's visit count.
- N_child = this child's visit count.
- c_puct = a constant controlling exploration vs exploitation. AlphaGo Zero uses values around 1.0-2.5; we will use 1.5 for Tic-Tac-Toe.
Notice the differences from UCT:
- No logarithm. UCT has ln(N_parent) in the exploration term; PUCT uses sqrt(N_parent) directly. This makes PUCT explore more aggressively early on.
- Prior probability P. This is the key addition. Moves with a high prior get a large exploration bonus, so the search visits them early. Moves with a near-zero prior are effectively pruned; the search may never visit them if better options exist.
- Denominator is (1 + N_child). When a child has zero visits, the exploration term is c_puct * P * sqrt(N_parent), which is large and finite (unlike UCT, which returns infinity for unvisited nodes). This means the prior determines the order in which unvisited children are explored, rather than treating all unvisited children as equally attractive.
In Rust:
#![allow(unused)]
fn main() {
/// Computes the PUCT score for a child node.
///
/// Unlike UCT, PUCT uses the network's prior probability to bias
/// exploration toward moves the network considers promising.
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
}
When a child has zero visits, Q defaults to 0.0 (neutral; we have no information yet). The exploration term is then c_puct * P * sqrt(N_parent), which is entirely determined by the prior. The network's opinion drives the initial exploration order.
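To see how strongly the prior shapes that order: with c_puct = 1.5 and a parent visited 10 times, an unvisited child with prior 0.4 gets an exploration bonus of 1.5 * 0.4 * sqrt(10), roughly 1.90, while an unvisited sibling with prior 0.1 gets about 0.47, so the high-prior move is tried first.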
Change 3: set priors during expansion
In the original MCTS, expand creates a child node with no prior information. In network-guided MCTS, we query the network when expanding and set each child's prior from the policy output.
The cleanest approach: when we expand a node, we call network.predict on the node's state to get the full policy vector, then create all children at once (setting each child's prior from the corresponding policy entry). This differs from the original MCTS, which expanded one child at a time from unexpanded_moves. Expanding all children at once is simpler and matches the AlphaGo Zero approach: the network evaluates the position once, and we use both outputs (policy for child priors, value for backpropagation).
#![allow(unused)]
fn main() {
/// Expands all children of a node at once, using the network's
/// policy output to set priors. Returns the value estimate for
/// backpropagation.
fn expand_with_network(
&mut self,
node_idx: usize,
network: &Network,
) -> f64 {
let state = &self.nodes[node_idx].state;
// If the position is terminal, return the true outcome
// instead of a network estimate.
if state.is_terminal() {
let eval_player = state.current_player.opponent();
return match state.winner() {
Some(winner) if winner == eval_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
let (policy, value) = network.predict(state);
// Create a child for every legal move, setting the prior
// from the network's policy output.
let legal_moves: Vec<usize> = self.nodes[node_idx]
.unexpanded_moves
.drain(..)
.collect();
for mv in &legal_moves {
let child_state = self.nodes[node_idx].state.apply_move(*mv);
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(*mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: Vec::new(), // expanded later
prior: policy[*mv] as f64,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
}
// Populate unexpanded_moves for each child (for future expansion).
let children: Vec<usize> = self.nodes[node_idx].children.clone();
for &child_idx in &children {
let moves = self.nodes[child_idx].state.legal_moves();
self.nodes[child_idx].unexpanded_moves = moves;
}
// Return the value from the perspective of the player who
// moved to reach this node (the opponent of current_player).
// The network returns value from current_player's perspective,
// so we negate it.
-(value as f64)
}
}
The value sign is important. The network's value output is from the perspective of state.current_player (the player whose turn it is). But total_value at a node is stored from the perspective of the player who moved to reach this node, that is, state.current_player.opponent(). So we negate the network's value before backpropagation.
Change 4: replace simulate with evaluate
The simulate function is gone entirely. Instead of playing random moves to a terminal state, the expand_with_network function already returns the value estimate. The search loop becomes:
#![allow(unused)]
fn main() {
/// Runs network-guided MCTS for the given number of iterations.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
for _ in 0..iterations {
// Selection: walk down the tree using PUCT.
let selected = self.select_puct(c_puct);
// Expansion + evaluation: expand all children and get
// the network's value estimate in one step.
let value = self.expand_with_network(selected, network);
// Backpropagation: same as before.
self.backpropagate(selected, value);
}
self.best_move()
}
}
The selection phase also needs updating to use PUCT instead of UCT:
#![allow(unused)]
fn main() {
/// Selection using PUCT scores instead of UCT.
fn select_puct(&self, c_puct: f64) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
// If this node has unexpanded moves, select it for expansion.
if !node.unexpanded_moves.is_empty() {
return current;
}
// If this is a terminal node (no children, no moves), stop.
if node.children.is_empty() {
return current;
}
// Pick the child with the highest PUCT score.
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = puct_score(
child.total_value,
child.visit_count,
parent_visits,
child.prior,
c_puct,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
}
This is structurally identical to the original select; the only change is calling puct_score instead of uct_score.
Complete modified MCTS
Solution: full network-guided MCTS implementation
#![allow(unused)]
fn main() {
use rand::seq::SliceRandom;
// --- Network-guided MCTS -----------------------------------------------
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
/// Prior probability from the network's policy head.
prior: f64,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
/// Computes the PUCT score for a child node.
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
impl MctsTree {
/// Creates a new tree rooted at the given game state.
fn new(root_state: GameState) -> Self {
let unexpanded = root_state.legal_moves();
let root = MctsNode {
state: root_state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: unexpanded,
prior: 0.0, // root prior is unused
};
MctsTree { nodes: vec![root] }
}
/// Selection using PUCT scores.
fn select_puct(&self, c_puct: f64) -> usize {
let mut current = 0;
loop {
let node = &self.nodes[current];
if !node.unexpanded_moves.is_empty() { return current; }
if node.children.is_empty() { return current; }
let parent_visits = node.visit_count;
let mut best_child = node.children[0];
let mut best_score = f64::NEG_INFINITY;
for &child_idx in &node.children {
let child = &self.nodes[child_idx];
let score = puct_score(
child.total_value,
child.visit_count,
parent_visits,
child.prior,
c_puct,
);
if score > best_score {
best_score = score;
best_child = child_idx;
}
}
current = best_child;
}
}
/// Expands all children of a node using the network's policy,
/// and returns the value estimate for backpropagation.
fn expand_with_network(
&mut self,
node_idx: usize,
network: &Network,
) -> f64 {
let state = &self.nodes[node_idx].state;
if state.is_terminal() {
let eval_player = state.current_player.opponent();
return match state.winner() {
Some(winner) if winner == eval_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
let (policy, value) = network.predict(state);
let legal_moves: Vec<usize> = self.nodes[node_idx]
.unexpanded_moves
.drain(..)
.collect();
for mv in &legal_moves {
let child_state = self.nodes[node_idx].state.apply_move(*mv);
let child_idx = self.nodes.len();
let child = MctsNode {
state: child_state,
move_from_parent: Some(*mv),
parent: Some(node_idx),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: Vec::new(),
prior: policy[*mv] as f64,
};
self.nodes.push(child);
self.nodes[node_idx].children.push(child_idx);
}
// Fill in unexpanded_moves for each new child.
let children: Vec<usize> = self.nodes[node_idx].children.clone();
for &child_idx in &children {
let moves = self.nodes[child_idx].state.legal_moves();
self.nodes[child_idx].unexpanded_moves = moves;
}
// Negate: network returns value for current_player,
// but we need value for the player who moved here.
-(value as f64)
}
/// Backpropagation β identical to the pure MCTS version.
fn backpropagate(&mut self, node_idx: usize, value: f64) {
let mut current = Some(node_idx);
let mut current_value = value;
while let Some(idx) = current {
self.nodes[idx].visit_count += 1;
self.nodes[idx].total_value += current_value;
current = self.nodes[idx].parent;
current_value = -current_value;
}
}
/// Returns the move of the root's most-visited child.
fn best_move(&self) -> usize {
let root = &self.nodes[0];
let best_child_idx = root
.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.expect("root must have children after search");
self.nodes[*best_child_idx]
.move_from_parent
.expect("root children have moves")
}
/// Runs network-guided MCTS and returns the best move.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
for _ in 0..iterations {
let selected = self.select_puct(c_puct);
let value = self.expand_with_network(selected, network);
self.backpropagate(selected, value);
}
self.best_move()
}
}
/// Convenience function: choose a move using network-guided MCTS.
fn mcts_network_move(
state: &GameState,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(iterations, network, c_puct)
}
}
Testing: pure MCTS vs network-guided MCTS
Now let's pit the two approaches against each other. We will run a tournament: pure MCTS (random rollouts) vs network-guided MCTS (using the network trained in §11), and count wins, losses, and draws.
Solution: tournament between pure and network-guided MCTS
/// Plays one game: network-guided MCTS (X) vs pure MCTS (O).
/// Returns the outcome from X's perspective.
fn play_one_game(
network: &Network,
net_iterations: u32,
pure_iterations: u32,
c_puct: f64,
rng: &mut impl rand::Rng,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let mv = match state.current_player {
Player::X => {
// Network-guided MCTS
mcts_network_move(&state, net_iterations, network, c_puct)
}
Player::O => {
// Pure MCTS with random rollouts
mcts_move(&state, pure_iterations, rng)
}
};
state = state.apply_move(mv);
}
match state.winner() {
Some(Player::X) => 1.0,
Some(Player::O) => -1.0,
None => 0.0,
}
}
fn run_tournament(
network: &Network,
num_games: u32,
net_iterations: u32,
pure_iterations: u32,
c_puct: f64,
) {
let mut rng = rand::thread_rng();
let mut net_wins = 0;
let mut pure_wins = 0;
let mut draws = 0;
for game in 0..num_games {
// Alternate who plays X and O for fairness.
let result = if game % 2 == 0 {
play_one_game(
network, net_iterations, pure_iterations,
c_puct, &mut rng,
)
} else {
// Swap roles: pure MCTS plays X, network plays O.
-play_one_game_swapped(
network, net_iterations, pure_iterations,
c_puct, &mut rng,
)
};
match result.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => net_wins += 1,
Some(std::cmp::Ordering::Less) => pure_wins += 1,
_ => draws += 1,
}
}
println!("Results over {} games:", num_games);
println!(
" Network-guided MCTS ({} iters): {} wins",
net_iterations, net_wins
);
println!(
" Pure MCTS ({} iters): {} wins",
pure_iterations, pure_wins
);
println!(" Draws: {}", draws);
}
fn main() {
// Load or train the network (from Β§11).
let network = train_network();
// Fair comparison: same iteration budget.
println!("=== Same budget: 200 iterations each ===");
run_tournament(&network, 100, 200, 200, 1.5);
// Efficiency test: can the network do more with less?
println!("\n=== Network 100 iters vs Pure 500 iters ===");
run_tournament(&network, 100, 100, 500, 1.5);
}
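The tournament code above calls a helper named play_one_game_swapped that is not shown. A minimal sketch of what it could look like, mirroring play_one_game with the roles reversed (pure MCTS as X, network-guided MCTS as O), so that the caller's negation turns the result back into the network's perspective:
/// Mirror of play_one_game: pure MCTS plays X, network-guided MCTS plays O.
/// The result is still reported from X's perspective, which is why the
/// caller negates it.
fn play_one_game_swapped(
    network: &Network,
    net_iterations: u32,
    pure_iterations: u32,
    c_puct: f64,
    rng: &mut impl rand::Rng,
) -> f64 {
    let mut state = GameState::new();
    while !state.is_terminal() {
        let mv = match state.current_player {
            Player::X => mcts_move(&state, pure_iterations, rng),
            Player::O => {
                mcts_network_move(&state, net_iterations, network, c_puct)
            }
        };
        state = state.apply_move(mv);
    }
    match state.winner() {
        Some(Player::X) => 1.0,
        Some(Player::O) => -1.0,
        None => 0.0,
    }
}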
What should you expect? With a well-trained network (from §11 with 100+ self-play games), the network-guided MCTS should:
- Win most games at equal iteration counts. At 200 iterations each, network-guided MCTS has a significant advantage because every iteration is informed by learned knowledge rather than random noise.
- Compete with fewer iterations. Network-guided MCTS at 100 iterations should perform comparably to pure MCTS at 500 iterations. The network replaces thousands of random moves with a single forward pass that encodes the patterns learned from training data.
If the network-guided version performs worse than pure MCTS, the network is not trained well enough. Go back to §11 and train on more games (200+) or more iterations per move (1000+).
Why is this better?
Consider what happens in pure MCTS when it encounters a new position. It must try every legal move at least once, then play random moves to the end of the game, repeating hundreds of times to build up reliable statistics. Most of those random rollouts are wasted: they follow nonsensical move sequences that no reasonable player would choose.
Network-guided MCTS skips all of this. The policy head says "move 4 looks best, move 0 is second, the rest are unlikely", so the search spends most of its budget on the two or three most promising moves. The value head says "this position is +0.6 for the current player": one number that summarizes what random rollouts would have (noisily) converged to after thousands of games.
This is the core insight of AlphaGo Zero: a learned evaluation function turns MCTS from a brute-force explorer into a focused, efficient search. The same iteration budget goes much further because the search is guided by knowledge rather than randomness.
But there is a subtlety. The network's predictions are only as good as its training data. If the network was trained on weak MCTS play (few iterations, few games), its priors may be wrong, and it could steer the search in bad directions. This is why the full AlphaGo Zero loop (§13) keeps iterating: the network guides better MCTS, which generates better training data, which trains a better network, and so on. Each generation is stronger than the last.
Experiments to try
- Vary c_puct. Try 0.5, 1.0, 1.5, and 3.0. Low values make the search exploit the network's prior more (less exploration). High values make the search explore more broadly despite the prior. Which value wins the most games? (A small sweep sketch follows this list.)
- Vary the iteration budget. Run network-guided MCTS at 50, 100, 200, and 500 iterations against pure MCTS at 1000. At what point does the network-guided version start winning consistently?
- Test with an untrained network. Create a Network::new() with random weights and use it to guide MCTS. It should perform worse than pure MCTS, because the random priors actively mislead the search. This confirms that the improvement comes from learned knowledge, not from the PUCT formula itself.
- Inspect the search tree. After running network-guided MCTS, print the visit counts at the root's children. Compare with pure MCTS on the same position. The network-guided version should concentrate visits on fewer moves: the ones the network considers most promising.
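For the first experiment, a hypothetical sweep (the helper name sweep_c_puct is not from the course code) could simply reuse run_tournament from the tournament solution above:
#![allow(unused)]
fn main() {
/// Plays 100 tournament games at each candidate c_puct value,
/// 200 MCTS iterations per side, and prints the results.
fn sweep_c_puct(network: &Network) {
    for &c_puct in &[0.5, 1.0, 1.5, 3.0] {
        println!("=== c_puct = {} ===", c_puct);
        run_tournament(network, 100, 200, 200, c_puct);
    }
}
}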
Part 5 β Self-Play Loop
13. The full AlphaGo Zero training loop
In §11 we trained a neural network to mimic pure MCTS. In §12 we plugged that network back into MCTS, replacing random rollouts with learned evaluation. The result was a stronger player, but a static one. The network was trained once on data from pure MCTS, and that was it. It could not improve beyond the quality of its training data.
This section closes the loop. Instead of training once and stopping, we iterate: the network guides MCTS to generate better training data, we train a new network on that data, check that it actually improved, and repeat. Each cycle produces a stronger player than the last. This is the full AlphaGo Zero self-improvement pipeline.
The training cycle
The AlphaGo Zero training loop has three phases that repeat indefinitely:
+------------------------+
|  1. Self-Play          |
|  Network-guided        |
|  MCTS plays games      |
|  against itself.       |
|  Collect (state,       |
|  policy, outcome)      |
|  triples.              |
+-----------+------------+
            |
            v
+------------------------+
|  2. Train              |
|  Update the network    |
|  on the collected      |
|  data. Produces a      |
|  candidate network.    |
+-----------+------------+
            |
            v
+------------------------+
|  3. Evaluate           |
|  Pit the candidate     |-----No-----+
|  against the current   |            |
|  best network.         |            |
|  Does it win > 55%?    |            |
+-----------+------------+            |
            | Yes                     |
            v                         |
+------------------------+            |
|  Candidate becomes     |            |
|  the new best.         |            |
+-----------+------------+            |
            |                         |
            +<------------------------+
            |
            v
     Back to step 1
Step 1: Self-play. The current best network guides MCTS to play complete games against itself. At each position, MCTS runs for a fixed number of iterations (e.g., 200), using the network's policy head for move priors and value head for position evaluation. We record every position along with the MCTS visit distribution (the policy target) and, once the game ends, the outcome (the value target). This produces a batch of TrainingExample triples.
Step 2: Train. We create a copy of the current best network and train it on the self-play data for several epochs. This produces a candidate network that (hopefully) has internalized the patterns discovered during self-play.
Step 3: Evaluate. We play a set of evaluation games between the candidate network and the current best network, each using MCTS with the same iteration budget. If the candidate wins more than a threshold (typically 55% of games), it becomes the new best. Otherwise, we discard it and try again.
Then we go back to step 1 and generate new self-play data using the (possibly updated) best network.
Why the evaluation gate matters
It might seem wasteful to discard a freshly trained network. Why not always use the latest one? The problem is that training can sometimes make things worse. A few specific failure modes:
- Policy collapse. The network becomes overconfident in a single move and assigns near-zero probability to everything else. MCTS guided by this network effectively stops searching: it always plays the same move. The resulting self-play data is narrow and repetitive, which reinforces the collapse in the next training round.
- Catastrophic forgetting. Training on a new batch of games can overwrite knowledge from previous iterations. The network might learn to handle mid-game positions well but forget how to play openings, or vice versa.
- Overfitting to noise. With small datasets (which is what Tic-Tac-Toe generates), the network can memorize training examples rather than learning general patterns. A memorized network performs well on positions it has seen but poorly on novel ones.
The evaluation gate catches all of these. If the candidate cannot beat the current best in head-to-head play, something went wrong during training, and we keep the proven network. This acts as a ratchet: strength can only go up, never down.
In practice with Tic-Tac-Toe, the evaluation gate rarely rejects a candidate because the game is simple enough that training is stable. But for larger games, this safeguard is essential.
Temperature in move selection
During self-play, we do not always pick the move with the highest visit count. Instead, we use a temperature parameter to control how deterministic the move selection is.
After MCTS completes a search, the rootβs children have visit counts \(N_1, N_2, \ldots, N_k\). We convert these into a probability distribution using temperature \(\tau\):
\[ P(\mathrm{move}_i) = \frac{N_i^{1/\tau}}{\sum_j N_j^{1/\tau}} \]
The temperature controls exploration vs exploitation:
- \(\tau = 1\) (temperature = 1): The probability is proportional to visit counts. If move A got 60% of visits and move B got 40%, we pick A 60% of the time and B 40% of the time. This adds randomness, which creates diverse training data.
- \(\tau \to 0\) (low temperature): The distribution becomes sharply peaked. The most-visited move gets probability ~1.0, everything else gets ~0.0. This is essentially greedy: always play the best move. Good for strong play, bad for exploration.
- \(\tau > 1\) (high temperature): The distribution flattens. Even rarely-visited moves get a reasonable chance of being selected. This creates maximum diversity but at the cost of playing some weak moves.
AlphaGo Zero uses a schedule: for the first N moves of each game, use \(\tau = 1\) (explore). After that, use \(\tau \to 0\) (exploit). For Tic-Tac-Toe, a reasonable choice is \(\tau = 1\) for the first 3 moves and \(\tau \to 0\) (greedy) for the rest.
Why does this matter? If self-play always picks the best move (\(\tau = 0\) everywhere), every game follows the same trajectory. The network only sees a handful of positions and learns nothing about the rest of the game tree. With \(\tau = 1\) in the opening, games branch into diverse positions, producing a richer training set that covers more of the state space.
Here is the temperature-scaled move selection in Rust:
#![allow(unused)]
fn main() {
/// Selects a move from the MCTS root using temperature-scaled
/// visit counts.
///
/// When temperature is very small (< 0.01), picks the most-visited
/// move deterministically. Otherwise, samples from the distribution
/// P(move_i) = N_i^(1/tau) / sum_j N_j^(1/tau).
fn select_move_with_temperature(
tree: &MctsTree,
temperature: f64,
rng: &mut impl rand::Rng,
) -> usize {
let root = &tree.nodes[0];
if temperature < 0.01 {
// Greedy: pick the most-visited child.
return tree.best_move();
}
// Compute temperature-scaled visit counts.
let inv_temp = 1.0 / temperature;
let scaled: Vec<f64> = root
.children
.iter()
.map(|&idx| (tree.nodes[idx].visit_count as f64).powf(inv_temp))
.collect();
let total: f64 = scaled.iter().sum();
if total == 0.0 {
// Fallback: no child has been visited yet; defer to best_move.
return tree.best_move();
}
// Sample from the distribution.
let mut r: f64 = rng.gen::<f64>() * total;
for (i, &s) in scaled.iter().enumerate() {
r -= s;
if r <= 0.0 {
let child_idx = root.children[i];
return tree.nodes[child_idx]
.move_from_parent
.expect("child must have a move");
}
}
// Rounding edge case β return last child.
let last = *root.children.last().unwrap();
tree.nodes[last]
.move_from_parent
.expect("child must have a move")
}
}
Self-play with network-guided MCTS
The self-play function is similar to generate_game_data from §11, but with two key differences: it uses network-guided MCTS (from §12) instead of pure MCTS, and it uses temperature for move selection.
#![allow(unused)]
fn main() {
/// Plays one complete game using network-guided MCTS and collects
/// training examples.
///
/// Uses temperature = 1.0 for the first `temp_moves` moves
/// (exploration), then temperature -> 0 (exploitation) for
/// the rest.
fn self_play_game(
network: &Network,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
rng: &mut impl rand::Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
let mut move_number = 0;
while !state.is_terminal() {
// Run network-guided MCTS from the current position.
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(mcts_iterations, network, c_puct);
// Record the visit distribution as the policy target.
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
// Select a move using temperature.
let temperature = if move_number < temp_moves {
1.0
} else {
0.0 // greedy
};
let chosen_move =
select_move_with_temperature(&tree, temperature, rng);
state = state.apply_move(chosen_move);
move_number += 1;
}
// Determine the game outcome and build training examples.
let winner = state.winner();
history
.into_iter()
.map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample {
state: s,
policy_target,
value_target,
}
})
.collect()
}
}
Notice that the policy target is always the raw visit distribution from MCTS, not the temperature-adjusted distribution. We want the network to learn what MCTS thinks is best, not the exploration noise we added for diversity. The temperature only affects which move is played (and therefore which positions appear later in the game).
Evaluating the candidate network
After training a candidate network, we need to know if it is actually better than the current best. The simplest approach: play a set of games where each side uses network-guided MCTS with its respective network, and count wins.
#![allow(unused)]
fn main() {
/// Plays one evaluation game between two networks.
///
/// `network_x` plays as X, `network_o` plays as O.
/// Returns +1.0 if X wins, -1.0 if O wins, 0.0 for draw.
fn play_evaluation_game(
network_x: &Network,
network_o: &Network,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let network = match state.current_player {
Player::X => network_x,
Player::O => network_o,
};
// Evaluation games use temperature = 0 (always play
// the best move). No exploration; we want to measure
// true strength.
let mv = mcts_network_move(&state, mcts_iterations, network, c_puct);
state = state.apply_move(mv);
}
match state.winner() {
Some(Player::X) => 1.0,
Some(Player::O) => -1.0,
None => 0.0,
}
}
/// Evaluates a candidate network against the current best.
///
/// Plays `num_games` games, alternating which network plays X
/// and which plays O for fairness. Returns the candidate's win
/// rate (0.0 to 1.0).
fn evaluate_networks(
candidate: &Network,
current_best: &Network,
num_games: u32,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut candidate_wins = 0;
let mut total_decisive = 0;
for game in 0..num_games {
let result = if game % 2 == 0 {
// Candidate plays X
play_evaluation_game(
candidate, current_best,
mcts_iterations, c_puct,
)
} else {
// Candidate plays O (negate result so positive = candidate won)
-play_evaluation_game(
current_best, candidate,
mcts_iterations, c_puct,
)
};
if result > 0.0 {
candidate_wins += 1;
total_decisive += 1;
} else if result < 0.0 {
total_decisive += 1;
}
// Draws don't count toward win rate
}
if total_decisive == 0 {
0.5 // All draws: treat as neutral
} else {
candidate_wins as f64 / total_decisive as f64
}
}
}
A note on draws: Tic-Tac-Toe between strong players ends in a draw most of the time. The evaluation function only counts decisive games (wins and losses) when computing the win rate. If all games are draws, the candidate is considered equal (50%) and does not replace the current best. This avoids the situation where a network that always draws "wins" 100% because it never lost.
The main training loop
Now we tie everything together. The outer loop runs for a fixed number of iterations, each time generating self-play data, training a candidate, and evaluating it.
#![allow(unused)]
fn main() {
/// Hyperparameters for the AlphaGo Zero training loop.
struct TrainingConfig {
/// Number of training iterations (cycles of self-play + train + evaluate).
num_iterations: u32,
/// Number of self-play games per iteration.
games_per_iteration: u32,
/// MCTS iterations per move during self-play.
mcts_iterations: u32,
/// PUCT exploration constant.
c_puct: f64,
/// Number of opening moves to play with temperature = 1.
temp_moves: usize,
/// Training epochs per iteration.
training_epochs: u32,
/// Mini-batch size for training.
batch_size: usize,
/// Learning rate for SGD.
learning_rate: f32,
/// Number of evaluation games to play.
eval_games: u32,
/// Win rate threshold to accept the candidate (e.g., 0.55 = 55%).
win_threshold: f64,
}
/// Runs the full AlphaGo Zero training loop.
///
/// Starts with a randomly initialized network and iteratively
/// improves it through self-play, training, and evaluation.
fn alphago_zero_training(config: &TrainingConfig, seed: u64) -> Network {
let mut rng = StdRng::seed_from_u64(seed);
let mut best_network = Network::new(rng.gen());
println!("Starting AlphaGo Zero training loop");
println!(" Iterations: {}", config.num_iterations);
println!(" Games/iter: {}", config.games_per_iteration);
println!(" MCTS iters: {}", config.mcts_iterations);
println!(" Eval games: {}", config.eval_games);
println!(" Win threshold: {:.0}%", config.win_threshold * 100.0);
println!();
for iteration in 1..=config.num_iterations {
println!("βββ Iteration {}/{} βββ", iteration, config.num_iterations);
// -- Phase 1: Self-play ----------------------------------------
println!("Phase 1: Generating self-play data...");
let mut all_examples = Vec::new();
for game in 1..=config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
all_examples.extend(examples);
if game % 25 == 0 {
println!(
" Game {}/{} ({} examples so far)",
game, config.games_per_iteration, all_examples.len()
);
}
}
println!(
" Self-play complete: {} examples from {} games",
all_examples.len(),
config.games_per_iteration
);
// -- Phase 2: Train --------------------------------------------
println!("Phase 2: Training candidate network...");
let mut candidate = best_network.clone();
let final_loss = train_network(
&mut candidate,
&mut all_examples,
config.training_epochs,
config.batch_size,
config.learning_rate,
&mut rng,
);
println!(" Training complete: final loss = {:.4}", final_loss);
// -- Phase 3: Evaluate -----------------------------------------
println!("Phase 3: Evaluating candidate vs current best...");
let win_rate = evaluate_networks(
&candidate,
&best_network,
config.eval_games,
config.mcts_iterations,
config.c_puct,
);
println!(" Candidate win rate: {:.1}%", win_rate * 100.0);
if win_rate >= config.win_threshold {
println!(" β Candidate accepted as new best network!");
best_network = candidate;
} else {
println!(" β Candidate rejected. Keeping current best.");
}
println!();
}
best_network
}
}
A concrete training schedule for Tic-Tac-Toe
Tic-Tac-Toe has a small state space (at most 5,478 reachable positions) and short games (5-9 moves). The network converges quickly. Here is a training schedule that works well:
fn main() {
let config = TrainingConfig {
num_iterations: 20, // 20 cycles of improvement
games_per_iteration: 100, // 100 self-play games per cycle
mcts_iterations: 200, // 200 MCTS iterations per move
c_puct: 1.5, // PUCT exploration constant
temp_moves: 3, // temperature = 1 for first 3 moves
training_epochs: 20, // 20 epochs of training per cycle
batch_size: 32, // mini-batch size
learning_rate: 0.01, // SGD learning rate
eval_games: 40, // 40 evaluation games
win_threshold: 0.55, // candidate must win > 55%
};
let trained = alphago_zero_training(&config, 42);
// Test the final network on a few positions.
println!("βββ Final Network Evaluation βββ");
let empty = GameState::new();
let (policy, value) = trained.predict(&empty);
let masked = mask_illegal(&policy, &empty);
println!("Empty board:");
println!(" Policy: {:?}", masked);
println!(" Value: {:.3}", value);
println!(" Best move: cell {}", masked
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0
);
// Test against pure MCTS to verify strength.
println!("\nFinal tournament: trained network vs pure MCTS (500 iters)");
run_tournament(&trained, 100, 200, 500, 1.5);
}
This schedule produces roughly 2,000 self-play games total (20 iterations times 100 games), each generating 5-9 training examples. Total training time on a modern machine is under a minute.
What to expect
As training progresses, you should observe several things:
Iteration 1-3: rapid initial improvement. The network starts with random weights, so its predictions are noise. Even one round of self-play data gives the network some signal. The candidate easily beats the random network, and the evaluation gate accepts it every time.
Iteration 4-8: diminishing returns. The network already plays reasonably. Self-play games between identical copies of a decent player produce less diverse data; many games follow similar lines. Improvement slows down, and the candidate occasionally fails the evaluation gate.
Iteration 9-15: convergence. The network approaches near-perfect play. Tic-Tac-Toe is a solved game (perfect play from both sides results in a draw), so the network should converge to a strategy that never loses. Most evaluation games end in draws.
Iteration 15-20: plateau. The network has learned the game. Further iterations produce candidates that tie with the current best but rarely beat it. The evaluation gate rejects most candidates, not because they are worse, but because perfect-vs-perfect play always draws, and draws do not count as wins.
The key insight: the network teaches itself to play the game with zero human knowledge. It starts knowing nothing except the rules. Through repeated cycles of self-play, training, and evaluation, it discovers opening strategy, tactical patterns, and endgame play entirely on its own. For Tic-Tac-Toe this happens in minutes. For Go (the original AlphaGo Zero domain), it took three days on thousands of TPUs, but the same algorithm produced superhuman play.
Summary
The AlphaGo Zero training loop is a simple cycle with powerful consequences:
- Self-play generates training data from the network's own games. Temperature in the opening ensures diverse positions.
- Training updates a candidate network on the self-play data.
- Evaluation gates the update: the candidate must prove it is stronger than the current best. This prevents regression.
- Repeat with the improved network, generating higher-quality data each round.
Each component is something we built in earlier sections. §11 showed how to generate training data from MCTS and train a network on it. §12 showed how to use the network to guide MCTS. This section connected them into a loop. The only new pieces are temperature-based move selection (for diverse self-play data) and the evaluation gate (to prevent regression).
The result is an algorithm that, given only the rules of a game, learns to play it at an expert level, by playing against itself, learning from its own games, and steadily improving.
14. Exercise 5: 1000 self-play games; observe improvement
This is the capstone exercise. You have built every piece of the AlphaGo Zero pipeline: the game engine (§6), Monte Carlo Tree Search (§7-8), a neural network with policy and value heads (§9-10), supervised training on MCTS data (§11), network-guided MCTS (§12), and the full training loop (§13). Now you wire them all together, press run, and watch an AI teach itself to play perfect Tic-Tac-Toe from nothing.
The exercise
Run the full AlphaGo Zero training loop for 20 iterations of 50 self-play games each (1,000 games total). After each iteration, print a diagnostics panel so you can observe the network improving in real time. When training finishes, validate that the network has converged to perfect play.
Step 1: Wire up all the modules
Your main.rs needs everything from the previous exercises assembled into one file:
- GameState, Player, and the game logic (§6)
- MctsNode, MctsTree, and MCTS search, both pure and network-guided (§7-8, §12)
- Network with forward pass, backpropagation, and train_step (§9-10)
- TrainingExample, self_play_game, evaluate_networks, and train_network (§11, §13)
- select_move_with_temperature and visit_distribution (§13)
- TrainingConfig and alphago_zero_training (§13)
If you have been building incrementally, you already have all of these. If not, the complete solution at the end of this section provides everything in one place.
Step 2: Add diagnostics to the training loop
Modify the alphago_zero_training function to collect and print diagnostics after each iteration. Here is what to track:
Self-play diagnostics:
#![allow(unused)]
fn main() {
/// Statistics collected during one iteration of self-play.
struct IterationStats {
/// Average number of moves per game.
avg_game_length: f64,
/// Number of games won by X.
x_wins: u32,
/// Number of games won by O.
o_wins: u32,
/// Number of draws.
draws: u32,
/// Total training examples generated.
num_examples: usize,
}
}
To compute these, modify the self-play phase to track game outcomes and lengths:
#![allow(unused)]
fn main() {
// Inside the self-play loop:
let mut total_moves = 0u32;
let mut x_wins = 0u32;
let mut o_wins = 0u32;
let mut draws = 0u32;
for _game in 0..config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
// The number of examples equals the number of moves in the game.
let game_length = examples.len() as u32;
total_moves += game_length;
// Determine the outcome. The terminal state itself is not stored;
// instead, the last example's value_target encodes the result from
// the perspective of the player to move at that last position.
if let Some(last) = examples.last() {
match last.value_target.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => {
// Last position's current player won.
match last.state.current_player {
Player::X => x_wins += 1,
Player::O => o_wins += 1,
}
}
Some(std::cmp::Ordering::Less) => {
match last.state.current_player {
Player::X => o_wins += 1,
Player::O => x_wins += 1,
}
}
_ => draws += 1,
}
}
all_examples.extend(examples);
}
let avg_game_length =
total_moves as f64 / config.games_per_iteration as f64;
}
Training diagnostics (policy loss and value loss):
Split the loss reported by train_network into its two components. If you followed §10, the total loss is policy_loss + value_loss. Modify train_network (or add a wrapper) to return both:
#![allow(unused)]
fn main() {
/// Loss components from one training run.
struct TrainingLoss {
/// Cross-entropy loss between predicted and target policy.
policy_loss: f32,
/// Mean squared error between predicted and target value.
value_loss: f32,
}
}
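One way to obtain both components is a small wrapper that accumulates the per-example losses returned by train_step and averages them (the complete solution at the end of this section does the same thing inside train_on_examples). A sketch, assuming train_step returns the (policy_loss, value_loss) pair for a single example:
#![allow(unused)]
fn main() {
/// Runs one pass over the examples and returns the mean loss components.
fn train_epoch(
    net: &mut Network,
    examples: &[TrainingExample],
    learning_rate: f32,
) -> TrainingLoss {
    let mut policy_loss = 0.0f32;
    let mut value_loss = 0.0f32;
    for ex in examples {
        let (pl, vl) = net.train_step(
            &ex.state,
            &ex.policy_target,
            ex.value_target,
            learning_rate,
        );
        policy_loss += pl;
        value_loss += vl;
    }
    let n = examples.len().max(1) as f32;
    TrainingLoss {
        policy_loss: policy_loss / n,
        value_loss: value_loss / n,
    }
}
}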
Sample prediction on the empty board:
After each iteration, run a forward pass on the empty board and print the policy and value. This is the simplest way to see the network's "opening book" evolving:
#![allow(unused)]
fn main() {
// After each iteration:
let empty = GameState::new();
let (policy, value) = best_network.predict(&empty);
let masked = mask_illegal(&policy, &empty);
let best_cell = masked
    .iter()
    .enumerate()
    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
    .unwrap()
    .0;
println!("Empty board: best move = cell {}, value = {:+.3}", best_cell, value);
}
Step 3: Print the diagnostics panel
After each iteration, print a summary block like this:
--- Iteration 7/20 ----------------------------------------
  Self-play: 50 games, 312 examples
  Game length: 6.2 moves (avg)
  Outcomes: X wins 12 | O wins 8 | Draws 30
  Training: policy_loss = 1.2843 value_loss = 0.4521
  Evaluation: candidate win rate = 62.5% -> ACCEPTED
  Empty board: best move = cell 4 (center)
    policy = [.06 .04 .07 .05 .42 .05 .08 .04 .06]
    value = +0.031
------------------------------------------------------------
Here is the Rust code that prints this:
#![allow(unused)]
fn main() {
println!("βββ Iteration {}/{} βββ", iteration, config.num_iterations);
println!(" Self-play: {} games, {} examples",
config.games_per_iteration, all_examples.len());
println!(" Game length: {:.1} moves (avg)", avg_game_length);
println!(" Outcomes: X wins {} | O wins {} | Draws {}",
x_wins, o_wins, draws);
println!(" Training: policy_loss = {:.4} value_loss = {:.4}",
loss.policy_loss, loss.value_loss);
println!(" Evaluation: candidate win rate = {:.1}% β {}",
win_rate * 100.0,
if win_rate >= config.win_threshold { "ACCEPTED" } else { "REJECTED" });
let (policy, value) = best_network.predict(&GameState::new());
let masked = mask_illegal(&policy, &GameState::new());
let best_cell = masked
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
let cell_name = match best_cell {
0 => "top-left", 1 => "top-center", 2 => "top-right",
3 => "mid-left", 4 => "center", 5 => "mid-right",
6 => "bot-left", 7 => "bot-center", 8 => "bot-right",
_ => "unknown",
};
println!(" Empty board: best move = cell {} ({})", best_cell, cell_name);
println!(" policy = [{:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2} {:.2}]",
masked[0], masked[1], masked[2], masked[3], masked[4],
masked[5], masked[6], masked[7], masked[8]);
println!(" value = {:+.3}", value);
println!();
}
What to expect in the diagnostics
Watch for these trends as training progresses:
Average game length should generally decrease. Early on, the network plays semi-random moves, and games meander for 7-9 moves. As the network improves, games become more decisive: strong opening play leads to faster wins or quicker draws. Expect the average to settle around 5-7 moves.
Win/draw/loss ratio should shift toward draws. In early iterations, the network is weak and games are chaotic, with many decisive results (X or O wins). As the network approaches perfect play, most self-play games end in draws because both sides defend correctly. By iteration 15-20, you should see 80-100% draws.
Policy loss should decrease steadily. The network is learning to predict MCTS visit distributions more accurately. Starting from random weights (policy loss around 2.0-2.5, which is roughly -ln(1/9)), it should drop below 1.0 within a few iterations. Perfect prediction would be 0.0, but the network will plateau around 0.5-1.0 because MCTS visit distributions have inherent noise.
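As a quick sanity check on that starting point: a uniform prediction assigns probability 1/9 to every move, so its cross-entropy against any target distribution is
\[ -\ln\tfrac{1}{9} = \ln 9 \approx 2.20. \]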
Value loss should also decrease. Starting around 0.5-1.0 (random predictions vs outcomes in {-1, 0, +1}), it should drop below 0.3 within a few iterations. As more games end in draws, the value loss may stabilize at a low level because the target is consistently 0.0.
Evaluation results should show early candidates accepted (they easily beat the random initial network), middle candidates sometimes rejected (improvements are marginal), and late candidates mostly rejected (perfect play draws against perfect play, so no candidate can demonstrate superiority).
Empty board policy is the most visually satisfying diagnostic. In early iterations, the policy is roughly uniform: the network has no preference. By iteration 5-10, the center (cell 4) should emerge as the clear favorite. By iteration 15-20, the policy should be heavily concentrated on cell 4, with corners (cells 0, 2, 6, 8) receiving the remaining probability and edges (cells 1, 3, 5, 7) getting almost nothing. This matches optimal Tic-Tac-Toe strategy: center is the strongest opening move.
Step 4: Validate convergence to perfect play
After training completes, run three validation tests.
Test 1: Trained network vs pure MCTS. Play 100 games between the trained network (using 200 MCTS iterations) and pure MCTS (using 10,000 iterations, a massive search budget). If the network has converged to perfect play, every game should be a draw. Pure MCTS with 10,000 iterations plays near-perfectly through brute force; the trained network should match it with far less search.
#![allow(unused)]
fn main() {
// Validation: trained network vs strong pure MCTS.
println!("βββ Validation 1: Trained AI vs Pure MCTS (10k iters) βββ");
let mut draws = 0;
let mut rng = StdRng::seed_from_u64(123);
for game in 0..100 {
let mut state = GameState::new();
while !state.is_terminal() {
let mv = match state.current_player {
Player::X => {
if game % 2 == 0 {
mcts_network_move(&state, 200, &trained, 1.5)
} else {
mcts_move(&state, 10_000, &mut rng)
}
}
Player::O => {
if game % 2 == 0 {
mcts_move(&state, 10_000, &mut rng)
} else {
mcts_network_move(&state, 200, &trained, 1.5)
}
}
};
state = state.apply_move(mv);
}
if state.winner().is_none() {
draws += 1;
}
}
println!(" Result: {}/100 draws", draws);
if draws == 100 {
println!(" PASS β the network plays perfectly.");
} else {
println!(" Some decisive games β network may need more training.");
}
}
Test 2: Check the opening move. The optimal first move in Tic-Tac-Toe is the center (cell 4). Verify that the trained network always plays it:
#![allow(unused)]
fn main() {
println!("\nβββ Validation 2: Opening Move βββ");
let empty = GameState::new();
let mv = mcts_network_move(&empty, 200, &trained, 1.5);
println!(" First move: cell {} (expected: 4 = center)", mv);
if mv == 4 {
println!(" PASS β network plays center.");
} else {
println!(" UNEXPECTED β center is the optimal opening.");
}
}
Test 3: Human vs AI. Add an interactive mode where you can play against the trained network. Try to beat it; you should not be able to.
#![allow(unused)]
fn main() {
println!("\nβββ Validation 3: Human vs Trained AI βββ");
println!("You are X. Enter a cell number (0-8):");
println!(" 0 | 1 | 2");
println!(" ---------");
println!(" 3 | 4 | 5");
println!(" ---------");
println!(" 6 | 7 | 8");
println!();
let mut state = GameState::new();
while !state.is_terminal() {
state.print_board();
match state.current_player {
Player::X => {
// Human move
print!("Your move (0-8): ");
std::io::Write::flush(&mut std::io::stdout()).unwrap();
let mut input = String::new();
std::io::stdin().read_line(&mut input).unwrap();
let mv: usize = input.trim().parse().expect("enter 0-8");
if !state.legal_moves().contains(&mv) {
println!("Illegal move. Try again.");
continue;
}
state = state.apply_move(mv);
}
Player::O => {
// AI move
let mv = mcts_network_move(&state, 200, &trained, 1.5);
println!("AI plays cell {}.", mv);
state = state.apply_move(mv);
}
}
}
state.print_board();
match state.winner() {
Some(Player::X) => println!("You win! (The AI has a bug.)"),
Some(Player::O) => println!("AI wins. Better luck next time."),
None => println!("Draw. That is the best you can do against perfect play."),
}
}
What you just built
Take a moment to appreciate what happened. You started with:
- A set of game rules (how to place X's and O's).
- A neural network initialized with random noise.
- An algorithm: self-play, train, evaluate, repeat.
And you ended with an AI that plays Tic-Tac-Toe perfectly. No one told it that the center is a good opening move. No one programmed in the concept of a "fork" or a "block." No one provided any examples of expert play. The network discovered all of this on its own, purely by playing against itself and learning from the results.
The same algorithm, with no structural changes, was used to train AlphaGo Zero, which surpassed the earlier versions of AlphaGo that had beaten the world's top Go players. The difference is only in scale: Go has 10^170 possible positions instead of 5,478; AlphaGo Zero used 4.9 million self-play games instead of 1,000; it trained on 64 GPU workers and 19 CPU parameter servers instead of your laptop. But the loop is identical: self-play, train, evaluate, repeat.
This is what makes the AlphaGo Zero approach remarkable. It is a general algorithm. It does not encode any game-specific strategy. Give it the rules of any two-player perfect-information game, and it will learn to play it at an expert level, given enough compute and enough iterations.
Where to go from here
You have a working AlphaGo Zero implementation for Tic-Tac-Toe. Here are some directions to extend it:
Scale up to Connect Four. Connect Four has a 7x6 board (42 cells), branching factor of 7, and games that last up to 42 moves. It is significantly more complex than Tic-Tac-Toe but still tractable on a single machine. You will need to adjust the network architecture (larger hidden layers), increase the MCTS iteration budget, and train for more iterations. The same algorithm applies without changes to its structure.
Add data augmentation via board symmetries. A Tic-Tac-Toe board has 8 symmetries: 4 rotations (0, 90, 180, 270 degrees) and 4 reflections. Every training example can be transformed into 8 equivalent examples for free. This multiplies your effective dataset by 8x and helps the network learn that the game is symmetric. In the AlphaGo Zero paper, this technique was crucial for efficient learning on the Go board's 8-fold symmetry.
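A minimal sketch of how the eight transforms could be generated and applied (the helper names symmetry_maps and apply_map are not part of the course code). Each map sends the original cell map[i] to position i, so the same map can permute the encoded board, the board cells, and the policy target; the value target is unchanged by symmetry:
#![allow(unused)]
fn main() {
/// Builds the eight index maps of the 3x3 board: identity, three
/// rotations, and the mirror image of each.
fn symmetry_maps() -> Vec<[usize; 9]> {
    // Rotate 90 degrees clockwise: new (r, c) comes from old (2 - c, r).
    let rot = |m: [usize; 9]| {
        let mut out = [0usize; 9];
        for r in 0..3 {
            for c in 0..3 {
                out[r * 3 + c] = m[(2 - c) * 3 + r];
            }
        }
        out
    };
    // Mirror left-right: new (r, c) comes from old (r, 2 - c).
    let mirror = |m: [usize; 9]| {
        let mut out = [0usize; 9];
        for r in 0..3 {
            for c in 0..3 {
                out[r * 3 + c] = m[r * 3 + (2 - c)];
            }
        }
        out
    };
    let id: [usize; 9] = [0, 1, 2, 3, 4, 5, 6, 7, 8];
    let mut maps = vec![id, rot(id), rot(rot(id)), rot(rot(rot(id)))];
    let mirrored: Vec<[usize; 9]> = maps.iter().map(|&m| mirror(m)).collect();
    maps.extend(mirrored);
    maps
}
/// Applies one index map to a 9-element array (an encoded board or a
/// policy target), producing the transformed example.
fn apply_map(values: &[f32; 9], map: &[usize; 9]) -> [f32; 9] {
    let mut out = [0.0f32; 9];
    for i in 0..9 {
        out[i] = values[map[i]];
    }
    out
}
}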
Try different network architectures. The simple two-layer network in this course works for Tic-Tac-Toe but would struggle on larger games. AlphaGo Zero used a deep residual network (ResNet) with 20-40 residual blocks. Try adding residual connections (output = input + layer(input)) and see if training converges faster. If you move to a 2D board game like Connect Four, try convolutional layers; they are natural for grid-structured inputs.
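As a rough illustration of the residual idea (the function below is hypothetical and not part of the course's Network struct; it assumes a square weight matrix so the shapes line up), a residual block just adds its input back onto the transformed output:
#![allow(unused)]
fn main() {
/// Hypothetical residual block: output[i] = input[i] + relu(w[i] . input + b[i]).
fn residual_block(input: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    input
        .iter()
        .enumerate()
        .map(|(i, &x)| {
            let pre: f32 = w[i]
                .iter()
                .zip(input.iter())
                .map(|(wij, xj)| wij * xj)
                .sum::<f32>()
                + b[i];
            // Skip connection: the input is added back after the nonlinearity.
            x + pre.max(0.0)
        })
        .collect()
}
}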
Read the original papers. Now that you understand the algorithm from the ground up, the papers will make much more sense:
- Mastering the Game of Go without Human Knowledge (Silver et al., 2017): the AlphaGo Zero paper. Describes the full training pipeline you just implemented.
- A General Reinforcement Learning Algorithm that Masters Chess, Shogi, and Go Through Self-Play (Silver et al., 2018): the AlphaZero paper. Generalizes AlphaGo Zero to chess and shogi with minimal modifications.
- Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model (Schrittwieser et al., 2020): MuZero. Removes the requirement of knowing the game rules by learning a model of the environment.
Explore other self-play domains. Self-play is not limited to board games. It has been applied to:
- Real-time strategy games (AlphaStar for StarCraft II)
- Card games with hidden information (Pluribus for six-player poker)
- Robotic control (learning dexterous manipulation through simulated self-play)
- Multi-agent emergent behavior (OpenAI's hide-and-seek experiments)
The pattern is always the same: agents learn by competing against themselves, bootstrapping skill from zero.
Solution: complete main.rs for the full training pipeline
use rand::prelude::*;
// -----------------------------------------------------------
// Game logic (§6)
// -----------------------------------------------------------
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Player {
X,
O,
}
impl Player {
fn opponent(self) -> Player {
match self {
Player::X => Player::O,
Player::O => Player::X,
}
}
}
#[derive(Debug, Clone)]
struct GameState {
board: [Option<Player>; 9],
current_player: Player,
}
impl GameState {
fn new() -> Self {
GameState {
board: [None; 9],
current_player: Player::X,
}
}
fn legal_moves(&self) -> Vec<usize> {
(0..9).filter(|&i| self.board[i].is_none()).collect()
}
fn apply_move(&self, cell: usize) -> GameState {
let mut new = self.clone();
new.board[cell] = Some(self.current_player);
new.current_player = self.current_player.opponent();
new
}
fn winner(&self) -> Option<Player> {
const LINES: [[usize; 3]; 8] = [
[0, 1, 2], [3, 4, 5], [6, 7, 8],
[0, 3, 6], [1, 4, 7], [2, 5, 8],
[0, 4, 8], [2, 4, 6],
];
for line in &LINES {
if let Some(p) = self.board[line[0]] {
if self.board[line[1]] == Some(p)
&& self.board[line[2]] == Some(p)
{
return Some(p);
}
}
}
None
}
fn is_terminal(&self) -> bool {
self.winner().is_some() || self.legal_moves().is_empty()
}
fn print_board(&self) {
for row in 0..3 {
for col in 0..3 {
let i = row * 3 + col;
let ch = match self.board[i] {
Some(Player::X) => "X",
Some(Player::O) => "O",
None => ".",
};
print!(" {}", ch);
}
println!();
}
}
/// Encodes the board as a flat array of f32 for the network.
/// X = +1.0, O = -1.0, empty = 0.0, from the perspective of
/// the current player.
fn encode(&self) -> [f32; 9] {
let mut encoded = [0.0f32; 9];
for i in 0..9 {
encoded[i] = match self.board[i] {
Some(p) if p == self.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
}
encoded
}
}
// -----------------------------------------------------------
// Neural network (§9-10)
// -----------------------------------------------------------
#[derive(Clone)]
struct Network {
// Input (9) β Hidden (64) β Policy (9) + Value (1)
w1: Vec<Vec<f32>>, // 64 x 9
b1: Vec<f32>, // 64
w_policy: Vec<Vec<f32>>, // 9 x 64
b_policy: Vec<f32>, // 9
w_value: Vec<Vec<f32>>, // 1 x 64
b_value: Vec<f32>, // 1
}
impl Network {
fn new(seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
let hidden = 64;
let input = 9;
let w1 = (0..hidden)
.map(|_| (0..input).map(|_| rng.gen_range(-0.5..0.5)).collect())
.collect();
let b1 = vec![0.0; hidden];
let w_policy = (0..9)
.map(|_| (0..hidden).map(|_| rng.gen_range(-0.5..0.5)).collect())
.collect();
let b_policy = vec![0.0; 9];
let w_value = vec![
(0..hidden).map(|_| rng.gen_range(-0.5..0.5)).collect()
];
let b_value = vec![0.0];
Network { w1, b1, w_policy, b_policy, w_value, b_value }
}
fn forward(&self, input: &[f32; 9]) -> (Vec<f32>, [f32; 9], f32) {
// Hidden layer with ReLU
let mut hidden = vec![0.0f32; self.w1.len()];
for (h, (weights, bias)) in hidden.iter_mut()
.zip(self.w1.iter().zip(self.b1.iter()))
{
*h = weights.iter().zip(input.iter())
.map(|(w, x)| w * x).sum::<f32>() + bias;
*h = h.max(0.0); // ReLU
}
// Policy head (softmax applied later)
let mut policy_logits = [0.0f32; 9];
for (i, (weights, bias)) in self.w_policy.iter()
.zip(self.b_policy.iter()).enumerate()
{
policy_logits[i] = weights.iter().zip(hidden.iter())
.map(|(w, h)| w * h).sum::<f32>() + bias;
}
// Value head (tanh)
let value_raw: f32 = self.w_value[0].iter().zip(hidden.iter())
.map(|(w, h)| w * h).sum::<f32>() + self.b_value[0];
let value = value_raw.tanh();
(hidden, policy_logits, value)
}
fn predict(&self, state: &GameState) -> ([f32; 9], f32) {
let input = state.encode();
let (_, logits, value) = self.forward(&input);
// Softmax
let max_logit = logits.iter().cloned()
.fold(f32::NEG_INFINITY, f32::max);
let mut policy = [0.0f32; 9];
let mut sum = 0.0f32;
for i in 0..9 {
policy[i] = (logits[i] - max_logit).exp();
sum += policy[i];
}
for p in policy.iter_mut() {
*p /= sum;
}
(policy, value)
}
fn train_step(
&mut self,
state: &GameState,
policy_target: &[f32; 9],
value_target: f32,
lr: f32,
) -> (f32, f32) {
let input = state.encode();
let (hidden, logits, value) = self.forward(&input);
// Softmax for policy
let max_logit = logits.iter().cloned()
.fold(f32::NEG_INFINITY, f32::max);
let mut policy = [0.0f32; 9];
let mut sum = 0.0f32;
for i in 0..9 {
policy[i] = (logits[i] - max_logit).exp();
sum += policy[i];
}
for p in policy.iter_mut() {
*p /= sum;
}
// Policy loss: cross-entropy
let policy_loss: f32 = -(0..9)
.map(|i| {
if policy_target[i] > 0.0 {
policy_target[i] * policy[i].max(1e-8).ln()
} else {
0.0
}
})
.sum::<f32>();
// Value loss: MSE
let value_loss = (value - value_target).powi(2);
// -- Backprop --
// Policy gradient: d_logits[i] = policy[i] - target[i]
let mut d_logits = [0.0f32; 9];
for i in 0..9 {
d_logits[i] = policy[i] - policy_target[i];
}
// Value gradient: d_value = 2 * (value - target) * (1 - tanh^2)
let d_value = 2.0 * (value - value_target) * (1.0 - value * value);
// Update policy weights
for i in 0..9 {
for j in 0..hidden.len() {
self.w_policy[i][j] -= lr * d_logits[i] * hidden[j];
}
self.b_policy[i] -= lr * d_logits[i];
}
// Update value weights
for j in 0..hidden.len() {
self.w_value[0][j] -= lr * d_value * hidden[j];
}
self.b_value[0] -= lr * d_value;
// Gradient for hidden layer
let mut d_hidden = vec![0.0f32; hidden.len()];
for j in 0..hidden.len() {
for i in 0..9 {
d_hidden[j] += d_logits[i] * self.w_policy[i][j];
}
d_hidden[j] += d_value * self.w_value[0][j];
// ReLU gradient
if hidden[j] <= 0.0 {
d_hidden[j] = 0.0;
}
}
// Update input-to-hidden weights
for j in 0..hidden.len() {
for k in 0..9 {
self.w1[j][k] -= lr * d_hidden[j] * input[k];
}
self.b1[j] -= lr * d_hidden[j];
}
(policy_loss, value_loss)
}
}
/// Masks illegal moves by zeroing their probabilities and
/// renormalizing.
fn mask_illegal(policy: &[f32; 9], state: &GameState) -> [f32; 9] {
let mut masked = [0.0f32; 9];
let legal = state.legal_moves();
let mut sum = 0.0f32;
for &m in &legal {
masked[m] = policy[m];
sum += policy[m];
}
if sum > 0.0 {
for m in &legal {
masked[*m] /= sum;
}
} else {
// Uniform over legal moves
let uniform = 1.0 / legal.len() as f32;
for &m in &legal {
masked[m] = uniform;
}
}
masked
}
// -----------------------------------------------------------
// MCTS: pure and network-guided (§7-8, §12)
// -----------------------------------------------------------
#[derive(Debug, Clone)]
struct MctsNode {
state: GameState,
move_from_parent: Option<usize>,
parent: Option<usize>,
children: Vec<usize>,
visit_count: u32,
total_value: f64,
unexpanded_moves: Vec<usize>,
prior: f64,
}
struct MctsTree {
nodes: Vec<MctsNode>,
}
impl MctsTree {
fn new(state: GameState) -> Self {
let moves = state.legal_moves();
let root = MctsNode {
state,
move_from_parent: None,
parent: None,
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: moves,
prior: 0.0,
};
MctsTree { nodes: vec![root] }
}
fn best_move(&self) -> usize {
let root = &self.nodes[0];
root.children
.iter()
.max_by_key(|&&idx| self.nodes[idx].visit_count)
.map(|&idx| self.nodes[idx].move_from_parent.unwrap())
.expect("root must have children after search")
}
/// Pure MCTS with random rollouts.
fn search(&mut self, iterations: u32, rng: &mut impl Rng) {
for _ in 0..iterations {
let leaf = self.select(true);
let leaf = self.expand_pure(leaf, rng);
let value = self.rollout(leaf, rng);
self.backpropagate(leaf, value);
}
}
/// Network-guided MCTS.
fn search_with_network(
&mut self,
iterations: u32,
network: &Network,
c_puct: f64,
) {
for _ in 0..iterations {
let leaf = self.select_puct(c_puct);
let (leaf, value) =
self.expand_with_network(leaf, network);
self.backpropagate(leaf, value);
}
}
fn select(&mut self, use_uct: bool) -> usize {
let mut node = 0;
loop {
let n = &self.nodes[node];
if n.state.is_terminal() {
return node;
}
if !n.unexpanded_moves.is_empty() {
return node;
}
// Select child with highest UCT
let parent_visits = n.visit_count as f64;
let c = 1.414;
node = *n.children.iter().max_by(|&&a, &&b| {
let sa = self.uct_score(a, parent_visits, c);
let sb = self.uct_score(b, parent_visits, c);
sa.partial_cmp(&sb).unwrap()
}).unwrap();
}
}
fn select_puct(&self, c_puct: f64) -> usize {
let mut node = 0;
loop {
let n = &self.nodes[node];
if n.state.is_terminal() {
return node;
}
if !n.unexpanded_moves.is_empty() {
return node;
}
let parent_visits = n.visit_count;
node = *n.children.iter().max_by(|&&a, &&b| {
let sa = puct_score(
self.nodes[a].total_value,
self.nodes[a].visit_count,
parent_visits,
self.nodes[a].prior,
c_puct,
);
let sb = puct_score(
self.nodes[b].total_value,
self.nodes[b].visit_count,
parent_visits,
self.nodes[b].prior,
c_puct,
);
sa.partial_cmp(&sb).unwrap()
}).unwrap();
}
}
fn uct_score(&self, idx: usize, parent_visits: f64, c: f64) -> f64 {
let node = &self.nodes[idx];
if node.visit_count == 0 {
return f64::INFINITY;
}
let exploitation =
node.total_value / node.visit_count as f64;
let exploration =
c * (parent_visits.ln() / node.visit_count as f64).sqrt();
exploitation + exploration
}
fn expand_pure(
&mut self,
node: usize,
rng: &mut impl Rng,
) -> usize {
if self.nodes[node].state.is_terminal()
|| self.nodes[node].unexpanded_moves.is_empty()
{
return node;
}
let idx = rng.gen_range(
0..self.nodes[node].unexpanded_moves.len()
);
let mv = self.nodes[node].unexpanded_moves.swap_remove(idx);
let new_state = self.nodes[node].state.apply_move(mv);
let moves = new_state.legal_moves();
let child = MctsNode {
state: new_state,
move_from_parent: Some(mv),
parent: Some(node),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: moves,
prior: 0.0,
};
let child_idx = self.nodes.len();
self.nodes.push(child);
self.nodes[node].children.push(child_idx);
child_idx
}
fn expand_with_network(
&mut self,
node: usize,
network: &Network,
) -> (usize, f64) {
if self.nodes[node].state.is_terminal() {
    // Exact value for terminal states, from the perspective of the
    // player who moved into this node (the opponent of current_player),
    // which is what backpropagation expects (see §12).
    let eval_player =
        self.nodes[node].state.current_player.opponent();
    let value = match self.nodes[node].state.winner() {
        Some(p) if p == eval_player => 1.0,
        Some(_) => -1.0,
        None => 0.0,
    };
    return (node, value);
}
if self.nodes[node].unexpanded_moves.is_empty() {
    // Already fully expanded: use the network value, negated to
    // the perspective of the player who moved here.
    let (_, value) =
        network.predict(&self.nodes[node].state);
    return (node, -(value as f64));
}
// Get network predictions for this state.
let (policy, value) =
network.predict(&self.nodes[node].state);
let masked =
mask_illegal(&policy, &self.nodes[node].state);
// Expand ALL children at once with priors.
let moves: Vec<usize> =
self.nodes[node].unexpanded_moves.drain(..).collect();
for mv in &moves {
let new_state =
self.nodes[node].state.apply_move(*mv);
let child_moves = new_state.legal_moves();
let child = MctsNode {
state: new_state,
move_from_parent: Some(*mv),
parent: Some(node),
children: Vec::new(),
visit_count: 0,
total_value: 0.0,
unexpanded_moves: child_moves,
prior: masked[*mv] as f64,
};
let child_idx = self.nodes.len();
self.nodes.push(child);
self.nodes[node].children.push(child_idx);
}
// Negate: the network's value is for current_player, but
// backpropagation expects the value for the player who moved here.
(node, -(value as f64))
}
fn rollout(&self, node: usize, rng: &mut impl Rng) -> f64 {
    // Play random moves to the end and score the result from the
    // perspective of the player who moved into `node`, so the value
    // can be passed straight to backpropagate.
    let eval_player =
        self.nodes[node].state.current_player.opponent();
    let mut state = self.nodes[node].state.clone();
    while !state.is_terminal() {
        let moves = state.legal_moves();
        let mv = moves[rng.gen_range(0..moves.len())];
        state = state.apply_move(mv);
    }
    match state.winner() {
        Some(p) if p == eval_player => 1.0,
        Some(_) => -1.0,
        None => 0.0,
    }
}
fn backpropagate(&mut self, mut node: usize, mut value: f64) {
loop {
self.nodes[node].visit_count += 1;
self.nodes[node].total_value += value;
value = -value; // Flip perspective
if let Some(parent) = self.nodes[node].parent {
node = parent;
} else {
break;
}
}
}
}
fn puct_score(
child_value: f64,
child_visits: u32,
parent_visits: u32,
prior: f64,
c_puct: f64,
) -> f64 {
let q = if child_visits == 0 {
0.0
} else {
child_value / child_visits as f64
};
let exploration = c_puct
* prior
* (parent_visits as f64).sqrt()
/ (1.0 + child_visits as f64);
q + exploration
}
/// Returns the best move using pure MCTS with random rollouts.
fn mcts_move(
state: &GameState,
iterations: u32,
rng: &mut impl Rng,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search(iterations, rng);
tree.best_move()
}
/// Returns the best move using network-guided MCTS.
fn mcts_network_move(
state: &GameState,
iterations: u32,
network: &Network,
c_puct: f64,
) -> usize {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(iterations, network, c_puct);
tree.best_move()
}
// -----------------------------------------------------------
// Training data and self-play (§11, §13)
// -----------------------------------------------------------
struct TrainingExample {
state: GameState,
policy_target: [f32; 9],
value_target: f32,
}
/// Extracts the visit distribution from the root of an MCTS tree.
fn visit_distribution(tree: &MctsTree) -> [f32; 9] {
let root = &tree.nodes[0];
let total: u32 = root.children.iter()
.map(|&idx| tree.nodes[idx].visit_count)
.sum();
let mut dist = [0.0f32; 9];
if total == 0 {
return dist;
}
for &idx in &root.children {
let mv = tree.nodes[idx].move_from_parent.unwrap();
dist[mv] = tree.nodes[idx].visit_count as f32 / total as f32;
}
dist
}
fn select_move_with_temperature(
tree: &MctsTree,
temperature: f64,
rng: &mut impl Rng,
) -> usize {
let root = &tree.nodes[0];
if temperature < 0.01 {
return tree.best_move();
}
let inv_temp = 1.0 / temperature;
let scaled: Vec<f64> = root.children.iter()
.map(|&idx| {
(tree.nodes[idx].visit_count as f64).powf(inv_temp)
})
.collect();
let total: f64 = scaled.iter().sum();
if total == 0.0 {
return tree.best_move();
}
let mut r: f64 = rng.gen::<f64>() * total;
for (i, &s) in scaled.iter().enumerate() {
r -= s;
if r <= 0.0 {
let child_idx = root.children[i];
return tree.nodes[child_idx]
.move_from_parent
.expect("child must have a move");
}
}
let last = *root.children.last().unwrap();
tree.nodes[last].move_from_parent
.expect("child must have a move")
}
fn self_play_game(
network: &Network,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
rng: &mut impl Rng,
) -> Vec<TrainingExample> {
let mut state = GameState::new();
let mut history: Vec<(GameState, [f32; 9])> = Vec::new();
let mut move_number = 0;
while !state.is_terminal() {
let mut tree = MctsTree::new(state.clone());
tree.search_with_network(mcts_iterations, network, c_puct);
let dist = visit_distribution(&tree);
history.push((state.clone(), dist));
let temperature = if move_number < temp_moves {
1.0
} else {
0.0
};
let chosen_move =
select_move_with_temperature(&tree, temperature, rng);
state = state.apply_move(chosen_move);
move_number += 1;
}
let winner = state.winner();
history.into_iter().map(|(s, policy_target)| {
let value_target = match winner {
Some(w) if w == s.current_player => 1.0,
Some(_) => -1.0,
None => 0.0,
};
TrainingExample { state: s, policy_target, value_target }
}).collect()
}
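/// In-place Fisher-Yates shuffle of the training examples.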
fn shuffle_examples(examples: &mut Vec<TrainingExample>, rng: &mut impl Rng) {
for i in (1..examples.len()).rev() {
let j = rng.gen_range(0..=i);
examples.swap(i, j);
}
}
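/// Average per-example losses from the final training epoch.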
struct TrainingLoss {
policy_loss: f32,
value_loss: f32,
}
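/// Runs `epochs` passes over the examples, reshuffling before each pass.
/// Note that `train_step` is called once per example, so `batch_size` only
/// groups the iteration; it does not accumulate gradients across a batch.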
fn train_on_examples(
net: &mut Network,
examples: &mut Vec<TrainingExample>,
epochs: u32,
batch_size: usize,
learning_rate: f32,
rng: &mut impl Rng,
) -> TrainingLoss {
let mut final_policy_loss = 0.0f32;
let mut final_value_loss = 0.0f32;
for _epoch in 0..epochs {
shuffle_examples(examples, rng);
let mut epoch_policy_loss = 0.0f32;
let mut epoch_value_loss = 0.0f32;
let mut count = 0;
for batch_start in (0..examples.len()).step_by(batch_size) {
let batch_end =
(batch_start + batch_size).min(examples.len());
for i in batch_start..batch_end {
let ex = &examples[i];
let (pl, vl) = net.train_step(
&ex.state,
&ex.policy_target,
ex.value_target,
learning_rate,
);
epoch_policy_loss += pl;
epoch_value_loss += vl;
count += 1;
}
}
final_policy_loss = epoch_policy_loss / count as f32;
final_value_loss = epoch_value_loss / count as f32;
}
TrainingLoss {
policy_loss: final_policy_loss,
value_loss: final_value_loss,
}
}
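/// Plays one game with `network_x` moving as X and `network_o` as O.
/// Returns +1.0 for an X win, -1.0 for an O win, and 0.0 for a draw.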
fn play_evaluation_game(
network_x: &Network,
network_o: &Network,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut state = GameState::new();
while !state.is_terminal() {
let network = match state.current_player {
Player::X => network_x,
Player::O => network_o,
};
let mv = mcts_network_move(
&state, mcts_iterations, network, c_puct,
);
state = state.apply_move(mv);
}
match state.winner() {
Some(Player::X) => 1.0,
Some(Player::O) => -1.0,
None => 0.0,
}
}
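/// Pits the candidate against the current best over `num_games`, alternating
/// which network plays X. Returns the candidate's share of decisive games,
/// or 0.5 if every game is drawn.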
fn evaluate_networks(
candidate: &Network,
current_best: &Network,
num_games: u32,
mcts_iterations: u32,
c_puct: f64,
) -> f64 {
let mut candidate_wins = 0;
let mut total_decisive = 0;
for game in 0..num_games {
let result = if game % 2 == 0 {
play_evaluation_game(
candidate, current_best,
mcts_iterations, c_puct,
)
} else {
-play_evaluation_game(
current_best, candidate,
mcts_iterations, c_puct,
)
};
if result > 0.0 {
candidate_wins += 1;
total_decisive += 1;
} else if result < 0.0 {
total_decisive += 1;
}
}
if total_decisive == 0 {
0.5
} else {
candidate_wins as f64 / total_decisive as f64
}
}
// ───────────────────────────────────────────────────────────
// Full training loop with diagnostics (§14)
// ───────────────────────────────────────────────────────────
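/// Hyperparameters for the self-play / train / evaluate loop.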
struct TrainingConfig {
num_iterations: u32,
games_per_iteration: u32,
mcts_iterations: u32,
c_puct: f64,
temp_moves: usize,
training_epochs: u32,
batch_size: usize,
learning_rate: f32,
eval_games: u32,
win_threshold: f64,
}
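/// The full AlphaGo Zero-style loop: generate self-play games with the
/// current best network, train a cloned candidate on the collected examples,
/// and promote the candidate only if its evaluation win rate reaches
/// `win_threshold`.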
fn alphago_zero_training(
config: &TrainingConfig,
seed: u64,
) -> Network {
let mut rng = StdRng::seed_from_u64(seed);
let mut best_network = Network::new(rng.gen());
println!("Starting AlphaGo Zero training loop");
println!(" Iterations: {}", config.num_iterations);
println!(" Games/iter: {}", config.games_per_iteration);
println!(" MCTS iters: {}", config.mcts_iterations);
println!(" Total games: {}",
config.num_iterations * config.games_per_iteration);
println!();
for iteration in 1..=config.num_iterations {
        // ── Phase 1: Self-play ─────────────────────────────
let mut all_examples = Vec::new();
let mut total_moves = 0u32;
let mut x_wins = 0u32;
let mut o_wins = 0u32;
let mut draws = 0u32;
for _game in 0..config.games_per_iteration {
let examples = self_play_game(
&best_network,
config.mcts_iterations,
config.c_puct,
config.temp_moves,
&mut rng,
);
let game_length = examples.len() as u32;
total_moves += game_length;
if let Some(last) = examples.last() {
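                // The last recorded position is the one the final move was
                // played from, so a positive value_target means that
                // position's player to move went on to win the game.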
match last.value_target.partial_cmp(&0.0) {
Some(std::cmp::Ordering::Greater) => {
match last.state.current_player {
Player::X => x_wins += 1,
Player::O => o_wins += 1,
}
}
Some(std::cmp::Ordering::Less) => {
match last.state.current_player {
Player::X => o_wins += 1,
Player::O => x_wins += 1,
}
}
_ => draws += 1,
}
}
all_examples.extend(examples);
}
let avg_game_length =
total_moves as f64 / config.games_per_iteration as f64;
        // ── Phase 2: Train ─────────────────────────────────
let mut candidate = best_network.clone();
let loss = train_on_examples(
&mut candidate,
&mut all_examples,
config.training_epochs,
config.batch_size,
config.learning_rate,
&mut rng,
);
        // ── Phase 3: Evaluate ──────────────────────────────
let win_rate = evaluate_networks(
&candidate,
&best_network,
config.eval_games,
config.mcts_iterations,
config.c_puct,
);
let accepted = win_rate >= config.win_threshold;
if accepted {
best_network = candidate;
}
        // ── Print diagnostics ──────────────────────────────
let (policy, value) =
best_network.predict(&GameState::new());
let masked =
mask_illegal(&policy, &GameState::new());
let best_cell = masked.iter().enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap().0;
let cell_name = match best_cell {
0 => "top-left", 1 => "top-center",
2 => "top-right", 3 => "mid-left",
4 => "center", 5 => "mid-right",
6 => "bot-left", 7 => "bot-center",
8 => "bot-right", _ => "unknown",
};
        println!(
            "─── Iteration {}/{} ───",
            iteration, config.num_iterations
        );
println!(
" Self-play: {} games, {} examples",
config.games_per_iteration, all_examples.len()
);
println!(
" Game length: {:.1} moves (avg)",
avg_game_length
);
println!(
" Outcomes: X wins {} | O wins {} | Draws {}",
x_wins, o_wins, draws
);
println!(
" Training: policy_loss = {:.4} value_loss = {:.4}",
loss.policy_loss, loss.value_loss
);
        println!(
            "  Evaluation: candidate win rate = {:.1}% -> {}",
            win_rate * 100.0,
            if accepted { "ACCEPTED" } else { "REJECTED" }
        );
println!(
" Empty board: best move = cell {} ({})",
best_cell, cell_name
);
println!(
" policy = [{:.2} {:.2} {:.2} {:.2} \
{:.2} {:.2} {:.2} {:.2} {:.2}]",
masked[0], masked[1], masked[2], masked[3],
masked[4], masked[5], masked[6], masked[7],
masked[8]
);
println!(
" value = {:+.3}",
value
);
println!();
}
best_network
}
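// Entry point: train a network with the config below, then run three sanity
// checks -- a 100-game match against strong pure MCTS, an opening-move check,
// and an interactive game against a human.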
fn main() {
    // ── Train ──────────────────────────────────────────────
let config = TrainingConfig {
num_iterations: 20,
games_per_iteration: 50,
mcts_iterations: 200,
c_puct: 1.5,
temp_moves: 3,
training_epochs: 20,
batch_size: 32,
learning_rate: 0.01,
eval_games: 40,
win_threshold: 0.55,
};
let trained = alphago_zero_training(&config, 42);
    // ── Validation 1: vs strong pure MCTS ──────────────────
    println!("─── Validation 1: Trained AI vs Pure MCTS (10k) ───");
let mut rng = StdRng::seed_from_u64(123);
let mut val_draws = 0;
for game in 0..100 {
let mut state = GameState::new();
while !state.is_terminal() {
let mv = match state.current_player {
Player::X => {
if game % 2 == 0 {
mcts_network_move(
&state, 200, &trained, 1.5,
)
} else {
mcts_move(&state, 10_000, &mut rng)
}
}
Player::O => {
if game % 2 == 0 {
mcts_move(&state, 10_000, &mut rng)
} else {
mcts_network_move(
&state, 200, &trained, 1.5,
)
}
}
};
state = state.apply_move(mv);
}
if state.winner().is_none() {
val_draws += 1;
}
}
println!(" Result: {}/100 draws", val_draws);
if val_draws == 100 {
        println!("  PASS -- the network plays perfectly.");
} else {
        println!(
            "  Some decisive games -- network may need \
             more training."
        );
}
    // ── Validation 2: opening move ─────────────────────────
    println!("\n─── Validation 2: Opening Move ───");
let empty = GameState::new();
let mv = mcts_network_move(&empty, 200, &trained, 1.5);
println!(
" First move: cell {} (expected: 4 = center)", mv
);
if mv == 4 {
        println!("  PASS -- network plays center.");
} else {
        println!("  UNEXPECTED -- center is the optimal opening.");
}
    // ── Validation 3: human vs AI ──────────────────────────
    println!("\n─── Validation 3: Human vs Trained AI ───");
println!("You are X. Enter a cell number (0-8):");
println!(" 0 | 1 | 2");
println!(" ---------");
println!(" 3 | 4 | 5");
println!(" ---------");
println!(" 6 | 7 | 8");
println!();
let mut state = GameState::new();
while !state.is_terminal() {
state.print_board();
match state.current_player {
Player::X => {
print!("Your move (0-8): ");
std::io::Write::flush(
&mut std::io::stdout(),
).unwrap();
let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.unwrap();
                // Re-prompt instead of panicking when the input is not a number.
                let mv: usize = match input.trim().parse() {
                    Ok(m) => m,
                    Err(_) => {
                        println!("Please enter a number from 0 to 8.");
                        continue;
                    }
                };
if !state.legal_moves().contains(&mv) {
println!("Illegal move. Try again.");
continue;
}
state = state.apply_move(mv);
}
Player::O => {
let mv = mcts_network_move(
&state, 200, &trained, 1.5,
);
println!("AI plays cell {}.", mv);
state = state.apply_move(mv);
}
}
}
state.print_board();
match state.winner() {
Some(Player::X) => {
println!("You win! (The AI has a bug.)")
}
Some(Player::O) => {
println!("AI wins. Better luck next time.")
}
None => println!(
"Draw. That is the best you can do against \
perfect play."
),
}
}
Congratulations
You have built a complete AlphaGo Zero implementation from scratch. Starting with nothing but the rules of Tic-Tac-Toe and a bag of random numbers, you watched an AI discover optimal play through pure self-improvement. No human games to study. No handcrafted evaluation functions. No opening books. Just a loop: play, learn, evaluate, repeat.
The network you trained has internalized something remarkable. It does not search through game trees itself; MCTS still does that at play time, guided by the network's suggestions. What the network has done is compress the experience of roughly a thousand self-play games into a set of floating-point weights. When it looks at an empty board and says “play center,” it is not following a rule someone programmed: it learned that preference from experience, the same way a human player would.
This is the same algorithm that, scaled up with deeper networks and more compute, mastered Go, chess, and shogi at a superhuman level. The architecture changes. The training budget grows by orders of magnitude. But the core idea remains exactly what you built in this chapter: self-play generating training data, a neural network learning from that data, and an evaluation gate preventing regression.
Well done.