Model Deployment
From Research to Production
Training is only half the battle. Production requires:
Fast inference
Minimal dependencies
Reproducibility
Monitoring
TorchScript (JIT Compilation)
import torch
model = MyModel()
model.eval()
# Option 1: Tracing (captures single path)
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
# Option 2: Scripting (handles control flow)
scripted_model = torch.jit.script(model)
# Save
traced_model.save("model_traced.pt")
# Load (no Python dependencies!)
loaded = torch.jit.load("model_traced.pt")
output = loaded(example_input)
Use tracing for straightforward models, scripting for models with control flow (if/else, loops).
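To see why this matters, consider the hypothetical toy module below (not from the original text): its forward pass branches on the input values. Tracing records only the branch taken for the example input, while scripting compiles both branches, so only the scripted version behaves correctly on inputs that take the other path.

import torch
import torch.nn as nn

class GatedModel(nn.Module):
    # Hypothetical module with data-dependent control flow
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x + 10

model = GatedModel().eval()

# Tracing bakes in the branch taken for this positive example input
traced = torch.jit.trace(model, torch.ones(3))
# Scripting compiles both branches, so negative inputs still work
scripted = torch.jit.script(model)

print(traced(-torch.ones(3)))   # follows the traced (x * 2) path: wrong result
print(scripted(-torch.ones(3))) # correctly returns x + 10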
ONNX Export
import torch.onnx
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
)
Verify the export:
import onnx
import onnxruntime as ort
# Check model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
# Run inference
session = ort.InferenceSession("model.onnx")
output = session.run(None, {"input": dummy_input.numpy()})
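Beyond checking that the model loads, it is worth comparing the ONNX Runtime output against the original PyTorch output. A minimal sketch, assuming model and dummy_input from the export snippet above, with a loose numerical tolerance:

import numpy as np
import torch

# Reference output from the original PyTorch model
with torch.no_grad():
    torch_out = model(dummy_input).numpy()

onnx_out = output[0]  # first output returned by session.run above

# Small numerical differences between runtimes are expected
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("ONNX output matches PyTorch within tolerance")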
Model Optimization
Quantization (INT8)
Reduce precision for faster inference:
import torch.quantization as quant
# Post-training quantization
model.eval()
model_fp32 = model
# Dynamic quantization (weights only)
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8,
)
# Static quantization (weights + activations); for eager-mode static
# quantization the model must wrap the quantized region in QuantStub/DeQuantStub
model.qconfig = quant.get_default_qconfig('fbgemm')
model_prepared = quant.prepare(model)
# Calibrate with representative data
for batch in calibration_loader:
    model_prepared(batch)
model_quantized = quant.convert(model_prepared)
Model Size Comparison
Precision | Model Size | Speed | Accuracy Drop
FP32 | 100 MB | 1x | Baseline
FP16 | 50 MB | 1.5-2x | ~0%
INT8 | 25 MB | 2-4x | 0-1%
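A rough way to reproduce the size column is to serialize each variant and compare file sizes. A minimal sketch, assuming model_fp32 and model_int8 from the dynamic-quantization snippet above:

import os
import torch

def file_size_mb(model, path):
    # Serialize the state dict and report the on-disk size in MB
    torch.save(model.state_dict(), path)
    return os.path.getsize(path) / 1e6

print(f"FP32: {file_size_mb(model_fp32, 'fp32.pt'):.1f} MB")
print(f"INT8: {file_size_mb(model_int8, 'int8.pt'):.1f} MB")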
Pruning
Remove unimportant weights:
import torch.nn.utils.prune as prune
# Prune 30% of weights in linear layers
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.3)
# Make pruning permanent
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.remove(module, 'weight')
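To confirm how much of each weight tensor was actually zeroed out, you can compute the per-layer sparsity after pruning. A small sketch, assuming the model from the pruning loop above:

import torch

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        # Fraction of weights that are exactly zero after pruning
        sparsity = (module.weight == 0).float().mean().item()
        print(f"{name}: {sparsity:.1%} of weights pruned")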
Serving with FastAPI
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import torch
import torchvision.transforms as T
import io
app = FastAPI()
# Load model once at startup
model = torch.jit.load("model_traced.pt")
model.eval()
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read image
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Preprocess
    input_tensor = transform(image).unsqueeze(0)
    # Inference
    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.softmax(output, dim=1)
    pred_class = probs.argmax(dim=1).item()
    confidence = probs[0, pred_class].item()
    return {
        "class": pred_class,
        "confidence": confidence,
    }
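A quick way to exercise the endpoint is a small client script. This sketch assumes the server is running locally on port 8000 and that test.jpg is any image on disk (both hypothetical):

import requests

# Send an image to the /predict endpoint and print the JSON response
with open("test.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/predict",
        files={"file": ("test.jpg", f, "image/jpeg")},
    )
print(response.json())  # {"class": <int>, "confidence": <float>}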
GPU Serving with Triton
# config.pbtxt
name: "image_classifier"
platform: "onnxruntime_onnx"
max_batch_size: 32
input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]
instance_group [
  {
    count: 2
    kind: KIND_GPU
  }
]
dynamic_batching {
  max_queue_delay_microseconds: 100
}
Docker Deployment
FROM python:3.10-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy model and code
COPY model_traced.pt .
COPY app.py .
# Expose port
EXPOSE 8000
# Run server
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
# Build and run
docker build -t model-server .
docker run -p 8000:8000 model-server
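The Dockerfile above expects a requirements.txt next to app.py. A minimal example for this particular server might look like the following (versions are left unpinned here; pin the versions you actually test against):

torch
torchvision
fastapi
uvicorn
pillow
python-multipart  # required by FastAPI for UploadFile/File form parsing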
Edge Deployment
ONNX Runtime Mobile
# Export for mobile
torch.onnx.export(
    model,
    dummy_input,
    "model_mobile.onnx",
    opset_version=13,  # Compatible with mobile
    do_constant_folding=True,
)
# Optimize for mobile
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
    "model_mobile.onnx",
    "model_mobile_quantized.onnx",
    weight_type=QuantType.QUInt8,
)
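You can sanity-check the quantized file by loading it with ONNX Runtime and confirming the output shape is unchanged. A short sketch, reusing dummy_input from the export snippets above:

import onnxruntime as ort

# Run the quantized model and confirm it still produces the expected output shape
session = ort.InferenceSession("model_mobile_quantized.onnx")
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: dummy_input.numpy()})
print(outputs[0].shape)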
TensorFlow Lite Conversion
import tensorflow as tf
# Convert ONNX to TF SavedModel first (using onnx-tf)
# Then convert to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
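After conversion, the TFLite interpreter can run the model directly. A minimal smoke test, with the input shape taken from the converted model itself:

import numpy as np
import tensorflow as tf

# Load the converted model and run a single inference
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy = np.random.randn(*input_details[0]["shape"]).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]["index"]).shape)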
Monitoring in Production
import time
from prometheus_client import Counter, Histogram, start_http_server
# Metrics
PREDICTIONS = Counter('predictions_total', 'Total predictions')
PREDICTION_LATENCY = Histogram('prediction_latency_seconds', 'Prediction latency')
# Instrumented version of the /predict handler shown earlier
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    start_time = time.time()
    # ... inference code ...
    # Record metrics
    PREDICTIONS.inc()
    PREDICTION_LATENCY.observe(time.time() - start_time)
    return result
# Start metrics server
start_http_server(8001)
Deployment Checklist
Step | Check
Model export | ✅ TorchScript/ONNX exports correctly
Validation | ✅ Output matches original model
Optimization | ✅ Quantization/pruning applied
Batching | ✅ Dynamic batching configured
Monitoring | ✅ Latency/throughput metrics
Error handling | ✅ Graceful failure modes
Scaling | ✅ Horizontal scaling ready
Exercises
Exercise 1: Export Pipeline
Export a ResNet model to both TorchScript and ONNX. Compare inference speeds.
Exercise 2: Quantization Impact
Apply INT8 quantization to a model. Measure size reduction and accuracy change.
Exercise 3: FastAPI Service
Build a complete image classification API with proper error handling and documentation.
What’s Next