Quantization Deep Dive
Why Quantization?
Quantization reduces model size and increases inference speed:

| Precision | Bits | Memory | Speed |
|---|---|---|---|
| FP32 | 32 | 1x | 1x |
| FP16 | 16 | 0.5x | ~2x |
| INT8 | 8 | 0.25x | ~4x |
| INT4 | 4 | 0.125x | ~8x |
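As a rough sanity check of the memory column, the snippet below estimates the weight-only footprint of a hypothetical 7B-parameter model at each precision (the parameter count is made up, and activation memory and runtime overhead are ignored).

# Rough weight-memory estimate for a hypothetical 7B-parameter model
# (weights only; activations, KV caches, and framework overhead are ignored).
PARAMS = 7_000_000_000

for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    gib = PARAMS * bits / 8 / 1024**3
    print(f"{name}: {gib:.1f} GiB")
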
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization as quant
from torch.quantization import QConfig, default_observer, default_weight_observer
from typing import Optional, Tuple, Dict, List, Callable
import numpy as np
from dataclasses import dataclass
torch.manual_seed(42)
Quantization Fundamentals
Number Representation
@dataclass
class QuantizationParams:
"""Parameters for quantization mapping."""
scale: float
zero_point: int
qmin: int
qmax: int
dtype: torch.dtype = torch.qint8
class QuantizationMath:
"""
Core quantization operations.
Quantization maps floating point values to integers:
q = round(x / scale) + zero_point
x = (q - zero_point) * scale
For symmetric quantization:
scale = max(|x|) / (qmax)
zero_point = 0
For asymmetric quantization:
scale = (max(x) - min(x)) / (qmax - qmin)
zero_point = round(-min(x) / scale)
"""
@staticmethod
def compute_scale_zero_point(
tensor: torch.Tensor,
qmin: int = -128,
qmax: int = 127,
symmetric: bool = True
) -> Tuple[float, int]:
"""
Compute scale and zero point for quantization.
Args:
tensor: Input tensor to quantize
qmin, qmax: Quantization range
symmetric: Use symmetric quantization
"""
x_min = tensor.min().item()
x_max = tensor.max().item()
if symmetric:
# Symmetric: zero_point = 0, scale based on max abs
max_abs = max(abs(x_min), abs(x_max))
scale = max_abs / qmax if max_abs != 0 else 1.0
zero_point = 0
else:
# Asymmetric: full range mapping
scale = (x_max - x_min) / (qmax - qmin) if x_max != x_min else 1.0
zero_point = round(-x_min / scale) + qmin
zero_point = max(qmin, min(qmax, zero_point))
return scale, zero_point
@staticmethod
def quantize(
tensor: torch.Tensor,
scale: float,
zero_point: int,
qmin: int = -128,
qmax: int = 127
) -> torch.Tensor:
"""Quantize a floating point tensor."""
q = torch.round(tensor / scale) + zero_point
q = torch.clamp(q, qmin, qmax)
return q.to(torch.int8)
@staticmethod
def dequantize(
q_tensor: torch.Tensor,
scale: float,
zero_point: int
) -> torch.Tensor:
"""Dequantize an integer tensor."""
return (q_tensor.float() - zero_point) * scale
# Demonstrate quantization
def demo_quantization():
"""Show quantization in action."""
# Original values
x = torch.tensor([0.1, 0.5, 1.0, -0.3, 2.0, -1.5])
# Compute parameters
scale, zp = QuantizationMath.compute_scale_zero_point(x)
print(f"Scale: {scale:.4f}, Zero Point: {zp}")
# Quantize
q = QuantizationMath.quantize(x, scale, zp)
print(f"Quantized: {q}")
# Dequantize
x_deq = QuantizationMath.dequantize(q, scale, zp)
print(f"Dequantized: {x_deq}")
# Error
error = torch.abs(x - x_deq)
print(f"Quantization error: {error}")
demo_quantization()
Granularity Levels
class QuantizationGranularity:
"""
Different levels of quantization granularity.
Finer granularity = better accuracy but more overhead.
"""
@staticmethod
def per_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, float, int]:
"""
One scale/zero_point for entire tensor.
Pros: Simple, minimal overhead
Cons: Less accurate for varying distributions
"""
scale, zp = QuantizationMath.compute_scale_zero_point(tensor)
q = QuantizationMath.quantize(tensor, scale, zp)
return q, scale, zp
@staticmethod
def per_channel(tensor: torch.Tensor, axis: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Separate scale/zero_point per channel.
Pros: Better accuracy for weights
Cons: More storage for parameters
"""
n_channels = tensor.shape[axis]
scales = torch.zeros(n_channels)
zero_points = torch.zeros(n_channels, dtype=torch.int32)
q = torch.zeros_like(tensor, dtype=torch.int8)
        for c in range(n_channels):
            channel_data = tensor.select(axis, c)
            scale, zp = QuantizationMath.compute_scale_zero_point(channel_data)
            scales[c] = scale
            zero_points[c] = zp
            # Quantize and write the result back along the chosen axis
            q.select(axis, c).copy_(QuantizationMath.quantize(channel_data, scale, zp))
        return q, scales, zero_points
@staticmethod
def per_group(
tensor: torch.Tensor,
group_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Separate scale/zero_point per group of elements.
Used in LLM quantization (e.g., GPTQ, AWQ).
"""
flat = tensor.flatten()
n_groups = (len(flat) + group_size - 1) // group_size
# Pad if necessary
padded = F.pad(flat, (0, n_groups * group_size - len(flat)))
groups = padded.view(n_groups, group_size)
scales = torch.zeros(n_groups)
zero_points = torch.zeros(n_groups, dtype=torch.int32)
q_groups = torch.zeros_like(groups, dtype=torch.int8)
for i in range(n_groups):
scale, zp = QuantizationMath.compute_scale_zero_point(groups[i])
scales[i] = scale
zero_points[i] = zp
q_groups[i] = QuantizationMath.quantize(groups[i], scale, zp)
q = q_groups.flatten()[:len(flat)].view(tensor.shape)
return q, scales, zero_points
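The short demo below (a sketch that assumes the QuantizationMath and QuantizationGranularity classes defined above) builds a weight matrix whose channels have very different ranges and compares the reconstruction error of per-tensor versus per-channel quantization.

# Compare per-tensor vs per-channel quantization error on a weight matrix
# whose channel magnitudes span two orders of magnitude (illustrative data).
def demo_granularity():
    w = torch.randn(8, 64) * torch.logspace(-2, 0, 8).unsqueeze(1)

    q_t, s_t, zp_t = QuantizationGranularity.per_tensor(w)
    w_t = QuantizationMath.dequantize(q_t, s_t, zp_t)

    q_c, s_c, zp_c = QuantizationGranularity.per_channel(w, axis=0)
    w_c = torch.stack([
        QuantizationMath.dequantize(q_c[c], s_c[c].item(), int(zp_c[c]))
        for c in range(w.shape[0])
    ])

    print(f"Per-tensor MSE : {F.mse_loss(w, w_t).item():.6f}")
    print(f"Per-channel MSE: {F.mse_loss(w, w_c).item():.6f}")

demo_granularity()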
Post-Training Quantization (PTQ)
class PostTrainingQuantization:
"""
Quantize a trained model without retraining.
Methods:
1. Dynamic quantization: Quantize weights, compute scales at runtime
2. Static quantization: Calibrate with representative data
"""
@staticmethod
def dynamic_quantization(
model: nn.Module,
dtype: torch.dtype = torch.qint8
) -> nn.Module:
"""
Dynamic quantization: weights are quantized,
activations are quantized dynamically at runtime.
Best for: RNNs, Transformers, models with variable input
"""
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear, nn.LSTM, nn.GRU}, # Layers to quantize
dtype=dtype
)
return quantized_model
@staticmethod
def static_quantization(
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
backend: str = 'fbgemm' # or 'qnnpack' for mobile
) -> nn.Module:
"""
Static quantization: both weights and activations
are quantized using calibrated scales.
Requires: Representative calibration data
"""
# 1. Prepare model
model.eval()
# Set quantization config
model.qconfig = torch.quantization.get_default_qconfig(backend)
        # Fuse operations (Conv+BN+ReLU, Linear+ReLU, etc.)
        # Note: the module names below must match the attribute names in the
        # actual model; ['conv', 'bn', 'relu'] is only an example pattern.
        model_fused = torch.quantization.fuse_modules(
            model,
            [['conv', 'bn', 'relu']]
        )
# 2. Insert observers
model_prepared = torch.quantization.prepare(model_fused)
# 3. Calibrate with representative data
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, (tuple, list)):
inputs = batch[0]
else:
inputs = batch
model_prepared(inputs)
# 4. Convert to quantized model
model_quantized = torch.quantization.convert(model_prepared)
return model_quantized
class CalibrationMethods:
"""
Methods to determine optimal quantization parameters.
"""
@staticmethod
def min_max_calibration(activations: List[torch.Tensor]) -> Tuple[float, float]:
"""Simple min/max calibration."""
all_min = min(a.min().item() for a in activations)
all_max = max(a.max().item() for a in activations)
return all_min, all_max
@staticmethod
def percentile_calibration(
activations: List[torch.Tensor],
percentile: float = 99.99
) -> Tuple[float, float]:
"""
Use percentiles instead of min/max.
More robust to outliers.
"""
all_values = torch.cat([a.flatten() for a in activations])
low = np.percentile(all_values.numpy(), 100 - percentile)
high = np.percentile(all_values.numpy(), percentile)
return low, high
@staticmethod
def entropy_calibration(
activations: List[torch.Tensor],
num_bins: int = 2048,
num_quantized_bins: int = 128
) -> float:
"""
KL divergence based calibration (TensorRT style).
Find threshold that minimizes KL divergence between
original and quantized distributions.
"""
all_values = torch.cat([a.flatten() for a in activations])
# Build histogram
hist, bin_edges = np.histogram(all_values.numpy(), bins=num_bins)
# Try different thresholds
best_threshold = bin_edges[-1]
best_kl = float('inf')
for i in range(num_quantized_bins, num_bins):
# Reference distribution (up to threshold)
p = hist[:i].copy().astype(np.float64)
p[p == 0] = 1e-10
p /= p.sum()
            # Quantized distribution: collapse the first i bins into
            # num_quantized_bins, with the last bin absorbing any remainder
            q = np.zeros(num_quantized_bins, dtype=np.float64)
            bin_size = i // num_quantized_bins
            for j in range(num_quantized_bins):
                start = j * bin_size
                end = i if j == num_quantized_bins - 1 else start + bin_size
                q[j] = hist[start:end].sum()
            # Expand q back to length i so it is comparable with p
            q_expanded = np.repeat(q, bin_size).astype(np.float64)
            q_expanded = np.concatenate(
                [q_expanded, np.full(i - len(q_expanded), q[-1], dtype=np.float64)]
            )
            q_expanded[q_expanded == 0] = 1e-10
            q_expanded /= q_expanded.sum()
            # KL divergence between reference and quantized distributions
            kl = np.sum(p * np.log(p / q_expanded))
if kl < best_kl:
best_kl = kl
best_threshold = bin_edges[i]
return best_threshold
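A minimal end-to-end sketch of dynamic PTQ, assuming the helper class above and a CPU quantization backend such as fbgemm; the toy two-layer model and the 0.1 tolerance are arbitrary choices for illustration.

# Dynamic quantization on a throwaway model: compare outputs and serialized size.
import io

def state_dict_bytes(m: nn.Module) -> int:
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes

toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
toy_q = PostTrainingQuantization.dynamic_quantization(toy)

x = torch.randn(4, 512)
print("FP32 output close to INT8 output:",
      torch.allclose(toy(x), toy_q(x), atol=0.1))
print(f"FP32 size: {state_dict_bytes(toy) / 1e6:.2f} MB, "
      f"INT8 size: {state_dict_bytes(toy_q) / 1e6:.2f} MB")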
Quantization-Aware Training (QAT)
class QuantizationAwareTraining:
"""
Train with simulated quantization for better accuracy.
Key idea: Insert fake quantization during training
so the model learns to be robust to quantization noise.
"""
@staticmethod
def prepare_qat(
model: nn.Module,
backend: str = 'fbgemm'
) -> nn.Module:
"""Prepare model for QAT."""
model.train()
# QAT config with fake quantization
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
# Prepare for QAT (inserts fake quant modules)
model_prepared = torch.quantization.prepare_qat(model)
return model_prepared
@staticmethod
def convert_qat(model: nn.Module) -> nn.Module:
"""Convert QAT model to quantized model."""
model.eval()
model_quantized = torch.quantization.convert(model)
return model_quantized
class FakeQuantize(torch.autograd.Function):
"""
Fake quantization for training.
Forward: Quantize then immediately dequantize
Backward: Straight-through estimator (gradient unchanged)
"""
@staticmethod
def forward(ctx, x, scale, zero_point, qmin=-128, qmax=127):
# Quantize
q = torch.round(x / scale) + zero_point
q = torch.clamp(q, qmin, qmax)
# Immediately dequantize
x_q = (q - zero_point) * scale
        # Save for backward: clip bounds expressed in the original float domain
        ctx.save_for_backward(
            x,
            torch.tensor([(qmin - zero_point) * scale, (qmax - zero_point) * scale])
        )
return x_q
@staticmethod
def backward(ctx, grad_output):
x, bounds = ctx.saved_tensors
qmin_scaled, qmax_scaled = bounds[0].item(), bounds[1].item()
# Straight-through estimator with clipping
# Zero gradient for values outside quantization range
grad_input = grad_output.clone()
grad_input[x < qmin_scaled] = 0
grad_input[x > qmax_scaled] = 0
return grad_input, None, None, None, None
class QATLinear(nn.Module):
"""
Quantization-aware linear layer.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True
):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias)
# Observers for scale/zero_point
self.weight_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8
)
self.activation_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Fake quantize input
x_q = self.activation_fake_quant(x)
# Fake quantize weights
w_q = self.weight_fake_quant(self.linear.weight)
# Compute with fake-quantized values
return F.linear(x_q, w_q, self.linear.bias)
# LSQ: Learned Step Size Quantization
class LSQQuantizer(nn.Module):
"""
Learned Step Size Quantization.
Learn the quantization step size (scale) during training
instead of using fixed calibration.
Reference: "Learned Step Size Quantization" (Esser et al., 2020)
"""
def __init__(
self,
num_bits: int = 8,
symmetric: bool = True,
per_channel: bool = False,
num_channels: int = 1
):
super().__init__()
self.num_bits = num_bits
self.symmetric = symmetric
self.per_channel = per_channel
# Compute quantization bounds
if symmetric:
self.qmin = -2**(num_bits - 1)
self.qmax = 2**(num_bits - 1) - 1
else:
self.qmin = 0
self.qmax = 2**num_bits - 1
# Learnable step size (scale)
if per_channel:
self.scale = nn.Parameter(torch.ones(num_channels))
else:
self.scale = nn.Parameter(torch.tensor(1.0))
self.initialized = False
def init_scale(self, x: torch.Tensor):
"""Initialize scale based on data."""
if self.per_channel:
# Per-channel initialization
for c in range(x.shape[0]):
max_val = x[c].abs().max()
self.scale.data[c] = 2 * max_val / (self.qmax - self.qmin)
else:
max_val = x.abs().max()
self.scale.data.fill_(2 * max_val / (self.qmax - self.qmin))
self.initialized = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.initialized:
self.init_scale(x)
# Ensure positive scale
scale = self.scale.abs() + 1e-8
# Quantize
if self.per_channel:
scale_shape = [1] * x.dim()
scale_shape[0] = -1
scale = scale.view(*scale_shape)
x_scaled = x / scale
x_clipped = torch.clamp(x_scaled, self.qmin, self.qmax)
x_rounded = torch.round(x_clipped)
# Dequantize with gradient trick
x_q = (x_rounded - x_clipped).detach() + x_clipped
x_deq = x_q * scale
return x_deq
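A quick illustrative check of the LSQ quantizer above: after one backward pass the learned step size has a gradient, confirming that the scale is trainable, while the straight-through estimator keeps the fake-quantized output close to the input. The tensor sizes and the loss are arbitrary.

# Verify that the LSQ step size is trainable (sketch using the class above).
lsq = LSQQuantizer(num_bits=4)
x = torch.randn(16, 32)
x_q = lsq(x)                        # fake-quantize / dequantize
loss = F.mse_loss(x_q, x)           # penalize quantization error
loss.backward()
print("learned step size:", lsq.scale.item())
print("step-size gradient:", lsq.scale.grad.item())
print("max abs quantization error:", (x - x_q).abs().max().item())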
Advanced Quantization Methods
GPTQ (LLM Quantization)
class GPTQQuantizer:
"""
GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
Key idea: Use second-order information (Hessian) to minimize
quantization error layer by layer.
Simplified implementation for educational purposes.
"""
def __init__(
self,
num_bits: int = 4,
group_size: int = 128,
symmetric: bool = True
):
self.num_bits = num_bits
self.group_size = group_size
self.symmetric = symmetric
if symmetric:
self.qmin = -2**(num_bits - 1)
self.qmax = 2**(num_bits - 1) - 1
else:
self.qmin = 0
self.qmax = 2**num_bits - 1
def quantize_layer(
self,
weight: torch.Tensor,
H: torch.Tensor # Hessian (X @ X.T from calibration)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize a weight matrix using GPTQ algorithm.
Args:
weight: Weight matrix [out_features, in_features]
H: Hessian matrix [in_features, in_features]
"""
W = weight.clone().float()
d, n = W.shape
# Add damping to Hessian for numerical stability
damp = 0.01 * torch.diag(H).mean()
H = H + damp * torch.eye(n)
# Cholesky decomposition
H_inv = torch.linalg.cholesky(H)
H_inv = torch.cholesky_inverse(H_inv)
# Process in groups
Q = torch.zeros_like(W)
scales = torch.zeros(d, (n + self.group_size - 1) // self.group_size)
for i1 in range(0, n, self.group_size):
i2 = min(i1 + self.group_size, n)
# Get group weights
W_group = W[:, i1:i2].clone()
# Compute scale for group
if self.symmetric:
max_val = W_group.abs().max(dim=1, keepdim=True)[0]
scale = max_val / self.qmax
else:
min_val = W_group.min(dim=1, keepdim=True)[0]
max_val = W_group.max(dim=1, keepdim=True)[0]
scale = (max_val - min_val) / (self.qmax - self.qmin)
scale = scale.clamp(min=1e-10)
scales[:, i1 // self.group_size] = scale.squeeze()
# Quantize columns one by one
for i in range(i1, i2):
# Quantize column
q = torch.round(W[:, i] / scale.squeeze())
q = torch.clamp(q, self.qmin, self.qmax)
Q[:, i] = q
# Compute quantization error
error = W[:, i] - q * scale.squeeze()
                # Update remaining weights to compensate (simplified; full GPTQ
                # also normalizes by the diagonal of the inverse-Hessian factor)
                if i < n - 1:
                    W[:, i+1:] -= error.unsqueeze(1) * H_inv[i, i+1:].unsqueeze(0)
return Q.to(torch.int8), scales
def collect_hessian(
self,
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
layer_name: str
) -> torch.Tensor:
"""
Collect Hessian for a layer using calibration data.
H = sum(X @ X.T) where X is the input to the layer.
"""
H = None
def hook(module, input, output):
nonlocal H
x = input[0]
x = x.view(-1, x.shape[-1]).float()
if H is None:
H = x.T @ x
else:
H += x.T @ x
# Register hook
target_module = dict(model.named_modules())[layer_name]
handle = target_module.register_forward_hook(hook)
# Run calibration
model.eval()
with torch.no_grad():
for batch in calibration_data:
if isinstance(batch, (tuple, list)):
inputs = batch[0]
else:
inputs = batch
model(inputs)
handle.remove()
return H
class AWQQuantizer:
"""
Activation-aware Weight Quantization.
Key insight: Not all weights are equally important.
Scale weights by their activation magnitudes.
Simplified implementation.
"""
    def __init__(self, num_bits: int = 4):
        self.num_bits = num_bits
        # Symmetric signed range so negative weights remain representable
        self.qmin = -2**(num_bits - 1)
        self.qmax = 2**(num_bits - 1) - 1
def compute_importance(
self,
weight: torch.Tensor,
activations: List[torch.Tensor]
) -> torch.Tensor:
"""
Compute importance of each weight column
based on activation magnitudes.
"""
# Average activation magnitude per input dimension
act_cat = torch.cat([a.flatten(end_dim=-2) for a in activations])
importance = act_cat.abs().mean(dim=0)
return importance
def quantize_with_scaling(
self,
weight: torch.Tensor,
importance: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantize weights with importance-aware scaling.
"""
# Find optimal scaling factor
best_error = float('inf')
best_scale = 1.0
for alpha in np.linspace(0.1, 1.0, 20):
# Scale by importance^alpha
scaling = importance.pow(alpha)
scaling = scaling / scaling.max() # Normalize
# Apply scaling
W_scaled = weight * scaling.unsqueeze(0)
# Quantize
max_val = W_scaled.abs().max()
scale = max_val / self.qmax
W_q = torch.round(W_scaled / scale)
W_q = torch.clamp(W_q, self.qmin, self.qmax)
# Dequantize and undo scaling
W_deq = W_q * scale / scaling.unsqueeze(0)
# Compute error
error = ((weight - W_deq) ** 2).mean()
if error < best_error:
best_error = error
best_scale = alpha
# Final quantization with best scaling
scaling = importance.pow(best_scale)
scaling = scaling / scaling.max()
W_scaled = weight * scaling.unsqueeze(0)
max_val = W_scaled.abs().max()
quant_scale = max_val / self.qmax
W_q = torch.round(W_scaled / quant_scale)
W_q = torch.clamp(W_q, self.qmin, self.qmax)
return W_q.to(torch.int8), quant_scale, scaling
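A small end-to-end sketch of the simplified GPTQ quantizer above on a random linear layer; the synthetic calibration inputs, dimensions, and group size are made up for illustration, and the Hessian proxy is built directly as X^T X, matching collect_hessian.

# Run the simplified GPTQ quantizer on a random layer and measure error.
in_features, out_features = 256, 128
weight = torch.randn(out_features, in_features)

X = torch.randn(1024, in_features)        # pretend calibration activations
H = X.T @ X                               # Hessian proxy used by GPTQ

gptq = GPTQQuantizer(num_bits=4, group_size=64)
Q, scales = gptq.quantize_layer(weight, H)

# Dequantize group-wise (symmetric, so zero_point is 0) to measure error
W_hat = torch.zeros_like(weight)
for g in range(scales.shape[1]):
    i1, i2 = g * 64, min((g + 1) * 64, in_features)
    W_hat[:, i1:i2] = Q[:, i1:i2].float() * scales[:, g:g+1]

print(f"Relative weight error: {(weight - W_hat).norm() / weight.norm():.4f}")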
Mixed-Precision Quantization
class MixedPrecisionQuantizer:
"""
Different bit-widths for different layers.
Sensitive layers get higher precision,
less important layers get lower precision.
"""
def __init__(self, default_bits: int = 8):
self.default_bits = default_bits
self.layer_bits = {}
def analyze_sensitivity(
self,
model: nn.Module,
calibration_data: torch.utils.data.DataLoader,
criterion: nn.Module
) -> Dict[str, float]:
"""
Analyze each layer's sensitivity to quantization.
Method: Quantize each layer individually and measure
the impact on model output/loss.
"""
sensitivity = {}
model.eval()
# Get baseline output
with torch.no_grad():
baseline_outputs = []
for batch in calibration_data:
inputs, targets = batch
outputs = model(inputs)
baseline_outputs.append(outputs)
# Test each layer
for name, module in model.named_modules():
if not isinstance(module, (nn.Linear, nn.Conv2d)):
continue
# Save original weights
orig_weight = module.weight.data.clone()
# Quantize this layer to low precision
scale, zp = QuantizationMath.compute_scale_zero_point(orig_weight)
q_weight = QuantizationMath.quantize(orig_weight, scale, zp)
deq_weight = QuantizationMath.dequantize(q_weight, scale, zp)
module.weight.data = deq_weight
# Measure impact
total_error = 0
with torch.no_grad():
for i, batch in enumerate(calibration_data):
inputs, targets = batch
outputs = model(inputs)
error = F.mse_loss(outputs, baseline_outputs[i])
total_error += error.item()
sensitivity[name] = total_error
# Restore original weights
module.weight.data = orig_weight
return sensitivity
def assign_bit_widths(
self,
sensitivity: Dict[str, float],
target_size_ratio: float = 0.25,
bit_options: List[int] = [2, 4, 8]
) -> Dict[str, int]:
"""
Assign bit-widths to layers based on sensitivity.
More sensitive layers get more bits.
"""
# Normalize sensitivity scores
max_sens = max(sensitivity.values())
normalized = {k: v / max_sens for k, v in sensitivity.items()}
# Assign bits based on sensitivity
# High sensitivity -> more bits
self.layer_bits = {}
for name, sens in normalized.items():
if sens > 0.7:
bits = max(bit_options)
elif sens > 0.3:
bits = bit_options[len(bit_options) // 2]
else:
bits = min(bit_options)
self.layer_bits[name] = bits
return self.layer_bits
def quantize_model(self, model: nn.Module) -> nn.Module:
"""Apply mixed-precision quantization."""
for name, module in model.named_modules():
if name not in self.layer_bits:
continue
bits = self.layer_bits[name]
qmin = -2**(bits - 1)
qmax = 2**(bits - 1) - 1
if hasattr(module, 'weight'):
w = module.weight.data
scale, zp = QuantizationMath.compute_scale_zero_point(
w, qmin=qmin, qmax=qmax
)
q = QuantizationMath.quantize(w, scale, zp, qmin, qmax)
module.weight.data = QuantizationMath.dequantize(q, scale, zp)
return model
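The sketch below exercises the bit-width assignment with a hand-made sensitivity dictionary (the layer names and scores are invented), skipping the expensive sensitivity analysis.

# Assign bit-widths from illustrative sensitivity scores and inspect the result.
mp = MixedPrecisionQuantizer(default_bits=8)
fake_sensitivity = {
    'layers.0': 0.95,   # very sensitive -> keep high precision
    'layers.1': 0.40,   # moderately sensitive
    'layers.2': 0.05,   # robust -> aggressive quantization
}
bits = mp.assign_bit_widths(fake_sensitivity, bit_options=[2, 4, 8])
print(bits)  # {'layers.0': 8, 'layers.1': 4, 'layers.2': 2}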
Hardware-Aware Quantization
class HardwareAwareQuantization:
"""
Quantization considering target hardware constraints.
"""
HARDWARE_PROFILES = {
'cpu_x86': {
'supported_dtypes': ['int8'],
'optimal_alignment': 32,
'vectorized_ops': True
},
'gpu_nvidia': {
'supported_dtypes': ['int8', 'fp16', 'bf16', 'fp8'],
'tensor_cores': True,
'optimal_batch': 64
},
'mobile_arm': {
'supported_dtypes': ['int8', 'int4'],
'neon_simd': True,
'memory_limited': True
},
'edge_tpu': {
'supported_dtypes': ['int8'],
'requires_symmetric': True,
'batch_size': 1
}
}
def __init__(self, target: str = 'cpu_x86'):
self.target = target
self.profile = self.HARDWARE_PROFILES.get(target, {})
def get_optimal_config(self) -> Dict:
"""Get optimal quantization config for target."""
config = {
'dtype': self.profile.get('supported_dtypes', ['int8'])[0],
'symmetric': self.profile.get('requires_symmetric', False),
}
if self.target == 'gpu_nvidia':
config['use_tensor_cores'] = True
config['dtype'] = 'int8' # For INT8 tensor core ops
return config
def estimate_speedup(
self,
model: nn.Module,
original_dtype: str = 'fp32'
) -> Dict[str, float]:
"""
Estimate speedup from quantization.
Very rough estimates - actual speedup varies by hardware.
"""
speedups = {
'fp32_to_fp16': 1.5,
'fp32_to_int8': 2.5,
'fp32_to_int4': 4.0
}
target_dtype = self.profile.get('supported_dtypes', ['int8'])[0]
key = f'{original_dtype}_to_{target_dtype}'
base_speedup = speedups.get(key, 1.0)
# Adjust for hardware features
if self.profile.get('tensor_cores'):
base_speedup *= 1.5
if self.profile.get('vectorized_ops'):
base_speedup *= 1.2
return {
'compute_speedup': base_speedup,
'memory_reduction': 32 / int(target_dtype.replace('int', '').replace('fp', ''))
}
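
# Illustrative use of the hardware profiles above. The printed numbers are the
# rough, hand-tuned estimates returned by estimate_speedup(), not benchmarks.
for target in ['cpu_x86', 'gpu_nvidia', 'mobile_arm', 'edge_tpu']:
    haq = HardwareAwareQuantization(target)
    cfg = haq.get_optimal_config()
    est = haq.estimate_speedup(model=None)  # model is not used by the estimate
    print(f"{target:12s} dtype={cfg['dtype']:5s} "
          f"~{est['compute_speedup']:.1f}x compute, "
          f"{est['memory_reduction']:.0f}x smaller weights")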
# Summary of quantization methods
def quantization_summary():
"""Print quantization methods summary."""
summary = """
╔════════════════════════════════════════════════════════════════════╗
║ QUANTIZATION METHODS SUMMARY ║
╠════════════════════════════════════════════════════════════════════╣
║ ║
║ POST-TRAINING QUANTIZATION (PTQ) ║
║ • Dynamic: Weights quantized, activations at runtime ║
║ • Static: Both quantized using calibration data ║
║ • Pros: Fast, no training needed ║
║ • Cons: May lose accuracy, especially at low bits ║
║ ║
║ QUANTIZATION-AWARE TRAINING (QAT) ║
║ • Simulate quantization during training ║
║ • Model learns to be robust to quantization ║
║ • Pros: Best accuracy ║
║ • Cons: Requires training, more complex ║
║ ║
║ LLM-SPECIFIC METHODS ║
║ • GPTQ: Second-order optimization ║
║ • AWQ: Activation-aware weight scaling ║
║ • SqueezeLLM: Sparsity + quantization ║
║ ║
║ MIXED PRECISION ║
║ • Different bits for different layers ║
║ • Balance accuracy vs compression ║
║ ║
╠════════════════════════════════════════════════════════════════════╣
║ RECOMMENDED APPROACH: ║
║ 1. Start with PTQ (fast baseline) ║
║ 2. If accuracy drops, try QAT ║
║ 3. For LLMs, use GPTQ/AWQ ║
║ 4. For production, consider mixed precision ║
╚════════════════════════════════════════════════════════════════════╝
"""
print(summary)
quantization_summary()
Exercises
Exercise 1: Compare Calibration Methods
Implement and compare:
- Min-max calibration
- Percentile calibration
- Entropy calibration
Exercise 2: Implement Per-Group Quantization
Implement 4-bit quantization with:
- Group size of 128
- Separate scale per group
- Compare to per-tensor quantization
Exercise 3: Mixed-Precision Search
Implement automated mixed-precision search:
- Use sensitivity analysis
- Optimize for target model size
- Meet accuracy constraint