We introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form. The standard RLHF pipeline it builds on proceeds in three phases:
Supervised Fine-tuning (SFT): fine-tuning a pre-trained LM, resulting in a model $\pi^{SFT}$
Preference Sampling and Reward Learning:
Generating answer pairs: $(y_1, y_2) \sim \pi^{SFT}(y \vert x)$ where $x$ denotes the input prompt
Getting preferences from humans: $y_w \succ y_l \vert x$. The preferences are assumed to be generated by some latent reward model $r^{\star}(x, y)$
Forming the preference dataset $\mathcal{D} = \{(x^{(i)}, y_w^{(i)}, y_l^{(i)})\}_{i=1}^{N}$: the dataset $\mathcal{D}$ is sampled from the human preference distribution $p^{\star}$ and is used to learn a parameterized reward model $r_\phi(x,y)$
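For concreteness, each element of $\mathcal{D}$ can be stored as a simple prompt/chosen/rejected record. The sketch below is illustrative only; the field names and example strings are placeholders rather than a fixed schema.

```python
# Minimal sketch of the preference dataset D = {(x, y_w, y_l)}.
# Field names and example strings are illustrative placeholders.
from dataclasses import dataclass

@dataclass
class PreferenceTriple:
    prompt: str    # x: the input prompt
    chosen: str    # y_w: the completion the labeler preferred
    rejected: str  # y_l: the completion the labeler rejected

# Two completions (y_1, y_2) are sampled from pi_SFT for the same prompt;
# the human label decides which one becomes `chosen` vs `rejected`.
D = [
    PreferenceTriple(
        prompt="Explain KL divergence in one sentence.",
        chosen="KL divergence measures how one distribution differs from a reference distribution.",
        rejected="It is a kind of distance, I think.",
    ),
]
```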
The human preference distribution is modeled by the Bradley-Terry (BT) model as:
$p^{\star}(y_1 \succ y_2 \mid x) = \frac{\exp \big( r^{\star}(x, y_1) \big)}{\exp \big( r^{\star}(x, y_1) \big) + \exp \big( r^{\star}(x, y_2) \big)}.$
In practice, the latent reward $r^{\star}$ is approximated by the parameterized reward model $r_\phi(x,y)$, whose parameters are updated during training.
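Because the BT probability depends only on the difference of the two rewards, it reduces to a sigmoid of that difference, which is exactly the form that appears in the reward-learning loss below. A small numerical sketch (the reward values are made-up placeholders):

```python
import torch

# Bradley-Terry preference probability p*(y1 > y2 | x).
# Since only the reward difference matters, it equals sigmoid(r(x,y1) - r(x,y2)).
r_y1 = torch.tensor(1.7)   # placeholder for r(x, y1), e.g. from r_phi
r_y2 = torch.tensor(0.4)   # placeholder for r(x, y2)

p_bt = torch.exp(r_y1) / (torch.exp(r_y1) + torch.exp(r_y2))
p_sigmoid = torch.sigmoid(r_y1 - r_y2)

assert torch.allclose(p_bt, p_sigmoid)
print(p_bt.item())  # probability that y1 is preferred over y2
```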
Learning / parameterizing the reward model: the reward model is learned via maximum likelihood. More precisely, the problem is framed as binary classification, and we minimize the negative expected log-likelihood loss:
$\mathcal{L}_R (r_{\phi}, \mathcal{D}) = - \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \big( r_{\phi} (x, y_w) - r_{\phi} (x, y_l) \big) \right]$
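A minimal sketch of this loss for one mini-batch, assuming the rewards $r_\phi(x, y_w)$ and $r_\phi(x, y_l)$ have already been computed by some reward model (the tensors below are placeholders):

```python
import torch
import torch.nn.functional as F

def reward_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of the BT model, L_R(r_phi, D).

    r_chosen:   r_phi(x, y_w) for each example in the batch, shape (B,)
    r_rejected: r_phi(x, y_l) for each example in the batch, shape (B,)
    """
    # -log sigma(r_phi(x, y_w) - r_phi(x, y_l)), averaged over the batch
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# Toy usage with placeholder rewards (a real setup would score each (x, y)
# pair with a learned reward head).
r_w = torch.tensor([1.2, 0.3, 2.0])
r_l = torch.tensor([0.4, 0.9, 1.1])
loss = reward_loss(r_w, r_l)
```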
RL optimization: during the RL phase, we maximize the reward given by the learned reward model, penalized by a KL divergence from a reference policy $\pi_{\text{ref}}$ (typically the SFT model $\pi^{SFT}$):
$\max_{\pi_{\theta}} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_{\theta}(y \mid x)} \left[ r_{\phi}(x, y) \right] - \beta \, \mathbb{D}_{\text{KL}} \left[ \pi_{\theta}(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x) \right].$
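In practice this objective is typically optimized with an RL algorithm such as PPO, folding the KL penalty into the reward via a single-sample estimate. The sketch below shows the shaped reward for one sampled completion, assuming the sequence log-probabilities under $\pi_\theta$ and $\pi_{\text{ref}}$ are precomputed (the function name and values are illustrative):

```python
import torch

def kl_shaped_reward(r_phi: torch.Tensor,
                     logp_theta: torch.Tensor,
                     logp_ref: torch.Tensor,
                     beta: float) -> torch.Tensor:
    """Objective term for one sampled y ~ pi_theta(. | x):

        r_phi(x, y) - beta * [log pi_theta(y|x) - log pi_ref(y|x)]

    The bracketed term is a single-sample estimate of the KL penalty
    D_KL[pi_theta || pi_ref]; summing log-probs over tokens is assumed
    to have been done already.
    """
    return r_phi - beta * (logp_theta - logp_ref)

# Placeholder values for one completion.
shaped = kl_shaped_reward(torch.tensor(1.5), torch.tensor(-42.0),
                          torch.tensor(-40.0), beta=0.1)
```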
Our approach uses a particular choice of reward-model parameterization that enables extraction of its optimal policy in closed form, without an RL training loop. The key insight is an analytical mapping from reward functions to optimal policies, which lets us transform a loss function over reward functions into a loss function over policies.
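Concretely, substituting the closed-form optimal policy back into the BT negative log-likelihood yields a classification-style loss directly over the policy (the DPO objective). The sketch below assumes summed sequence log-probabilities for each completion under the current and reference policies:

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_theta_w, logp_theta_l, logp_ref_w, logp_ref_l, beta: float):
    """Loss over policies obtained by plugging the closed-form optimal policy
    back into the BT negative log-likelihood:

        -log sigma( beta * [ (log pi_theta(y_w|x) - log pi_ref(y_w|x))
                           - (log pi_theta(y_l|x) - log pi_ref(y_l|x)) ] )

    All inputs are summed sequence log-probabilities, shape (B,).
    """
    chosen_logratio = logp_theta_w - logp_ref_w
    rejected_logratio = logp_theta_l - logp_ref_l
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()

# Placeholder log-probabilities for a batch of two preference pairs.
loss = dpo_loss(
    torch.tensor([-12.0, -30.0]), torch.tensor([-15.0, -33.0]),
    torch.tensor([-13.0, -31.0]), torch.tensor([-14.0, -32.0]),
    beta=0.1,
)
```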