Quantization Deep Dive
Why Quantization?
Think of quantization like compressing a high-resolution photograph into a JPEG: you trade some fidelity for a dramatically smaller file that loads faster. In neural networks, we replace 32-bit floating-point weights with lower-precision values: 16-bit floats, or 8-bit, 4-bit, or even narrower integers. The model gets smaller, inference gets faster, and, if done carefully, accuracy barely moves. This is not a theoretical trick: many of the models you interact with on your phone (keyboard prediction, face unlock, voice assistant) run quantized. The math is simple but the gains are dramatic. In the table below, the memory column follows directly from the bit width, while the speed column is an idealized figure that depends on hardware and kernel support.

| Precision | Bits | Relative Memory | Approx. Speedup | Typical Accuracy Loss |
|---|---|---|---|---|
| FP32 | 32 | 1x | 1x | Baseline |
| FP16 | 16 | 0.5x | ~2x | Negligible |
| INT8 | 8 | 0.25x | ~4x | Less than 1% for most models |
| INT4 | 4 | 0.125x | ~8x | 1-3%, varies by architecture |
Quantization is not a free lunch. Models with very narrow weight distributions (e.g., small transformer layers) quantize well, but models with outlier weights (common in large language models) can degrade significantly. Always measure accuracy after quantization; do not assume the table above applies to your specific model.
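The memory column is plain arithmetic: bytes = parameters × bits / 8. The sketch below works that out for a given parameter count. It counts weights only and ignores scales, zero points, activations, and runtime buffers, so treat the numbers as a lower bound. The function names are illustrative, not part of any library.

```python
# Rough weight-memory estimate at different precisions.
# Assumption: only the weights are counted; quantization metadata,
# activations, and runtime buffers are ignored.
BITS_PER_WEIGHT = {"FP32": 32, "FP16": 16, "INT8": 8, "INT4": 4}


def weight_memory_gib(num_params: int, bits: int) -> float:
    """Return the storage needed for num_params weights, in GiB."""
    return num_params * bits / 8 / 2**30


def memory_table(num_params: int) -> dict:
    """Map each precision name to its estimated weight memory in GiB."""
    return {
        name: weight_memory_gib(num_params, bits)
        for name, bits in BITS_PER_WEIGHT.items()
    }


# Example: a 7-billion-parameter model
for name, gib in memory_table(7_000_000_000).items():
    print(f"{name}: {gib:.2f} GiB")
```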
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as quant
from torch.quantization import QConfig, default_observer, default_weight_observer
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
from dataclasses import dataclass
torch.manual_seed(42)
Quantization Fundamentals
At its core, quantization is a mapping problem: given a continuous range of floating-point values, map them to a smaller set of discrete integer values. The art is choosing this mapping so that the values the model cares about most (the ones that actually affect predictions) are represented accurately, while values in less-important ranges can tolerate more error.

Think of it like a piano: FP32 is a piano with thousands of keys spanning the full audible range. INT8 is a piano with only 256 keys. Quantization decides which 256 pitches to include. A good quantization scheme puts more keys where the music actually plays (near the weight distribution’s center) and fewer keys in the rarely-used extremes.

Number Representation
@dataclass
class QuantizationParams:
"""Parameters for quantization mapping."""
scale: float
zero_point: int
qmin: int
qmax: int
dtype: torch.dtype = torch.qint8
class QuantizationMath:
"""
Core quantization operations.
Quantization maps floating point values to integers:
q = round(x / scale) + zero_point
x = (q - zero_point) * scale
For symmetric quantization:
scale = max(|x|) / (qmax)
zero_point = 0
For asymmetric quantization:
scale = (max(x) - min(x)) / (qmax - qmin)
zero_point = qmin + round(-min(x) / scale)
"""
@staticmethod
def compute_scale_zero_point(
tensor: torch.Tensor,
qmin: int = -128,
qmax: int = 127,
symmetric: bool = True
) -> Tuple[float, int]:
"""
Compute scale and zero point for quantization.
Args:
tensor: Input tensor to quantize
qmin, qmax: Quantization range
symmetric: Use symmetric quantization
"""
x_min = tensor.min().item()
x_max = tensor.max().item()
if symmetric:
# Symmetric: zero_point = 0, scale based on max abs
max_abs = max(abs(x_min), abs(x_max))
scale = max_abs / qmax if max_abs != 0 else 1.0
zero_point = 0
else:
# Asymmetric: full range mapping
scale = (x_max - x_min) / (qmax - qmin) if x_max != x_min else 1.0
zero_point = round(-x_min / scale) + qmin
zero_point = max(qmin, min(qmax, zero_point))
return scale, zero_point
@staticmethod
def quantize(
tensor: torch.Tensor,
scale: float,
zero_point: int,
qmin: int = -128,
qmax: int = 127
) -> torch.Tensor:
"""Quantize a floating point tensor."""
q = torch.round(tensor / scale) + zero_point
q = torch.clamp(q, qmin, qmax)
return q.to(torch.int8)
@staticmethod
def dequantize(
q_tensor: torch.Tensor,
scale: float,
zero_point: int
) -> torch.Tensor:
"""Dequantize an integer tensor."""
return (q_tensor.float() - zero_point) * scale
# Demonstrate quantization
def demo_quantization():
"""Show quantization in action."""
# Original values
x = torch.tensor([0.1, 0.5, 1.0, -0.3, 2.0, -1.5])
# Compute parameters
scale, zp = QuantizationMath.compute_scale_zero_point(x)
print(f"Scale: {scale:.4f}, Zero Point: {zp}")
# Quantize
q = QuantizationMath.quantize(x, scale, zp)
print(f"Quantized: {q}")
# Dequantize
x_deq = QuantizationMath.dequantize(q, scale, zp)
print(f"Dequantized: {x_deq}")
# Error
error = torch.abs(x - x_deq)
print(f"Quantization error: {error}")
demo_quantization()
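The demo above uses the symmetric default. Asymmetric quantization tends to help when the data is one-sided, such as post-ReLU activations, because a symmetric scheme then wastes half of its integer range on negative values that never occur. The sketch below re-implements the same two formulas in plain Python (so it runs without PyTorch) and compares the worst-case error on non-negative data. The helper names and the `[0, 6]` range are assumptions chosen for illustration.

```python
def quantize_roundtrip(values, qmin=-128, qmax=127, symmetric=True):
    """Quantize then dequantize a list of floats; return the reconstruction."""
    lo, hi = min(values), max(values)
    if symmetric:
        max_abs = max(abs(lo), abs(hi))
        scale = max_abs / qmax if max_abs != 0 else 1.0
        zero_point = 0
    else:
        scale = (hi - lo) / (qmax - qmin) if hi != lo else 1.0
        zero_point = qmin + round(-lo / scale)
        zero_point = max(qmin, min(qmax, zero_point))
    out = []
    for x in values:
        q = round(x / scale) + zero_point
        q = max(qmin, min(qmax, q))          # clamp to the integer range
        out.append((q - zero_point) * scale)  # dequantize
    return out


def max_abs_error(values, **kwargs):
    """Largest absolute difference between values and their reconstruction."""
    recon = quantize_roundtrip(values, **kwargs)
    return max(abs(a - b) for a, b in zip(values, recon))


# Non-negative, ReLU-like values spread over [0, 6]
relu_like = [i * 6.0 / 999 for i in range(1000)]
err_sym = max_abs_error(relu_like, symmetric=True)
err_asym = max_abs_error(relu_like, symmetric=False)
print(f"symmetric max error:  {err_sym:.5f}")
print(f"asymmetric max error: {err_asym:.5f}")
```

On this input the asymmetric scale is roughly half the symmetric one, so its rounding error is roughly halved as well. For data centered on zero the two schemes behave almost identically.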
Granularity Levels
class QuantizationGranularity:
"""
Different levels of quantization granularity.
Finer granularity = better accuracy but more overhead.
"""
@staticmethod
def per_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, int]:
"""
One scale/zero_point for entire tensor.
Pros: Simple, minimal overhead
Cons: Less accurate for varying distributions
"""
scale, zp = QuantizationMath.compute_scale_zero_point(tensor)
q = QuantizationMath.quantize(tensor, scale, zp)
return q, scale, zp
@staticmethod
def per_channel(tensor: torch.Tensor, axis: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Separate scale/zero_point per channel.
Pros: Better accuracy for weights
Cons: More storage for parameters
"""
n_channels = tensor.shape[axis]
scales = torch.zeros(n_channels)
zero_points = torch.zeros(n_channels, dtype=torch.int32)
q = torch.zeros_like(tensor, dtype=torch.int8)
for c in range(n_channels):
if axis == 0:
channel_data = tensor[c]
else:
channel_data = tensor.select(axis, c)
scale, zp = QuantizationMath.compute_scale_zero_point(channel_data)
scales[c] = scale
zero_points[c] = zp
# Write through a view so this works for any axis, not just axis 0
q.select(axis, c).copy_(QuantizationMath.quantize(channel_data, scale, zp))
return q, scales, zero_points
@staticmethod
def per_group(
tensor: torch.Tensor,
group_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Separate scale/zero_point per group of elements.
Used in LLM quantization (e.g., GPTQ, AWQ).
"""
flat = tensor.flatten()
n_groups = (len(flat) + group_size - 1) // group_size
# Pad if necessary
padded = F.pad(flat, (0, n_groups * group_size - len(flat)))
groups = padded.view(n_groups, group_size)
scales = torch.zeros(n_groups)
zero_points = torch.zeros(n_groups, dtype=torch.int32)
q_groups = torch.zeros_like(groups, dtype=torch.int8)
for i in range(n_groups):
scale, zp = QuantizationMath.compute_scale_zero_point(groups[i])
scales[i] = scale
zero_points[i] = zp
q_groups[i] = QuantizationMath.quantize(groups[i], scale, zp)
q = q_groups.flatten()[:len(flat)].view(tensor.shape)
return q, scales, zero_points
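To see why finer granularity helps, consider a weight matrix in which one channel has a much larger range than the others. With a single per-tensor scale, the large channel dictates the step size and the small channels collapse onto a handful of integer levels. The following plain-Python sketch (hypothetical data, symmetric 8-bit quantization) measures that effect.

```python
import random


def roundtrip_error(rows, scales, qmax=127):
    """Mean absolute error after symmetric int8 quantize/dequantize.

    rows:   list of channels, each a list of floats
    scales: one scale per channel
    """
    total, count = 0.0, 0
    for row, scale in zip(rows, scales):
        for x in row:
            q = max(-qmax - 1, min(qmax, round(x / scale)))
            total += abs(x - q * scale)
            count += 1
    return total / count


random.seed(0)
# Hypothetical weights: three small-magnitude channels and one large one
channel_ranges = [0.05, 0.05, 0.05, 5.0]
weights = [[random.uniform(-r, r) for _ in range(256)] for r in channel_ranges]

# Per-tensor: a single scale set by the largest magnitude anywhere
global_max = max(abs(x) for row in weights for x in row)
per_tensor_scales = [global_max / 127] * len(weights)

# Per-channel: each channel gets a scale from its own largest magnitude
per_channel_scales = [max(abs(x) for x in row) / 127 for row in weights]

err_tensor = roundtrip_error(weights, per_tensor_scales)
err_channel = roundtrip_error(weights, per_channel_scales)
print(f"per-tensor  mean abs error: {err_tensor:.6f}")
print(f"per-channel mean abs error: {err_channel:.6f}")
```

The cost of the per-channel variant is one extra scale per channel, which is small next to the weights themselves.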
Post-Training Quantization (PTQ)
PTQ is the “quick and easy” path to quantization: take a fully trained FP32 model and convert it to lower precision without any additional training. The appeal is obvious: no retraining means no GPU hours and no hyperparameter tuning, and you can quantize models whose training data you do not have. The trade-off is that PTQ typically loses more accuracy than QAT (Quantization-Aware Training), especially at very low bit widths (4-bit).

For most models, start with dynamic quantization (quantize weights ahead of time, compute activation scales at runtime). It requires zero calibration data and typically provides a 2-4x speedup on CPU inference with under 1% accuracy loss. Only move to static quantization if dynamic does not meet your latency requirements: static quantization requires collecting calibration data but provides additional speedup by pre-computing activation scales.
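The difference between the two modes comes down to where the activation scale comes from. The sketch below illustrates this in plain Python: the static path reuses a scale computed once from calibration batches, while the dynamic path derives a scale from each input. All data and helper names are hypothetical, and the example isolates the accuracy side of the trade-off, not the runtime cost.

```python
def int8_scale(values, qmax=127):
    """Symmetric scale covering the largest magnitude in values."""
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs != 0 else 1.0


def fake_quant(values, scale, qmax=127):
    """Quantize to int8 with the given scale, then dequantize."""
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]


def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# Hypothetical calibration batches collected before deployment
calibration = [
    [0.1 * i for i in range(-10, 11)],
    [0.2 * i for i in range(-10, 11)],
]
static_scale = int8_scale([v for batch in calibration for v in batch])

# A runtime input whose range is much smaller than the calibration data
runtime_input = [0.001 * i for i in range(-10, 11)]

# Static: reuse the precomputed scale (no per-input work at inference time)
static_out = fake_quant(runtime_input, static_scale)
# Dynamic: derive the scale from this input (extra work, tighter fit)
dynamic_out = fake_quant(runtime_input, int8_scale(runtime_input))

static_err = mean_abs_error(runtime_input, static_out)
dynamic_err = mean_abs_error(runtime_input, dynamic_out)
print(f"static  error: {static_err:.6f}")
print(f"dynamic error: {dynamic_err:.6f}")
```

When runtime inputs resemble the calibration data, the static scale fits well and the gap closes, which is why representative calibration data matters for static quantization.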
class PostTrainingQuantization:
"""
Quantize a trained model without retraining.
Two flavors:
1. Dynamic quantization: Weights are quantized offline, activations are
quantized on-the-fly at runtime. Zero calibration data needed.
Best for: RNNs, Transformers, CPU inference.
2. Static quantization: Both weights and activations are quantized using
pre-calibrated scales. Requires representative data.
Best for: CNNs, mobile/edge deployment.
"""
@staticmethod
def dynamic_quantization(
model: nn.Module,
dtype: torch.dtype = torch.qint8
) -> nn.Module:
"""
Dynamic quantization: weights are quantized,
activations are quantized dynamically at runtime.
Best for: RNNs, Transformers, models with variable input
"""
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear, nn.LSTM, nn.GRU}, # Layers to quantize
dtype=dtype
)
return quantized_model
@staticmethod
def static_quantization(
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
backend: str = 'fbgemm' # or 'qnnpack' for mobile
) -> nn.Module:
"""
Static quantization: both weights and activations
are quantized using calibrated scales.
Requires: Representative calibration data
"""
# 1. Prepare model
model.eval()
# Select the matching quantized kernel backend and quantization config
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qconfig(backend)
# Fuse operations (Conv+BN+ReLU, Linear+ReLU, etc.)
model_fused = torch.quantization.fuse_modules(
model,
[['conv', 'bn', 'relu']] # Example fusion
)
# 2. Insert observers
model_prepared = torch.quantization.prepare(model_fused)
# 3. Calibrate with representative data
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, (tuple, list)):
inputs = batch[0]
else:
inputs = batch
model_prepared(inputs)
# 4. Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared)
return model_quantized
class CalibrationMethods:
"""
Methods to determine optimal quantization parameters.
"""
@staticmethod
def min_max_calibration(activations: List[torch.Tensor]) -> Tuple[float, float]:
"""Simple min/max calibration."""
all_min = min(a.min().item() for a in activations)
all_max = max(a.max().item() for a in activations)
return all_min, all_max
@staticmethod
def percentile_calibration(
activations: List[torch.Tensor],
percentile: float = 99.99
) -> Tuple[float, float]:
"""
Use percentiles instead of min/max.
More robust to outliers.
"""
all_values = torch.cat([a.flatten() for a in activations])
low = np.percentile(all_values.numpy(), 100 - percentile)
high = np.percentile(all_values.numpy(), percentile)
return low, high
@staticmethod
def entropy_calibration(
activations: List[torch.Tensor],
num_bins: int = 2048,
num_quantized_bins: int = 128
) -> float:
"""
KL divergence based calibration (TensorRT style).
Find threshold that minimizes KL divergence between
original and quantized distributions.
"""
all_values = torch.cat([a.flatten() for a in activations])
# Build histogram of magnitudes: a single clipping threshold
# implies a symmetric range [-threshold, threshold]
hist, bin_edges = np.histogram(all_values.abs().numpy(), bins=num_bins)
# Try different thresholds
best_threshold = bin_edges[-1]
best_kl = float('inf')
for i in range(num_quantized_bins, num_bins):
# Reference distribution: bins below the threshold, with the
# clipped (outlier) mass folded into the last kept bin
p = hist[:i].astype(np.float64)
p[-1] += hist[i:].sum()
p[p == 0] = 1e-10
p /= p.sum()
# Quantized distribution: merge the i bins into
# num_quantized_bins chunks, then spread each chunk's
# mass evenly back over the bins it came from
edges = np.linspace(0, i, num_quantized_bins + 1).astype(int)
q_expanded = np.zeros(i, dtype=np.float64)
for j in range(num_quantized_bins):
start, end = edges[j], edges[j + 1]
q_expanded[start:end] = hist[start:end].sum() / (end - start)
q_expanded[q_expanded == 0] = 1e-10
q_expanded /= q_expanded.sum()
# KL divergence
kl = np.sum(p * np.log(p / q_expanded))
if kl < best_kl:
best_kl = kl
best_threshold = bin_edges[i]
return best_threshold
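The practical difference between min-max and percentile calibration shows up as soon as the data contains an outlier. Here is a minimal plain-Python sketch, assuming synthetic Gaussian activations plus one extreme value. The helper names are illustrative, and the percentile helper uses a simple nearest-rank rule, not NumPy's interpolation.

```python
import random


def percentile_range(values, percentile=99.9):
    """Clip range covering roughly the central `percentile` percent of values."""
    ordered = sorted(values)
    n = len(ordered)
    lo_idx = int((100 - percentile) / 100 * (n - 1))
    hi_idx = int(percentile / 100 * (n - 1))
    return ordered[lo_idx], ordered[hi_idx]


def mean_abs_error(values, lo, hi, levels=256):
    """Mean absolute error after clipping to [lo, hi] and rounding to a uniform grid."""
    scale = (hi - lo) / (levels - 1)
    total = 0.0
    for v in values:
        clipped = max(lo, min(hi, v))
        recon = lo + round((clipped - lo) / scale) * scale
        total += abs(v - recon)
    return total / len(values)


random.seed(0)
# Hypothetical activations: standard normal values plus one extreme outlier
acts = [random.gauss(0.0, 1.0) for _ in range(10_000)] + [1000.0]

minmax_lo, minmax_hi = min(acts), max(acts)
pct_lo, pct_hi = percentile_range(acts, 99.9)

mae_minmax = mean_abs_error(acts, minmax_lo, minmax_hi)
mae_pct = mean_abs_error(acts, pct_lo, pct_hi)
print(f"min-max    range [{minmax_lo:.2f}, {minmax_hi:.2f}]  MAE {mae_minmax:.4f}")
print(f"percentile range [{pct_lo:.2f}, {pct_hi:.2f}]  MAE {mae_pct:.4f}")
```

Percentile calibration gives up on the outlier, which is clipped hard, in exchange for much finer resolution on the typical values. Whether that is a good trade depends on how much the outlier matters to the model, and on the error metric: a squared-error measure penalizes the clipped outlier far more heavily than the absolute error used here.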
Quantization-Aware Training (QAT)
class QuantizationAwareTraining:
"""
Train with simulated quantization for better accuracy.
Key idea: Insert fake quantization during training
so the model learns to be robust to quantization noise.
"""
@staticmethod
def prepare_qat(
model: nn.Module,
backend: str = 'fbgemm'
) -> nn.Module:
"""Prepare model for QAT."""
model.train()
# QAT config with fake quantization
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
# Prepare for QAT (inserts fake quant modules)
model_prepared = torch.quantization.prepare_qat(model)
return model_prepared
@staticmethod
def convert_qat(model: nn.Module) -> nn.Module:
"""Convert QAT model to quantized model."""
model.eval()
model_quantized = torch.quantization.convert(model)
return model_quantized
class FakeQuantize(torch.autograd.Function):
"""
Fake quantization for training.
Forward: Quantize then immediately dequantize
Backward: Straight-through estimator (gradient unchanged)
"""
@staticmethod
def forward(ctx, x, scale, zero_point, qmin=-128, qmax=127):
# Quantize
q = torch.round(x / scale) + zero_point
q = torch.clamp(q, qmin, qmax)
# Immediately dequantize
x_q = (q - zero_point) * scale
# Save for backward: the representable float range, including the zero-point shift
ctx.save_for_backward(x, torch.tensor([(qmin - zero_point) * scale, (qmax - zero_point) * scale]))
return x_q
@staticmethod
def backward(ctx, grad_output):
x, bounds = ctx.saved_tensors
qmin_scaled, qmax_scaled = bounds[0].item(), bounds[1].item()
# Straight-through estimator with clipping
# Zero gradient for values outside quantization range
grad_input = grad_output.clone()
grad_input[x < qmin_scaled] = 0
grad_input[x > qmax_scaled] = 0
return grad_input, None, None, None, None
class QATLinear(nn.Module):
"""
Quantization-aware linear layer.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True
):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias)
# Observers for scale/zero_point
self.weight_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8
)
self.activation_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Fake quantize input
x_q = self.activation_fake_quant(x)
# Fake quantize weights
w_q = self.weight_fake_quant(self.linear.weight)
# Compute with fake-quantized values
return F.linear(x_q, w_q, self.linear.bias)
# LSQ: Learned Step Size Quantization
class LSQQuantizer(nn.Module):
"""
Learned Step Size Quantization.
Learn the quantization step size (scale) during training
instead of using fixed calibration.
Reference: "Learned Step Size Quantization" (Esser et al., 2020)
"""
def __init__(
self,
num_bits: int = 8,
symmetric: bool = True,
per_channel: bool = False,
num_channels: int = 1
):
super().__init__()
self.num_bits = num_bits
self.symmetric = symmetric
self.per_channel = per_channel
# Compute quantization bounds
if symmetric:
self.qmin = -2**(num_bits - 1)
self.qmax = 2**(num_bits - 1) - 1
else:
self.qmin = 0
self.qmax = 2**num_bits - 1
# Learnable step size (scale)
if per_channel:
self.scale = nn.Parameter(torch.ones(num_channels))
else:
self.scale = nn.Parameter(torch.tensor(1.0))
self.initialized = False
def init_scale(self, x: torch.Tensor):
"""Initialize scale based on data."""
if self.per_channel:
# Per-channel initialization
for c in range(x.shape[0]):
max_val = x[c].abs().max()
self.scale.data[c] = 2 * max_val / (self.qmax - self.qmin)
else:
max_val = x.abs().max()
self.scale.data.fill_(2 * max_val / (self.qmax - self.qmin))
self.initialized = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.initialized:
self.init_scale(x)
# Ensure positive scale
scale = self.scale.abs() + 1e-8
# Quantize
if self.per_channel:
scale_shape = [1] * x.dim()
scale_shape[0] = -1
scale = scale.view(*scale_shape)
x_scaled = x / scale
x_clipped = torch.clamp(x_scaled, self.qmin, self.qmax)
x_rounded = torch.round(x_clipped)
# Dequantize with gradient trick
x_q = (x_rounded - x_clipped).detach() + x_clipped
x_deq = x_q * scale
return x_deq
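The straight-through estimator is easiest to appreciate on a one-parameter toy problem. Rounding has zero gradient almost everywhere, so without the STE a latent weight never moves. With it, the weight drifts until its quantized value is close to the target. The following plain-Python sketch uses an assumed step size of 0.25 and a single hypothetical training pair.

```python
STEP = 0.25  # quantization step size (assumed for this toy example)


def fake_quant(w: float) -> float:
    """Round w to the nearest multiple of STEP (quantize, then dequantize)."""
    return round(w / STEP) * STEP


def train(use_ste: bool, steps: int = 200, lr: float = 0.05) -> float:
    """Fit fake_quant(w) * x to y by gradient descent; return the latent w."""
    x, y = 1.0, 0.8  # one hypothetical training pair
    w = 0.0
    for _ in range(steps):
        pred = fake_quant(w) * x
        dloss_dpred = 2.0 * (pred - y)  # derivative of (pred - y)^2
        # d(fake_quant)/dw is 0 almost everywhere; the straight-through
        # estimator treats it as 1 so the gradient can reach w.
        dquant_dw = 1.0 if use_ste else 0.0
        w -= lr * dloss_dpred * x * dquant_dw
    return w


w_ste = train(use_ste=True)
w_plain = train(use_ste=False)
print("quantized weight with STE:   ", fake_quant(w_ste))
print("quantized weight without STE:", fake_quant(w_plain))
```

Because the target 0.8 is not on the grid, the latent weight can end up hovering near a rounding boundary and flipping between two neighbouring levels. That oscillation is a known side effect of training through a rounding function.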
Advanced Quantization Methods
The methods above (PTQ, QAT) work well for models up to a few hundred million parameters. But large language models (LLMs) with billions of parameters present unique challenges: they have outlier activations that break standard quantization, retraining them (QAT) is prohibitively expensive for most teams, and they often need to run on consumer hardware. The following methods were designed specifically for this regime.

For quantizing LLMs in practice, the landscape in 2024-2025 has largely converged: use GPTQ or AWQ for weight-only quantization (good for inference), and use bitsandbytes for training (QLoRA). The choice between GPTQ and AWQ is mostly about toolchain preference; accuracy is similar. For models under 7B parameters, INT8 is usually sufficient. For 13B+ models, INT4 with group quantization (group_size=128) usually gives a good size-accuracy trade-off.
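Group quantization is not quite "4 bits per weight", because every group also stores its own scale. The arithmetic below estimates the effective bit-width and resulting model size. It assumes one 16-bit scale per group and no stored zero point; real formats differ in these details, so the defaults are an assumption, not a description of any specific library.

```python
def effective_bits_per_weight(
    weight_bits: int,
    group_size: int,
    scale_bits: int = 16,       # assumption: one FP16 scale per group
    zero_point_bits: int = 0,   # assumption: symmetric, no stored zero point
) -> float:
    """Average storage per weight when each group carries its own metadata."""
    return weight_bits + (scale_bits + zero_point_bits) / group_size


def model_size_gb(num_params: float, bits_per_weight: float) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


bits = effective_bits_per_weight(weight_bits=4, group_size=128)
print(f"{bits} bits per weight")
print(f"13B parameters: {model_size_gb(13e9, bits):.2f} GB")
```

Smaller groups track the weight distribution more closely but raise the overhead: under the same assumptions a group size of 32 costs 4.5 bits per weight.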
GPTQ (LLM Quantization)
class GPTQQuantizer:
"""
GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
Key idea: Use second-order information (Hessian) to minimize
quantization error layer by layer. When you quantize one weight,
adjust its neighbors to compensate -- like adjusting the rest of
a jigsaw puzzle when one piece is slightly off.
Simplified implementation for educational purposes.
In practice, use the `auto-gptq` or `gptqmodel` library.
"""
def __init__(
self,
num_bits: int = 4,
group_size: int = 128,
symmetric: bool = True
):
self.num_bits = num_bits
self.group_size = group_size
self.symmetric = symmetric
if symmetric:
self.qmin = -2**(num_bits - 1)
self.qmax = 2**(num_bits - 1) - 1
else:
self.qmin = 0
self.qmax = 2**num_bits - 1
def quantize_layer(
self,
weight: torch.Tensor,
H: torch.Tensor # Hessian (X @ X.T from calibration)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize a weight matrix using GPTQ algorithm.
Args:
weight: Weight matrix [out_features, in_features]
H: Hessian matrix [in_features, in_features]
"""
W = weight.clone().float()
d, n = W.shape
# Add damping to Hessian for numerical stability
damp = 0.01 * torch.diag(H).mean()
H = H + damp * torch.eye(n, device=H.device, dtype=H.dtype)
# Invert H via Cholesky, then take the upper Cholesky factor of H^-1.
# Row i of this factor gives the compensation weights for column i
# with the effect of already-quantized columns removed.
L = torch.linalg.cholesky(H)
H_inv = torch.cholesky_inverse(L)
H_inv = torch.linalg.cholesky(H_inv, upper=True)
# Process in groups
Q = torch.zeros_like(W)
scales = torch.zeros(d, (n + self.group_size - 1) // self.group_size)
for i1 in range(0, n, self.group_size):
i2 = min(i1 + self.group_size, n)
# Get group weights (already adjusted by earlier compensation steps)
W_group = W[:, i1:i2].clone()
# Compute scale for group
if self.symmetric:
max_val = W_group.abs().max(dim=1, keepdim=True)[0]
scale = max_val / self.qmax
else:
min_val = W_group.min(dim=1, keepdim=True)[0]
max_val = W_group.max(dim=1, keepdim=True)[0]
scale = (max_val - min_val) / (self.qmax - self.qmin)
scale = scale.clamp(min=1e-10)
scales[:, i1 // self.group_size] = scale.squeeze()
# Quantize columns one by one
for i in range(i1, i2):
# Quantize column
q = torch.round(W[:, i] / scale.squeeze())
q = torch.clamp(q, self.qmin, self.qmax)
Q[:, i] = q
# Compute quantization error, normalized by the diagonal entry
error = (W[:, i] - q * scale.squeeze()) / H_inv[i, i]
# Update remaining weights to compensate
if i < n - 1:
W[:, i+1:] -= error.unsqueeze(1) * H_inv[i, i+1:].unsqueeze(0)
return Q.to(torch.int8), scales
def collect_hessian(
self,
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
layer_name: str
) -> torch.Tensor:
"""
Collect Hessian for a layer using calibration data.
H = sum(X @ X.T) where X is the input to the layer.
"""
H = None
def hook(module, input, output):
nonlocal H
x = input[0]
x = x.view(-1, x.shape[-1]).float()
if H is None:
H = x.T @ x
else:
H += x.T @ x
# Register hook
target_module = dict(model.named_modules())[layer_name]
handle = target_module.register_forward_hook(hook)
# Run calibration
model.eval()
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, (tuple, list)):
inputs = batch[0]
else:
inputs = batch
model(inputs)
handle.remove()
return H
class AWQQuantizer:
"""
Activation-aware Weight Quantization.
Key insight: Not all weights are equally important.
Scale weights by their activation magnitudes.
Simplified implementation.
"""
def __init__(self, num_bits: int = 4):
self.num_bits = num_bits
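# Signed range: the quantizer below is symmetric (no zero point),
# so an unsigned range would clamp every negative weight to 0
self.qmin = -2**(num_bits - 1)
self.qmax = 2**(num_bits - 1) - 1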
def compute_importance(
self,
weight: torch.Tensor,
activations: List[torch.Tensor]
) -> torch.Tensor:
"""
Compute importance of each weight column
based on activation magnitudes.
"""
# Average activation magnitude per input dimension
act_cat = torch.cat([a.flatten(end_dim=-2) for a in activations])
importance = act_cat.abs().mean(dim=0)
return importance
def quantize_with_scaling(
self,
weight: torch.Tensor,
importance: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantize weights with importance-aware scaling.
"""
# Find optimal scaling factor
best_error = float('inf')
best_scale = 1.0
for alpha in np.linspace(0.1, 1.0, 20):
# Scale by importance^alpha
scaling = importance.pow(alpha)
scaling = scaling / scaling.max() # Normalize
# Apply scaling
W_scaled = weight * scaling.unsqueeze(0)
# Quantize
max_val = W_scaled.abs().max()
scale = max_val / self.qmax
W_q = torch.round(W_scaled / scale)
W_q = torch.clamp(W_q, self.qmin, self.qmax)
# Dequantize and undo scaling
W_deq = W_q * scale / scaling.unsqueeze(0)
# Compute error
error = ((weight - W_deq) ** 2).mean()
if error < best_error:
best_error = error
best_scale = alpha
# Final quantization with best scaling
scaling = importance.pow(best_scale)
scaling = scaling / scaling.max()
W_scaled = weight * scaling.unsqueeze(0)
max_val = W_scaled.abs().max()
quant_scale = max_val / self.qmax
W_q = torch.round(W_scaled / quant_scale)
W_q = torch.clamp(W_q, self.qmin, self.qmax)
return W_q.to(torch.int8), quant_scale, scaling
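The idea behind GPTQ's "update remaining weights to compensate" step can be checked by hand on a two-weight row. If the layer's squared output error is written as δᵀHδ, with δ the difference between original and modified weights, then after rounding the first weight the second one can be shifted to cancel part of the damage. The sketch below uses made-up numbers and plain Python. It stops before the second weight is itself quantized, so it demonstrates only the compensation step, not the full algorithm.

```python
def output_error(delta, H):
    """Squared output error delta^T H delta for a 2-weight row."""
    d1, d2 = delta
    return H[0][0] * d1 * d1 + 2 * H[0][1] * d1 * d2 + H[1][1] * d2 * d2


def inverse_2x2(H):
    """Inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = H
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


# Hypothetical numbers: two weights whose inputs are positively correlated
H = [[2.0, 1.2], [1.2, 1.0]]  # H = X^T X from calibration inputs
w = [0.37, -0.52]             # original weights
step = 0.25                   # quantization grid spacing

# Quantize the first weight only
q1 = round(w[0] / step) * step
err = w[0] - q1

# Option A: leave the second weight untouched
delta_plain = (err, 0.0)

# Option B: shift the second weight to absorb part of the error,
# using the first row of H^-1 normalized by its diagonal entry
H_inv = inverse_2x2(H)
w2_adjusted = w[1] - err / H_inv[0][0] * H_inv[0][1]
delta_comp = (err, w[1] - w2_adjusted)

err_plain = output_error(delta_plain, H)
err_comp = output_error(delta_comp, H)
print(f"output error without compensation: {err_plain:.6f}")
print(f"output error with compensation:    {err_comp:.6f}")
```

The stronger the correlation between the two inputs (the off-diagonal entry of H), the more of the rounding error the second weight can absorb.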
Mixed-Precision Quantization
class MixedPrecisionQuantizer:
"""
Different bit-widths for different layers.
Sensitive layers get higher precision,
less important layers get lower precision.
"""
def __init__(self, default_bits: int = 8):
self.default_bits = default_bits
self.layer_bits = {}
def analyze_sensitivity(
self,
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
criterion: nn.Module
) -> Dict[str, float]:
"""
Analyze each layer's sensitivity to quantization.
Method: Quantize each layer individually and measure
the impact on model output/loss.
"""
sensitivity = {}
model.eval()
# Get baseline output
with torch.no_grad():
baseline_outputs = []
for batch in calibration_data:
inputs, targets = batch
outputs = model(inputs)
baseline_outputs.append(outputs)
# Test each layer
for name, module in model.named_modules():
if not isinstance(module, (nn.Linear, nn.Conv2d)):
continue
# Save original weights
orig_weight = module.weight.data.clone()
# Quantize this layer to low precision
scale, zp = QuantizationMath.compute_scale_zero_point(orig_weight)
q_weight = QuantizationMath.quantize(orig_weight, scale, zp)
deq_weight = QuantizationMath.dequantize(q_weight, scale, zp)
module.weight.data = deq_weight
# Measure impact
total_error = 0
with torch.no_grad():
for i, batch in enumerate(calibration_data):
inputs, targets = batch
outputs = model(inputs)
error = F.mse_loss(outputs, baseline_outputs[i])
total_error += error.item()
sensitivity[name] = total_error
# Restore original weights
module.weight.data = orig_weight
return sensitivity
def assign_bit_widths(
self,
sensitivity: Dict[str, float],
target_size_ratio: float = 0.25,
bit_options: List[int] = [2, 4, 8]
) -> Dict[str, int]:
"""
Assign bit-widths to layers based on sensitivity.
More sensitive layers get more bits.
"""
# Normalize sensitivity scores
max_sens = max(sensitivity.values())
normalized = {k: v / max_sens for k, v in sensitivity.items()}
# Assign bits based on sensitivity
# High sensitivity -> more bits
self.layer_bits = {}
for name, sens in normalized.items():
if sens > 0.7:
bits = max(bit_options)
elif sens > 0.3:
bits = bit_options[len(bit_options) // 2]
else:
bits = min(bit_options)
self.layer_bits[name] = bits
return self.layer_bits
def quantize_model(self, model: nn.Module) -> nn.Module:
"""Apply mixed-precision quantization."""
for name, module in model.named_modules():
if name not in self.layer_bits:
continue
bits = self.layer_bits[name]
qmin = -2**(bits - 1)
qmax = 2**(bits - 1) - 1
if hasattr(module, 'weight'):
w = module.weight.data
scale, zp = QuantizationMath.compute_scale_zero_point(
w, qmin=qmin, qmax=qmax
)
q = QuantizationMath.quantize(w, scale, zp, qmin, qmax)
module.weight.data = QuantizationMath.dequantize(q, scale, zp)
return model
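Note that `assign_bit_widths` accepts a `target_size_ratio` but its thresholds never consult it, so nothing guarantees the resulting assignment fits a size budget. A simple check such as the one sketched below could be run on any assignment. The layer names and parameter counts are hypothetical, and only weights are counted.

```python
def size_ratio(param_counts: dict, layer_bits: dict, baseline_bits: int = 32) -> float:
    """Compressed size divided by baseline size, counting weights only."""
    compressed = sum(param_counts[name] * layer_bits[name] for name in param_counts)
    baseline = sum(param_counts.values()) * baseline_bits
    return compressed / baseline


def average_bits(param_counts: dict, layer_bits: dict) -> float:
    """Parameter-weighted mean bit-width across layers."""
    total = sum(param_counts.values())
    return sum(param_counts[n] * layer_bits[n] for n in param_counts) / total


# Hypothetical layer names and sizes, for illustration only
param_counts = {
    "embed": 1_000_000,
    "block1": 4_000_000,
    "block2": 4_000_000,
    "head": 1_000_000,
}
layer_bits = {"embed": 8, "block1": 4, "block2": 2, "head": 8}

ratio = size_ratio(param_counts, layer_bits)
avg = average_bits(param_counts, layer_bits)
print(f"average bits per weight: {avg}")
print(f"size relative to FP32:   {ratio}")
print("meets 0.25 budget:", ratio <= 0.25)
```

Because large layers dominate the total, lowering the bit-width of one big, insensitive layer usually buys more compression than lowering several small ones.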
Hardware-Aware Quantization
class HardwareAwareQuantization:
"""
Quantization considering target hardware constraints.
"""
HARDWARE_PROFILES = {
'cpu_x86': {
'supported_dtypes': ['int8'],
'optimal_alignment': 32,
'vectorized_ops': True
},
'gpu_nvidia': {
'supported_dtypes': ['int8', 'fp16', 'bf16', 'fp8'],
'tensor_cores': True,
'optimal_batch': 64
},
'mobile_arm': {
'supported_dtypes': ['int8', 'int4'],
'neon_simd': True,
'memory_limited': True
},
'edge_tpu': {
'supported_dtypes': ['int8'],
'requires_symmetric': True,
'batch_size': 1
}
}
def __init__(self, target: str = 'cpu_x86'):
self.target = target
self.profile = self.HARDWARE_PROFILES.get(target, {})
def get_optimal_config(self) -> Dict:
"""Get optimal quantization config for target."""
config = {
'dtype': self.profile.get('supported_dtypes', ['int8'])[0],
'symmetric': self.profile.get('requires_symmetric', False),
}
if self.target == 'gpu_nvidia':
config['use_tensor_cores'] = True
config['dtype'] = 'int8' # For INT8 tensor core ops
return config
def estimate_speedup(
self,
model: nn.Module,
original_dtype: str = 'fp32'
) -> Dict[str, float]:
"""
Estimate speedup from quantization.
Very rough estimates - actual speedup varies by hardware.
"""
speedups = {
'fp32_to_fp16': 1.5,
'fp32_to_int8': 2.5,
'fp32_to_int4': 4.0
}
target_dtype = self.profile.get('supported_dtypes', ['int8'])[0]
key = f'{original_dtype}_to_{target_dtype}'
base_speedup = speedups.get(key, 1.0)
# Adjust for hardware features
if self.profile.get('tensor_cores'):
base_speedup *= 1.5
if self.profile.get('vectorized_ops'):
base_speedup *= 1.2
return {
'compute_speedup': base_speedup,
'memory_reduction': 32 / int(target_dtype.replace('int', '').replace('fp', ''))
}
# Summary of quantization methods
def quantization_summary():
"""Print quantization methods summary."""
summary = """
╔════════════════════════════════════════════════════════════════════╗
║ QUANTIZATION METHODS SUMMARY ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ POST-TRAINING QUANTIZATION (PTQ) ║
║ • Dynamic: Weights quantized, activations at runtime ║
║ • Static: Both quantized using calibration data ║
║ • Pros: Fast, no training needed ║
║ • Cons: May lose accuracy, especially at low bits ║
║ ║
║ QUANTIZATION-AWARE TRAINING (QAT) ║
║ • Simulate quantization during training ║
║ • Model learns to be robust to quantization ║
║ • Pros: Best accuracy ║
║ • Cons: Requires training, more complex ║
║ ║
║ LLM-SPECIFIC METHODS ║
║ • GPTQ: Second-order optimization ║
║ • AWQ: Activation-aware weight scaling ║
║ • SqueezeLLM: Sparsity + quantization ║
║ ║
║ MIXED PRECISION ║
║ • Different bits for different layers ║
║ • Balance accuracy vs compression ║
║ ║
╠════════════════════════════════════════════════════════════════════╣
║ RECOMMENDED APPROACH: ║
║ 1. Start with PTQ (fast baseline) ║
║ 2. If accuracy drops, try QAT ║
║ 3. For LLMs, use GPTQ/AWQ ║
║ 4. For production, consider mixed precision ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(summary)
quantization_summary()
Exercises
Exercise 1: Compare Calibration Methods
Implement and compare:
- Min-max calibration
- Percentile calibration
- Entropy calibration
Exercise 2: Implement Per-Group Quantization
Implement 4-bit quantization with:
- Group size of 128
- Separate scale per group
- Compare to per-tensor quantization
Exercise 3: Mixed-Precision Search
Implement automated mixed-precision search:
- Use sensitivity analysis
- Optimize for target model size
- Meet accuracy constraint
What’s Next?
Knowledge Distillation
Transfer knowledge to smaller models
Continual Learning
Learn new tasks without forgetting