Quantization Deep Dive

Why Quantization?

Think of quantization like compressing a high-resolution photograph into a JPEG: you are trading some fidelity for a dramatically smaller file that loads faster. In neural networks, we replace 32-bit floating-point weights with lower-precision integers (8-bit, 4-bit, or even fewer). The model gets smaller, inference gets faster, and — if done carefully — accuracy barely moves. This is not a theoretical trick: virtually every model you interact with on your phone (keyboard prediction, face unlock, voice assistant) runs quantized. The math is simple but the gains are dramatic:
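| Precision | Bits | Memory | Speed | Typical Accuracy Loss |
|-----------|------|--------|-------|-----------------------|
| FP32 | 32 | 1x | 1x | Baseline |
| FP16 | 16 | 0.5x | ~2x | Negligible |
| INT8 | 8 | 0.25x | ~4x | Less than 1% for most models |
| INT4 | 4 | 0.125x | ~8x | 1-3%, varies by architecture |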
Quantization is not a free lunch. Models with very narrow weight distributions (e.g., small transformer layers) quantize well, but models with outlier weights (common in large language models) can degrade significantly. Always measure accuracy after quantization; do not assume the table above applies to your specific model.
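
The memory column in the table is plain arithmetic: parameter count times bits per parameter. A minimal sketch follows. The helper `model_size_gb` and the 7-billion-parameter count are illustrative assumptions, not taken from any specific model, and the overhead of storing scales and zero points is ignored.

```python
def model_size_gb(num_params: int, bits_per_param: int) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes), ignoring metadata."""
    return num_params * bits_per_param / 8 / 1e9


# Hypothetical 7-billion-parameter model, used only for illustration.
NUM_PARAMS = 7_000_000_000

sizes = {
    name: model_size_gb(NUM_PARAMS, bits)
    for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]
}

for name, size in sizes.items():
    print(f"{name}: {size:.1f} GB ({size / sizes['FP32']:.3f}x of FP32)")
```

The ratios reproduce the Memory column. The Speed column cannot be derived this way, because it depends on kernels and hardware.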
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as quant
from torch.quantization import QConfig, default_observer, default_weight_observer
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
from dataclasses import dataclass

torch.manual_seed(42)

Quantization Fundamentals

At its core, quantization is a mapping problem: given a continuous range of floating-point values, map them to a smaller set of discrete integer values. The art is choosing this mapping so that the values the model cares about most (the ones that actually affect predictions) are represented accurately, while values in less-important ranges can tolerate more error. Think of it like a piano: FP32 is a piano with thousands of keys spanning the full audible range. INT8 is a piano with only 256 keys. Quantization decides which 256 pitches to include. A good quantization scheme puts more keys where the music actually plays (near the weight distribution’s center) and fewer keys in the rarely-used extremes.

Number Representation

@dataclass
class QuantizationParams:
    """Parameters for quantization mapping."""
    scale: float
    zero_point: int
    qmin: int
    qmax: int
    dtype: torch.dtype = torch.qint8


class QuantizationMath:
    """
    Core quantization operations.
    
    Quantization maps floating point values to integers:
    q = round(x / scale) + zero_point
    x = (q - zero_point) * scale
    
    For symmetric quantization:
    scale = max(|x|) / (qmax)
    zero_point = 0
    
    For asymmetric quantization:
    scale = (max(x) - min(x)) / (qmax - qmin)
    zero_point = qmin + round(-min(x) / scale)
    """
    
    @staticmethod
    def compute_scale_zero_point(
        tensor: torch.Tensor,
        qmin: int = -128,
        qmax: int = 127,
        symmetric: bool = True
    ) -> Tuple[float, int]:
        """
        Compute scale and zero point for quantization.
        
        Args:
            tensor: Input tensor to quantize
            qmin, qmax: Quantization range
            symmetric: Use symmetric quantization
        """
        x_min = tensor.min().item()
        x_max = tensor.max().item()
        
        if symmetric:
            # Symmetric: zero_point = 0, scale based on max abs
            max_abs = max(abs(x_min), abs(x_max))
            scale = max_abs / qmax if max_abs != 0 else 1.0
            zero_point = 0
        else:
            # Asymmetric: full range mapping
            scale = (x_max - x_min) / (qmax - qmin) if x_max != x_min else 1.0
            zero_point = round(-x_min / scale) + qmin
            zero_point = max(qmin, min(qmax, zero_point))
        
        return scale, zero_point
    
    @staticmethod
    def quantize(
        tensor: torch.Tensor,
        scale: float,
        zero_point: int,
        qmin: int = -128,
        qmax: int = 127
    ) -> torch.Tensor:
        """Quantize a floating point tensor."""
        q = torch.round(tensor / scale) + zero_point
        q = torch.clamp(q, qmin, qmax)
        return q.to(torch.int8)
    
    @staticmethod
    def dequantize(
        q_tensor: torch.Tensor,
        scale: float,
        zero_point: int
    ) -> torch.Tensor:
        """Dequantize an integer tensor."""
        return (q_tensor.float() - zero_point) * scale


# Demonstrate quantization
def demo_quantization():
    """Show quantization in action."""
    
    # Original values
    x = torch.tensor([0.1, 0.5, 1.0, -0.3, 2.0, -1.5])
    
    # Compute parameters
    scale, zp = QuantizationMath.compute_scale_zero_point(x)
    print(f"Scale: {scale:.4f}, Zero Point: {zp}")
    
    # Quantize
    q = QuantizationMath.quantize(x, scale, zp)
    print(f"Quantized: {q}")
    
    # Dequantize
    x_deq = QuantizationMath.dequantize(q, scale, zp)
    print(f"Dequantized: {x_deq}")
    
    # Error
    error = torch.abs(x - x_deq)
    print(f"Quantization error: {error}")

demo_quantization()
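
The demo above uses the symmetric default. As a complement, here is a minimal plain-Python sketch of the asymmetric case, which suits one-sided data such as post-ReLU activations. The function names and sample values are illustrative assumptions. It also checks a useful rule of thumb: for values inside the calibrated range, the round-trip error is at most half a step (`scale / 2`).

```python
def affine_params(x_min, x_max, qmin=-128, qmax=127):
    """Scale and zero point that map [x_min, x_max] onto [qmin, qmax]."""
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = qmin + round(-x_min / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


# ReLU-like activations: all non-negative, so a symmetric range would
# waste the negative half of the integer grid.
values = [0.0, 0.1, 0.7, 1.3, 2.9, 6.0]
scale, zp = affine_params(min(values), max(values))

errors = []
for x in values:
    q = quantize(x, scale, zp)
    errors.append(abs(x - dequantize(q, scale, zp)))

print(f"scale={scale:.5f}, zero_point={zp}")
print(f"max error={max(errors):.5f}, bound scale/2={scale / 2:.5f}")
```

With the minimum at 0.0, the zero point lands on `qmin`, so all 256 integer levels cover the range that is actually used.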

Granularity Levels

class QuantizationGranularity:
    """
    Different levels of quantization granularity.
    
    Finer granularity = better accuracy but more overhead.
    """
    
    @staticmethod
    def per_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, int]:
        """
        One scale/zero_point for entire tensor.
        
        Pros: Simple, minimal overhead
        Cons: Less accurate for varying distributions
        """
        scale, zp = QuantizationMath.compute_scale_zero_point(tensor)
        q = QuantizationMath.quantize(tensor, scale, zp)
        return q, scale, zp
    
    @staticmethod
    def per_channel(tensor: torch.Tensor, axis: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Separate scale/zero_point per channel.
        
        Pros: Better accuracy for weights
        Cons: More storage for parameters
        """
        n_channels = tensor.shape[axis]
        scales = torch.zeros(n_channels)
        zero_points = torch.zeros(n_channels, dtype=torch.int32)
        q = torch.zeros_like(tensor, dtype=torch.int8)
        
        for c in range(n_channels):
            # select() returns a view, so this works for any axis
            channel_data = tensor.select(axis, c)
            
            scale, zp = QuantizationMath.compute_scale_zero_point(channel_data)
            scales[c] = scale
            zero_points[c] = zp
            
            # Write through the view so every axis is filled, not only axis 0
            q.select(axis, c).copy_(
                QuantizationMath.quantize(channel_data, scale, zp)
            )
        
        return q, scales, zero_points
    
    @staticmethod
    def per_group(
        tensor: torch.Tensor,
        group_size: int = 128
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Separate scale/zero_point per group of elements.
        
        Used in LLM quantization (e.g., GPTQ, AWQ).
        """
        flat = tensor.flatten()
        n_groups = (len(flat) + group_size - 1) // group_size
        
        # Pad if necessary
        padded = F.pad(flat, (0, n_groups * group_size - len(flat)))
        groups = padded.view(n_groups, group_size)
        
        scales = torch.zeros(n_groups)
        zero_points = torch.zeros(n_groups, dtype=torch.int32)
        q_groups = torch.zeros_like(groups, dtype=torch.int8)
        
        for i in range(n_groups):
            scale, zp = QuantizationMath.compute_scale_zero_point(groups[i])
            scales[i] = scale
            zero_points[i] = zp
            q_groups[i] = QuantizationMath.quantize(groups[i], scale, zp)
        
        q = q_groups.flatten()[:len(flat)].view(tensor.shape)
        return q, scales, zero_points
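
To see why finer granularity helps, consider a weight matrix whose rows differ widely in magnitude. The sketch below is plain Python with a made-up 2x4 matrix; `symmetric_roundtrip` and `mse` are hypothetical helpers, not part of the classes above. It compares the round-trip error on the small-magnitude row under per-tensor and per-channel scaling.

```python
def symmetric_roundtrip(values, qmax=127):
    """Quantize floats with one shared symmetric scale, then dequantize."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs else 1.0
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Hypothetical 2x4 weight matrix: row 0 is roughly 100x larger than row 1.
weight = [
    [10.0, -7.5, 3.2, -9.1],
    [0.08, -0.05, 0.02, -0.09],
]
flat = [v for row in weight for v in row]

# Per-tensor: one scale shared by all 8 values.
per_tensor = symmetric_roundtrip(flat)

# Per-channel: each row gets its own scale.
per_channel = [v for row in weight for v in symmetric_roundtrip(row)]

small_row = slice(4, 8)
err_tensor = mse(flat[small_row], per_tensor[small_row])
err_channel = mse(flat[small_row], per_channel[small_row])
print(f"small-row MSE  per-tensor: {err_tensor:.2e}  per-channel: {err_channel:.2e}")
```

Under the shared scale, the small row collapses onto two or three integer levels. With its own scale it uses the full grid, at the cost of storing one extra scale per row.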

Post-Training Quantization (PTQ)

PTQ is the “quick and easy” path to quantization: take a fully-trained FP32 model and convert it to lower precision without any additional training. The appeal is obvious — no retraining means no GPU hours, no hyperparameter tuning, and you can quantize models you do not even have the training data for. The trade-off is that PTQ typically loses more accuracy than QAT (Quantization-Aware Training), especially at very low bit widths (4-bit).
For most models, start with dynamic quantization (quantize weights, compute activation scales at runtime). It requires zero calibration data and provides 2-4x speedup on CPU inference with typically under 1% accuracy loss. Only move to static quantization if dynamic does not meet your latency requirements, since static quantization requires collecting calibration data but provides additional speedup by pre-computing activation scales.
class PostTrainingQuantization:
    """
    Quantize a trained model without retraining.
    
    Two flavors:
    1. Dynamic quantization: Weights are quantized offline, activations are
       quantized on-the-fly at runtime. Zero calibration data needed.
       Best for: RNNs, Transformers, CPU inference.
    2. Static quantization: Both weights and activations are quantized using
       pre-calibrated scales. Requires representative data.
       Best for: CNNs, mobile/edge deployment.
    """
    
    @staticmethod
    def dynamic_quantization(
        model: nn.Module,
        dtype: torch.dtype = torch.qint8
    ) -> nn.Module:
        """
        Dynamic quantization: weights are quantized,
        activations are quantized dynamically at runtime.
        
        Best for: RNNs, Transformers, models with variable input
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear, nn.LSTM, nn.GRU},  # Layers to quantize
            dtype=dtype
        )
        return quantized_model
    
    @staticmethod
    def static_quantization(
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        backend: str = 'fbgemm'  # or 'qnnpack' for mobile
    ) -> nn.Module:
        """
        Static quantization: both weights and activations
        are quantized using calibrated scales.
        
        Requires: Representative calibration data
        """
        # 1. Prepare model
        model.eval()
        
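        # Select the quantized kernel backend and the matching qconfig
        torch.backends.quantized.engine = backend
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        
        # Fuse operations (Conv+BN+ReLU, Linear+ReLU, etc.).
        # 'conv', 'bn', 'relu' are placeholder submodule names; replace them
        # with the attribute names used in your model.
        model_fused = torch.quantization.fuse_modules(
            model, 
            [['conv', 'bn', 'relu']]  # Example fusion
        )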
        
        # 2. Insert observers
        model_prepared = torch.quantization.prepare(model_fused)
        
        # 3. Calibrate with representative data
        with torch.no_grad():
            for batch in calibration_data:
                if isinstance(batch, (tuple, list)):
                    inputs = batch[0]
                else:
                    inputs = batch
                model_prepared(inputs)
        
        # 4. Convert to quantized model
        model_quantized = torch.quantization.convert(model_prepared)
        
        return model_quantized
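
The practical difference between the two flavors is where the activation scale comes from. The following plain-Python sketch uses made-up activation values and a hypothetical `symmetric_scale` helper. It is a simplified model of the idea, not of PyTorch's internals. Static quantization fixes one scale from calibration data, while dynamic quantization recomputes it for each input.

```python
def symmetric_scale(values, qmax=127):
    """Per-tensor symmetric scale: the largest magnitude maps to qmax."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs else 1.0


# Hypothetical calibration batches of activations (illustrative numbers).
calibration_batches = [[0.2, -1.1, 0.9], [1.4, -0.3, 0.5], [-0.8, 1.0, 0.1]]
calibrated_max = max(abs(v) for batch in calibration_batches for v in batch)

# Static: one scale, fixed before deployment from the calibration set.
static_scale = symmetric_scale([calibrated_max])

# Two inputs seen at inference time; the second is far outside the
# calibrated range.
runtime_inputs = [[0.3, -0.2, 0.1], [4.0, -3.5, 2.0]]

# Dynamic: the scale is recomputed from each input as it arrives.
dynamic_scales = [symmetric_scale(x) for x in runtime_inputs]

# Values beyond the calibrated range saturate under the static scale.
clipped_counts = [sum(abs(v) > calibrated_max for v in x) for x in runtime_inputs]

print(f"static scale:   {static_scale:.4f}")
print(f"dynamic scales: {[round(s, 4) for s in dynamic_scales]}")
print(f"values clipped by the static scale: {clipped_counts}")
```

This is why static quantization needs calibration data that is representative of production inputs. Dynamic quantization adapts to each input but pays for the range computation on every call.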


class CalibrationMethods:
    """
    Methods to determine optimal quantization parameters.
    """
    
    @staticmethod
    def min_max_calibration(activations: List[torch.Tensor]) -> Tuple[float, float]:
        """Simple min/max calibration."""
        all_min = min(a.min().item() for a in activations)
        all_max = max(a.max().item() for a in activations)
        return all_min, all_max
    
    @staticmethod
    def percentile_calibration(
        activations: List[torch.Tensor],
        percentile: float = 99.99
    ) -> Tuple[float, float]:
        """
        Use percentiles instead of min/max.
        
        More robust to outliers.
        """
        all_values = torch.cat([a.flatten() for a in activations])
        low = np.percentile(all_values.numpy(), 100 - percentile)
        high = np.percentile(all_values.numpy(), percentile)
        return low, high
    
    @staticmethod
    def entropy_calibration(
        activations: List[torch.Tensor],
        num_bins: int = 2048,
        num_quantized_bins: int = 128
    ) -> float:
        """
        KL divergence based calibration (TensorRT style).
        
        Find threshold that minimizes KL divergence between
        original and quantized distributions.
        """
        all_values = torch.cat([a.flatten() for a in activations])
        
        # Build a histogram of magnitudes: the search is for a single
        # symmetric clipping threshold
        abs_values = all_values.abs().numpy()
        hist, bin_edges = np.histogram(
            abs_values, bins=num_bins, range=(0.0, float(abs_values.max()))
        )
        hist = hist.astype(np.float64)
        
        # Try different thresholds
        best_threshold = bin_edges[-1]
        best_kl = float('inf')
        
        for i in range(num_quantized_bins, num_bins + 1):
            # Reference distribution: keep bins [0, i) and fold the
            # clipped tail into the last kept bin
            p = hist[:i].copy()
            p[-1] += hist[i:].sum()
            
            # Quantized distribution: merge the i bins into
            # num_quantized_bins buckets, then spread each bucket's mass
            # evenly back over the bins it covers (keeps len(q) == len(p))
            edges = np.linspace(0, i, num_quantized_bins + 1).astype(int)
            q_expanded = np.zeros(i, dtype=np.float64)
            for j in range(num_quantized_bins):
                start, end = edges[j], edges[j + 1]
                q_expanded[start:end] = p[start:end].sum() / (end - start)
            
            # Smooth empty bins to avoid log(0), then normalize
            p[p == 0] = 1e-10
            p /= p.sum()
            q_expanded[q_expanded == 0] = 1e-10
            q_expanded /= q_expanded.sum()
            
            # KL divergence
            kl = np.sum(p * np.log(p / q_expanded))
            
            if kl < best_kl:
                best_kl = kl
                best_threshold = bin_edges[i]
        
        return best_threshold
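
A small experiment shows why the choice of calibration method matters. The sketch below is plain Python on synthetic data; the nearest-rank `percentile` helper and the single outlier of 50.0 are illustrative assumptions. It compares the round-trip error on typical values when the clipping range comes from min/max versus the 99th percentile.

```python
import random

random.seed(0)

# 1,000 "typical" activations in [-1, 1] plus a single large outlier.
bulk = [random.uniform(-1.0, 1.0) for _ in range(1000)]
activations = bulk + [50.0]


def percentile(sorted_vals, pct):
    """Nearest-rank percentile on a pre-sorted list (kept simple on purpose)."""
    idx = round(pct / 100 * (len(sorted_vals) - 1))
    return sorted_vals[min(len(sorted_vals) - 1, max(0, idx))]


def bulk_mse(clip, qmax=127):
    """Round-trip error on the bulk values for a symmetric range [-clip, clip]."""
    scale = clip / qmax
    total = 0.0
    for v in bulk:
        q = max(-qmax, min(qmax, round(v / scale)))
        total += (v - q * scale) ** 2
    return total / len(bulk)


magnitudes = sorted(abs(v) for v in activations)
minmax_clip = magnitudes[-1]             # set entirely by the outlier
pct_clip = percentile(magnitudes, 99.0)  # ignores the top 1% of magnitudes

print(f"min/max clip={minmax_clip:.2f}  bulk MSE={bulk_mse(minmax_clip):.2e}")
print(f"p99 clip={pct_clip:.2f}  bulk MSE={bulk_mse(pct_clip):.2e}")
```

The percentile range trades a large error on the one clipped outlier for much finer resolution everywhere else. Whether that trade is acceptable depends on how much the model relies on the outlier, so measure accuracy after calibrating.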

Quantization-Aware Training (QAT)

class QuantizationAwareTraining:
    """
    Train with simulated quantization for better accuracy.
    
    Key idea: Insert fake quantization during training
    so the model learns to be robust to quantization noise.
    """
    
    @staticmethod
    def prepare_qat(
        model: nn.Module,
        backend: str = 'fbgemm'
    ) -> nn.Module:
        """Prepare model for QAT."""
        model.train()
        
        # QAT config with fake quantization
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        
        # Prepare for QAT (inserts fake quant modules)
        model_prepared = torch.quantization.prepare_qat(model)
        
        return model_prepared
    
    @staticmethod
    def convert_qat(model: nn.Module) -> nn.Module:
        """Convert QAT model to quantized model."""
        model.eval()
        model_quantized = torch.quantization.convert(model)
        return model_quantized


class FakeQuantize(torch.autograd.Function):
    """
    Fake quantization for training.
    
    Forward: Quantize then immediately dequantize
    Backward: Straight-through estimator (gradient unchanged)
    """
    
    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin=-128, qmax=127):
        # Quantize
        q = torch.round(x / scale) + zero_point
        q = torch.clamp(q, qmin, qmax)
        
        # Immediately dequantize
        x_q = (q - zero_point) * scale
        
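        # Save for backward: the representable float range (which shifts
        # with zero_point), used to zero gradients of clipped values
        ctx.save_for_backward(
            x,
            torch.tensor([(qmin - zero_point) * scale, (qmax - zero_point) * scale])
        )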
        
        return x_q
    
    @staticmethod
    def backward(ctx, grad_output):
        x, bounds = ctx.saved_tensors
        qmin_scaled, qmax_scaled = bounds[0].item(), bounds[1].item()
        
        # Straight-through estimator with clipping
        # Zero gradient for values outside quantization range
        grad_input = grad_output.clone()
        grad_input[x < qmin_scaled] = 0
        grad_input[x > qmax_scaled] = 0
        
        return grad_input, None, None, None, None
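
The straight-through estimator is needed because rounding is a staircase, so its true derivative is zero almost everywhere and the weights would never receive a gradient. The plain-Python sketch below illustrates this with a finite-difference check. The `scale=0.1` setting and the function names are illustrative assumptions, and the zero point is fixed at 0 for brevity.

```python
def fake_quant(x, scale=0.1, qmin=-128, qmax=127):
    """Quantize then dequantize: the forward pass of fake quantization."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale


def numerical_grad(f, x, eps=1e-4):
    """Central finite difference, standing in for the 'true' derivative."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def ste_grad(x, scale=0.1, qmin=-128, qmax=127):
    """Straight-through estimator: 1 inside the representable range, else 0."""
    return 1.0 if qmin * scale <= x <= qmax * scale else 0.0


for x in [0.537, 3.14, 20.0]:
    print(f"x={x}: true grad={numerical_grad(fake_quant, x):.1f}, "
          f"STE grad={ste_grad(x):.1f}")
```

Inside the range the STE treats the quantizer as the identity. Outside it (here, beyond 12.7) the gradient is zeroed, which matches the masking in `backward` above.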


class QATLinear(nn.Module):
    """
    Quantization-aware linear layer.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__()
        
        self.linear = nn.Linear(in_features, out_features, bias)
        
        # Observers for scale/zero_point
        self.weight_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8
        )
        
        self.activation_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake quantize input
        x_q = self.activation_fake_quant(x)
        
        # Fake quantize weights
        w_q = self.weight_fake_quant(self.linear.weight)
        
        # Compute with fake-quantized values
        return F.linear(x_q, w_q, self.linear.bias)


# LSQ: Learned Step Size Quantization
class LSQQuantizer(nn.Module):
    """
    Learned Step Size Quantization.
    
    Learn the quantization step size (scale) during training
    instead of using fixed calibration.
    
    Reference: "Learned Step Size Quantization" (Esser et al., 2020)
    """
    
    def __init__(
        self,
        num_bits: int = 8,
        symmetric: bool = True,
        per_channel: bool = False,
        num_channels: int = 1
    ):
        super().__init__()
        
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel
        
        # Compute quantization bounds
        if symmetric:
            self.qmin = -2**(num_bits - 1)
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1
        
        # Learnable step size (scale)
        if per_channel:
            self.scale = nn.Parameter(torch.ones(num_channels))
        else:
            self.scale = nn.Parameter(torch.tensor(1.0))
        
        self.initialized = False
    
    def init_scale(self, x: torch.Tensor):
        """Initialize scale based on data."""
        if self.per_channel:
            # Per-channel initialization
            for c in range(x.shape[0]):
                max_val = x[c].abs().max()
                self.scale.data[c] = 2 * max_val / (self.qmax - self.qmin)
        else:
            max_val = x.abs().max()
            self.scale.data.fill_(2 * max_val / (self.qmax - self.qmin))
        
        self.initialized = True
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initialized:
            self.init_scale(x)
        
        # Ensure positive scale
        scale = self.scale.abs() + 1e-8
        
        # Quantize
        if self.per_channel:
            scale_shape = [1] * x.dim()
            scale_shape[0] = -1
            scale = scale.view(*scale_shape)
        
        x_scaled = x / scale
        x_clipped = torch.clamp(x_scaled, self.qmin, self.qmax)
        x_rounded = torch.round(x_clipped)
        
        # Dequantize with gradient trick
        x_q = (x_rounded - x_clipped).detach() + x_clipped
        x_deq = x_q * scale
        
        return x_deq
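
The "gradient trick" in `forward` is what lets the scale receive a gradient. Differentiating `x_deq = scale * round(clamp(x / scale))` with the rounding treated as identity gives a simple per-element rule. The sketch below is derived from the code above rather than quoted from the paper, and the function name `lsq_scale_grad` is hypothetical.

```python
def lsq_scale_grad(x, scale, qmin=-128, qmax=127):
    """d(x_deq)/d(scale) for one element under the straight-through estimator."""
    ratio = x / scale
    if ratio <= qmin:
        return float(qmin)   # clipped low: x_deq = qmin * scale
    if ratio >= qmax:
        return float(qmax)   # clipped high: x_deq = qmax * scale
    # Inside the range: round(x / scale) - x / scale, the signed rounding residual
    return round(ratio) - ratio


for x in [0.30, 0.45, 100.0]:
    print(f"x={x}: d(x_deq)/d(scale) = {lsq_scale_grad(x, scale=0.25):+.3f}")
```

Values that round down pull the scale one way and values that round up pull it the other. Clipped values contribute the full `qmin` or `qmax`, which pushes the scale to grow when many values saturate.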

Advanced Quantization Methods

The methods above (PTQ, QAT) work well for models up to a few hundred million parameters. But large language models (LLMs) with billions of parameters present unique challenges: they have outlier activations that break standard quantization, retraining them with QAT is prohibitively expensive for most teams, and they need to run on consumer hardware. The following methods were designed specifically for this regime.
For quantizing LLMs in practice, the landscape in 2024-2025 has converged: use GPTQ or AWQ for weight-only quantization (good for inference), and use bitsandbytes for training (QLoRA). The choice between GPTQ and AWQ is mostly about toolchain preference — accuracy is similar. For models under 7B parameters, INT8 is usually sufficient. For 13B+ models, INT4 with group quantization (group_size=128) gives the best size-accuracy trade-off.
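
Group quantization is not quite "4 bits per weight", because each group also stores its own scale. A quick sketch of the overhead follows. It assumes 16-bit scales and no stored zero point; real formats differ, so treat the numbers as illustrative.

```python
def effective_bits(weight_bits: int, group_size: int,
                   scale_bits: int = 16, zero_point_bits: int = 0) -> float:
    """Average stored bits per weight, including per-group metadata."""
    return weight_bits + (scale_bits + zero_point_bits) / group_size


for group_size in (32, 64, 128, 1024):
    bits = effective_bits(4, group_size)
    print(f"group_size={group_size:>4}: {bits:.3f} bits/weight")
```

At `group_size=128` the overhead is about 3% on top of 4 bits. Smaller groups track local weight ranges more closely but cost more.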

GPTQ (LLM Quantization)

class GPTQQuantizer:
    """
    GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
    
    Key idea: Use second-order information (Hessian) to minimize
    quantization error layer by layer. When you quantize one weight,
    adjust its neighbors to compensate -- like adjusting the rest of
    a jigsaw puzzle when one piece is slightly off.
    
    Simplified implementation for educational purposes.
    In practice, use the `auto-gptq` or `gptqmodel` library.
    """
    
    def __init__(
        self,
        num_bits: int = 4,
        group_size: int = 128,
        symmetric: bool = True
    ):
        self.num_bits = num_bits
        self.group_size = group_size
        self.symmetric = symmetric
        
        if symmetric:
            self.qmin = -2**(num_bits - 1)
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        H: torch.Tensor  # Hessian proxy (X.T @ X accumulated over calibration inputs)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize a weight matrix using GPTQ algorithm.
        
        Args:
            weight: Weight matrix [out_features, in_features]
            H: Hessian matrix [in_features, in_features]
        """
        W = weight.clone().float()
        d, n = W.shape
        
        # Add damping to Hessian for numerical stability
        damp = 0.01 * torch.diag(H).mean()
        H = H.float() + damp * torch.eye(n, device=H.device)
        
        # Invert H via Cholesky, then take the upper Cholesky factor of
        # H^-1. Row i of that factor holds the inverse-Hessian terms for
        # the columns that are still unquantized.
        H_inv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        H_inv = torch.linalg.cholesky(H_inv, upper=True)
        
        # Process in groups
        Q = torch.zeros_like(W)
        scales = torch.zeros(d, (n + self.group_size - 1) // self.group_size)
        
        for i1 in range(0, n, self.group_size):
            i2 = min(i1 + self.group_size, n)
            
            # Get group weights
            W_group = W[:, i1:i2].clone()
            
            # Compute scale (and zero point) for group
            if self.symmetric:
                max_val = W_group.abs().max(dim=1, keepdim=True)[0]
                scale = (max_val / self.qmax).clamp(min=1e-10)
                zero = torch.zeros_like(scale)
            else:
                min_val = W_group.min(dim=1, keepdim=True)[0]
                max_val = W_group.max(dim=1, keepdim=True)[0]
                scale = ((max_val - min_val) / (self.qmax - self.qmin)).clamp(min=1e-10)
                # NOTE: zero points are applied here but not returned by
                # this simplified version
                zero = torch.round(-min_val / scale) + self.qmin
            
            scale, zero = scale.squeeze(1), zero.squeeze(1)
            scales[:, i1 // self.group_size] = scale
            
            # Quantize columns one by one
            for i in range(i1, i2):
                # Quantize column
                q = torch.clamp(torch.round(W[:, i] / scale) + zero, self.qmin, self.qmax)
                Q[:, i] = q
                
                # Quantization error, normalized by the diagonal term
                error = (W[:, i] - (q - zero) * scale) / H_inv[i, i]
                
                # Update remaining weights to compensate
                if i < n - 1:
                    W[:, i+1:] -= error.unsqueeze(1) * H_inv[i, i+1:].unsqueeze(0)
        
        return Q.to(torch.int8), scales
    
    def collect_hessian(
        self,
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        layer_name: str
    ) -> torch.Tensor:
        """
        Collect Hessian for a layer using calibration data.
        
        H = sum(X.T @ X), where X is the layer input flattened to [tokens, in_features].
        """
        H = None
        
        def hook(module, input, output):
            nonlocal H
            x = input[0]
            x = x.view(-1, x.shape[-1]).float()
            
            if H is None:
                H = x.T @ x
            else:
                H += x.T @ x
        
        # Register hook
        target_module = dict(model.named_modules())[layer_name]
        handle = target_module.register_forward_hook(hook)
        
        # Run calibration
        model.eval()
        with torch.no_grad():
            for batch in calibration_data:
                if isinstance(batch, (tuple, list)):
                    inputs = batch[0]
                else:
                    inputs = batch
                model(inputs)
        
        handle.remove()
        
        return H
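
The compensation step is easiest to see on a two-weight toy problem. The plain-Python sketch below uses a made-up Hessian for two strongly correlated inputs and an integer grid (scale = 1). It applies a single compensation update using the inverse Hessian and compares the resulting output error against plain round-to-nearest. It illustrates the idea only and is not the full GPTQ procedure.

```python
def output_error(w, w_hat, H):
    """Squared output error (w - w_hat)^T H (w - w_hat) for a 2-weight row."""
    d0, d1 = w[0] - w_hat[0], w[1] - w_hat[1]
    return d0 * d0 * H[0][0] + 2 * d0 * d1 * H[0][1] + d1 * d1 * H[1][1]


def inverse_2x2(H):
    det = H[0][0] * H[1][1] - H[0][1] * H[1][0]
    return [[H[1][1] / det, -H[0][1] / det], [-H[1][0] / det, H[0][0] / det]]


# Toy setup: two strongly correlated inputs, integer grid (scale = 1).
H = [[1.0, 0.9], [0.9, 1.0]]
w = [0.4, 0.4]

# Baseline: round each weight independently.
rtn = [float(round(w[0])), float(round(w[1]))]

# Compensated: quantize w[0], then shift w[1] to absorb the error.
H_inv = inverse_2x2(H)
q0 = float(round(w[0]))
err = (w[0] - q0) / H_inv[0][0]
w1_adjusted = w[1] - err * H_inv[0][1]
compensated = [q0, float(round(w1_adjusted))]

print(f"round-to-nearest: {rtn}, error={output_error(w, rtn, H):.3f}")
print(f"compensated:      {compensated}, error={output_error(w, compensated, H):.3f}")
```

Because the two inputs are correlated, rounding the first weight down can be partly cancelled by rounding the second one up. Each weight alone looks worse, yet the layer output is closer to the original.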


class AWQQuantizer:
    """
    Activation-aware Weight Quantization.
    
    Key insight: Not all weights are equally important.
    Scale weights by their activation magnitudes.
    
    Simplified implementation.
    """
    
    def __init__(self, num_bits: int = 4):
        self.num_bits = num_bits
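        # Signed symmetric range, matching the zero-point-free
        # quantization used below (an unsigned range would clamp every
        # negative weight to 0)
        self.qmin = -2**(num_bits - 1)
        self.qmax = 2**(num_bits - 1) - 1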
    
    def compute_importance(
        self,
        weight: torch.Tensor,
        activations: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute importance of each weight column
        based on activation magnitudes.
        """
        # Average activation magnitude per input dimension
        act_cat = torch.cat([a.flatten(end_dim=-2) for a in activations])
        importance = act_cat.abs().mean(dim=0)
        
        return importance
    
    def quantize_with_scaling(
        self,
        weight: torch.Tensor,
        importance: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize weights with importance-aware scaling.
        """
        # Find optimal scaling factor
        best_error = float('inf')
        best_scale = 1.0
        
        for alpha in np.linspace(0.1, 1.0, 20):
            # Scale by importance^alpha
            scaling = importance.pow(alpha)
            scaling = scaling / scaling.max()  # Normalize
            
            # Apply scaling
            W_scaled = weight * scaling.unsqueeze(0)
            
            # Quantize
            max_val = W_scaled.abs().max()
            scale = max_val / self.qmax
            W_q = torch.round(W_scaled / scale)
            W_q = torch.clamp(W_q, self.qmin, self.qmax)
            
            # Dequantize and undo scaling
            W_deq = W_q * scale / scaling.unsqueeze(0)
            
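            # Compute error, weighted by activation magnitude as a proxy
            # for the error in the layer output
            error = (((weight - W_deq) * importance.unsqueeze(0)) ** 2).mean()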
            
            if error < best_error:
                best_error = error
                best_scale = alpha
        
        # Final quantization with best scaling
        scaling = importance.pow(best_scale)
        scaling = scaling / scaling.max()
        
        W_scaled = weight * scaling.unsqueeze(0)
        max_val = W_scaled.abs().max()
        quant_scale = max_val / self.qmax
        
        W_q = torch.round(W_scaled / quant_scale)
        W_q = torch.clamp(W_q, self.qmin, self.qmax)
        
        return W_q.to(torch.int8), quant_scale, scaling
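
A toy example makes the activation-aware idea concrete. In the plain-Python sketch below, all numbers are invented for illustration and `quantize_row` and `output_error` are hypothetical helpers. Channel 0 has a small weight but large activations, so rounding it to zero is costly. Scaling that weight up before quantization, and dividing it back afterwards, preserves it.

```python
def quantize_row(row, qmax=7):
    """Symmetric 4-bit round trip with one scale for the whole row."""
    scale = max(abs(v) for v in row) / qmax
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in row]


# Toy row: channel 0 has a small weight but sees large activations.
weights = [0.05, 1.0]
act_magnitude = [10.0, 0.1]


def output_error(s):
    """Activation-weighted error when channel 0's weight is scaled up by s."""
    scaled = [weights[0] * s, weights[1]]
    deq = quantize_row(scaled)
    restored = [deq[0] / s, deq[1]]  # undo the scaling after dequantization
    return sum(abs(w - r) * a for w, r, a in zip(weights, restored, act_magnitude))


# Small grid search over the scaling factor, in the spirit of the alpha
# search above.
candidates = [1.0, 2.0, 3.0, 4.0]
errors = {s: output_error(s) for s in candidates}
best = min(errors, key=errors.get)

for s, e in errors.items():
    print(f"s={s}: weighted output error={e:.4f}")
print(f"best scaling factor: {best}")
```

In a real AWQ pipeline the inverse scaling is folded into the activations feeding the layer rather than applied to the dequantized weights, so it adds no cost at inference time.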

Mixed-Precision Quantization

class MixedPrecisionQuantizer:
    """
    Different bit-widths for different layers.
    
    Sensitive layers get higher precision,
    less important layers get lower precision.
    """
    
    def __init__(self, default_bits: int = 8):
        self.default_bits = default_bits
        self.layer_bits = {}
        
    def analyze_sensitivity(
        self,
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        criterion: nn.Module
    ) -> Dict[str, float]:
        """
        Analyze each layer's sensitivity to quantization.
        
        Method: Quantize each layer individually and measure
        the impact on model output/loss.
        """
        sensitivity = {}
        
        model.eval()
        
        # Get baseline output
        with torch.no_grad():
            baseline_outputs = []
            for batch in calibration_data:
                inputs, targets = batch
                outputs = model(inputs)
                baseline_outputs.append(outputs)
        
        # Test each layer
        for name, module in model.named_modules():
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                continue
            
            # Save original weights
            orig_weight = module.weight.data.clone()
            
            # Quantize this layer to low precision
            scale, zp = QuantizationMath.compute_scale_zero_point(orig_weight)
            q_weight = QuantizationMath.quantize(orig_weight, scale, zp)
            deq_weight = QuantizationMath.dequantize(q_weight, scale, zp)
            
            module.weight.data = deq_weight
            
            # Measure impact
            total_error = 0
            with torch.no_grad():
                for i, batch in enumerate(calibration_data):
                    inputs, targets = batch
                    outputs = model(inputs)
                    error = F.mse_loss(outputs, baseline_outputs[i])
                    total_error += error.item()
            
            sensitivity[name] = total_error
            
            # Restore original weights
            module.weight.data = orig_weight
        
        return sensitivity
    
    def assign_bit_widths(
        self,
        sensitivity: Dict[str, float],
        target_size_ratio: float = 0.25,
        bit_options: List[int] = [2, 4, 8]
    ) -> Dict[str, int]:
        """
        Assign bit-widths to layers based on sensitivity.
        
        More sensitive layers get more bits.
        """
        # Normalize sensitivity scores
        max_sens = max(sensitivity.values())
        normalized = {k: v / max_sens for k, v in sensitivity.items()}
        
        # Assign bits based on sensitivity
        # High sensitivity -> more bits
        self.layer_bits = {}
        
        for name, sens in normalized.items():
            if sens > 0.7:
                bits = max(bit_options)
            elif sens > 0.3:
                bits = bit_options[len(bit_options) // 2]
            else:
                bits = min(bit_options)
            
            self.layer_bits[name] = bits
        
        return self.layer_bits
    
    def quantize_model(self, model: nn.Module) -> nn.Module:
        """Apply mixed-precision quantization."""
        
        for name, module in model.named_modules():
            if name not in self.layer_bits:
                continue
            
            bits = self.layer_bits[name]
            qmin = -2**(bits - 1)
            qmax = 2**(bits - 1) - 1
            
            if hasattr(module, 'weight'):
                w = module.weight.data
                scale, zp = QuantizationMath.compute_scale_zero_point(
                    w, qmin=qmin, qmax=qmax
                )
                q = QuantizationMath.quantize(w, scale, zp, qmin, qmax)
                module.weight.data = QuantizationMath.dequantize(q, scale, zp)
        
        return model
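
Once bit-widths are assigned, it is worth checking what they do to overall model size, since large layers dominate the average. A minimal sketch follows; the layer names and parameter counts are hypothetical, and per-layer scale overhead is ignored.

```python
def average_bits(layer_params, layer_bits):
    """Parameter-weighted average bit-width across layers."""
    total_params = sum(layer_params.values())
    total_bits = sum(layer_params[name] * layer_bits[name] for name in layer_params)
    return total_bits / total_params


# Hypothetical layers: parameter counts and the bit-widths assigned to them.
layer_params = {"embed": 1000, "block1": 3000, "head": 1000}
layer_bits = {"embed": 8, "block1": 4, "head": 8}

avg = average_bits(layer_params, layer_bits)
size_ratio = avg / 32  # relative to an FP32 baseline

print(f"average bits/weight: {avg:.2f}")
print(f"size vs FP32: {size_ratio:.3f}x")
```

A check like this is one way to use a budget such as `target_size_ratio`: if the ratio exceeds the target, lower the bit-width of the least sensitive large layers and recompute.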

Hardware-Aware Quantization

class HardwareAwareQuantization:
    """
    Quantization considering target hardware constraints.
    """
    
    HARDWARE_PROFILES = {
        'cpu_x86': {
            'supported_dtypes': ['int8'],
            'optimal_alignment': 32,
            'vectorized_ops': True
        },
        'gpu_nvidia': {
            'supported_dtypes': ['int8', 'fp16', 'bf16', 'fp8'],
            'tensor_cores': True,
            'optimal_batch': 64
        },
        'mobile_arm': {
            'supported_dtypes': ['int8', 'int4'],
            'neon_simd': True,
            'memory_limited': True
        },
        'edge_tpu': {
            'supported_dtypes': ['int8'],
            'requires_symmetric': True,
            'batch_size': 1
        }
    }
    
    def __init__(self, target: str = 'cpu_x86'):
        self.target = target
        self.profile = self.HARDWARE_PROFILES.get(target, {})
    
    def get_optimal_config(self) -> Dict:
        """Get optimal quantization config for target."""
        config = {
            'dtype': self.profile.get('supported_dtypes', ['int8'])[0],
            'symmetric': self.profile.get('requires_symmetric', False),
        }
        
        if self.target == 'gpu_nvidia':
            config['use_tensor_cores'] = True
            config['dtype'] = 'int8'  # For INT8 tensor core ops
        
        return config
    
    def estimate_speedup(
        self,
        model: nn.Module,
        original_dtype: str = 'fp32'
    ) -> Dict[str, float]:
        """
        Estimate speedup from quantization.
        
        Very rough estimates - actual speedup varies by hardware.
        `model` is not inspected; the estimate depends only on the
        dtype pair and the hardware profile.
        """
        speedups = {
            'fp32_to_fp16': 1.5,
            'fp32_to_int8': 2.5,
            'fp32_to_int4': 4.0
        }
        
        target_dtype = self.profile.get('supported_dtypes', ['int8'])[0]
        key = f'{original_dtype}_to_{target_dtype}'
        
        base_speedup = speedups.get(key, 1.0)
        
        # Adjust for hardware features
        if self.profile.get('tensor_cores'):
            base_speedup *= 1.5
        if self.profile.get('vectorized_ops'):
            base_speedup *= 1.2
        
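        # Read the bit width from names such as 'int8', 'fp16', or 'bf16'.
        def bits_of(dtype: str) -> int:
            return int(''.join(ch for ch in dtype if ch.isdigit()))
        
        return {
            'compute_speedup': base_speedup,
            'memory_reduction': bits_of(original_dtype) / bits_of(target_dtype)
        }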


# Summary of quantization methods
def quantization_summary():
    """Print quantization methods summary."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════════╗
    ║                    QUANTIZATION METHODS SUMMARY                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  POST-TRAINING QUANTIZATION (PTQ)                                   ║
    ║  • Dynamic: Weights quantized, activations at runtime               ║
    ║  • Static: Both quantized using calibration data                    ║
    ║  • Pros: Fast, no training needed                                   ║
    ║  • Cons: May lose accuracy, especially at low bits                  ║
    ║                                                                     ║
    ║  QUANTIZATION-AWARE TRAINING (QAT)                                  ║
    ║  • Simulate quantization during training                            ║
    ║  • Model learns to be robust to quantization                        ║
    ║  • Pros: Best accuracy                                              ║
    ║  • Cons: Requires training, more complex                            ║
    ║                                                                     ║
    ║  LLM-SPECIFIC METHODS                                               ║
    ║  • GPTQ: Second-order optimization                                  ║
    ║  • AWQ: Activation-aware weight scaling                             ║
    ║  • SqueezeLLM: Non-uniform quant + sparse outliers                  ║
    ║                                                                     ║
    ║  MIXED PRECISION                                                    ║
    ║  • Different bits for different layers                              ║
    ║  • Balance accuracy vs compression                                  ║
    ║                                                                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║  RECOMMENDED APPROACH:                                              ║
    ║  1. Start with PTQ (fast baseline)                                  ║
    ║  2. If accuracy drops, try QAT                                      ║
    ║  3. For LLMs, use GPTQ/AWQ                                          ║
    ║  4. For production, consider mixed precision                        ║
    ╚════════════════════════════════════════════════════════════════════╝
    """
    print(summary)

quantization_summary()
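The `edge_tpu` profile above sets `requires_symmetric`, which `get_optimal_config` passes through as the `symmetric` flag. As a rough illustration of what that flag trades away, the sketch below compares symmetric and asymmetric INT8 parameters on non-negative, ReLU-like values. It uses only the standard library, and `symmetric_params` and `asymmetric_params` are hypothetical helpers for this example, not part of the classes above.

```python
from typing import List, Tuple


def symmetric_params(values: List[float], bits: int = 8) -> Tuple[float, int]:
    """Symmetric scheme: the range is centered on zero, so zero_point is 0."""
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return scale, 0


def asymmetric_params(values: List[float], bits: int = 8) -> Tuple[float, int]:
    """Asymmetric scheme: the integer range is fitted to [min, max]."""
    qmin, qmax = -2 ** (bits - 1), 2 ** (bits - 1) - 1
    lo = min(min(values), 0.0)  # keep 0.0 representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0
    zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
    return scale, zero_point


# Non-negative values, as produced by a ReLU activation.
activations = [0.0, 0.4, 1.3, 2.7, 5.9]

sym_scale, sym_zp = symmetric_params(activations)
asym_scale, asym_zp = asymmetric_params(activations)

print(f"symmetric:  step={sym_scale:.5f}, zero_point={sym_zp}")
print(f"asymmetric: step={asym_scale:.5f}, zero_point={asym_zp}")
```

On one-sided data the symmetric scheme leaves the negative half of the integer range unused, so its step size is roughly twice as coarse. The benefit is a zero point of 0, which removes the zero-point terms from integer matrix multiplication and is one reason some hardware targets prefer or require it.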

Exercises

1. Implement and compare the following calibration methods, then measure accuracy vs. calibration set size:
  • Min-max calibration
  • Percentile calibration
  • Entropy calibration
2. Implement 4-bit quantization with:
  • Group size of 128
  • Separate scale per group
Then compare the result to per-tensor quantization.
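As one possible starting point for the group-wise exercise, here is a small standard-library sketch of symmetric 4-bit fake quantization with one scale per group. The helper names are hypothetical, and the toy list uses a group size of 4 rather than 128 so the effect is visible by eye. A real implementation would operate on tensors and store the integer codes and scales rather than the dequantized floats.

```python
from typing import List


def quantize_group(values: List[float], bits: int = 4) -> List[float]:
    """Symmetric quantize-then-dequantize of one group with its own scale."""
    qmax = 2 ** (bits - 1) - 1
    qmin = -qmax - 1
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in values]


def quantize_grouped(values: List[float], group_size: int, bits: int = 4) -> List[float]:
    """Split values into consecutive groups and quantize each independently."""
    out: List[float] = []
    for start in range(0, len(values), group_size):
        out.extend(quantize_group(values[start:start + group_size], bits))
    return out


def mean_sq_error(a: List[float], b: List[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Toy weights: four small values followed by four much larger ones.
weights = [0.01, -0.02, 0.03, -0.04, 0.9, -1.2, 0.7, -0.5]

# Per-tensor quantization is the special case of a single group.
per_tensor = quantize_grouped(weights, group_size=len(weights))
per_group = quantize_grouped(weights, group_size=4)

mse_tensor = mean_sq_error(weights, per_tensor)
mse_group = mean_sq_error(weights, per_group)
print(f"per-tensor MSE: {mse_tensor:.6f}")
print(f"per-group  MSE: {mse_group:.6f}")
```

With a single scale, the large values set the step size and the small values all collapse to zero. Giving each group its own scale preserves them, at the cost of storing one extra scale per group.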

What’s Next?

  • Knowledge Distillation: Transfer knowledge to smaller models
  • Continual Learning: Learn new tasks without forgetting