Notations:
Construction of the MC Tree: For each question, we build a Monte Carlo Tree, as shown in Fig. 1; this tree is the final output of the data-collection process for each question.
Each node $s$ in the tree contains the question $q$ and prefix solution $x_{1:t}$, together with all previous rollouts $\{(s, r_i)\}_{i=1}^k$ from the state. The nodes also store a set of statistics $\{N(s), \text{MC}(s), Q(s, r)\},$
where $N(s)$ denotes the visit count of a state,
$\text{MC}(s)$ represents the Monte Carlo estimation of a state, calculated as the Monte Carlo (MC) ratio: $c_t = \text{MC}(q, x_{1:t}) = \text{MC}(s) = \frac{\text{num(correct rollouts from the } t\text{-th step)}}{\text{num(total rollouts from the } t\text{-th step)}},$ i.e., $c_t$ measures the proportion of rollouts from the $t$-th step that reach a correct final answer.
$Q(s, r)$ is a state-rollout value function that is correlated to the chance of selecting a rollout during the selection phase of tree traversal.
Specifically, $Q(s, r) = \alpha^{1 - \text{MC}(s)} \cdot \beta^{\frac{\text{len}(r)}{L}},$ where $\alpha$ and $\beta$ are constants, and $L$ represents the maximum rollout length.
Each edge $(s, a)$ is either a single step or a sequence of consecutive steps from the node $s$.
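The node statistics above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation; the class name `Node`, the helper `q_value`, and the default values for $\alpha$, $\beta$, and $L$ are all assumptions for the sketch.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """A tree node holding the question q plus prefix solution x_{1:t}."""
    prefix: str
    # Rollouts from this state, stored as (rollout_text, is_correct) pairs.
    rollouts: list = field(default_factory=list)
    visit_count: int = 0  # N(s)

    def mc(self) -> float:
        """MC(s): fraction of rollouts from this state that are correct."""
        if not self.rollouts:
            return 0.0
        return sum(ok for _, ok in self.rollouts) / len(self.rollouts)

def q_value(node: Node, rollout: str, alpha: float = 0.5,
            beta: float = 0.9, max_len: int = 2048) -> float:
    """Q(s, r) = alpha^(1 - MC(s)) * beta^(len(r) / L).

    alpha, beta, and max_len are placeholder constants for this sketch.
    """
    return alpha ** (1 - node.mc()) * beta ** (len(rollout) / max_len)
```

A shorter rollout from a state with high MC(s) receives a higher Q(s, r), which biases tree traversal toward promising, cheap-to-verify rollouts.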
Monte Carlo Tree Search
Objective: As suggested by Lightman et al. (2023), supervising up to the first incorrect step in a solution is sufficient to train a PRM. Therefore, our objective is to locate the first error efficiently.
Overview of three stages: The dotted lines in the Select stage represent the rollouts available for binary search. The bold colored edges represent steps with correctness estimations: yellow indicates a correct step, i.e., one whose preceding state $s$ satisfies $\text{MC}(s) > 0$, and blue indicates an incorrect step, i.e., one with $\text{MC}(s) = 0$. The number of dashes in each colored edge indicates the number of steps.
Selection Stage: In the selection phase, we maintain a pool of all rollouts $\{(s_i, r_i^t)\}$ from previous searches that satisfy $0 < \text{MC}(s_i) < 1$.
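Since $Q(s, r)$ is only stated to be correlated with the chance of selection, the exact policy is open; one simple choice is a greedy argmax over $Q$, sketched below. The tuple layout of `pool` is an assumption for this sketch.

```python
def select(pool):
    """Pick the (state, rollout) pair with the highest Q(s, r).

    `pool` holds (mc, q, state, rollout) tuples; only entries with
    0 < MC(s) < 1 are eligible, since MC = 0 or 1 means the rollout's
    correctness is already fully resolved.
    """
    candidates = [item for item in pool if 0.0 < item[0] < 1.0]
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[1])  # argmax over Q(s, r)
```

A stochastic variant (sampling proportionally to $Q$) would also be consistent with the description above.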
Binary Search Stage (to identify the first error location in the selected rollout): Given the inefficiency of performing rollouts for every step (as done in previous works), we propose a binary search to find the first incorrect step as follows:
Split the solution at the midpoint $m$ and perform rollouts from the prefix $x_{1:m}$.
If $c_m > 0$, indicating at least one correct rollout, the error is in the latter half of the solution.
If $c_m = 0$, the error is in the first half.
This process iterates until the first error is isolated, reducing complexity to $O(k \log M)$ instead of $O(kM)$, where $k$ is the number of rollouts per position and $M$ is the total number of steps in the solution.
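The three rules above amount to a standard binary search for the smallest step index $m$ with $c_m = 0$. A minimal sketch, assuming a `monte_carlo(m)` callback that performs $k$ rollouts from the prefix $x_{1:m}$ and returns the fraction that reach a correct answer (and assuming the full solution is known to end incorrectly, so a first error exists):

```python
def find_first_error(num_steps, monte_carlo):
    """Return the 1-based index of the first incorrect step.

    Invariant: the first error lies in [lo, hi]. Each iteration spends
    O(k) rollouts at the midpoint, for O(k log M) rollouts overall.
    """
    lo, hi = 1, num_steps
    while lo < hi:
        mid = (lo + hi) // 2
        if monte_carlo(mid) > 0:
            # c_m > 0: at least one correct rollout, so steps 1..m are
            # fine and the error lies in the latter half.
            lo = mid + 1
        else:
            # c_m = 0: the error is at or before step m.
            hi = mid
    return lo
```

Compared with rolling out at every step, this touches only $O(\log M)$ positions per solution.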
Visualization:
All divide-and-rollout positions before the first error become new states. The trajectory $s[q] \to s[q, x_{1:4}] \to s[q, x_{1:6}] \to s[q, x_{1:7}]$ is added to the tree after the binary search. The edges $s[q] \to s[q, x_{1:4}]$ and $s[q, x_{1:4}] \to s[q, x_{1:6}]$ are correct, with MC values of 0.25 and 0.5, respectively, while the edge $s[q, x_{1:6}] \to s[q, x_{1:7}]$ is incorrect, with an MC value of 0.
Maintain Stage: After the binary search, the tree statistics $N(s)$, $\text{MC}(s)$, and $Q(s, r)$ are updated. Specifically, $N(s)$ is incremented by 1 for the selected $(s, r)$, and both $\text{MC}(s)$ and $Q(s, r)$ are updated for the new rollouts sampled during the binary search.
PRM Training: For the main results, we train with a pointwise soft label, using the Monte Carlo estimation $\text{MC}(s)$ directly as the correctness label for each step.
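A pointwise soft label typically means regressing the PRM's per-step probability toward $\text{MC}(s)$ with binary cross-entropy. The function below is an illustrative sketch of that objective, not the paper's training code:

```python
import math

def soft_label_bce(p, c, eps=1e-12):
    """Pointwise soft-label BCE: -[c*log(p) + (1-c)*log(1-p)].

    p is the PRM's predicted correctness probability for a step;
    c = MC(s) in [0, 1] is the soft correctness label. The loss is
    minimized when p matches c.
    """
    p = min(max(p, eps), 1 - eps)  # clamp for numerical stability
    return -(c * math.log(p) + (1 - c) * math.log(1 - p))
```

With hard labels, $c$ would be thresholded to 0 or 1; the soft version retains the uncertainty in the Monte Carlo estimate.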
Main results
Step Distribution:
Unlike Lightman et al. (2023) and Wang et al. (2024a,b), which use newlines as step delimiters, we propose a more flexible method for step division, treating any sequence of consecutive tokens in a solution as a valid step. We observe that many step divisions in Math-Shepherd lack semantic coherence to some extent, so we hypothesize that semantically explicit segmentation is not necessary for training a PRM.
During binary search, we aim to divide a full solution into 16 pieces. To calculate the expected step length, we divide the average solution length by 16. The binary search terminates when a step is shorter than this value.
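The termination rule can be stated as a one-line threshold. A minimal sketch, with illustrative names and the target of 16 pieces taken from the text:

```python
def expected_step_length(avg_solution_len, target_pieces=16):
    """Expected step length: average solution length / target piece count."""
    return avg_solution_len / target_pieces

def should_stop_splitting(span_len, avg_solution_len, target_pieces=16):
    """Binary search stops splitting a span shorter than the expected step length."""
    return span_len < expected_step_length(avg_solution_len, target_pieces)
```

Lengths here could be measured in tokens or characters; either way the rule caps the recursion depth so each solution yields roughly 16 steps.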