Model Deployment

From Research to Production

Training is only half the battle. Production requires:
  • Fast inference
  • Minimal dependencies
  • Reproducibility
  • Monitoring

Export Formats

TorchScript (JIT Compilation)

import torch

model = MyModel()  # your trained nn.Module
model.eval()

# Option 1: Tracing (captures single path)
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Option 2: Scripting (handles control flow)
scripted_model = torch.jit.script(model)

# Save
traced_model.save("model_traced.pt")

# Load (no need for the original Python model class; loadable from C++ via LibTorch too)
loaded = torch.jit.load("model_traced.pt")
output = loaded(example_input)
Use tracing for straightforward models, scripting for models with control flow (if/else, loops).
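To see the difference, consider a small model with data-dependent control flow (a hypothetical GatedModel, not from the examples above): tracing bakes in whichever branch the example input happens to take, while scripting preserves both.

import torch
import torch.nn as nn

class GatedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # Branch depends on input values, not just shapes
        if x.sum() > 0:
            return self.fc(x)
        return -self.fc(x)

m = GatedModel().eval()
scripted = torch.jit.script(m)                   # keeps both branches
traced = torch.jit.trace(m, torch.randn(1, 4))   # warns: only the taken branch is captured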

ONNX Export

import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
)
Verify the export:
import onnx
import onnxruntime as ort

# Check model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run inference
session = ort.InferenceSession("model.onnx")
output = session.run(None, {"input": dummy_input.numpy()})
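Beyond the structural check, you can also compare the ONNX Runtime result against the original PyTorch model. A rough sketch (the tolerances here are arbitrary):

import numpy as np

with torch.no_grad():
    torch_output = model(dummy_input).numpy()

# session.run returns a list of outputs; this model has a single output
np.testing.assert_allclose(torch_output, output[0], rtol=1e-3, atol=1e-5)
print("ONNX output matches PyTorch within tolerance")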

Model Optimization

Quantization (INT8)

Reduce precision for faster inference:
import torch.quantization as quant

# Post-training quantization
model.eval()
model_fp32 = model

# Dynamic quantization (weights only; activations quantized on the fly)
model_int8 = quant.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},  # layer types to quantize
    dtype=torch.qint8
)

# Static quantization (weights + activations).
# Eager-mode static quantization also expects QuantStub/DeQuantStub
# in the model's forward pass.
model_fp32.qconfig = quant.get_default_qconfig('fbgemm')
model_prepared = quant.prepare(model_fp32)

# Calibrate with representative data (assumes the loader yields input tensors)
for batch in calibration_loader:
    model_prepared(batch)

model_quantized = quant.convert(model_prepared)

Model Size Comparison

Precision | Model Size | Speed  | Accuracy Drop
FP32      | 100 MB     | 1x     | Baseline
FP16      | 50 MB      | 1.5-2x | ~0%
INT8      | 25 MB      | 2-4x   | 0-1%
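The exact numbers depend on the model. One way to measure the on-disk size yourself (a sketch reusing model_fp32 and model_int8 from the quantization example above):

import os
import torch

def size_on_disk_mb(m, path="tmp_weights.pt"):
    # Serialize the state dict and report the file size
    torch.save(m.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    os.remove(path)
    return size_mb

print(f"FP32: {size_on_disk_mb(model_fp32):.1f} MB")
print(f"INT8: {size_on_disk_mb(model_int8):.1f} MB")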

Pruning

Remove unimportant weights:
import torch.nn.utils.prune as prune

# Prune 30% of weights in linear layers
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.3)

# Make pruning permanent
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.remove(module, 'weight')
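A quick sanity check on how much was actually pruned (a sketch that counts zeroed weights in the linear layers):

zero_weights = 0
total_weights = 0
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        zero_weights += (module.weight == 0).sum().item()
        total_weights += module.weight.numel()

print(f"Linear-layer sparsity: {zero_weights / total_weights:.1%}")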

Serving with FastAPI

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import torch
import torchvision.transforms as T
import io

app = FastAPI()

# Load model once at startup
model = torch.jit.load("model_traced.pt")
model.eval()

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Read image
    image_bytes = await file.read()
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    
    # Preprocess
    input_tensor = transform(image).unsqueeze(0)
    
    # Inference
    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.softmax(output, dim=1)
        pred_class = probs.argmax(dim=1).item()
        confidence = probs[0, pred_class].item()
    
    return {
        "class": pred_class,
        "confidence": confidence,
    }
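With the server running (for example via uvicorn app:app), a quick client-side test might look like this (cat.jpg is a placeholder file):

import requests

with open("cat.jpg", "rb") as f:
    response = requests.post("http://localhost:8000/predict", files={"file": f})

print(response.json())  # e.g. {"class": 281, "confidence": 0.93}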

GPU Serving with Triton

# config.pbtxt
name: "image_classifier"
platform: "onnxruntime_onnx"
max_batch_size: 32

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [3, 224, 224]
  }
]

output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [1000]
  }
]

instance_group [
  {
    count: 2
    kind: KIND_GPU
  }
]

dynamic_batching {
  max_queue_delay_microseconds: 100
}
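Assuming the Triton server is running with its HTTP endpoint on localhost:8000, a client-side sketch using the tritonclient package could look like this:

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# The config's dims omit the batch dimension; client requests include it
batch = np.random.rand(1, 3, 224, 224).astype(np.float32)
infer_input = httpclient.InferInput("input", list(batch.shape), "FP32")
infer_input.set_data_from_numpy(batch)

result = client.infer(
    model_name="image_classifier",
    inputs=[infer_input],
    outputs=[httpclient.InferRequestedOutput("output")],
)
logits = result.as_numpy("output")  # shape (1, 1000)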

Docker Deployment

FROM python:3.10-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy model and code
COPY model_traced.pt .
COPY app.py .

# Expose port
EXPOSE 8000

# Run server
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

# Build and run
docker build -t model-server .
docker run -p 8000:8000 model-server
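The requirements.txt copied in the Dockerfile is not shown in this guide; for the FastAPI server above it would need at least something like the following (versions unpinned here, but pin them for reproducibility):

# requirements.txt (sketch)
torch
torchvision
fastapi
uvicorn
pillow
python-multipart   # needed by FastAPI for UploadFile/form parsing
prometheus-client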

Edge Deployment

ONNX Runtime Mobile

# Export for mobile
torch.onnx.export(
    model,
    dummy_input,
    "model_mobile.onnx",
    opset_version=13,  # Compatible with mobile
    do_constant_folding=True,
)

# Quantize weights for mobile
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    "model_mobile.onnx",
    "model_mobile_quantized.onnx",
    weight_type=QuantType.QUInt8,
)

TensorFlow Lite Conversion

import tensorflow as tf

# Convert ONNX to TF SavedModel first (using onnx-tf)
# Then convert to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
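The ONNX → SavedModel step mentioned in the comment above might look like this (a sketch assuming onnx-tf's prepare/export_graph interface):

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)          # wrap the ONNX graph as a TensorFlow representation
tf_rep.export_graph("saved_model")    # writes a TF SavedModel directory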

Monitoring in Production

import time
from prometheus_client import Counter, Histogram, start_http_server

# Metrics
PREDICTIONS = Counter('predictions_total', 'Total predictions')
PREDICTION_LATENCY = Histogram('prediction_latency_seconds', 'Prediction latency')

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    start_time = time.time()
    
    # ... inference code ...
    
    # Record metrics
    PREDICTIONS.inc()
    PREDICTION_LATENCY.observe(time.time() - start_time)
    
    return result

# Start metrics server
start_http_server(8001)
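With the metrics server running alongside the API, you can confirm the counters and histograms are exposed with a quick local check:

import requests

metrics = requests.get("http://localhost:8001/metrics").text
print([line for line in metrics.splitlines() if line.startswith("prediction")])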

Deployment Checklist

Step           | Check
Model export   | ✅ TorchScript/ONNX exports correctly
Validation     | ✅ Output matches original model
Optimization   | ✅ Quantization/pruning applied
Batching       | ✅ Dynamic batching configured
Monitoring     | ✅ Latency/throughput metrics
Error handling | ✅ Graceful failure modes
Scaling        | ✅ Horizontal scaling ready

Exercises

1. Export a ResNet model to both TorchScript and ONNX. Compare inference speeds.
2. Apply INT8 quantization to a model. Measure size reduction and accuracy change.
3. Build a complete image classification API with proper error handling and documentation.

What’s Next