
Foundation Models & LLMs

The Foundation Model Paradigm

Foundation models are large models trained on broad data that can be adapted to many downstream tasks. Key characteristics:
  • Scale (billions of parameters)
  • Self-supervised pretraining
  • Emergent capabilities
  • Transfer to diverse tasks

Scaling Laws

The Chinchilla Scaling Law

For compute-optimal training:

$$N_{\text{opt}} \propto C^{0.5}, \qquad D_{\text{opt}} \propto C^{0.5}$$

Where:
  • $N$ = number of parameters
  • $D$ = dataset size (tokens)
  • $C$ = compute budget (FLOPs)
Rule of thumb: Train on ~20 tokens per parameter.
| Model      | Parameters | Training Tokens | Tokens/Parameter |
|------------|------------|-----------------|------------------|
| GPT-3      | 175B       | 300B            | 1.7              |
| Chinchilla | 70B        | 1.4T            | 20               |
| LLaMA 2    | 70B        | 2T              | 29               |
| Mistral    | 7B         | Unknown         | -                |
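
As a rough sketch of how the rule of thumb turns a compute budget into model and data sizes, the helper below assumes the common approximation C ≈ 6·N·D (training FLOPs ≈ 6 × parameters × tokens) together with D ≈ 20·N; the constants are approximations, not exact values from the paper.

def chinchilla_optimal(compute_flops, tokens_per_param=20):
    """Rough compute-optimal sizes, assuming C ≈ 6 * N * D and D ≈ tokens_per_param * N."""
    n_params = (compute_flops / (6 * tokens_per_param)) ** 0.5
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Example: a ~5.9e23 FLOP budget (roughly Chinchilla's) gives ~70B params and ~1.4T tokens
n, d = chinchilla_optimal(5.9e23)
print(f"~{n / 1e9:.0f}B parameters, ~{d / 1e12:.1f}T tokens")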

LLM Architecture

Modern Transformer Improvements

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModernTransformerBlock(nn.Module):
    """LLaMA-style transformer block with modern improvements."""
    
    def __init__(self, dim, num_heads, mlp_ratio=4, dropout=0.0):
        super().__init__()
        
        # Pre-normalization with RMSNorm
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        
        # Grouped Query Attention
        self.attn = GroupedQueryAttention(dim, num_heads, num_kv_heads=num_heads // 4)
        
        # SwiGLU MLP
        self.mlp = SwiGLU(dim, int(dim * mlp_ratio * 2/3))
    
    def forward(self, x, freqs_cis=None):
        # Pre-norm + residual
        x = x + self.attn(self.norm1(x), freqs_cis)
        x = x + self.mlp(self.norm2(x))
        return x


class RMSNorm(nn.Module):
    """Root Mean Square Normalization."""
    
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class SwiGLU(nn.Module):
    """SwiGLU activation (better than ReLU/GELU for LLMs)."""
    
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
    
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
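
The block above assumes a GroupedQueryAttention module, which is not defined in this snippet. A minimal sketch of grouped-query attention (several query heads share each key/value head, which shrinks the KV cache) could look like the following; it uses apply_rotary_emb from the RoPE section below and PyTorch's built-in scaled_dot_product_attention.

class GroupedQueryAttention(nn.Module):
    """Attention where num_heads query heads share num_kv_heads key/value heads."""
    
    def __init__(self, dim, num_heads, num_kv_heads):
        super().__init__()
        assert dim % num_heads == 0 and num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.wq = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(num_heads * self.head_dim, dim, bias=False)
    
    def forward(self, x, freqs_cis=None):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.num_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.num_kv_heads, self.head_dim)
        
        if freqs_cis is not None:
            # Assumes freqs_cis has shape (seq_len, head_dim // 2), as produced by
            # precompute_freqs_cis below; reshape so it broadcasts over batch/heads.
            q, k = apply_rotary_emb(q, k, freqs_cis[:T].reshape(1, T, 1, -1))
        
        # Repeat K/V heads so every query head has a matching key/value head
        repeat = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(repeat, dim=2)
        v = v.repeat_interleave(repeat, dim=2)
        
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (B, heads, T, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(B, T, -1)
        return self.wo(out)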

Rotary Position Embeddings (RoPE)

def precompute_freqs_cis(dim, max_seq_len, base=10000):
    """Precompute rotary embedding frequencies."""
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # Complex exponentials


def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary embeddings to queries and keys."""
    # Reshape to complex
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # Apply rotation
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)
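
A quick usage sketch (shapes chosen arbitrarily for illustration): the frequencies are precomputed once, then reshaped so they broadcast over the batch and head dimensions of the query/key tensors.

# Hypothetical shapes: batch 2, sequence length 16, 8 heads, head_dim 64
B, T, H, D = 2, 16, 8, 64
freqs_cis = precompute_freqs_cis(D, max_seq_len=T)   # complex tensor, shape (T, D // 2)
xq = torch.randn(B, T, H, D)
xk = torch.randn(B, T, H, D)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis.reshape(1, T, 1, -1))
print(xq_rot.shape)  # torch.Size([2, 16, 8, 64])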

Pretraining Objectives

Causal Language Modeling (GPT-style)

$$\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t \mid x_{<t})$$

def causal_lm_loss(logits, labels):
    """Next token prediction loss."""
    # Shift so we predict next token
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100  # Padding
    )
    return loss

Masked Language Modeling (BERT-style)

def create_mlm_inputs(tokens, mask_token_id, mask_prob=0.15, vocab_size=32000):
    """Create masked inputs and labels for MLM training."""
    tokens = tokens.clone()  # don't modify the caller's tensor in place
    labels = tokens.clone()
    
    # Randomly select positions to mask
    probability_matrix = torch.full(tokens.shape, mask_prob)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    
    # Compute the loss only on masked positions
    labels[~masked_indices] = -100
    
    # Of the masked positions: 80% -> [MASK], 10% -> random token, 10% -> unchanged
    indices_replaced = masked_indices & (torch.rand(tokens.shape) < 0.8)
    tokens[indices_replaced] = mask_token_id
    
    indices_random = masked_indices & ~indices_replaced & (torch.rand(tokens.shape) < 0.5)
    tokens[indices_random] = torch.randint(vocab_size, tokens.shape)[indices_random]
    
    return tokens, labels

Emergent Capabilities

As models scale, new abilities emerge:
| Scale   | Emergent Capability                 |
|---------|-------------------------------------|
| ~1B     | Basic language understanding        |
| ~10B    | Few-shot learning                   |
| ~100B   | Complex reasoning, code generation  |
| ~500B+  | Multi-step reasoning, tool use      |

Training LLMs

Distributed Training Setup

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

def train_llm():
    # Initialize distributed training (one process per GPU)
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # bind this process to its GPU
    
    # Create model and wrap with FSDP
    model = LLM(config)
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        ),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
    )
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )
    
    # Training loop
    for step, batch in enumerate(dataloader):
        loss = model(batch).loss
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        optimizer.zero_grad()
        
        if rank == 0 and step % 100 == 0:
            print(f"Step {step}: loss = {loss.item():.4f}")

Instruction Tuning

INSTRUCTION_TEMPLATE = """<|system|>
{system_message}
<|user|>
{instruction}
<|assistant|>
{response}"""

def format_instruction_data(example):
    # Return a dict so the function can be used with datasets.map
    return {
        "text": INSTRUCTION_TEMPLATE.format(
            system_message="You are a helpful assistant.",
            instruction=example["instruction"],
            response=example["response"],
        )
    }

# Fine-tune on an instruction dataset (load_dataset comes from the Hugging Face datasets library)
from datasets import load_dataset

instruction_dataset = load_dataset("instruction_data")
formatted = instruction_dataset.map(format_instruction_data)

RLHF (Reinforcement Learning from Human Feedback)
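
RLHF first trains a reward model on pairs of responses ranked by humans. Under the Bradley-Terry model, the probability that the chosen response $y_w$ is preferred over the rejected response $y_l$ is $\sigma(r_\theta(x, y_w) - r_\theta(x, y_l))$, which gives the loss

$$\mathcal{L}_{\text{RM}} = -\mathbb{E}\left[\log \sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\right]$$

implemented by the reward model and preference loss below.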

class RewardModel(nn.Module):
    """Reward model trained on human preferences."""
    
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.base(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state[:, -1, :]  # Last token
        return self.reward_head(last_hidden)


def compute_preference_loss(reward_model, chosen, rejected):
    """Bradley-Terry preference model loss."""
    reward_chosen = reward_model(**chosen)
    reward_rejected = reward_model(**rejected)
    
    loss = -F.logsigmoid(reward_chosen - reward_rejected).mean()
    return loss

PPO Training
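
The policy is then optimized against the reward model with PPO. The objective adds a KL penalty toward the frozen reference model so the policy does not drift too far from its supervised starting point:

$$\max_{\pi}\; \mathbb{E}_{x,\, y \sim \pi}\!\left[r_\phi(x, y)\right] - \beta\, \mathbb{E}_x\!\left[\mathrm{KL}\!\left(\pi(\cdot \mid x)\,\|\,\pi_{\text{ref}}(\cdot \mid x)\right)\right]$$

The sketch below folds the KL penalty into the reward and applies the clipped PPO surrogate to a batch of previously collected rollouts; value-function and GAE advantage estimation are omitted for brevity.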

def ppo_step(policy, ref_policy, reward_model, prompts, responses, old_logprobs,
             kl_coef=0.1, clip_eps=0.2):
    """Single PPO update on a batch of rollouts.
    
    prompts, responses, and old_logprobs come from an earlier rollout phase in
    which the (then-current) policy generated the responses.
    """
    # Score responses with the learned reward model
    rewards = reward_model(prompts, responses)
    
    # KL penalty keeps the policy close to the frozen reference model
    with torch.no_grad():
        ref_logprobs = ref_policy.log_prob(prompts, responses)
    policy_logprobs = policy.log_prob(prompts, responses)
    kl = policy_logprobs - ref_logprobs
    
    # Total reward = reward model score - KL penalty (used directly as the advantage)
    total_reward = rewards - kl_coef * kl
    
    # PPO clipped surrogate objective
    ratio = torch.exp(policy_logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    loss = -torch.min(ratio * total_reward, clipped * total_reward).mean()
    
    return loss

Using Foundation Models

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load pretrained LLM
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

def generate(prompt, max_tokens=100):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
    )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
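
For example (the exact output will vary because sampling is enabled):

print(generate("Explain the Chinchilla scaling law in one sentence."))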

Model Comparison

| Model    | Size    | Open Weights | Strengths              |
|----------|---------|--------------|------------------------|
| GPT-4    | ~1T?    | No           | Multimodal, reasoning  |
| Claude 3 | Unknown | No           | Safety, long context   |
| LLaMA 3  | 8B-70B  | Yes          | Open, efficient        |
| Mistral  | 7B      | Yes          | Quality/size ratio     |
| Gemma    | 2B-7B   | Yes          | Small, efficient       |

Exercises

  • Plot loss vs. compute for different model sizes. Verify the Chinchilla scaling law.
  • Train a small (10M parameter) causal language model on a text corpus.
  • Fine-tune a small LLM on instruction data using LoRA. Compare before/after.

What’s Next