Scalability Patterns - From Single Server to Millions
Senior Level: This covers advanced scaling patterns expected at L5+ interviews. Know when and why to apply each pattern.

Horizontal vs Vertical: The Real Trade-offs

Interview Insight: Don’t immediately jump to “scale horizontally.” First ask: “What’s the actual bottleneck?” Sometimes a bigger machine or query optimization is the right answer.

Stateless vs Stateful Services

Making Services Stateless

Stateless Architecture

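A minimal sketch of the idea, assuming a redis.asyncio client and a hypothetical shopping-cart handler: session data lives in Redis keyed by session ID, so any instance behind the load balancer can serve the next request and losing a process loses no user-visible state.

import json
from redis import asyncio as aioredis  # requires redis-py >= 4.2

redis = aioredis.from_url("redis://localhost:6379", decode_responses=True)
SESSION_TTL = 1800  # 30 minutes

async def save_session(session_id: str, data: dict) -> None:
    # State is external to the process, not in local memory
    await redis.setex(f"session:{session_id}", SESSION_TTL, json.dumps(data))

async def load_session(session_id: str) -> dict:
    raw = await redis.get(f"session:{session_id}")
    return json.loads(raw) if raw else {}

async def get_cart(session_id: str) -> list:
    # Hypothetical handler: behaves identically on any app instance
    session = await load_session(session_id)
    return session.get("cart", [])
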
When Stateful is Okay

Stateful services are fine when:
• WebSocket connections (natural affinity)
• Real-time gaming (session state)
• In-memory caching (local cache + distributed)
• Batch processing (worker owns work)

Key: Design for graceful degradation when state is lost

Caching at Scale

Multi-Level Caching


Multi-Level Cache Implementation

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional, Any, Dict, List, Callable
from datetime import datetime, timedelta
import asyncio
import hashlib
import json
import logging

logger = logging.getLogger(__name__)

@dataclass
class CacheEntry:
    value: Any
    created_at: datetime
    ttl: timedelta
    tags: List[str] = field(default_factory=list)
    
    @property
    def is_expired(self) -> bool:
        return datetime.now() > self.created_at + self.ttl
    
    @property
    def remaining_ttl(self) -> timedelta:
        remaining = (self.created_at + self.ttl) - datetime.now()
        return max(timedelta(0), remaining)


class CacheLayer(ABC):
    """Abstract cache layer"""
    
    @abstractmethod
    async def get(self, key: str) -> Optional[CacheEntry]:
        pass
    
    @abstractmethod
    async def set(self, key: str, entry: CacheEntry) -> None:
        pass
    
    @abstractmethod
    async def delete(self, key: str) -> None:
        pass
    
    @abstractmethod
    async def delete_by_tag(self, tag: str) -> int:
        pass


class LocalCache(CacheLayer):
    """L1: In-process memory cache"""
    
    def __init__(self, max_size: int = 10000):
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.access_order: List[str] = []
        self.tag_index: Dict[str, set] = {}
    
    async def get(self, key: str) -> Optional[CacheEntry]:
        entry = self.cache.get(key)
        if entry and not entry.is_expired:
            # Move to end (LRU)
            if key in self.access_order:
                self.access_order.remove(key)
            self.access_order.append(key)
            return entry
        elif entry:
            await self.delete(key)
        return None
    
    async def set(self, key: str, entry: CacheEntry) -> None:
        # Remove any existing entry first so access_order never holds duplicates
        if key in self.cache:
            await self.delete(key)
        
        # Evict least-recently-used entries if full (also cleans the tag index)
        while len(self.cache) >= self.max_size and self.access_order:
            await self.delete(self.access_order[0])
        
        self.cache[key] = entry
        self.access_order.append(key)
        
        # Update tag index
        for tag in entry.tags:
            if tag not in self.tag_index:
                self.tag_index[tag] = set()
            self.tag_index[tag].add(key)
    
    async def delete(self, key: str) -> None:
        if key in self.cache:
            entry = self.cache.pop(key)
            if key in self.access_order:
                self.access_order.remove(key)
            for tag in entry.tags:
                if tag in self.tag_index:
                    self.tag_index[tag].discard(key)
    
    async def delete_by_tag(self, tag: str) -> int:
        keys = self.tag_index.get(tag, set()).copy()
        for key in keys:
            await self.delete(key)
        return len(keys)


class RedisCache(CacheLayer):
    """L2: Distributed Redis cache"""
    
    def __init__(self, redis_client, prefix: str = "cache"):
        self.redis = redis_client  # assumes an async client created with decode_responses=True
        self.prefix = prefix
    
    def _key(self, key: str) -> str:
        return f"{self.prefix}:{key}"
    
    def _tag_key(self, tag: str) -> str:
        return f"{self.prefix}:tag:{tag}"
    
    async def get(self, key: str) -> Optional[CacheEntry]:
        data = await self.redis.get(self._key(key))
        if data:
            entry_data = json.loads(data)
            return CacheEntry(
                value=entry_data["value"],
                created_at=datetime.fromisoformat(entry_data["created_at"]),
                ttl=timedelta(seconds=entry_data["ttl_seconds"]),
                tags=entry_data.get("tags", [])
            )
        return None
    
    async def set(self, key: str, entry: CacheEntry) -> None:
        data = json.dumps({
            "value": entry.value,
            "created_at": entry.created_at.isoformat(),
            "ttl_seconds": entry.ttl.total_seconds(),
            "tags": entry.tags
        })
        
        ttl_seconds = int(entry.remaining_ttl.total_seconds())
        if ttl_seconds > 0:
            await self.redis.setex(self._key(key), ttl_seconds, data)
            
            # Update tag sets
            for tag in entry.tags:
                await self.redis.sadd(self._tag_key(tag), key)
    
    async def delete(self, key: str) -> None:
        await self.redis.delete(self._key(key))
    
    async def delete_by_tag(self, tag: str) -> int:
        keys = await self.redis.smembers(self._tag_key(tag))
        if keys:
            await self.redis.delete(*[self._key(k) for k in keys])
            await self.redis.delete(self._tag_key(tag))
        return len(keys)


class MultiLevelCache:
    """
    Multi-level cache with read-through loading and promotion of hot keys to L1.
    L1: Local in-memory (fastest, smallest)
    L2: Redis distributed (fast, shared)
    L3: Database (source of truth)
    """
    
    def __init__(
        self,
        l1_cache: LocalCache,
        l2_cache: RedisCache,
        db_loader: Callable,
        default_ttl: timedelta = timedelta(minutes=5)
    ):
        self.l1 = l1_cache
        self.l2 = l2_cache
        self.db_loader = db_loader
        self.default_ttl = default_ttl
        
        # Metrics
        self.hits = {"l1": 0, "l2": 0, "db": 0}
        self.misses = 0
    
    async def get(
        self, 
        key: str, 
        ttl: Optional[timedelta] = None,
        tags: Optional[List[str]] = None
    ) -> Optional[Any]:
        """Get with automatic cache population"""
        
        # Try L1 (local)
        entry = await self.l1.get(key)
        if entry:
            self.hits["l1"] += 1
            return entry.value
        
        # Try L2 (Redis)
        entry = await self.l2.get(key)
        if entry:
            self.hits["l2"] += 1
            # Promote to L1
            await self.l1.set(key, entry)
            return entry.value
        
        # Load from DB
        value = await self.db_loader(key)
        if value is not None:
            self.hits["db"] += 1
            entry = CacheEntry(
                value=value,
                created_at=datetime.now(),
                ttl=ttl or self.default_ttl,
                tags=tags or []
            )
            # Populate both cache levels
            await asyncio.gather(
                self.l1.set(key, entry),
                self.l2.set(key, entry)
            )
            return value
        
        self.misses += 1
        return None
    
    async def invalidate(self, key: str) -> None:
        """Invalidate across all levels"""
        await asyncio.gather(
            self.l1.delete(key),
            self.l2.delete(key)
        )
    
    async def invalidate_by_tag(self, tag: str) -> int:
        """Invalidate all entries with a tag"""
        l1_count = await self.l1.delete_by_tag(tag)
        l2_count = await self.l2.delete_by_tag(tag)
        return max(l1_count, l2_count)
    
    def get_hit_rates(self) -> Dict[str, float]:
        total = sum(self.hits.values()) + self.misses
        if total == 0:
            return {"l1": 0, "l2": 0, "db": 0, "miss": 0}
        return {
            "l1": self.hits["l1"] / total,
            "l2": self.hits["l2"] / total,
            "db": self.hits["db"] / total,
            "miss": self.misses / total
        }


# Usage example
async def create_cache_system(redis_client, db):
    async def load_user(key: str):
        # key format: "user:123"
        user_id = key.split(":")[1]
        return await db.fetch_user(user_id)
    
    cache = MultiLevelCache(
        l1_cache=LocalCache(max_size=1000),
        l2_cache=RedisCache(redis_client, prefix="app"),
        db_loader=load_user,
        default_ttl=timedelta(minutes=10)
    )
    
    # Get user (auto-populates cache)
    user = await cache.get(
        "user:123",
        tags=["users", "user:123"]
    )
    
    # Invalidate on update
    await cache.invalidate("user:123")
    
    # Invalidate all users
    await cache.invalidate_by_tag("users")
    
    # Check hit rates
    print(cache.get_hit_rates())

Cache Consistency Strategies

# Strategy 1: Cache-Aside with TTL (Simple)
def get_user(user_id):
    # 1. Try cache
    user = cache.get(f"user:{user_id}")
    if user:
        return user
    
    # 2. Load from DB
    user = db.query("SELECT * FROM users WHERE id = ?", user_id)
    
    # 3. Set cache with TTL
    cache.set(f"user:{user_id}", user, ttl=300)  # 5 min
    return user

def update_user(user_id, data):
    # Update DB
    db.update(user_id, data)
    # Invalidate cache
    cache.delete(f"user:{user_id}")


# Strategy 2: Write-Through (Strong Consistency)
def update_user(user_id, data):
    with transaction():
        # Update DB
        db.update(user_id, data)
        # Update cache in same transaction
        cache.set(f"user:{user_id}", data)


# Strategy 3: Event-Driven Invalidation (Best for microservices)
def update_user(user_id, data):
    db.update(user_id, data)
    # Publish event
    event_bus.publish("user.updated", {"user_id": user_id})

# Cache service listens for events
@event_handler("user.updated")
def invalidate_user_cache(event):
    cache.delete(f"user:{event.user_id}")

Database Scaling Patterns

Read Replicas Pattern

Read Replicas
import time

class ReadWriteRouter:
    def __init__(self, primary, replicas):
        self.primary = primary
        self.replicas = replicas
        self.replica_index = 0
    
    def get_connection(self, query_type: str, user_session=None):
        if query_type == "write":
            return self.primary
        
        # Read-your-writes: Check if user recently wrote
        if user_session and user_session.last_write_time:
            if time.time() - user_session.last_write_time < 5:
                # Recent write, use primary to avoid stale reads
                return self.primary
        
        # Round-robin across replicas
        replica = self.replicas[self.replica_index % len(self.replicas)]
        self.replica_index += 1
        return replica

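The router above checks user_session.last_write_time but never sets it; a usage sketch (the Session object and string connection handles are placeholders) showing where the timestamp comes from:

import time

class Session:
    last_write_time = None  # placeholder for a per-user session object

router = ReadWriteRouter("primary-db", ["replica-1", "replica-2"])  # placeholder handles
session = Session()

# Write path: stamp the session so the next reads are served from primary
router.get_connection("write")           # -> "primary-db"
session.last_write_time = time.time()

router.get_connection("read", session)   # -> "primary-db" (inside the 5s read-your-writes window)
# After ~5 seconds, reads round-robin across the replicas again
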
Sharding Strategies Deep Dive

Sharding Strategies

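The usual choices here are hash-based, range-based, and directory-based sharding. A minimal hash-based router is sketched below (the shard count and connection handles are hypothetical); note that plain modulo hashing remaps most keys when the shard count changes, which is why consistent hashing or a directory service is often layered on top.

import hashlib

NUM_SHARDS = 16  # hypothetical fixed shard count
shard_connections = {i: f"shard-{i}-connection" for i in range(NUM_SHARDS)}  # placeholders

def shard_for(key: str) -> int:
    # Use a stable hash; Python's built-in hash() is randomized per process
    return int(hashlib.md5(key.encode()).hexdigest(), 16) % NUM_SHARDS

def get_shard_connection(user_id: str):
    return shard_connections[shard_for(f"user:{user_id}")]
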
Cross-Shard Operations

# Problem: Query needs data from multiple shards

# Solution 1: Scatter-Gather
async def get_user_orders_all_time(user_id):
    # User's orders might be on different time-based shards
    shard_ids = get_all_shards()
    
    # Query all shards in parallel
    tasks = [query_shard(shard, user_id) for shard in shard_ids]
    results = await asyncio.gather(*tasks)
    
    # Merge results
    return merge_and_sort(results)

# Solution 2: Denormalization
# Store frequently-joined data together
# Instead of: users shard + orders shard + products shard
# Store: user_orders (denormalized) on user's shard

# Solution 3: Global Tables
# Some tables replicated to all shards (read-only)
# Example: countries, currencies, product categories

Async Processing Patterns

Task Queue Architecture

Task Queue
class ReliableTaskProcessor:
    """
    Production-grade task processor with exactly-once semantics
    """
    
    def __init__(self, queue, db, max_retries=3):
        self.queue = queue
        self.db = db
        self.max_retries = max_retries
    
    async def process_task(self, task):
        task_id = task.id
        
        # Idempotency check
        if await self.is_processed(task_id):
            await self.queue.ack(task)
            return
        
        try:
            # Process with timeout
            async with asyncio.timeout(30):  # requires Python 3.11+
                result = await self.do_work(task)
            
            # Mark as processed (atomically with result storage)
            await self.mark_completed(task_id, result)
            await self.queue.ack(task)
            
        except asyncio.TimeoutError:
            await self.handle_timeout(task)
            
        except RetryableError as e:  # application-defined transient failure
            await self.retry_or_dlq(task, e)
            
        except Exception as e:
            # Non-retryable, send to DLQ immediately
            await self.send_to_dlq(task, e)
    
    async def retry_or_dlq(self, task, error):
        if task.retry_count < self.max_retries:
            delay = 2 ** task.retry_count  # Exponential backoff
            await self.queue.retry(task, delay_seconds=delay)
        else:
            await self.send_to_dlq(task, error)

Event-Driven Architecture

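A minimal in-process sketch of the pattern (a stand-in for a real broker such as Kafka, RabbitMQ, or SNS): producers publish facts, consumers subscribe independently, and new consumers can be added without touching the producer.

from collections import defaultdict
from typing import Callable, Dict, List

class EventBus:
    """In-process stand-in for a real message broker."""
    def __init__(self):
        self._handlers: Dict[str, List[Callable]] = defaultdict(list)

    def subscribe(self, topic: str, handler: Callable) -> None:
        self._handlers[topic].append(handler)

    def publish(self, topic: str, event: dict) -> None:
        # The producer doesn't know or care who consumes the event
        for handler in self._handlers[topic]:
            handler(event)

bus = EventBus()
bus.subscribe("order.placed", lambda e: print("charge payment for", e["order_id"]))
bus.subscribe("order.placed", lambda e: print("reserve inventory for", e["order_id"]))
bus.publish("order.placed", {"order_id": 42})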

Load Shedding & Backpressure

Graceful Degradation

Load Shedding
import asyncio
import time
import random
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Optional, Dict, Any
from enum import Enum
import logging

logger = logging.getLogger(__name__)

class Priority(Enum):
    CRITICAL = 0  # Never shed (health checks, admin)
    HIGH = 1      # Shed last (paid users, transactions)
    NORMAL = 2    # Standard requests
    LOW = 3       # Shed first (analytics, prefetch)

@dataclass
class LoadShedderConfig:
    target_latency_ms: float = 100.0
    max_latency_ms: float = 500.0
    window_size: int = 100
    adjustment_interval: float = 1.0  # seconds
    min_accept_rate: float = 0.1  # Always accept 10%

class AdaptiveLoadShedder:
    """
    Adaptively sheds load based on system health metrics.
    Uses latency as the primary signal with priority-based shedding.
    """
    
    def __init__(self, config: LoadShedderConfig = None):
        self.config = config or LoadShedderConfig()
        self.latencies = deque(maxlen=self.config.window_size)
        self.shed_rates: Dict[Priority, float] = {
            Priority.CRITICAL: 0.0,  # Never shed
            Priority.HIGH: 0.0,
            Priority.NORMAL: 0.0,
            Priority.LOW: 0.0
        }
        self.last_adjustment = time.time()
        
        # Metrics
        self.total_requests = 0
        self.shed_requests = 0
        self.accepted_requests = 0
    
    def should_accept(self, priority: Priority = Priority.NORMAL) -> bool:
        """Determine if request should be accepted"""
        self.total_requests += 1
        
        # Critical requests always pass
        if priority == Priority.CRITICAL:
            self.accepted_requests += 1
            return True
        
        # Check shed rate for this priority
        shed_rate = self.shed_rates[priority]
        if random.random() < shed_rate:
            self.shed_requests += 1
            logger.debug(f"Shedding {priority.name} request (rate: {shed_rate:.2%})")
            return False
        
        self.accepted_requests += 1
        return True
    
    def record_latency(self, latency_ms: float) -> None:
        """Record request latency and adjust shed rates"""
        self.latencies.append(latency_ms)
        
        # Adjust periodically
        if time.time() - self.last_adjustment > self.config.adjustment_interval:
            self._adjust_shed_rates()
            self.last_adjustment = time.time()
    
    def _adjust_shed_rates(self) -> None:
        """Adjust shed rates based on current latency"""
        if len(self.latencies) < 10:
            return
        
        sorted_latencies = sorted(self.latencies)
        p50 = sorted_latencies[len(sorted_latencies) // 2]
        p99 = sorted_latencies[int(len(sorted_latencies) * 0.99)]
        
        # Calculate pressure based on latency
        if p99 > self.config.max_latency_ms:
            # Emergency: shed aggressively
            pressure = 0.8
        elif p99 > self.config.target_latency_ms * 2:
            # High pressure
            pressure = 0.5
        elif p99 > self.config.target_latency_ms:
            # Moderate pressure
            pressure = 0.2
        elif p99 < self.config.target_latency_ms * 0.5:
            # Low pressure: recover
            pressure = -0.2
        else:
            pressure = 0.0
        
        # Apply pressure to each priority level differently
        priority_multipliers = {
            Priority.LOW: 1.5,      # Shed first
            Priority.NORMAL: 1.0,
            Priority.HIGH: 0.3,     # Shed last
        }
        
        for priority, multiplier in priority_multipliers.items():
            current = self.shed_rates[priority]
            adjustment = pressure * 0.1 * multiplier
            new_rate = max(0.0, min(1.0 - self.config.min_accept_rate, current + adjustment))
            self.shed_rates[priority] = new_rate
        
        logger.info(
            "Load shedder adjusted: p50=%.1fms p99=%.1fms rates=%s",
            p50, p99,
            {k.name: f"{v:.2%}" for k, v in self.shed_rates.items()}
        )
    
    def get_metrics(self) -> Dict[str, Any]:
        return {
            "total_requests": self.total_requests,
            "shed_requests": self.shed_requests,
            "shed_rate": self.shed_requests / max(1, self.total_requests),
            "current_shed_rates": {
                k.name: v for k, v in self.shed_rates.items()
            },
            "latency_p50": sorted(self.latencies)[len(self.latencies) // 2] if self.latencies else 0,
            "latency_p99": sorted(self.latencies)[int(len(self.latencies) * 0.99)] if self.latencies else 0
        }


class TokenBucketRateLimiter:
    """
    Token bucket for rate limiting with burst support.
    Use alongside load shedding for complete traffic management.
    """
    
    def __init__(
        self,
        rate: float,  # tokens per second
        bucket_size: int = None,  # max burst
        initial_tokens: int = None
    ):
        self.rate = rate
        self.bucket_size = bucket_size or int(rate * 2)
        self.tokens = initial_tokens if initial_tokens is not None else self.bucket_size
        self.last_update = time.time()
    
    def acquire(self, tokens: int = 1) -> bool:
        """Try to acquire tokens. Returns True if successful."""
        self._refill()
        
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False
    
    def _refill(self) -> None:
        """Add tokens based on elapsed time"""
        now = time.time()
        elapsed = now - self.last_update
        self.last_update = now
        
        self.tokens = min(
            self.bucket_size,
            self.tokens + (elapsed * self.rate)
        )


class GracefulDegradationManager:
    """
    Manages graceful degradation strategies based on load.
    """
    
    def __init__(self):
        self.load_shedder = AdaptiveLoadShedder()
        self.degradation_level = 0  # 0=normal, 1=reduced, 2=minimal, 3=emergency
        self.feature_flags = {
            "recommendations": True,
            "analytics": True,
            "search_suggestions": True,
            "full_search": True,
            "image_processing": True,
            "notifications": True
        }
    
    def update_degradation_level(self) -> None:
        """Update degradation level based on metrics"""
        metrics = self.load_shedder.get_metrics()
        p99 = metrics.get("latency_p99", 0)
        
        if p99 > 1000:  # > 1 second
            new_level = 3
        elif p99 > 500:
            new_level = 2
        elif p99 > 200:
            new_level = 1
        else:
            new_level = 0
        
        # Re-apply (and log) only when the level actually changes
        if new_level != self.degradation_level:
            self.degradation_level = new_level
            self._apply_degradation()
    
    def _apply_degradation(self) -> None:
        """Disable features based on degradation level"""
        levels = {
            0: [],  # All features enabled
            1: ["analytics", "recommendations"],
            2: ["analytics", "recommendations", "search_suggestions", "notifications"],
            3: ["analytics", "recommendations", "search_suggestions", 
                "notifications", "image_processing"]
        }
        
        disabled = levels.get(self.degradation_level, [])
        for feature in self.feature_flags:
            self.feature_flags[feature] = feature not in disabled
        
        logger.warning(
            f"Degradation level: {self.degradation_level}, "
            f"Disabled: {disabled}"
        )
    
    def is_feature_enabled(self, feature: str) -> bool:
        return self.feature_flags.get(feature, True)


# FastAPI integration
from fastapi import FastAPI, Request, Response, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()
degradation_manager = GracefulDegradationManager()

class LoadSheddingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Determine priority from request
        priority = self._get_priority(request)
        
        if not degradation_manager.load_shedder.should_accept(priority):
            return Response(
                content='{"error": "Service temporarily overloaded"}',
                status_code=503,
                media_type="application/json",
                headers={"Retry-After": "5"}
            )
        
        start = time.time()
        response = await call_next(request)
        latency_ms = (time.time() - start) * 1000
        
        degradation_manager.load_shedder.record_latency(latency_ms)
        degradation_manager.update_degradation_level()
        
        return response
    
    def _get_priority(self, request: Request) -> Priority:
        # Health checks are critical
        if request.url.path == "/health":
            return Priority.CRITICAL
        # Premium users get high priority
        if request.headers.get("X-Premium-User"):
            return Priority.HIGH
        # Analytics/prefetch are low priority
        if request.url.path.startswith("/analytics"):
            return Priority.LOW
        return Priority.NORMAL

app.add_middleware(LoadSheddingMiddleware)

@app.get("/recommendations")
async def get_recommendations():
    if not degradation_manager.is_feature_enabled("recommendations"):
        return {"items": [], "degraded": True}
    return {"items": ["rec1", "rec2", "rec3"]}

Senior Interview Questions

How would you scale a system that needs to handle a very high write volume?
Approach:
  1. Identify the bottleneck: Is it DB? Network? Application?
  2. Batching: Combine multiple writes into one (see the sketch after this answer)
  3. Async writes: Write to queue, persist later
  4. Sharding: Distribute writes across nodes
  5. LSM-tree databases: Cassandra, RocksDB (optimized for writes)
Example answer: “First, I’d batch writes on the application side - instead of 1000 individual inserts, do bulk insert. Then add a queue like Kafka as a buffer. Finally, use a write-optimized database like Cassandra if the volume is truly massive.”
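
For the batching step, a minimal sketch assuming an asyncio buffer and a hypothetical db.bulk_insert; the trade-off is a small window of writes that exist only in memory until the next flush:

import asyncio

BATCH_SIZE = 500
FLUSH_INTERVAL = 0.5  # seconds

write_buffer: asyncio.Queue = asyncio.Queue()

async def enqueue_write(row: dict) -> None:
    await write_buffer.put(row)  # the request path returns immediately

async def flush_loop(db) -> None:
    while True:
        await asyncio.sleep(FLUSH_INTERVAL)
        batch = []
        while not write_buffer.empty() and len(batch) < BATCH_SIZE:
            batch.append(write_buffer.get_nowait())
        if batch:
            await db.bulk_insert(batch)  # hypothetical bulk-insert API
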
How do you plan capacity for a service whose traffic is growing?
Framework:
  1. Current baseline: Measure current QPS, latency, resource usage
  2. Growth projection: Expected traffic increase (e.g., 2x in 6 months)
  3. Headroom: Plan for 3x current load (for spikes)
  4. Load testing: Verify system handles projected load
  5. Monitoring: Track capacity metrics, alert at 70% utilization
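Illustrative numbers: at a current peak of 2,000 QPS with 2x projected growth, planning for 3x headroom means load-testing for roughly 2,000 × 2 × 3 = 12,000 QPS; if one node comfortably handles 1,000 QPS, budget on the order of 12 nodes plus redundancy.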
Key metrics to track:
  • CPU utilization by service
  • Memory usage and GC pressure
  • Database connections and query latency
  • Queue depth and processing rate
  • Network bandwidth
How would you migrate a production database schema without downtime?
Safe migration strategy:
  1. Dual-write: Write to both old and new schema (sketched after this list)
  2. Backfill: Migrate historical data in batches
  3. Shadow read: Read from new, compare with old
  4. Cutover: Switch reads to new schema
  5. Cleanup: Remove dual-write, drop old schema
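
A sketch of step 1 (dual-write), assuming an async db handle and a hypothetical v1-to-v2 transform; the key property is that a failing new-schema write never breaks the old path:

import logging

logger = logging.getLogger(__name__)

async def save_order(db, order: dict) -> None:
    # Old schema remains the source of truth during the migration window
    await db.insert("orders_v1", order)
    try:
        # Best-effort write to the new schema; the backfill job reconciles any gaps
        await db.insert("orders_v2", to_v2(order))
    except Exception:
        logger.warning("dual-write to orders_v2 failed", exc_info=True)

def to_v2(order: dict) -> dict:
    # Hypothetical shape change for illustration
    first, _, last = order.get("customer_name", "").partition(" ")
    return {**order, "first_name": first, "last_name": last}
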
Key principles:
  • Never lock tables in production
  • Migrations must be reversible
  • Test with production-sized data
  • Have a rollback plan
  • Do it during low-traffic periods
Users report seeing stale data. How do you debug it?
Systematic approach:
  1. Identify scope: All users? Some? Specific data?
  2. Check cache layers: CDN, app cache, Redis, DB cache
  3. Verify TTLs: Are caches expiring correctly?
  4. Check replication: Is replica lagging?
  5. Trace the write: Did write actually succeed?
Common causes:
  • Cache not being invalidated on write
  • Reading from stale replica
  • CDN caching dynamic content
  • Race condition between cache invalidation and read