E2E Encryption with AI Agents

The fundamental challenge of healthcare AI: LLMs need plaintext to process, but E2E encryption means only endpoints have plaintext. This module explores practical solutions for HIPAA-compliant AI chat systems.
Learning Objectives:
  • Understand the encryption-AI tension
  • Implement Signal Protocol for healthcare chat
  • Explore secure enclaves and TEEs
  • Design privacy-preserving AI architectures
  • Build HIPAA-compliant AI medical assistants

The Fundamental Tension

[Diagram: the fundamental tension between E2E encryption and AI processing]

The E2E Encryption + AI Challenge

The core challenge is that LLMs need plaintext to process data, while E2E encryption ensures only endpoints have access to plaintext.

Solution Architecture Overview

There is no perfect solution, but several practical approaches exist:

Secure Enclaves (TEE)

Process data in hardware-isolated environments. Highest security, complex to implement.

On-Premise LLMs

Deploy models within your infrastructure. Full control, significant cost.

Endpoint Processing

Run smaller models on user devices. Privacy-first, limited capability.

Hybrid Architecture

Combine approaches based on data sensitivity. Most practical for real-world use.

E2E Encrypted Chat Foundation

Signal Protocol Implementation

Before adding AI, let’s build proper E2E encrypted chat:
"""
Signal Protocol Implementation for Healthcare Chat

Components:
1. X3DH - Extended Triple Diffie-Hellman (initial key exchange)
2. Double Ratchet - Forward secrecy per message
3. Sesame - Multi-device support
"""

from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives import hashes, serialization
from dataclasses import dataclass, field
from typing import Optional, Tuple
import os
import hmac
import hashlib

@dataclass
class IdentityKeyPair:
    """Long-term identity keys"""
    private_key: Ed25519PrivateKey
    public_key: bytes
    
    @classmethod
    def generate(cls) -> "IdentityKeyPair":
        private = Ed25519PrivateKey.generate()
        public = private.public_key().public_bytes(
            serialization.Encoding.Raw,
            serialization.PublicFormat.Raw
        )
        return cls(private_key=private, public_key=public)

@dataclass
class PreKeyBundle:
    """Published keys for X3DH"""
    identity_key: bytes
    signed_prekey: bytes
    signed_prekey_signature: bytes
    one_time_prekey: Optional[bytes] = None

@dataclass
class DoubleRatchetState:
    """State for Double Ratchet algorithm"""
    root_key: bytes
    chain_key_sending: bytes
    chain_key_receiving: bytes
    dh_sending: X25519PrivateKey
    dh_receiving: Optional[X25519PublicKey] = None
    message_number_sending: int = 0
    message_number_receiving: int = 0
    previous_chain_length: int = 0
    skipped_keys: dict = field(default_factory=dict)


class SignalProtocol:
    """
    Signal Protocol implementation for E2E encryption
    
    Provides:
    - Perfect forward secrecy
    - Future secrecy (break-in recovery)
    - Deniability
    """
    
    def __init__(self):
        self.sessions: dict[str, DoubleRatchetState] = {}
        
    def x3dh_initiator(
        self,
        our_identity: IdentityKeyPair,
        their_prekey_bundle: PreKeyBundle
    ) -> Tuple[bytes, bytes]:
        """
        X3DH key agreement (initiator side)
        
        Computes shared secret from:
        - Our identity key + their signed prekey
        - Our ephemeral key + their identity key
        - Our ephemeral key + their signed prekey
        - Our ephemeral key + their one-time prekey (if available)
        """
        
        # Generate ephemeral key pair
        ephemeral_private = X25519PrivateKey.generate()
        ephemeral_public = ephemeral_private.public_key().public_bytes(
            serialization.Encoding.Raw,
            serialization.PublicFormat.Raw
        )
        
        # Load their keys
        their_identity = X25519PublicKey.from_public_bytes(
            their_prekey_bundle.identity_key
        )
        their_signed_prekey = X25519PublicKey.from_public_bytes(
            their_prekey_bundle.signed_prekey
        )
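
        # NOTE: a production implementation must verify
        # their_prekey_bundle.signed_prekey_signature against the sender's
        # Ed25519 identity key before trusting this bundle (omitted here)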
        
        # Convert our identity key for X25519
        our_identity_x25519 = self._ed25519_to_x25519_private(
            our_identity.private_key
        )
        
        # DH1: Our identity * their signed prekey
        dh1 = our_identity_x25519.exchange(their_signed_prekey)
        
        # DH2: Our ephemeral * their identity
        dh2 = ephemeral_private.exchange(their_identity)
        
        # DH3: Our ephemeral * their signed prekey
        dh3 = ephemeral_private.exchange(their_signed_prekey)
        
        # DH4: Our ephemeral * their one-time prekey (if available)
        if their_prekey_bundle.one_time_prekey:
            their_otpk = X25519PublicKey.from_public_bytes(
                their_prekey_bundle.one_time_prekey
            )
            dh4 = ephemeral_private.exchange(their_otpk)
        else:
            dh4 = b""
            
        # Derive shared secret
        shared_secret = self._kdf(dh1 + dh2 + dh3 + dh4, b"X3DH")
        
        return shared_secret, ephemeral_public
    
    def initialize_double_ratchet(
        self,
        session_id: str,
        shared_secret: bytes,
        their_ratchet_key: X25519PublicKey
    ):
        """Initialize Double Ratchet with X3DH shared secret"""
        
        # Generate our first ratchet key pair
        our_ratchet = X25519PrivateKey.generate()
        
        # Derive initial keys
        root_key, chain_key = self._kdf_rk(
            shared_secret,
            our_ratchet.exchange(their_ratchet_key)
        )
        
        self.sessions[session_id] = DoubleRatchetState(
            root_key=root_key,
            chain_key_sending=chain_key,
            chain_key_receiving=b"",
            dh_sending=our_ratchet,
            dh_receiving=their_ratchet_key,
        )
    
    def encrypt_message(
        self,
        session_id: str,
        plaintext: bytes
    ) -> dict:
        """
        Encrypt a message using Double Ratchet
        
        Each message uses a unique key derived from the chain key,
        providing perfect forward secrecy.
        """
        
        state = self.sessions[session_id]
        
        # Derive message key from chain key
        message_key, new_chain_key = self._kdf_ck(state.chain_key_sending)
        state.chain_key_sending = new_chain_key
        
        # Encrypt with AEAD
        nonce = os.urandom(12)
        aesgcm = AESGCM(message_key)
        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        
        # Prepare header
        header = {
            "dh_public": state.dh_sending.public_key().public_bytes(
                serialization.Encoding.Raw,
                serialization.PublicFormat.Raw
            ).hex(),
            "message_number": state.message_number_sending,
            "previous_chain_length": state.previous_chain_length,
        }
        
        state.message_number_sending += 1
        
        return {
            "header": header,
            "nonce": nonce.hex(),
            "ciphertext": ciphertext.hex(),
        }
    
    def decrypt_message(
        self,
        session_id: str,
        encrypted: dict
    ) -> bytes:
        """Decrypt a message, handling out-of-order delivery"""
        
        state = self.sessions[session_id]
        header = encrypted["header"]
        
        # Check if this triggers a DH ratchet. Compare raw key bytes:
        # X25519PublicKey objects are not reliably comparable with !=
        their_dh_bytes = bytes.fromhex(header["dh_public"])
        their_dh = X25519PublicKey.from_public_bytes(their_dh_bytes)

        current_dh_bytes = (
            state.dh_receiving.public_bytes(
                serialization.Encoding.Raw, serialization.PublicFormat.Raw
            )
            if state.dh_receiving
            else None
        )

        if their_dh_bytes != current_dh_bytes:
            # Perform DH ratchet
            self._dh_ratchet(state, their_dh)
        
        # Derive message key
        message_key, new_chain_key = self._kdf_ck(state.chain_key_receiving)
        state.chain_key_receiving = new_chain_key
        
        # Decrypt
        nonce = bytes.fromhex(encrypted["nonce"])
        ciphertext = bytes.fromhex(encrypted["ciphertext"])
        
        aesgcm = AESGCM(message_key)
        plaintext = aesgcm.decrypt(nonce, ciphertext, None)
        
        state.message_number_receiving += 1
        
        return plaintext
    
    def _dh_ratchet(self, state: DoubleRatchetState, their_dh: X25519PublicKey):
        """Perform DH ratchet step"""
        
        state.previous_chain_length = state.message_number_sending
        state.message_number_sending = 0
        state.message_number_receiving = 0
        state.dh_receiving = their_dh
        
        # Derive new receiving chain key
        dh_output = state.dh_sending.exchange(their_dh)
        state.root_key, state.chain_key_receiving = self._kdf_rk(
            state.root_key, dh_output
        )
        
        # Generate new sending key pair
        state.dh_sending = X25519PrivateKey.generate()
        dh_output = state.dh_sending.exchange(their_dh)
        state.root_key, state.chain_key_sending = self._kdf_rk(
            state.root_key, dh_output
        )
    
    def _kdf(self, input_key_material: bytes, info: bytes) -> bytes:
        """Key derivation function"""
        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=32,
            salt=b"\x00" * 32,
            info=info,
        )
        return hkdf.derive(input_key_material)
    
    def _kdf_rk(self, root_key: bytes, dh_output: bytes) -> Tuple[bytes, bytes]:
        """Root key KDF - derives new root key and chain key"""
        output = HKDF(
            algorithm=hashes.SHA256(),
            length=64,
            salt=root_key,
            info=b"RatchetKDF",
        ).derive(dh_output)
        return output[:32], output[32:]
    
    def _kdf_ck(self, chain_key: bytes) -> Tuple[bytes, bytes]:
        """Chain key KDF - derives message key and new chain key (HMAC-SHA256, per the Signal spec)"""
        message_key = hmac.new(chain_key, b"\x01", hashlib.sha256).digest()
        new_chain_key = hmac.new(chain_key, b"\x02", hashlib.sha256).digest()
        return message_key, new_chain_key
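
One method referenced above, _ed25519_to_x25519_private, is not defined in the listing. A minimal version follows the standard RFC 8032 construction (the X25519 scalar is the first 32 bytes of SHA-512 over the Ed25519 seed); treat it as a sketch, not audited crypto code:

    # Belongs on SignalProtocol
    def _ed25519_to_x25519_private(
        self, ed_key: Ed25519PrivateKey
    ) -> X25519PrivateKey:
        """Derive the X25519 private key matching an Ed25519 signing key"""
        seed = ed_key.private_bytes(
            serialization.Encoding.Raw,
            serialization.PrivateFormat.Raw,
            serialization.NoEncryption(),
        )
        return X25519PrivateKey.from_private_bytes(
            hashlib.sha512(seed).digest()[:32]
        )

With that in place, a one-sided usage sketch (responder-side X3DH and prekey signature verification omitted):

# Alice initiates a session against Bob's published bundle
alice_identity = IdentityKeyPair.generate()
bob_identity = X25519PrivateKey.generate()
bob_signed_prekey = X25519PrivateKey.generate()

def raw_public(private_key) -> bytes:
    return private_key.public_key().public_bytes(
        serialization.Encoding.Raw, serialization.PublicFormat.Raw
    )

bundle = PreKeyBundle(
    identity_key=raw_public(bob_identity),
    signed_prekey=raw_public(bob_signed_prekey),
    signed_prekey_signature=b"",  # verification omitted in this sketch
)

alice = SignalProtocol()
secret, ephemeral_pub = alice.x3dh_initiator(alice_identity, bundle)
alice.initialize_double_ratchet("alice:bob", secret, bob_signed_prekey.public_key())
envelope = alice.encrypt_message("alice:bob", b"Patient reports mild headache")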

Architecture 1: Secure Enclaves (TEE)

Trusted Execution Environment Approach

┌─────────────────────────────────────────────────────────────────────────────┐
│                    SECURE ENCLAVE ARCHITECTURE                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   ┌────────────┐              ┌─────────────────────────────────────────┐   │
│   │   Client   │              │         SECURE ENCLAVE (TEE)            │   │
│   │            │              │  ┌───────────────────────────────────┐  │   │
│   │  Encrypt   │              │  │     Decryption happens here      │  │   │
│   │  with      │──[E2E]──────►│  │                                   │  │   │
│   │  enclave   │              │  │     ┌─────────────────────┐       │  │   │
│   │  public    │              │  │     │    LLM Inference    │       │  │   │
│   │  key       │              │  │     │    (Plaintext)      │       │  │   │
│   │            │◄─[E2E]───────│  │     └─────────────────────┘       │  │   │
│   └────────────┘              │  │                                   │  │   │
│                               │  │     Result encrypted before       │  │   │
│                               │  │     leaving enclave               │  │   │
│                               │  └───────────────────────────────────┘  │   │
│                               │                                         │   │
│                               │  HARDWARE ISOLATION                     │   │
│                               │  • Intel SGX / AMD SEV / ARM TrustZone  │   │
│                               │  • Memory encryption                    │   │
│                               │  • Attestation                          │   │
│                               └─────────────────────────────────────────┘   │
│                                                                              │
│   EVEN THE SERVER OPERATOR CANNOT SEE PLAINTEXT                            │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

Implementation with AWS Nitro Enclaves

"""
AWS Nitro Enclaves implementation for secure AI inference

Nitro Enclaves provide:
- Isolated compute environment
- No persistent storage
- No external networking (only vsock to parent)
- Cryptographic attestation
"""

import socket
import json
import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

# Parent instance (outside enclave)
class EnclaveClient:
    """Client that communicates with the secure enclave"""
    
    def __init__(self, enclave_cid: int, enclave_port: int = 5000):
        self.enclave_cid = enclave_cid
        self.enclave_port = enclave_port
        
    def process_phi_with_ai(
        self,
        encrypted_message: bytes,
        session_id: str,
        attestation_doc: bytes
    ) -> bytes:
        """
        Send encrypted PHI to enclave for AI processing
        
        The enclave will:
        1. Verify attestation
        2. Decrypt message using session key
        3. Process with LLM
        4. Encrypt response
        5. Return encrypted response
        """
        
        # Connect via vsock (only connection allowed to enclave)
        sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
        sock.connect((self.enclave_cid, self.enclave_port))
        
        request = {
            "action": "process_medical_query",
            "encrypted_message": encrypted_message.hex(),
            "session_id": session_id,
        }
        
        sock.sendall(json.dumps(request).encode())
        response = sock.recv(65536)
        sock.close()
        
        return bytes.fromhex(json.loads(response)["encrypted_response"])


# Inside the enclave
class EnclaveServer:
    """
    Server running inside Nitro Enclave
    
    Has access to:
    - Decryption keys (fetched via attestation from KMS)
    - LLM model weights
    
    Cannot:
    - Access internet
    - Write to persistent storage
    - Be inspected by parent instance
    """
    
    def __init__(self):
        self.kms_key = self._fetch_key_from_kms()
        self.llm = self._load_llm_model()
        self.sessions = {}
        
    def _fetch_key_from_kms(self) -> bytes:
        """
        Fetch decryption key from KMS using attestation
        
        KMS verifies the enclave's attestation document before
        releasing the key. This ensures only valid enclave code
        can access the key.
        """
        
        # In real implementation, use aws-nitro-enclaves-sdk
        # This requires attestation document signed by Nitro Hypervisor
        attestation_doc = self._get_attestation_document()
        
        # KMS verifies:
        # - Enclave image measurement (code hash)
        # - Signing certificate
        # - Platform configuration
        
        # Only then does KMS release the key (kms_client and
        # decrypt_with_attestation are illustrative pseudocode; the real flow
        # goes through aws-nitro-enclaves-sdk)
        return kms_client.decrypt_with_attestation(
            attestation_doc=attestation_doc,
            key_id="alias/enclave-master-key"
        )
    
    def _load_llm_model(self):
        """Load quantized LLM model into enclave memory"""
        
        # Use a smaller, quantized model that fits in enclave memory
        # Examples: Llama-7B-GPTQ, Mistral-7B-GGUF
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        model = AutoModelForCausalLM.from_pretrained(
            "/enclave/models/medical-llm-7b-q4",
            device_map="auto",
            load_in_4bit=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "/enclave/models/medical-llm-7b-q4"
        )
        
        return model, tokenizer
    
    def handle_request(self, request: dict) -> dict:
        """Process request inside secure enclave"""
        
        session_id = request["session_id"]
        encrypted_message = bytes.fromhex(request["encrypted_message"])
        
        # Decrypt using session key
        session_key = self._get_session_key(session_id)
        plaintext = self._decrypt(encrypted_message, session_key)
        
        # Process with LLM (plaintext only exists in enclave memory)
        medical_query = plaintext.decode()
        response = self._generate_response(medical_query)
        
        # Encrypt response before returning
        encrypted_response = self._encrypt(response.encode(), session_key)
        
        # Drop references to plaintext (note: `del` removes the reference only;
        # CPython does not guarantee the underlying memory is zeroed)
        del plaintext
        del medical_query
        del response
        
        return {"encrypted_response": encrypted_response.hex()}
    
    def _generate_response(self, query: str) -> str:
        """Generate medical AI response"""
        
        model, tokenizer = self.llm
        
        prompt = f"""You are a medical assistant. Answer the following health query.
        Always recommend consulting a healthcare provider for medical decisions.
        
        Query: {query}
        
        Response:"""
        
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=500)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        return response
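
To wire EnclaveServer to the vsock channel that EnclaveClient dials, a minimal accept loop might look like the sketch below (single-threaded, one request per connection; the port matches the client's default):

def serve(server: EnclaveServer, port: int = 5000):
    """Minimal vsock request loop inside the enclave (illustrative)"""
    sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
    sock.bind((socket.VMADDR_CID_ANY, port))  # reachable only from the parent instance
    sock.listen(1)
    while True:
        conn, _ = sock.accept()
        try:
            request = json.loads(conn.recv(65536))
            response = server.handle_request(request)
            conn.sendall(json.dumps(response).encode())
        finally:
            conn.close()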

Architecture 2: On-Premise LLM Deployment

Self-Hosted LLM Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                    ON-PREMISE LLM ARCHITECTURE                               │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   HEALTHCARE ORGANIZATION'S INFRASTRUCTURE                                  │
│   ──────────────────────────────────────────                                │
│                                                                              │
│   ┌────────────┐     ┌─────────────────────────────────────────────────┐   │
│   │   Client   │     │           PRIVATE CLOUD / DATA CENTER           │   │
│   │            │     │                                                  │   │
│   │  E2E       │     │  ┌──────────────┐    ┌──────────────────────┐   │   │
│   │  Encrypted │────►│  │    API       │───►│  LLM Inference       │   │   │
│   │  to API    │     │  │    Gateway   │    │  Server              │   │   │
│   │  Gateway   │◄────│  │              │◄───│  (vLLM/TGI)         │   │   │
│   └────────────┘     │  │  Decrypts    │    │                      │   │   │
│                      │  │  here        │    │  Llama-70B / Mixtral │   │   │
│                      │  └──────────────┘    └──────────────────────┘   │   │
│                      │         │                      │                 │   │
│                      │         ▼                      ▼                 │   │
│                      │  ┌──────────────┐    ┌──────────────────────┐   │   │
│                      │  │   Audit      │    │  Vector Database     │   │   │
│                      │  │   Logging    │    │  (RAG Context)       │   │   │
│                      │  └──────────────┘    └──────────────────────┘   │   │
│                      │                                                  │   │
│                      │  ALL DATA STAYS WITHIN ORGANIZATION'S NETWORK   │   │
│                      └─────────────────────────────────────────────────┘   │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

Implementation with vLLM

"""
On-premise LLM deployment for HIPAA-compliant AI

Benefits:
- Complete data control
- No external API calls
- Full audit trail
- No Business Associate Agreement (BAA) needed with an external LLM provider
"""

from vllm import LLM, SamplingParams
from fastapi import FastAPI, Depends
from pydantic import BaseModel
from typing import Optional

# User, get_current_user, AuditLogger, and get_audit_logger are assumed to
# come from the application's auth and audit modules (not shown here)

# Initialize the LLM (runs on your infrastructure)
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # Or medical fine-tuned model
    tensor_parallel_size=4,  # For multi-GPU
    dtype="float16",
    max_model_len=8192,
)

app = FastAPI()

class MedicalQuery(BaseModel):
    query: str
    patient_context: Optional[str] = None
    session_id: str

class MedicalResponse(BaseModel):
    response: str
    disclaimer: str
    session_id: str

# Medical system prompt
MEDICAL_SYSTEM_PROMPT = """You are a medical AI assistant operating in a HIPAA-compliant environment.

IMPORTANT GUIDELINES:
1. Provide helpful medical information based on the query
2. Always recommend consulting with a healthcare provider
3. Do not diagnose conditions definitively
4. If information could be life-threatening, advise seeking emergency care
5. Keep responses factual and evidence-based
6. Respect patient privacy and confidentiality

"""

@app.post("/api/medical-ai", response_model=MedicalResponse)
async def process_medical_query(
    query: MedicalQuery,
    current_user: User = Depends(get_current_user),
    audit_logger: AuditLogger = Depends(get_audit_logger),
):
    """Process medical query with on-premise LLM"""
    
    # Log the access (without logging actual PHI content)
    await audit_logger.log({
        "event_type": "AI_MEDICAL_QUERY",
        "user_id": current_user.id,
        "session_id": query.session_id,
        "query_length": len(query.query),
        "has_context": bool(query.patient_context),
    })
    
    # Construct prompt
    full_prompt = f"""{MEDICAL_SYSTEM_PROMPT}

Patient Context: {query.patient_context or 'Not provided'}

User Query: {query.query}

Medical AI Response:"""

    # Generate response
    sampling_params = SamplingParams(
        temperature=0.7,
        max_tokens=1000,
        stop=["User Query:", "Patient Context:"],
    )
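
    # Note: LLM.generate is blocking; a production service would use vLLM's
    # async engine or offload inference to a worker thread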
    
    outputs = llm.generate([full_prompt], sampling_params)
    response_text = outputs[0].outputs[0].text.strip()
    
    # Log completion
    await audit_logger.log({
        "event_type": "AI_RESPONSE_GENERATED",
        "user_id": current_user.id,
        "session_id": query.session_id,
        "response_length": len(response_text),
    })
    
    return MedicalResponse(
        response=response_text,
        disclaimer="This AI response is for informational purposes only. Please consult a healthcare provider for medical advice.",
        session_id=query.session_id,
    )


# RAG Integration for medical knowledge
class MedicalRAG:
    """Retrieval-Augmented Generation for medical queries"""
    
    def __init__(self, vector_store, llm):
        self.vector_store = vector_store
        self.llm = llm
        
    async def query_with_context(
        self,
        query: str,
        patient_history: Optional[str] = None
    ) -> str:
        """
        Query with relevant medical knowledge retrieval
        
        The vector store contains:
        - Medical guidelines
        - Drug interactions
        - Treatment protocols
        - Patient history (encrypted, decrypted for this query)
        """
        
        # Retrieve relevant context
        relevant_docs = await self.vector_store.similarity_search(
            query,
            k=5,
            filter={"category": "medical_knowledge"}
        )
        
        context = "\n\n".join([doc.page_content for doc in relevant_docs])
        
        prompt = f"""{MEDICAL_SYSTEM_PROMPT}

Relevant Medical Knowledge:
{context}

Patient History: {patient_history or 'Not available'}

Query: {query}

Based on the medical knowledge and patient context, provide a helpful response:"""

        outputs = self.llm.generate([prompt], SamplingParams(
            temperature=0.5,  # Lower for more factual responses
            max_tokens=1000,
        ))
        
        return outputs[0].outputs[0].text.strip()

Architecture 3: Hybrid Approach

The Practical Solution

Most real-world healthcare AI systems use a hybrid approach:
"""
Hybrid Architecture for Healthcare AI Chat

Strategy:
1. E2E encrypt all chat messages
2. Use on-premise LLM for PHI processing
3. Use cloud LLMs only for de-identified/general queries
4. Client-side AI for simple tasks
"""

import re
from enum import Enum
from dataclasses import dataclass
from typing import Optional

class QuerySensitivity(Enum):
    PUBLIC = "public"         # General health info, no PHI
    CONTEXTUAL = "contextual" # Needs patient context but can be de-identified
    SENSITIVE = "sensitive"   # Contains PHI, must stay on-premise

@dataclass
class ProcessingDecision:
    sensitivity: QuerySensitivity
    processor: str  # "client", "on_premise", "cloud"
    requires_deidentification: bool

class HybridAIRouter:
    """
    Route queries to appropriate processing environment
    based on sensitivity and content
    """
    
    def __init__(
        self,
        client_model,
        on_premise_llm,
        cloud_llm,
        de_identifier,
    ):
        self.client_model = client_model
        self.on_premise_llm = on_premise_llm
        self.cloud_llm = cloud_llm
        self.de_identifier = de_identifier
        
    def classify_query(self, query: str, context: dict) -> ProcessingDecision:
        """Classify query to determine processing approach"""
        
        # Check for PHI indicators using coarse regex heuristics
        # (illustrative only; production systems should use a dedicated
        # de-identification / NER pipeline)
        phi_patterns = [
            r'\b\d{3}-\d{2}-\d{4}\b',            # SSN
            r'\b[A-Z][a-z]+ [A-Z][a-z]+\b',      # Capitalized name pairs (very coarse)
            r'\bMRN:\s*\d+\b',                   # Medical record numbers
            r'\bDOB:\s*\d{2}/\d{2}/\d{4}\b',     # Date of birth
        ]
        
        has_phi = any(re.search(p, query) for p in phi_patterns)
        references_patient = "patient" in context or context.get("patient_id")
        
        if has_phi:
            return ProcessingDecision(
                sensitivity=QuerySensitivity.SENSITIVE,
                processor="on_premise",
                requires_deidentification=False,
            )
        elif references_patient:
            return ProcessingDecision(
                sensitivity=QuerySensitivity.CONTEXTUAL,
                processor="on_premise",  # Or cloud with de-identification
                requires_deidentification=True,
            )
        else:
            return ProcessingDecision(
                sensitivity=QuerySensitivity.PUBLIC,
                processor="cloud",  # Can use more powerful cloud model
                requires_deidentification=False,
            )
    
    async def process_query(
        self,
        query: str,
        context: dict,
        session_id: str,
    ) -> str:
        """Process query using appropriate method"""
        
        decision = self.classify_query(query, context)
        # Record the routing decision; audit logging reads this attribute later
        self.last_processor = decision.processor
        
        if decision.processor == "client":
            # Simple queries handled on device
            return await self.client_model.generate(query)
            
        elif decision.processor == "on_premise":
            # PHI stays on-premise
            return await self.on_premise_llm.generate(
                query=query,
                context=context,
            )
            
        elif decision.processor == "cloud":
            if decision.requires_deidentification:
                # De-identify before sending to cloud
                safe_query = self.de_identifier.deidentify(query)
                safe_context = self.de_identifier.deidentify_context(context)
                
                response = await self.cloud_llm.generate(
                    query=safe_query,
                    context=safe_context,
                )
                
                # Re-identify response
                return self.de_identifier.reidentify(response, context)
            else:
                return await self.cloud_llm.generate(query=query)
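
For instance, under the classification rules above (the model handles are hypothetical stand-ins; PHIDeIdentifier is defined next):

router = HybridAIRouter(client_model, on_premise_llm, cloud_llm, PHIDeIdentifier())

router.classify_query("What are common migraine triggers?", {})
# -> PUBLIC, routed to "cloud"

router.classify_query("Summarize recent labs for MRN: 44821", {"patient_id": "44821"})
# -> SENSITIVE (matches the MRN pattern), routed to "on_premise"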


class PHIDeIdentifier:
    """
    De-identify PHI for safe cloud processing
    
    Uses consistent pseudonymization so context is preserved
    but actual identifiers are removed.
    """
    
    def __init__(self):
        self.mapping = {}  # Original -> Pseudonym
        self.reverse_mapping = {}  # Pseudonym -> Original
        
    def deidentify(self, text: str) -> str:
        """Replace PHI with consistent pseudonyms"""
        
        result = text
        
        # Names
        names = self._extract_names(text)
        for i, name in enumerate(names):
            pseudonym = f"[PATIENT_{i+1}]"
            self.mapping[name] = pseudonym
            self.reverse_mapping[pseudonym] = name
            result = result.replace(name, pseudonym)
        
        # Dates (shift by random consistent offset)
        dates = self._extract_dates(text)
        for date in dates:
            shifted = self._shift_date(date)
            self.mapping[date] = shifted
            self.reverse_mapping[shifted] = date
            result = result.replace(date, shifted)
        
        # Other identifiers...
        
        return result
    
    def reidentify(self, text: str, original_context: dict) -> str:
        """Restore original identifiers in response"""
        
        result = text
        for pseudonym, original in self.reverse_mapping.items():
            result = result.replace(pseudonym, original)
        
        return result
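
The extraction helpers used above are left undefined. Placeholder versions, regex-based and illustrative only (production systems should use a clinical NER / de-identification tool such as Presidio), could look like:

    # Placeholder helpers for PHIDeIdentifier
    def _extract_names(self, text: str) -> list:
        return re.findall(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', text)

    def _extract_dates(self, text: str) -> list:
        return re.findall(r'\b\d{2}/\d{2}/\d{4}\b', text)

    def _shift_date(self, date: str) -> str:
        # A fixed offset keeps intervals between dates consistent
        from datetime import datetime, timedelta
        parsed = datetime.strptime(date, "%m/%d/%Y")
        return (parsed + timedelta(days=37)).strftime("%m/%d/%Y")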

Architecture 4: Privacy-Preserving AI

Differential Privacy for Training

"""
Differential Privacy for Healthcare AI

Use DP when:
- Fine-tuning models on patient data
- Training on aggregate patterns
- Generating synthetic training data
"""

import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader

class DifferentiallyPrivateTraining:
    """
    Train models with differential privacy guarantees
    
    Provides mathematical guarantee that individual patient
    data cannot be extracted from the model.
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        epsilon: float = 1.0,  # Privacy budget
        delta: float = 1e-5,   # Probability of privacy breach
        max_grad_norm: float = 1.0,
    ):
        self.model = model
        self.epsilon = epsilon
        self.delta = delta
        self.max_grad_norm = max_grad_norm
        
    def train_with_dp(
        self,
        train_loader: DataLoader,
        epochs: int,
        learning_rate: float = 1e-4,
    ):
        """Train with differential privacy"""
        
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate
        )
        
        # Wrap with Opacus privacy engine
        privacy_engine = PrivacyEngine()
        
        model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=self.model,
            optimizer=optimizer,
            data_loader=train_loader,
            target_epsilon=self.epsilon,
            target_delta=self.delta,
            epochs=epochs,
            max_grad_norm=self.max_grad_norm,
        )
        
        for epoch in range(epochs):
            for batch in train_loader:
                optimizer.zero_grad()
                
                # Forward pass must run through the Opacus-wrapped `model`
                # so per-sample gradients are recorded
                loss = self._compute_loss(batch)
                loss.backward()
                
                optimizer.step()
            
            # Check privacy spent
            epsilon_spent = privacy_engine.get_epsilon(self.delta)
            print(f"Epoch {epoch}: ε = {epsilon_spent:.2f}")
            
            if epsilon_spent >= self.epsilon:
                print("Privacy budget exhausted")
                break
        
        return model
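
Usage is a thin wrapper around a normal training loop; for example (model and data loader assumed to exist, with _compute_loss implemented for your task):

trainer = DifferentiallyPrivateTraining(model, epsilon=1.0, delta=1e-5)
private_model = trainer.train_with_dp(train_loader, epochs=3)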

Federated Learning for Multi-Hospital Collaboration

"""
Federated Learning for Healthcare AI

Train models across multiple hospitals without sharing patient data.
Each hospital keeps their data local; only model updates are shared.
"""

import asyncio
from typing import Dict, List

import torch

class FederatedMedicalAI:
    """
    Federated learning coordinator for healthcare
    
    Architecture:
    - Central server coordinates training
    - Each hospital trains locally on their data
    - Only gradients/weights are shared (never raw data)
    - Secure aggregation prevents server from seeing individual updates
    """
    
    def __init__(self, global_model, hospitals: List[str]):
        self.global_model = global_model
        self.hospitals = hospitals
        
    async def federated_round(self) -> Dict:
        """Execute one round of federated learning"""
        
        # 1. Distribute global model to all hospitals
        global_weights = self.global_model.state_dict()
        
        # 2. Each hospital trains locally (in parallel)
        local_updates = await asyncio.gather(*[
            self._train_at_hospital(hospital, global_weights)
            for hospital in self.hospitals
        ])
        
        # 3. Securely aggregate updates
        aggregated = self._secure_aggregate(local_updates)
        
        # 4. Update global model
        self.global_model.load_state_dict(aggregated)
        
        return {
            "participating_hospitals": len(self.hospitals),
            "round_completed": True,
        }
    
    async def _train_at_hospital(
        self,
        hospital: str,
        global_weights: dict
    ) -> dict:
        """
        Train on hospital's local data
        
        This runs AT the hospital, not centrally.
        Only the weight updates leave the hospital.
        """
        
        # Hospital loads their local data (stays local)
        local_data = await self._get_hospital_data(hospital)
        
        # Create local model with global weights
        local_model = self._create_model()
        local_model.load_state_dict(global_weights)
        
        # Train on local data
        trained_model = self._local_training(local_model, local_data)
        
        # Compute weight delta (what changed)
        weight_delta = {
            k: trained_model.state_dict()[k] - global_weights[k]
            for k in global_weights.keys()
        }
        
        # Add noise for differential privacy
        noisy_delta = self._add_dp_noise(weight_delta)
        
        return noisy_delta
    
    def _secure_aggregate(self, updates: List[dict]) -> dict:
        """
        Securely aggregate model updates
        
        Uses secure multi-party computation or homomorphic
        encryption so server never sees individual updates.
        """
        
        # Simple averaging (in production, use secure aggregation)
        aggregated = {}
        num_updates = len(updates)
        
        for key in updates[0].keys():
            aggregated[key] = sum(u[key] for u in updates) / num_updates
            
        return aggregated
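
The _add_dp_noise helper called in _train_at_hospital is also left undefined above; a minimal Gaussian-noise version (the sigma value is an arbitrary placeholder, not a calibrated privacy budget) might be:

    def _add_dp_noise(self, delta: dict, sigma: float = 0.01) -> dict:
        """Add Gaussian noise to each weight tensor before it leaves the hospital"""
        return {k: v + torch.randn_like(v) * sigma for k, v in delta.items()}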

Complete E2E Chat + AI System

"""
Complete HIPAA-Compliant E2E Encrypted Chat with AI

Combines:
- Signal Protocol for E2E encryption
- Hybrid AI processing
- Comprehensive audit logging
"""

class HIPAACompliantChat:
    """
    End-to-end encrypted healthcare chat with AI capabilities
    """
    
    def __init__(
        self,
        signal_protocol: SignalProtocol,
        ai_router: HybridAIRouter,
        audit_logger: AuditLogger,
        key_manager: VaultKeyManager,
    ):
        self.signal = signal_protocol
        self.ai = ai_router
        self.audit = audit_logger
        self.keys = key_manager
        
    async def send_message(
        self,
        sender_id: str,
        recipient_id: str,
        message: str,
        request_ai_response: bool = False,
    ) -> dict:
        """Send E2E encrypted message, optionally with AI processing"""
        
        session_id = f"{sender_id}:{recipient_id}"
        
        # 1. Encrypt message for recipient
        encrypted = self.signal.encrypt_message(
            session_id,
            message.encode()
        )
        
        # 2. Log message send (not content)
        await self.audit.log({
            "event_type": "MESSAGE_SENT",
            "sender_id": sender_id,
            "recipient_id": recipient_id,
            "message_size": len(message),
            "encrypted": True,
        })
        
        result = {"encrypted_message": encrypted}
        
        # 3. If AI response requested, process securely
        if request_ai_response:
            ai_response = await self._process_with_ai(
                message,
                sender_id,
                session_id,
            )
            
            # Encrypt AI response
            encrypted_ai = self.signal.encrypt_message(
                session_id,
                ai_response.encode()
            )
            
            result["ai_response"] = encrypted_ai
            
        return result
    
    async def _process_with_ai(
        self,
        message: str,
        user_id: str,
        session_id: str,
    ) -> str:
        """Process message with AI while maintaining security"""
        
        # Get patient context (if authorized)
        context = await self._get_authorized_context(user_id)
        
        # Route to appropriate AI processor
        response = await self.ai.process_query(
            query=message,
            context=context,
            session_id=session_id,
        )
        
        # Log AI interaction
        await self.audit.log({
            "event_type": "AI_INTERACTION",
            "user_id": user_id,
            "session_id": session_id,
            "processing_location": self.ai.last_processor,
        })
        
        return response
    
    async def receive_message(
        self,
        recipient_id: str,
        sender_id: str,
        encrypted_message: dict,
    ) -> str:
        """Decrypt received message"""
        
        session_id = f"{sender_id}:{recipient_id}"
        
        plaintext = self.signal.decrypt_message(
            session_id,
            encrypted_message
        )
        
        await self.audit.log({
            "event_type": "MESSAGE_RECEIVED",
            "recipient_id": recipient_id,
            "sender_id": sender_id,
            "decrypted": True,
        })
        
        return plaintext.decode()
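
Putting it together, with the collaborators built earlier in this module (illustrative wiring):

chat = HIPAACompliantChat(signal, router, audit_logger, key_manager)

result = await chat.send_message(
    sender_id="dr_chen",
    recipient_id="patient_042",
    message="Your latest labs look stable compared with the last visit.",
    request_ai_response=True,
)
# result["encrypted_message"] goes to the recipient;
# result["ai_response"], if present, is the E2E-encrypted AI draft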

Key Takeaways

No Perfect Solution

LLMs fundamentally need plaintext. Choose architecture based on your risk tolerance.

Defense in Depth

Combine multiple approaches: E2E encryption, TEEs, on-premise deployment.

Minimize Cloud Exposure

Process sensitive PHI on-premise; use cloud only for de-identified data.

Audit Everything

Log all AI interactions without logging actual PHI content.

Decision Matrix

Approach           Security Level   Cost     Capability                     Complexity
Secure Enclaves    Highest          High     Limited by enclave resources   Very High
On-Premise LLM     High             High     Good (depends on hardware)     Medium
Hybrid + De-ID     Medium-High      Medium   Best (cloud + local)           Medium
Cloud Only         Lower            Low      Best                           Low

Practice Exercise

1. Implement Signal Protocol: build basic E2E encryption using the Signal Protocol patterns.

2. Deploy Local LLM: set up an on-premise LLM with vLLM or Ollama.

3. Build Hybrid Router: create a query classifier that routes to appropriate processing.

4. Add De-identification: implement PHI de-identification for cloud-safe queries.

5. Integrate Audit Logging: log all AI interactions with proper privacy controls.
