Reinforcement Learning for Deep Learning

Beyond Supervised Learning

Sometimes the “correct answer” isn’t available or isn’t what we want. RL provides a framework for:
  • Learning from human preferences (RLHF)
  • Optimizing non-differentiable objectives
  • Aligning AI systems with human values
  • Training agents that interact with environments
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass

torch.manual_seed(42)

RL Fundamentals for Deep Learning

The RL Framework

┌─────────────┐     action     ┌─────────────────┐
│             │───────────────▶│                 │
│    Agent    │                │   Environment   │
│   (Model)   │◀───────────────│                 │
│             │  state, reward │                 │
└─────────────┘                └─────────────────┘
Key concepts:
  • State $s$: Current situation
  • Action $a$: What the agent does
  • Reward $r$: Feedback signal
  • Policy $\pi(a|s)$: Probability of taking action $a$ in state $s$
  • Value $V(s)$: Expected cumulative reward from state $s$
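
To make these pieces concrete, here is a minimal, self-contained sketch of the interaction loop. The ToyEnv class is invented purely for illustration and is not used elsewhere in this chapter; a real setup would use an actual environment (or, for language models, a prompt-and-response loop).

import random

class ToyEnv:
    """Toy 1-D environment: move left/right, reward for reaching position +5."""
    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):                    # action: 0 = left, 1 = right
        self.pos += 1 if action == 1 else -1
        done = abs(self.pos) >= 5
        reward = 1.0 if self.pos >= 5 else 0.0
        return self.pos, reward, done

env = ToyEnv()
state = env.reset()
done, trajectory = False, []
while not done:
    action = random.choice([0, 1])             # stand-in for sampling a ~ pi(a|s)
    next_state, reward, done = env.step(action)
    trajectory.append((state, action, reward))
    state = next_state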

Policy Gradient

The fundamental theorem for learning policies:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\right]$$
class PolicyNetwork(nn.Module):
    """Simple policy network for discrete actions."""
    
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, state: torch.Tensor) -> torch.distributions.Categorical:
        logits = self.network(state)
        return torch.distributions.Categorical(logits=logits)
    
    def get_action(self, state: torch.Tensor) -> Tuple[int, torch.Tensor]:
        dist = self.forward(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob


class ValueNetwork(nn.Module):
    """Value function estimator."""
    
    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.network(state).squeeze(-1)

REINFORCE Algorithm

class REINFORCE:
    """Basic policy gradient algorithm."""
    
    def __init__(
        self,
        policy: PolicyNetwork,
        lr: float = 1e-3,
        gamma: float = 0.99
    ):
        self.policy = policy
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        self.gamma = gamma
    
    def compute_returns(self, rewards: List[float]) -> torch.Tensor:
        """Compute discounted returns."""
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + self.gamma * R
            returns.insert(0, R)
        
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        return returns
    
    def update(self, trajectories: List[Dict]):
        """Update policy from collected trajectories."""
        
        total_loss = 0
        
        for trajectory in trajectories:
            log_probs = trajectory['log_probs']
            rewards = trajectory['rewards']
            
            returns = self.compute_returns(rewards)
            
            # Policy gradient loss
            policy_loss = 0
            for log_prob, R in zip(log_probs, returns):
                policy_loss -= log_prob * R
            
            total_loss += policy_loss
        
        # Optimize
        self.optimizer.zero_grad()
        (total_loss / len(trajectories)).backward()
        self.optimizer.step()
        
        return total_loss.item() / len(trajectories)
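
A short usage sketch showing the data format REINFORCE.update expects: each trajectory is a dict with 'log_probs' (tensors returned by get_action) and 'rewards' (floats). The 4-dimensional states and random rewards below are placeholders for a real environment.

policy = PolicyNetwork(state_dim=4, action_dim=2)
agent = REINFORCE(policy, lr=1e-3, gamma=0.99)

trajectories = []
for _ in range(8):                                    # 8 rollouts per update
    log_probs, rewards = [], []
    for _ in range(20):                               # fixed-length episodes
        state = torch.randn(4)                        # placeholder for an env state
        action, log_prob = policy.get_action(state)
        log_probs.append(log_prob)
        rewards.append(float(np.random.rand()))       # placeholder reward signal
    trajectories.append({'log_probs': log_probs, 'rewards': rewards})

loss = agent.update(trajectories)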

Proximal Policy Optimization (PPO)

PPO is the workhorse of modern RL, used in RLHF and many robotics applications.

The PPO Objective

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]$$

where the probability ratio is

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$
@dataclass
class PPOConfig:
    """PPO hyperparameters."""
    clip_epsilon: float = 0.2
    gamma: float = 0.99
    gae_lambda: float = 0.95
    value_coef: float = 0.5
    entropy_coef: float = 0.01
    max_grad_norm: float = 0.5
    ppo_epochs: int = 4
    batch_size: int = 64


class PPO:
    """Proximal Policy Optimization implementation."""
    
    def __init__(
        self,
        policy: PolicyNetwork,
        value: ValueNetwork,
        config: PPOConfig = PPOConfig()
    ):
        self.policy = policy
        self.value = value
        self.config = config
        
        self.optimizer = torch.optim.Adam([
            {'params': policy.parameters(), 'lr': 3e-4},
            {'params': value.parameters(), 'lr': 1e-3}
        ])
    
    def compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        dones: torch.Tensor,
        next_value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Generalized Advantage Estimation."""
        
        advantages = torch.zeros_like(rewards)
        last_gae = 0
        
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_val = next_value
            else:
                next_val = values[t + 1]
            
            delta = rewards[t] + self.config.gamma * next_val * (1 - dones[t]) - values[t]
            advantages[t] = delta + self.config.gamma * self.config.gae_lambda * (1 - dones[t]) * last_gae
            last_gae = advantages[t]
        
        returns = advantages + values
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        return advantages, returns
    
    def update(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        old_log_probs: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        next_value: torch.Tensor
    ):
        """PPO update step."""
        
        # Compute advantages
        with torch.no_grad():
            values = self.value(states)
            advantages, returns = self.compute_gae(rewards, values, dones, next_value)
        
        # PPO epochs
        for _ in range(self.config.ppo_epochs):
            # Mini-batch updates
            indices = torch.randperm(len(states))
            
            for start in range(0, len(states), self.config.batch_size):
                end = start + self.config.batch_size
                batch_indices = indices[start:end]
                
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]
                
                # Current policy evaluation
                dist = self.policy(batch_states)
                log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()
                
                # Ratio for clipping
                ratio = torch.exp(log_probs - batch_old_log_probs)
                
                # Clipped surrogate objective
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # Value loss
                value_pred = self.value(batch_states)
                value_loss = F.mse_loss(value_pred, batch_returns)
                
                # Total loss
                loss = (
                    policy_loss 
                    + self.config.value_coef * value_loss 
                    - self.config.entropy_coef * entropy
                )
                
                # Optimize
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(
                    list(self.policy.parameters()) + list(self.value.parameters()),
                    self.config.max_grad_norm
                )
                self.optimizer.step()
        
        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item()
        }
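
A sketch of a single PPO update on a synthetic rollout; the random states, rewards, and dones stand in for data collected from a real environment.

policy = PolicyNetwork(state_dim=4, action_dim=2)
value = ValueNetwork(state_dim=4)
ppo = PPO(policy, value, PPOConfig(batch_size=32))

# Synthetic 128-step rollout (placeholders for real environment interaction)
states = torch.randn(128, 4)
with torch.no_grad():
    dist = policy(states)
    actions = dist.sample()
    old_log_probs = dist.log_prob(actions)
rewards = torch.randn(128)
dones = torch.zeros(128)
next_value = torch.tensor(0.0)

stats = ppo.update(states, actions, old_log_probs, rewards, dones, next_value)
print(stats)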

Reinforcement Learning from Human Feedback (RLHF)

RLHF is how models like ChatGPT are aligned with human preferences.

The RLHF Pipeline

1. Supervised Fine-Tuning (SFT)
   └── Train on high-quality demonstrations

2. Reward Model Training
   └── Train a model to predict human preferences

3. RL Fine-Tuning (PPO)
   └── Optimize policy to maximize reward model

Reward Model

class RewardModel(nn.Module):
    """Reward model trained on human preferences."""
    
    def __init__(self, base_model: nn.Module):
        super().__init__()
        
        self.base_model = base_model
        self.value_head = nn.Linear(base_model.hidden_size, 1)
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]
        
        Returns:
            rewards: [batch_size] - scalar reward for each sequence
        """
        # Get hidden states from base model
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        
        # Use last token's hidden state (for causal LMs)
        # Find last non-padding token
        seq_lengths = attention_mask.sum(dim=1) - 1
        batch_indices = torch.arange(input_ids.size(0))
        last_hidden = hidden_states[batch_indices, seq_lengths]
        
        # Predict reward
        rewards = self.value_head(last_hidden).squeeze(-1)
        
        return rewards
    
    def compute_preference_loss(
        self,
        chosen_ids: torch.Tensor,
        chosen_mask: torch.Tensor,
        rejected_ids: torch.Tensor,
        rejected_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Bradley-Terry preference loss.
        
        P(chosen > rejected) = sigmoid(r_chosen - r_rejected)
        """
        r_chosen = self.forward(chosen_ids, chosen_mask)
        r_rejected = self.forward(rejected_ids, rejected_mask)
        
        # Preference loss: we want r_chosen > r_rejected
        loss = -F.logsigmoid(r_chosen - r_rejected).mean()
        
        return loss


def train_reward_model(
    model: RewardModel,
    dataloader,  # yields batches with 'chosen' and 'rejected' (input_ids, attention_mask) pairs
    epochs: int = 3
):
    """Train reward model on preference data."""
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in dataloader:
            chosen_ids, chosen_mask = batch['chosen']
            rejected_ids, rejected_mask = batch['rejected']
            
            loss = model.compute_preference_loss(
                chosen_ids, chosen_mask,
                rejected_ids, rejected_mask
            )
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            # Accuracy
            with torch.no_grad():
                r_chosen = model(chosen_ids, chosen_mask)
                r_rejected = model(rejected_ids, rejected_mask)
                correct += (r_chosen > r_rejected).sum().item()
                total += len(r_chosen)
        
        print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}, Acc = {100*correct/total:.2f}%")
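
The loop above assumes each batch is a dict whose 'chosen' and 'rejected' entries are (input_ids, attention_mask) tuples. One way to produce that format, sketched below, is a collate function over already-tokenized, equal-length preference pairs; the field names ('chosen_ids', 'chosen_mask', ...) are assumptions for this example, not a fixed API.

from torch.utils.data import DataLoader

def preference_collate(examples):
    """Collate tokenized (chosen, rejected) pairs into the batch format above."""
    def stack(ids_key, mask_key):
        ids = torch.stack([ex[ids_key] for ex in examples])
        mask = torch.stack([ex[mask_key] for ex in examples])
        return ids, mask
    return {
        'chosen': stack('chosen_ids', 'chosen_mask'),
        'rejected': stack('rejected_ids', 'rejected_mask'),
    }

# dataloader = DataLoader(preference_dataset, batch_size=16, collate_fn=preference_collate)
# train_reward_model(reward_model, dataloader, epochs=3)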

RLHF Training Loop

class RLHFTrainer:
    """Complete RLHF training pipeline."""
    
    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,  # Frozen copy
        reward_model: RewardModel,
        tokenizer,
        kl_coef: float = 0.1,
        clip_reward: float = 10.0
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef
        self.clip_reward = clip_reward
        
        # Freeze reference model
        for param in self.reference.parameters():
            param.requires_grad = False
        
        self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=1e-6)
    
    def compute_rewards(
        self,
        query_ids: torch.Tensor,
        response_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute rewards with KL penalty."""
        
        # Full sequence = query + response
        full_ids = torch.cat([query_ids, response_ids], dim=1)
        
        # Reward from reward model
        rm_reward = self.reward_model(full_ids, attention_mask)
        rm_reward = torch.clamp(rm_reward, -self.clip_reward, self.clip_reward)
        
        # KL divergence penalty
        with torch.no_grad():
            ref_logits = self.reference(full_ids, attention_mask=attention_mask).logits
        policy_logits = self.policy(full_ids, attention_mask=attention_mask).logits
        
        # Compute KL per token
        ref_log_probs = F.log_softmax(ref_logits, dim=-1)
        policy_log_probs = F.log_softmax(policy_logits, dim=-1)
        
        # KL(policy || reference)
        kl = (torch.exp(policy_log_probs) * (policy_log_probs - ref_log_probs)).sum(dim=-1)
        kl = kl.mean(dim=-1)  # Average over sequence
        
        # Total reward = RM reward - KL penalty
        rewards = rm_reward - self.kl_coef * kl
        
        stats = {
            'rm_reward': rm_reward.mean().item(),
            'kl': kl.mean().item(),
            'total_reward': rewards.mean().item()
        }
        
        return rewards, stats
    
    def generate_and_score(
        self,
        prompts: List[str],
        max_new_tokens: int = 256
    ) -> Tuple[List[str], torch.Tensor, Dict]:
        """Generate responses and compute rewards."""
        
        # Tokenize prompts
        inputs = self.tokenizer(prompts, return_tensors='pt', padding=True)
        query_ids = inputs['input_ids'].cuda()
        
        # Generate responses
        with torch.no_grad():
            outputs = self.policy.generate(
                query_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=1.0,
                pad_token_id=self.tokenizer.pad_token_id
            )
        
        response_ids = outputs[:, query_ids.size(1):]
        
        # Create attention mask
        attention_mask = (outputs != self.tokenizer.pad_token_id).long()
        
        # Compute rewards
        rewards, stats = self.compute_rewards(query_ids, response_ids, attention_mask)
        
        # Decode responses
        responses = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        
        return responses, rewards, stats
    
    def ppo_step(
        self,
        prompts: List[str],
        ppo_epochs: int = 4,
        batch_size: int = 4
    ):
        """One PPO training step."""
        
        # Generate responses and get rewards
        responses, rewards, stats = self.generate_and_score(prompts)
        
        # Treat rewards as fixed targets for the policy update
        rewards = rewards.detach()
        
        # Get old log probs for PPO (per-token log-probs of the generated sequences)
        inputs = self.tokenizer(
            [p + r for p, r in zip(prompts, responses)],
            return_tensors='pt', padding=True
        )
        input_ids = inputs['input_ids'].cuda()
        
        def sequence_log_probs(logits: torch.Tensor) -> torch.Tensor:
            # Average log-prob of the tokens actually present in each sequence
            token_log_probs = F.log_softmax(logits[:, :-1], dim=-1)
            return token_log_probs.gather(
                2, input_ids[:, 1:].unsqueeze(-1)
            ).squeeze(-1).mean(dim=-1)
        
        with torch.no_grad():
            old_log_probs = sequence_log_probs(self.policy(input_ids).logits)
        
        # PPO epochs
        for _ in range(ppo_epochs):
            new_log_probs = sequence_log_probs(self.policy(input_ids).logits)
            
            # Ratio of sequence log-probs under the new vs. old policy
            ratio = torch.exp(new_log_probs - old_log_probs)
            
            # Clipped surrogate objective (sequence-level, rewards used as advantages)
            clip_epsilon = 0.2
            surr1 = ratio * rewards
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * rewards
            loss = -torch.min(surr1, surr2).mean()
            
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
            self.optimizer.step()
        
        stats['ppo_loss'] = loss.item()
        return stats
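
A sketch of how the trainer might be driven once the models exist. Here policy_model, reference_model, reward_model, and tokenizer are assumed to be a causal LM, its frozen copy, a trained RewardModel, and the matching tokenizer; none of them are defined in this snippet.

trainer = RLHFTrainer(
    policy_model=policy_model,
    reference_model=reference_model,
    reward_model=reward_model,
    tokenizer=tokenizer,
    kl_coef=0.1,
)

prompts = [
    "Explain reinforcement learning in one sentence.",
    "Write a polite reply declining a meeting invitation.",
]

for step in range(100):
    stats = trainer.ppo_step(prompts)
    if step % 10 == 0:
        print(f"step {step}: reward={stats['total_reward']:.3f}, KL={stats['kl']:.3f}")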

Direct Preference Optimization (DPO)

DPO simplifies RLHF by eliminating the reward model entirely.

DPO Objective

$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$
class DPOTrainer:
    """Direct Preference Optimization trainer."""
    
    def __init__(
        self,
        policy_model: nn.Module,
        reference_model: nn.Module,
        beta: float = 0.1,
        label_smoothing: float = 0.0
    ):
        self.policy = policy_model
        self.reference = reference_model
        self.beta = beta
        self.label_smoothing = label_smoothing
        
        # Freeze reference
        for param in self.reference.parameters():
            param.requires_grad = False
        
        self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=5e-7)
    
    def compute_log_probs(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute log probability of sequence."""
        
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits[:, :-1]  # Shift for next token prediction
        labels = labels[:, 1:]  # Shift labels
        
        log_probs = F.log_softmax(logits, dim=-1)
        
        # Gather log probs for actual tokens
        token_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
        
        # Mask out padding
        mask = (labels != -100).float()
        sequence_log_prob = (token_log_probs * mask).sum(dim=-1) / mask.sum(dim=-1)
        
        return sequence_log_prob
    
    def dpo_loss(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor
    ) -> torch.Tensor:
        """Compute DPO loss."""
        
        # Log ratios
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        
        # DPO loss
        logits = self.beta * (chosen_logratios - rejected_logratios)
        
        if self.label_smoothing > 0:
            losses = (
                -F.logsigmoid(logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-logits) * self.label_smoothing
            )
        else:
            losses = -F.logsigmoid(logits)
        
        return losses.mean()
    
    def train_step(
        self,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        chosen_labels: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor,
        rejected_labels: torch.Tensor
    ) -> Dict:
        """Single DPO training step."""
        
        # Policy log probs
        policy_chosen_logps = self.compute_log_probs(
            self.policy, chosen_input_ids, chosen_labels, chosen_attention_mask
        )
        policy_rejected_logps = self.compute_log_probs(
            self.policy, rejected_input_ids, rejected_labels, rejected_attention_mask
        )
        
        # Reference log probs (no gradient)
        with torch.no_grad():
            reference_chosen_logps = self.compute_log_probs(
                self.reference, chosen_input_ids, chosen_labels, chosen_attention_mask
            )
            reference_rejected_logps = self.compute_log_probs(
                self.reference, rejected_input_ids, rejected_labels, rejected_attention_mask
            )
        
        # DPO loss
        loss = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps
        )
        
        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.optimizer.step()
        
        # Compute metrics
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
        reward_margin = (chosen_rewards - rejected_rewards).mean()
        accuracy = (chosen_rewards > rejected_rewards).float().mean()
        
        return {
            'loss': loss.item(),
            'reward_margin': reward_margin.item(),
            'accuracy': accuracy.item()
        }
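
The DPO loss only needs four per-sequence log-probability tensors, so its behaviour can be sanity-checked in isolation with synthetic values; the two nn.Linear modules below are dummies used only to instantiate the trainer.

trainer = DPOTrainer(policy_model=nn.Linear(1, 1), reference_model=nn.Linear(1, 1), beta=0.1)

# Synthetic per-sequence log-probs: the policy separates chosen from rejected more than the reference
policy_chosen = torch.tensor([-4.0, -3.5])
policy_rejected = torch.tensor([-6.0, -5.0])
ref_chosen = torch.tensor([-5.0, -4.0])
ref_rejected = torch.tensor([-5.5, -4.8])

loss = trainer.dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected)
print(loss.item())   # decreases as the policy's preference margin grows relative to the reference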

Other RL Objectives

REINFORCE with Baseline for Text

class REINFORCETextTrainer:
    """REINFORCE for text generation tasks."""
    
    def __init__(
        self,
        model: nn.Module,
        reward_fn,  # Function: text -> reward
        baseline_model: Optional[nn.Module] = None
    ):
        self.model = model
        self.reward_fn = reward_fn
        self.baseline = baseline_model
        
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    def generate_with_log_probs(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate sequence and collect log probabilities."""
        
        generated = input_ids.clone()
        log_probs = []
        
        for _ in range(max_new_tokens):
            outputs = self.model(generated)
            next_token_logits = outputs.logits[:, -1, :]
            
            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Store log prob
            log_prob = F.log_softmax(next_token_logits, dim=-1)
            token_log_prob = log_prob.gather(1, next_token).squeeze(-1)
            log_probs.append(token_log_prob)
            
            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)
            
            # Check for EOS
            # (simplified - would check for actual EOS token)
        
        return generated, torch.stack(log_probs, dim=1)
    
    def train_step(
        self,
        prompts: torch.Tensor,
        tokenizer
    ):
        """REINFORCE training step."""
        
        # Generate samples
        generated, log_probs = self.generate_with_log_probs(prompts)
        
        # Decode to text
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        
        # Get rewards
        rewards = torch.tensor([self.reward_fn(t) for t in texts]).cuda()
        
        # Compute baseline if available
        if self.baseline is not None:
            with torch.no_grad():
                baseline_values = self.baseline(prompts)
            advantages = rewards - baseline_values
        else:
            advantages = rewards - rewards.mean()
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Policy gradient loss
        loss = -(log_probs.sum(dim=1) * advantages).mean()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return {
            'loss': loss.item(),
            'mean_reward': rewards.mean().item()
        }
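
A sketch of driving this trainer with a reward defined directly on text, for example rewarding concise responses. It assumes model is a causal LM whose forward pass returns .logits (as the class above expects) and tokenizer is its tokenizer; both are placeholders here.

def concise_reward(text: str) -> float:
    """Toy reward: prefer responses between 5 and 30 words."""
    n_words = len(text.split())
    return 1.0 if 5 <= n_words <= 30 else -1.0

trainer = REINFORCETextTrainer(model=model, reward_fn=concise_reward)

prompt_ids = tokenizer(
    ["Summarize: RL optimizes non-differentiable objectives.",
     "Summarize: RLHF aligns models with human preferences."],
    return_tensors='pt', padding=True
)['input_ids'].cuda()

stats = trainer.train_step(prompt_ids, tokenizer)
print(stats)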

Reward-Weighted Regression

class RewardWeightedRegression:
    """
    AWR / RWR style training.
    Simpler than PPO, works well for offline RL.
    """
    
    def __init__(
        self,
        model: nn.Module,
        beta: float = 1.0,  # Temperature for reward weighting
        top_k: float = 0.5  # Keep only this top fraction of samples (0.5 = top 50%)
    ):
        self.model = model
        self.beta = beta
        self.top_k = top_k
        
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    def compute_weights(self, rewards: torch.Tensor) -> torch.Tensor:
        """Compute importance weights from rewards."""
        
        # Normalize rewards
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        
        # Exponential weighting
        weights = torch.exp(self.beta * rewards)
        
        # Top-k filtering
        k = int(len(weights) * self.top_k)
        threshold = torch.topk(weights, k).values[-1]
        weights = weights * (weights >= threshold)
        
        # Normalize
        weights = weights / weights.sum()
        
        return weights
    
    def train_step(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        rewards: torch.Tensor,
        attention_mask: torch.Tensor
    ):
        """Weighted maximum likelihood training."""
        
        # Compute weights
        weights = self.compute_weights(rewards)
        
        # Forward pass
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits[:, :-1]
        targets = labels[:, 1:]
        
        # Cross-entropy loss per sample
        loss_per_sample = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            reduction='none'
        ).view(logits.size(0), -1).mean(dim=1)
        
        # Weighted loss
        loss = (weights * loss_per_sample).sum()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return {'loss': loss.item()}
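
The weighting scheme is easy to inspect on its own: rewards are standardized, exponentiated, and everything below the top-k threshold is zeroed out before renormalizing. The nn.Linear below is a dummy stand-in for the language model, since only compute_weights is exercised.

rwr = RewardWeightedRegression(model=nn.Linear(1, 1), beta=1.0, top_k=0.5)

rewards = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2, 0.8])
weights = rwr.compute_weights(rewards)
print(weights)   # nonzero only for the top half of samples, summing to 1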

Best Practices

def rlhf_best_practices():
    """Key practices for successful RLHF training."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                    RLHF BEST PRACTICES                         ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. REWARD MODEL                                               ║
    ║     • Use diverse, high-quality preference data                ║
    ║     • Include clear wins AND close comparisons                 ║
    ║     • Validate reward model accuracy (>70% preferred)          ║
    ║     • Monitor for reward hacking                               ║
    ║                                                                ║
    ║  2. KL PENALTY                                                 ║
    ║     • Start with β = 0.01-0.1                                  ║
    ║     • Increase if model diverges too much                      ║
    ║     • Monitor KL divergence during training                    ║
    ║     • Use adaptive KL (target KL penalty)                      ║
    ║                                                                ║
    ║  3. PPO HYPERPARAMETERS                                        ║
    ║     • Clip epsilon: 0.1-0.2                                    ║
    ║     • Value coefficient: 0.5-1.0                               ║
    ║     • Entropy coefficient: 0.01                                ║
    ║     • Use GAE (λ = 0.95)                                       ║
    ║                                                                ║
    ║  4. TRAINING STABILITY                                         ║
    ║     • Very low learning rate (1e-6 to 5e-6)                    ║
    ║     • Gradient clipping (1.0)                                  ║
    ║     • Warm up learning rate                                    ║
    ║     • Save checkpoints frequently                              ║
    ║                                                                ║
    ║  5. EVALUATION                                                 ║
    ║     • Use held-out preference data                             ║
    ║     • Win rate vs reference model                              ║
    ║     • Human evaluation on diverse prompts                      ║
    ║     • Check for regression on capabilities                     ║
    ║                                                                ║
    ║  6. DPO vs PPO                                                 ║
    ║     • DPO: Simpler, no reward model, offline                   ║
    ║     • PPO: More flexible, online, can iterate                  ║
    ║     • DPO often sufficient for single-turn                     ║
    ║     • PPO better for complex, multi-turn scenarios             ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

rlhf_best_practices()

Exercises

Implement Group Relative Policy Optimization:
class GRPO:
    def compute_group_advantages(self, rewards, group_size):
        # For each prompt, generate group_size responses
        # Compute advantages relative to the group mean
        # No value function needed!
        ...
Train an ensemble of reward models for more robust preferences:
class RewardEnsemble:
    def __init__(self, models):
        self.models = models
    
    def predict(self, x):
        rewards = [m(x) for m in self.models]
        return torch.stack(rewards).mean(dim=0)
    
    def uncertainty(self, x):
        # Use disagreement across ensemble members as uncertainty
        ...
Implement Identity Preference Optimization:
# IPO loss is simpler than DPO
def ipo_loss(chosen_logps, rejected_logps, ref_chosen, ref_rejected, beta=0.1):
    h_chosen = chosen_logps - ref_chosen
    h_rejected = rejected_logps - ref_rejected
    return ((h_chosen - h_rejected - 1 / (2 * beta)) ** 2).mean()

What’s Next?