Quantization Deep Dive

Why Quantization?

Quantization reduces model size and increases inference speed by storing weights (and optionally activations) in lower-precision formats:
Precision | Bits | Memory | Speed
----------|------|--------|------
FP32      | 32   | 1x     | 1x
FP16      | 16   | 0.5x   | ~2x
INT8      | 8    | 0.25x  | ~4x
INT4      | 4    | 0.125x | ~8x
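
As a quick sanity check, the memory column follows from simple arithmetic. The sketch below assumes a hypothetical 7-billion-parameter model and counts weights only (no activations, KV cache, or framework overhead):

params = 7_000_000_000  # hypothetical model size, chosen for illustration
for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    gb = params * bits / 8 / 1e9  # bits -> bytes -> gigabytes
    print(f"{name}: {gb:.1f} GB")
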
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as quant
from torch.quantization import QConfig, default_observer, default_weight_observer
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
from dataclasses import dataclass

torch.manual_seed(42)

Quantization Fundamentals

Number Representation

@dataclass
class QuantizationParams:
    """Parameters for quantization mapping."""
    scale: float
    zero_point: int
    qmin: int
    qmax: int
    dtype: torch.dtype = torch.qint8


class QuantizationMath:
    """
    Core quantization operations.
    
    Quantization maps floating point values to integers:
    q = round(x / scale) + zero_point
    x = (q - zero_point) * scale
    
    For symmetric quantization:
    scale = max(|x|) / qmax
    zero_point = 0
    
    For asymmetric quantization:
    scale = (max(x) - min(x)) / (qmax - qmin)
    zero_point = qmin - round(min(x) / scale)
    """
    
    @staticmethod
    def compute_scale_zero_point(
        tensor: torch.Tensor,
        qmin: int = -128,
        qmax: int = 127,
        symmetric: bool = True
    ) -> Tuple[float, int]:
        """
        Compute scale and zero point for quantization.
        
        Args:
            tensor: Input tensor to quantize
            qmin, qmax: Quantization range
            symmetric: Use symmetric quantization
        """
        x_min = tensor.min().item()
        x_max = tensor.max().item()
        
        if symmetric:
            # Symmetric: zero_point = 0, scale based on max abs
            max_abs = max(abs(x_min), abs(x_max))
            scale = max_abs / qmax if max_abs != 0 else 1.0
            zero_point = 0
        else:
            # Asymmetric: full range mapping
            scale = (x_max - x_min) / (qmax - qmin) if x_max != x_min else 1.0
            zero_point = round(-x_min / scale) + qmin
            zero_point = max(qmin, min(qmax, zero_point))
        
        return scale, zero_point
    
    @staticmethod
    def quantize(
        tensor: torch.Tensor,
        scale: float,
        zero_point: int,
        qmin: int = -128,
        qmax: int = 127
    ) -> torch.Tensor:
        """Quantize a floating point tensor."""
        q = torch.round(tensor / scale) + zero_point
        q = torch.clamp(q, qmin, qmax)
        return q.to(torch.int8)
    
    @staticmethod
    def dequantize(
        q_tensor: torch.Tensor,
        scale: float,
        zero_point: int
    ) -> torch.Tensor:
        """Dequantize an integer tensor."""
        return (q_tensor.float() - zero_point) * scale


# Demonstrate quantization
def demo_quantization():
    """Show quantization in action."""
    
    # Original values
    x = torch.tensor([0.1, 0.5, 1.0, -0.3, 2.0, -1.5])
    
    # Compute parameters
    scale, zp = QuantizationMath.compute_scale_zero_point(x)
    print(f"Scale: {scale:.4f}, Zero Point: {zp}")
    
    # Quantize
    q = QuantizationMath.quantize(x, scale, zp)
    print(f"Quantized: {q}")
    
    # Dequantize
    x_deq = QuantizationMath.dequantize(q, scale, zp)
    print(f"Dequantized: {x_deq}")
    
    # Error
    error = torch.abs(x - x_deq)
    print(f"Quantization error: {error}")

demo_quantization()

Granularity Levels

class QuantizationGranularity:
    """
    Different levels of quantization granularity.
    
    Finer granularity = better accuracy but more overhead.
    """
    
    @staticmethod
    def per_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, int]:
        """
        One scale/zero_point for entire tensor.
        
        Pros: Simple, minimal overhead
        Cons: Less accurate for varying distributions
        """
        scale, zp = QuantizationMath.compute_scale_zero_point(tensor)
        q = QuantizationMath.quantize(tensor, scale, zp)
        return q, scale, zp
    
    @staticmethod
    def per_channel(tensor: torch.Tensor, axis: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Separate scale/zero_point per channel.
        
        Pros: Better accuracy for weights
        Cons: More storage for parameters
        """
        n_channels = tensor.shape[axis]
        scales = torch.zeros(n_channels)
        zero_points = torch.zeros(n_channels, dtype=torch.int32)
        q = torch.zeros_like(tensor, dtype=torch.int8)
        
        for c in range(n_channels):
            # select() returns a view along `axis`, so this works for any axis
            channel_data = tensor.select(axis, c)
            
            scale, zp = QuantizationMath.compute_scale_zero_point(channel_data)
            scales[c] = scale
            zero_points[c] = zp
            
            # Write the quantized channel back through the view
            q.select(axis, c).copy_(QuantizationMath.quantize(channel_data, scale, zp))
        
        return q, scales, zero_points
    
    @staticmethod
    def per_group(
        tensor: torch.Tensor,
        group_size: int = 128
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Separate scale/zero_point per group of elements.
        
        Used in LLM quantization (e.g., GPTQ, AWQ).
        """
        flat = tensor.flatten()
        n_groups = (len(flat) + group_size - 1) // group_size
        
        # Pad if necessary
        padded = F.pad(flat, (0, n_groups * group_size - len(flat)))
        groups = padded.view(n_groups, group_size)
        
        scales = torch.zeros(n_groups)
        zero_points = torch.zeros(n_groups, dtype=torch.int32)
        q_groups = torch.zeros_like(groups, dtype=torch.int8)
        
        for i in range(n_groups):
            scale, zp = QuantizationMath.compute_scale_zero_point(groups[i])
            scales[i] = scale
            zero_points[i] = zp
            q_groups[i] = QuantizationMath.quantize(groups[i], scale, zp)
        
        q = q_groups.flatten()[:len(flat)].view(tensor.shape)
        return q, scales, zero_points
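
A quick comparison of the three granularities on a synthetic weight matrix whose rows have very different magnitudes. This is illustrative only; the exact errors depend on the weight distribution and the random seed.

def demo_granularity():
    """Compare reconstruction error across granularities (illustrative)."""
    # Weight matrix with strongly varying per-row scale
    w = torch.randn(64, 256) * torch.linspace(0.1, 2.0, 64).unsqueeze(1)

    # Per-tensor: one scale for the whole matrix
    q, s, zp = QuantizationGranularity.per_tensor(w)
    err_tensor = (w - QuantizationMath.dequantize(q, s, zp)).abs().mean()

    # Per-channel: one scale per row
    q, scales, zps = QuantizationGranularity.per_channel(w, axis=0)
    deq = (q.float() - zps.view(-1, 1)) * scales.view(-1, 1)
    err_channel = (w - deq).abs().mean()

    # Per-group: one scale per 128 consecutive elements (flattened order)
    q, scales, zps = QuantizationGranularity.per_group(w, group_size=128)
    flat_q, deq = q.flatten().float(), torch.zeros(w.numel())
    for i in range(0, w.numel(), 128):
        g = i // 128
        deq[i:i+128] = (flat_q[i:i+128] - zps[g]) * scales[g]
    err_group = (w - deq.view(w.shape)).abs().mean()

    print(f"Per-tensor error:  {err_tensor:.5f}")
    print(f"Per-channel error: {err_channel:.5f}")
    print(f"Per-group error:   {err_group:.5f}")

demo_granularity()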

Post-Training Quantization (PTQ)

class PostTrainingQuantization:
    """
    Quantize a trained model without retraining.
    
    Methods:
    1. Dynamic quantization: Quantize weights, compute scales at runtime
    2. Static quantization: Calibrate with representative data
    """
    
    @staticmethod
    def dynamic_quantization(
        model: nn.Module,
        dtype: torch.dtype = torch.qint8
    ) -> nn.Module:
        """
        Dynamic quantization: weights are quantized,
        activations are quantized dynamically at runtime.
        
        Best for: RNNs, Transformers, models with variable input
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear, nn.LSTM, nn.GRU},  # Layers to quantize
            dtype=dtype
        )
        return quantized_model
    
    @staticmethod
    def static_quantization(
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        backend: str = 'fbgemm'  # or 'qnnpack' for mobile
    ) -> nn.Module:
        """
        Static quantization: both weights and activations
        are quantized using calibrated scales.
        
        Requires: Representative calibration data
        """
        # 1. Prepare model
        model.eval()
        
        # Set quantization config
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        
        # Fuse operations (Conv+BN+ReLU, Linear+ReLU, etc.)
        # Note: the names below must match submodule attributes of the actual
        # model; ['conv', 'bn', 'relu'] is only a placeholder example.
        model_fused = torch.quantization.fuse_modules(
            model,
            [['conv', 'bn', 'relu']]
        )
        
        # 2. Insert observers
        model_prepared = torch.quantization.prepare(model_fused)
        
        # 3. Calibrate with representative data
        with torch.no_grad():
            for batch in calibration_data:
                if isinstance(batch, (tuple, list)):
                    inputs = batch[0]
                else:
                    inputs = batch
                model_prepared(inputs)
        
        # 4. Convert to quantized model
        model_quantized = torch.quantization.convert(model_prepared)
        
        return model_quantized
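
A minimal dynamic-quantization sketch on a toy MLP. The model and shapes are made up for illustration, and running the quantized model assumes a CPU build with a quantized engine such as fbgemm.

def demo_dynamic_quantization():
    """Dynamically quantize a toy MLP and compare outputs (illustrative)."""
    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    model.eval()

    model_q = PostTrainingQuantization.dynamic_quantization(model)

    x = torch.randn(4, 64)
    with torch.no_grad():
        out_fp32 = model(x)
        out_int8 = model_q(x)

    print(f"Max output difference: {(out_fp32 - out_int8).abs().max():.5f}")

demo_dynamic_quantization()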


class CalibrationMethods:
    """
    Methods to determine optimal quantization parameters.
    """
    
    @staticmethod
    def min_max_calibration(activations: List[torch.Tensor]) -> Tuple[float, float]:
        """Simple min/max calibration."""
        all_min = min(a.min().item() for a in activations)
        all_max = max(a.max().item() for a in activations)
        return all_min, all_max
    
    @staticmethod
    def percentile_calibration(
        activations: List[torch.Tensor],
        percentile: float = 99.99
    ) -> Tuple[float, float]:
        """
        Use percentiles instead of min/max.
        
        More robust to outliers.
        """
        all_values = torch.cat([a.flatten() for a in activations])
        low = np.percentile(all_values.numpy(), 100 - percentile)
        high = np.percentile(all_values.numpy(), percentile)
        return low, high
    
    @staticmethod
    def entropy_calibration(
        activations: List[torch.Tensor],
        num_bins: int = 2048,
        num_quantized_bins: int = 128
    ) -> float:
        """
        KL divergence based calibration (TensorRT style).
        
        Find threshold that minimizes KL divergence between
        original and quantized distributions.
        """
        all_values = torch.cat([a.flatten() for a in activations])
        
        # Build histogram
        hist, bin_edges = np.histogram(all_values.numpy(), bins=num_bins)
        
        # Try thresholds at multiples of num_quantized_bins so each candidate
        # range divides evenly into the quantized bins
        best_threshold = bin_edges[-1]
        best_kl = float('inf')
        
        for i in range(num_quantized_bins, num_bins + 1, num_quantized_bins):
            # Reference distribution (up to threshold)
            p = hist[:i].copy().astype(np.float64)
            p[p == 0] = 1e-10
            p /= p.sum()
            
            # Quantized distribution
            q = np.zeros(num_quantized_bins, dtype=np.float64)
            bin_size = i // num_quantized_bins
            
            for j in range(num_quantized_bins):
                start = j * bin_size
                end = start + bin_size
                q[j] = hist[start:end].sum()
            
            # Expand q back to original size
            q_expanded = np.repeat(q, bin_size)[:i]
            q_expanded = q_expanded.astype(np.float64)
            q_expanded[q_expanded == 0] = 1e-10
            q_expanded /= q_expanded.sum()
            
            # KL divergence
            kl = np.sum(p * np.log(p / q_expanded))
            
            if kl < best_kl:
                best_kl = kl
                best_threshold = bin_edges[i]
        
        return best_threshold
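
The difference between the calibration methods is easiest to see on data that contains outliers. The example below is synthetic and purely illustrative:

def demo_calibration():
    """Compare calibration methods on outlier-heavy activations."""
    acts = [torch.randn(1000) for _ in range(10)]
    acts[0][0] = 50.0  # inject a single large outlier

    mn, mx = CalibrationMethods.min_max_calibration(acts)
    p_lo, p_hi = CalibrationMethods.percentile_calibration(acts, percentile=99.9)
    thr = CalibrationMethods.entropy_calibration(acts)

    print(f"Min/max range:     [{mn:.2f}, {mx:.2f}]")    # blown up by the outlier
    print(f"Percentile range:  [{p_lo:.2f}, {p_hi:.2f}]")  # tracks the bulk of the data
    print(f"Entropy threshold: {thr:.2f}")

demo_calibration()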

Quantization-Aware Training (QAT)

class QuantizationAwareTraining:
    """
    Train with simulated quantization for better accuracy.
    
    Key idea: Insert fake quantization during training
    so the model learns to be robust to quantization noise.
    """
    
    @staticmethod
    def prepare_qat(
        model: nn.Module,
        backend: str = 'fbgemm'
    ) -> nn.Module:
        """Prepare model for QAT."""
        model.train()
        
        # QAT config with fake quantization
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        
        # Prepare for QAT (inserts fake quant modules)
        model_prepared = torch.quantization.prepare_qat(model)
        
        return model_prepared
    
    @staticmethod
    def convert_qat(model: nn.Module) -> nn.Module:
        """Convert QAT model to quantized model."""
        model.eval()
        model_quantized = torch.quantization.convert(model)
        return model_quantized
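
The typical eager-mode QAT workflow, sketched on a toy network. TinyQATNet and the training data are invented for illustration, and the final convert step assumes a CPU build with the fbgemm quantized backend.

class TinyQATNet(nn.Module):
    """Toy network with quant/dequant stubs at the float boundaries."""
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(32, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc2(self.relu(self.fc1(x)))
        return self.dequant(x)

qat_model = QuantizationAwareTraining.prepare_qat(TinyQATNet())
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.01)

for _ in range(3):  # a few fake-quantized training steps
    x = torch.randn(16, 32)
    y = torch.randint(0, 10, (16,))
    loss = F.cross_entropy(qat_model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

int8_model = QuantizationAwareTraining.convert_qat(qat_model)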


class FakeQuantize(torch.autograd.Function):
    """
    Fake quantization for training.
    
    Forward: Quantize then immediately dequantize
    Backward: Straight-through estimator (gradient unchanged)
    """
    
    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin=-128, qmax=127):
        # Quantize
        q = torch.round(x / scale) + zero_point
        q = torch.clamp(q, qmin, qmax)
        
        # Immediately dequantize
        x_q = (q - zero_point) * scale
        
        # Save for backward: clipping bounds in float units, accounting for
        # the zero point
        ctx.save_for_backward(x, torch.tensor([(qmin - zero_point) * scale,
                                               (qmax - zero_point) * scale]))
        
        return x_q
    
    @staticmethod
    def backward(ctx, grad_output):
        x, bounds = ctx.saved_tensors
        qmin_scaled, qmax_scaled = bounds[0].item(), bounds[1].item()
        
        # Straight-through estimator with clipping
        # Zero gradient for values outside quantization range
        grad_input = grad_output.clone()
        grad_input[x < qmin_scaled] = 0
        grad_input[x > qmax_scaled] = 0
        
        return grad_input, None, None, None, None


class QATLinear(nn.Module):
    """
    Quantization-aware linear layer.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__()
        
        self.linear = nn.Linear(in_features, out_features, bias)
        
        # Observers for scale/zero_point
        self.weight_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8
        )
        
        self.activation_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake quantize input
        x_q = self.activation_fake_quant(x)
        
        # Fake quantize weights
        w_q = self.weight_fake_quant(self.linear.weight)
        
        # Compute with fake-quantized values
        return F.linear(x_q, w_q, self.linear.bias)


# LSQ: Learned Step Size Quantization
class LSQQuantizer(nn.Module):
    """
    Learned Step Size Quantization.
    
    Learn the quantization step size (scale) during training
    instead of using fixed calibration.
    
    Reference: "Learned Step Size Quantization" (Esser et al., 2020)
    """
    
    def __init__(
        self,
        num_bits: int = 8,
        symmetric: bool = True,
        per_channel: bool = False,
        num_channels: int = 1
    ):
        super().__init__()
        
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel
        
        # Compute quantization bounds
        if symmetric:
            self.qmin = -2**(num_bits - 1)
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1
        
        # Learnable step size (scale)
        if per_channel:
            self.scale = nn.Parameter(torch.ones(num_channels))
        else:
            self.scale = nn.Parameter(torch.tensor(1.0))
        
        self.initialized = False
    
    def init_scale(self, x: torch.Tensor):
        """Initialize scale based on data."""
        if self.per_channel:
            # Per-channel initialization
            for c in range(x.shape[0]):
                max_val = x[c].abs().max()
                self.scale.data[c] = 2 * max_val / (self.qmax - self.qmin)
        else:
            max_val = x.abs().max()
            self.scale.data.fill_(2 * max_val / (self.qmax - self.qmin))
        
        self.initialized = True
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initialized:
            self.init_scale(x)
        
        # Ensure positive scale
        scale = self.scale.abs() + 1e-8
        
        # Quantize
        if self.per_channel:
            scale_shape = [1] * x.dim()
            scale_shape[0] = -1
            scale = scale.view(*scale_shape)
        
        x_scaled = x / scale
        x_clipped = torch.clamp(x_scaled, self.qmin, self.qmax)
        x_rounded = torch.round(x_clipped)
        
        # Dequantize with gradient trick
        x_q = (x_rounded - x_clipped).detach() + x_clipped
        x_deq = x_q * scale
        
        return x_deq
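
A short check (illustrative) that the step size really is learnable: after a backward pass it receives a gradient and can be updated alongside the weights.

lsq = LSQQuantizer(num_bits=4)
w = torch.randn(8, 8)

loss = lsq(w).pow(2).sum()
loss.backward()

print(f"Initialized scale: {lsq.scale.item():.4f}")
print(f"Scale gradient:    {lsq.scale.grad.item():.4f}")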

Advanced Quantization Methods

GPTQ (LLM Quantization)

class GPTQQuantizer:
    """
    GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
    
    Key idea: Use second-order information (Hessian) to minimize
    quantization error layer by layer.
    
    Simplified implementation for educational purposes.
    """
    
    def __init__(
        self,
        num_bits: int = 4,
        group_size: int = 128,
        symmetric: bool = True
    ):
        self.num_bits = num_bits
        self.group_size = group_size
        self.symmetric = symmetric
        
        if symmetric:
            self.qmin = -2**(num_bits - 1)
            self.qmax = 2**(num_bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2**num_bits - 1
    
    def quantize_layer(
        self,
        weight: torch.Tensor,
        H: torch.Tensor  # Hessian (X @ X.T from calibration)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize a weight matrix using GPTQ algorithm.
        
        Args:
            weight: Weight matrix [out_features, in_features]
            H: Hessian matrix [in_features, in_features]
        """
        W = weight.clone().float()
        d, n = W.shape
        
        # Add damping to Hessian for numerical stability
        damp = 0.01 * torch.diag(H).mean()
        H = H + damp * torch.eye(n)
        
        # Cholesky decomposition
        H_inv = torch.linalg.cholesky(H)
        H_inv = torch.cholesky_inverse(H_inv)
        
        # Process in groups
        Q = torch.zeros_like(W)
        scales = torch.zeros(d, (n + self.group_size - 1) // self.group_size)
        
        for i1 in range(0, n, self.group_size):
            i2 = min(i1 + self.group_size, n)
            
            # Get group weights
            W_group = W[:, i1:i2].clone()
            
            # Compute scale for group
            if self.symmetric:
                max_val = W_group.abs().max(dim=1, keepdim=True)[0]
                scale = max_val / self.qmax
            else:
                min_val = W_group.min(dim=1, keepdim=True)[0]
                max_val = W_group.max(dim=1, keepdim=True)[0]
                scale = (max_val - min_val) / (self.qmax - self.qmin)
            
            scale = scale.clamp(min=1e-10)
            scales[:, i1 // self.group_size] = scale.squeeze()
            
            # Quantize columns one by one
            for i in range(i1, i2):
                # Quantize column
                q = torch.round(W[:, i] / scale.squeeze())
                q = torch.clamp(q, self.qmin, self.qmax)
                Q[:, i] = q
                
                # Quantization error, normalized by the diagonal of the
                # inverse Hessian (OBQ/GPTQ update rule)
                error = (W[:, i] - q * scale.squeeze()) / H_inv[i, i]
                
                # Update remaining weights to compensate
                if i < n - 1:
                    W[:, i+1:] -= error.unsqueeze(1) * H_inv[i, i+1:].unsqueeze(0)
        
        return Q.to(torch.int8), scales
    
    def collect_hessian(
        self,
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        layer_name: str
    ) -> torch.Tensor:
        """
        Collect Hessian for a layer using calibration data.
        
        H = sum(X @ X.T) where X is the input to the layer.
        """
        H = None
        
        def hook(module, input, output):
            nonlocal H
            x = input[0]
            x = x.view(-1, x.shape[-1]).float()
            
            if H is None:
                H = x.T @ x
            else:
                H += x.T @ x
        
        # Register hook
        target_module = dict(model.named_modules())[layer_name]
        handle = target_module.register_forward_hook(hook)
        
        # Run calibration
        model.eval()
        with torch.no_grad():
            for batch in calibration_data:
                if isinstance(batch, (tuple, list)):
                    inputs = batch[0]
                else:
                    inputs = batch
                model(inputs)
        
        handle.remove()
        
        return H
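
A toy end-to-end run of the simplified GPTQ quantizer (illustrative; the weight and the calibration inputs are random, so only the shapes and mechanics are meaningful):

# Build a stand-in Hessian from random "calibration" inputs: H = X^T X
X = torch.randn(256, 128)
H = X.T @ X

W = torch.randn(64, 128)  # [out_features, in_features]
gptq = GPTQQuantizer(num_bits=4, group_size=64)
Q, scales = gptq.quantize_layer(W, H)

print(f"Quantized weight: {Q.shape}, dtype {Q.dtype}")
print(f"Group scales:     {scales.shape}")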


class AWQQuantizer:
    """
    Activation-aware Weight Quantization.
    
    Key insight: Not all weights are equally important.
    Scale weights by their activation magnitudes.
    
    Simplified implementation.
    """
    
    def __init__(self, num_bits: int = 4):
        self.num_bits = num_bits
        # Signed symmetric range so negative weights remain representable
        self.qmin = -2**(num_bits - 1)
        self.qmax = 2**(num_bits - 1) - 1
    
    def compute_importance(
        self,
        weight: torch.Tensor,
        activations: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute importance of each weight column
        based on activation magnitudes.
        """
        # Average activation magnitude per input dimension
        act_cat = torch.cat([a.flatten(end_dim=-2) for a in activations])
        importance = act_cat.abs().mean(dim=0)
        
        return importance
    
    def quantize_with_scaling(
        self,
        weight: torch.Tensor,
        importance: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize weights with importance-aware scaling.
        """
        # Find optimal scaling factor
        best_error = float('inf')
        best_scale = 1.0
        
        for alpha in np.linspace(0.1, 1.0, 20):
            # Scale by importance^alpha
            scaling = importance.pow(alpha)
            scaling = scaling / scaling.max()  # Normalize
            
            # Apply scaling
            W_scaled = weight * scaling.unsqueeze(0)
            
            # Quantize
            max_val = W_scaled.abs().max()
            scale = max_val / self.qmax
            W_q = torch.round(W_scaled / scale)
            W_q = torch.clamp(W_q, self.qmin, self.qmax)
            
            # Dequantize and undo scaling
            W_deq = W_q * scale / scaling.unsqueeze(0)
            
            # Compute error
            error = ((weight - W_deq) ** 2).mean()
            
            if error < best_error:
                best_error = error
                best_scale = alpha
        
        # Final quantization with best scaling
        scaling = importance.pow(best_scale)
        scaling = scaling / scaling.max()
        
        W_scaled = weight * scaling.unsqueeze(0)
        max_val = W_scaled.abs().max()
        quant_scale = max_val / self.qmax
        
        W_q = torch.round(W_scaled / quant_scale)
        W_q = torch.clamp(W_q, self.qmin, self.qmax)
        
        return W_q.to(torch.int8), quant_scale, scaling
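
A toy run of the simplified AWQ flow (illustrative; random activations and weights):

acts = [torch.randn(32, 64) for _ in range(4)]  # stand-in calibration activations
W = torch.randn(128, 64)                         # [out_features, in_features]

awq = AWQQuantizer(num_bits=4)
importance = awq.compute_importance(W, acts)
W_q, quant_scale, scaling = awq.quantize_with_scaling(W, importance)

# Reconstruct and measure error (undo both the quantization and the
# importance scaling)
W_rec = W_q.float() * quant_scale / scaling.unsqueeze(0)
print(f"Reconstruction MSE: {F.mse_loss(W_rec, W).item():.6f}")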

Mixed-Precision Quantization

class MixedPrecisionQuantizer:
    """
    Different bit-widths for different layers.
    
    Sensitive layers get higher precision,
    less important layers get lower precision.
    """
    
    def __init__(self, default_bits: int = 8):
        self.default_bits = default_bits
        self.layer_bits = {}
        
    def analyze_sensitivity(
        self,
        model: nn.Module,
        calibration_data: torch.utils.data.DataLoader,
        criterion: nn.Module
    ) -> Dict[str, float]:
        """
        Analyze each layer's sensitivity to quantization.
        
        Method: Quantize each layer individually and measure
        the impact on model output/loss.
        """
        sensitivity = {}
        
        model.eval()
        
        # Get baseline output
        with torch.no_grad():
            baseline_outputs = []
            for batch in calibration_data:
                inputs, targets = batch
                outputs = model(inputs)
                baseline_outputs.append(outputs)
        
        # Test each layer
        for name, module in model.named_modules():
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                continue
            
            # Save original weights
            orig_weight = module.weight.data.clone()
            
            # Quantize this layer to low precision
            scale, zp = QuantizationMath.compute_scale_zero_point(orig_weight)
            q_weight = QuantizationMath.quantize(orig_weight, scale, zp)
            deq_weight = QuantizationMath.dequantize(q_weight, scale, zp)
            
            module.weight.data = deq_weight
            
            # Measure impact
            total_error = 0
            with torch.no_grad():
                for i, batch in enumerate(calibration_data):
                    inputs, targets = batch
                    outputs = model(inputs)
                    error = F.mse_loss(outputs, baseline_outputs[i])
                    total_error += error.item()
            
            sensitivity[name] = total_error
            
            # Restore original weights
            module.weight.data = orig_weight
        
        return sensitivity
    
    def assign_bit_widths(
        self,
        sensitivity: Dict[str, float],
        target_size_ratio: float = 0.25,
        bit_options: List[int] = [2, 4, 8]
    ) -> Dict[str, int]:
        """
        Assign bit-widths to layers based on sensitivity.
        
        More sensitive layers get more bits.
        """
        # Normalize sensitivity scores
        max_sens = max(sensitivity.values())
        normalized = {k: v / max_sens for k, v in sensitivity.items()}
        
        # Assign bits based on sensitivity
        # High sensitivity -> more bits
        self.layer_bits = {}
        
        for name, sens in normalized.items():
            if sens > 0.7:
                bits = max(bit_options)
            elif sens > 0.3:
                bits = bit_options[len(bit_options) // 2]
            else:
                bits = min(bit_options)
            
            self.layer_bits[name] = bits
        
        return self.layer_bits
    
    def quantize_model(self, model: nn.Module) -> nn.Module:
        """Apply mixed-precision quantization."""
        
        for name, module in model.named_modules():
            if name not in self.layer_bits:
                continue
            
            bits = self.layer_bits[name]
            qmin = -2**(bits - 1)
            qmax = 2**(bits - 1) - 1
            
            if hasattr(module, 'weight'):
                w = module.weight.data
                scale, zp = QuantizationMath.compute_scale_zero_point(
                    w, qmin=qmin, qmax=qmax
                )
                q = QuantizationMath.quantize(w, scale, zp, qmin, qmax)
                module.weight.data = QuantizationMath.dequantize(q, scale, zp)
        
        return model
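
A small illustration of the bit-width assignment logic with hand-made sensitivity scores (the layer names and values are invented for the example):

mp_quantizer = MixedPrecisionQuantizer()
fake_sensitivity = {       # hypothetical per-layer sensitivity scores
    'encoder.fc1': 0.92,
    'encoder.fc2': 0.41,
    'head.proj': 0.05,
}
print(mp_quantizer.assign_bit_widths(fake_sensitivity))
# Expected mapping: most sensitive layer -> 8 bits, middle -> 4, least -> 2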

Hardware-Aware Quantization

class HardwareAwareQuantization:
    """
    Quantization considering target hardware constraints.
    """
    
    HARDWARE_PROFILES = {
        'cpu_x86': {
            'supported_dtypes': ['int8'],
            'optimal_alignment': 32,
            'vectorized_ops': True
        },
        'gpu_nvidia': {
            'supported_dtypes': ['int8', 'fp16', 'bf16', 'fp8'],
            'tensor_cores': True,
            'optimal_batch': 64
        },
        'mobile_arm': {
            'supported_dtypes': ['int8', 'int4'],
            'neon_simd': True,
            'memory_limited': True
        },
        'edge_tpu': {
            'supported_dtypes': ['int8'],
            'requires_symmetric': True,
            'batch_size': 1
        }
    }
    
    def __init__(self, target: str = 'cpu_x86'):
        self.target = target
        self.profile = self.HARDWARE_PROFILES.get(target, {})
    
    def get_optimal_config(self) -> Dict:
        """Get optimal quantization config for target."""
        config = {
            'dtype': self.profile.get('supported_dtypes', ['int8'])[0],
            'symmetric': self.profile.get('requires_symmetric', False),
        }
        
        if self.target == 'gpu_nvidia':
            config['use_tensor_cores'] = True
            config['dtype'] = 'int8'  # For INT8 tensor core ops
        
        return config
    
    def estimate_speedup(
        self,
        model: nn.Module,
        original_dtype: str = 'fp32'
    ) -> Dict[str, float]:
        """
        Estimate speedup from quantization.
        
        Very rough estimates - actual speedup varies by hardware.
        """
        speedups = {
            'fp32_to_fp16': 1.5,
            'fp32_to_int8': 2.5,
            'fp32_to_int4': 4.0
        }
        
        target_dtype = self.profile.get('supported_dtypes', ['int8'])[0]
        key = f'{original_dtype}_to_{target_dtype}'
        
        base_speedup = speedups.get(key, 1.0)
        
        # Adjust for hardware features
        if self.profile.get('tensor_cores'):
            base_speedup *= 1.5
        if self.profile.get('vectorized_ops'):
            base_speedup *= 1.2
        
        return {
            'compute_speedup': base_speedup,
            'memory_reduction': 32 / int(target_dtype.replace('int', '').replace('fp', ''))
        }
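
Querying the (intentionally rough) hardware profiles, for example for a mobile ARM target:

hw = HardwareAwareQuantization(target='mobile_arm')
print(hw.get_optimal_config())
print(hw.estimate_speedup(nn.Linear(10, 10)))
# The speedup numbers are the coarse estimates hard-coded above, not measurements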


# Summary of quantization methods
def quantization_summary():
    """Print quantization methods summary."""
    
    summary = """
    ╔════════════════════════════════════════════════════════════════════╗
    ║                    QUANTIZATION METHODS SUMMARY                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║                                                                     ║
    ║  POST-TRAINING QUANTIZATION (PTQ)                                   ║
    ║  • Dynamic: Weights quantized, activations at runtime               ║
    ║  • Static: Both quantized using calibration data                    ║
    ║  • Pros: Fast, no training needed                                   ║
    ║  • Cons: May lose accuracy, especially at low bits                  ║
    ║                                                                     ║
    ║  QUANTIZATION-AWARE TRAINING (QAT)                                  ║
    ║  • Simulate quantization during training                            ║
    ║  • Model learns to be robust to quantization                        ║
    ║  • Pros: Best accuracy                                              ║
    ║  • Cons: Requires training, more complex                            ║
    ║                                                                     ║
    ║  LLM-SPECIFIC METHODS                                               ║
    ║  • GPTQ: Second-order optimization                                  ║
    ║  • AWQ: Activation-aware weight scaling                             ║
    ║  • SqueezeLLM: Sparsity + quantization                              ║
    ║                                                                     ║
    ║  MIXED PRECISION                                                    ║
    ║  • Different bits for different layers                              ║
    ║  • Balance accuracy vs compression                                  ║
    ║                                                                     ║
    ╠════════════════════════════════════════════════════════════════════╣
    ║  RECOMMENDED APPROACH:                                              ║
    ║  1. Start with PTQ (fast baseline)                                  ║
    ║  2. If accuracy drops, try QAT                                      ║
    ║  3. For LLMs, use GPTQ/AWQ                                          ║
    ║  4. For production, consider mixed precision                        ║
    ╚════════════════════════════════════════════════════════════════════╝
    """
    print(summary)

quantization_summary()

Exercises

1. Implement and compare:
   • Min-max calibration
   • Percentile calibration
   • Entropy calibration
   Measure accuracy vs. calibration set size.
2. Implement 4-bit quantization with:
   • Group size of 128
   • Separate scale per group
   Compare it to per-tensor quantization.

What’s Next?