E2E Encryption with AI Agents
The fundamental challenge of healthcare AI: LLMs need plaintext to process, but E2E encryption means only endpoints have plaintext. This module explores practical solutions for HIPAA-compliant AI chat systems.

Learning Objectives:
- Understand the encryption-AI tension
- Implement Signal Protocol for healthcare chat
- Explore secure enclaves and TEEs
- Design privacy-preserving AI architectures
- Build HIPAA-compliant AI medical assistants
The Fundamental Tension
The E2E Encryption + AI Challenge
E2E encryption guarantees that only the communicating endpoints can ever read a message; the server in the middle sees nothing but ciphertext. A server-hosted AI assistant is, by definition, not one of those endpoints, so it cannot process what it cannot read. Every architecture in this module is a different answer to the same question: where is plaintext allowed to exist, and under whose control?
Solution Architecture Overview
There is no perfect solution, but several practical approaches exist:

Secure Enclaves (TEE)
Process data in hardware-isolated environments. Highest security, complex to implement.
On-Premise LLMs
Deploy models within your infrastructure. Full control, significant cost.
Endpoint Processing
Run smaller models on user devices. Privacy-first, limited capability.
Hybrid Architecture
Combine approaches based on data sensitivity. Most practical for real-world use.
E2E Encrypted Chat Foundation
Signal Protocol Implementation
Before adding AI, let’s build proper E2E encrypted chat:
"""
Signal Protocol Implementation for Healthcare Chat
Components:
1. X3DH - Extended Triple Diffie-Hellman (initial key exchange)
2. Double Ratchet - Forward secrecy per message
3. Sesame - Multi-device support
"""
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives import hashes, serialization
from dataclasses import dataclass, field
from typing import Optional, Tuple, List
import os
import hashlib
@dataclass
class IdentityKeyPair:
"""Long-term identity keys"""
private_key: Ed25519PrivateKey
public_key: bytes
@classmethod
def generate(cls) -> "IdentityKeyPair":
private = Ed25519PrivateKey.generate()
public = private.public_key().public_bytes(
serialization.Encoding.Raw,
serialization.PublicFormat.Raw
)
return cls(private_key=private, public_key=public)
@dataclass
class PreKeyBundle:
"""Published keys for X3DH"""
identity_key: bytes
signed_prekey: bytes
signed_prekey_signature: bytes
one_time_prekey: Optional[bytes] = None
@dataclass
class DoubleRatchetState:
"""State for Double Ratchet algorithm"""
root_key: bytes
chain_key_sending: bytes
chain_key_receiving: bytes
dh_sending: X25519PrivateKey
dh_receiving: Optional[X25519PublicKey] = None
message_number_sending: int = 0
message_number_receiving: int = 0
previous_chain_length: int = 0
skipped_keys: dict = field(default_factory=dict)
class SignalProtocol:
"""
Signal Protocol implementation for E2E encryption
Provides:
- Perfect forward secrecy
- Future secrecy (break-in recovery)
- Deniability
"""
def __init__(self):
self.sessions: dict[str, DoubleRatchetState] = {}
def x3dh_initiator(
self,
our_identity: IdentityKeyPair,
their_prekey_bundle: PreKeyBundle
) -> Tuple[bytes, bytes]:
"""
X3DH key agreement (initiator side)
Computes shared secret from:
- Our identity key + their signed prekey
- Our ephemeral key + their identity key
- Our ephemeral key + their signed prekey
- Our ephemeral key + their one-time prekey (if available)
"""
# Generate ephemeral key pair
ephemeral_private = X25519PrivateKey.generate()
ephemeral_public = ephemeral_private.public_key().public_bytes(
serialization.Encoding.Raw,
serialization.PublicFormat.Raw
)
        # Load their keys. (Simplified: we treat the published identity key
        # as X25519 bytes; real Signal converts the Ed25519 identity key to
        # its X25519 form first, as we do for our own key below.)
        their_identity = X25519PublicKey.from_public_bytes(
            their_prekey_bundle.identity_key
        )
their_signed_prekey = X25519PublicKey.from_public_bytes(
their_prekey_bundle.signed_prekey
)
# Convert our identity key for X25519
our_identity_x25519 = self._ed25519_to_x25519_private(
our_identity.private_key
)
# DH1: Our identity * their signed prekey
dh1 = our_identity_x25519.exchange(their_signed_prekey)
# DH2: Our ephemeral * their identity
dh2 = ephemeral_private.exchange(their_identity)
# DH3: Our ephemeral * their signed prekey
dh3 = ephemeral_private.exchange(their_signed_prekey)
# DH4: Our ephemeral * their one-time prekey (if available)
if their_prekey_bundle.one_time_prekey:
their_otpk = X25519PublicKey.from_public_bytes(
their_prekey_bundle.one_time_prekey
)
dh4 = ephemeral_private.exchange(their_otpk)
else:
dh4 = b""
# Derive shared secret
shared_secret = self._kdf(dh1 + dh2 + dh3 + dh4, b"X3DH")
return shared_secret, ephemeral_public
def initialize_double_ratchet(
self,
session_id: str,
shared_secret: bytes,
their_ratchet_key: X25519PublicKey
):
"""Initialize Double Ratchet with X3DH shared secret"""
# Generate our first ratchet key pair
our_ratchet = X25519PrivateKey.generate()
# Derive initial keys
root_key, chain_key = self._kdf_rk(
shared_secret,
our_ratchet.exchange(their_ratchet_key)
)
self.sessions[session_id] = DoubleRatchetState(
root_key=root_key,
chain_key_sending=chain_key,
chain_key_receiving=b"",
dh_sending=our_ratchet,
dh_receiving=their_ratchet_key,
)
def encrypt_message(
self,
session_id: str,
plaintext: bytes
) -> dict:
"""
Encrypt a message using Double Ratchet
Each message uses a unique key derived from the chain key,
providing perfect forward secrecy.
"""
state = self.sessions[session_id]
# Derive message key from chain key
message_key, new_chain_key = self._kdf_ck(state.chain_key_sending)
state.chain_key_sending = new_chain_key
# Encrypt with AEAD
nonce = os.urandom(12)
aesgcm = AESGCM(message_key)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
# Prepare header
header = {
"dh_public": state.dh_sending.public_key().public_bytes(
serialization.Encoding.Raw,
serialization.PublicFormat.Raw
).hex(),
"message_number": state.message_number_sending,
"previous_chain_length": state.previous_chain_length,
}
state.message_number_sending += 1
return {
"header": header,
"nonce": nonce.hex(),
"ciphertext": ciphertext.hex(),
}
def decrypt_message(
self,
session_id: str,
encrypted: dict
) -> bytes:
"""Decrypt a message, handling out-of-order delivery"""
state = self.sessions[session_id]
header = encrypted["header"]
        # Check whether the header carries a new ratchet key. Compare raw
        # bytes rather than key objects (object equality is not reliable
        # across cryptography library versions).
        their_dh_bytes = bytes.fromhex(header["dh_public"])
        their_dh = X25519PublicKey.from_public_bytes(their_dh_bytes)
        receiving_bytes = (
            state.dh_receiving.public_bytes(
                serialization.Encoding.Raw, serialization.PublicFormat.Raw
            )
            if state.dh_receiving is not None
            else None
        )
        if their_dh_bytes != receiving_bytes:
            # Perform DH ratchet
            self._dh_ratchet(state, their_dh)
# Derive message key
message_key, new_chain_key = self._kdf_ck(state.chain_key_receiving)
state.chain_key_receiving = new_chain_key
# Decrypt
nonce = bytes.fromhex(encrypted["nonce"])
ciphertext = bytes.fromhex(encrypted["ciphertext"])
aesgcm = AESGCM(message_key)
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
state.message_number_receiving += 1
return plaintext
def _dh_ratchet(self, state: DoubleRatchetState, their_dh: X25519PublicKey):
"""Perform DH ratchet step"""
state.previous_chain_length = state.message_number_sending
state.message_number_sending = 0
state.message_number_receiving = 0
state.dh_receiving = their_dh
# Derive new receiving chain key
dh_output = state.dh_sending.exchange(their_dh)
state.root_key, state.chain_key_receiving = self._kdf_rk(
state.root_key, dh_output
)
# Generate new sending key pair
state.dh_sending = X25519PrivateKey.generate()
dh_output = state.dh_sending.exchange(their_dh)
state.root_key, state.chain_key_sending = self._kdf_rk(
state.root_key, dh_output
)
def _kdf(self, input_key_material: bytes, info: bytes) -> bytes:
"""Key derivation function"""
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=b"\x00" * 32,
info=info,
)
return hkdf.derive(input_key_material)
def _kdf_rk(self, root_key: bytes, dh_output: bytes) -> Tuple[bytes, bytes]:
"""Root key KDF - derives new root key and chain key"""
output = HKDF(
algorithm=hashes.SHA256(),
length=64,
salt=root_key,
info=b"RatchetKDF",
).derive(dh_output)
return output[:32], output[32:]
def _kdf_ck(self, chain_key: bytes) -> Tuple[bytes, bytes]:
"""Chain key KDF - derives message key and new chain key"""
message_key = hashlib.sha256(chain_key + b"\x01").digest()
new_chain_key = hashlib.sha256(chain_key + b"\x02").digest()
return message_key, new_chain_key
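    def _ed25519_to_x25519_private(
        self, ed_key: Ed25519PrivateKey
    ) -> X25519PrivateKey:
        """Convert our Ed25519 identity key to X25519 for Diffie-Hellman.

        Minimal sketch of the standard conversion (what libsodium's
        crypto_sign_ed25519_sk_to_curve25519 does): hash the Ed25519 seed
        with SHA-512, keep the first 32 bytes, and clamp per RFC 7748.
        """
        seed = ed_key.private_bytes(
            serialization.Encoding.Raw,
            serialization.PrivateFormat.Raw,
            serialization.NoEncryption(),
        )
        scalar = bytearray(hashlib.sha512(seed).digest()[:32])
        scalar[0] &= 248
        scalar[31] &= 127
        scalar[31] |= 64
        return X25519PrivateKey.from_private_bytes(bytes(scalar))

To sanity-check the ratchet mechanics, here is a minimal round-trip sketch. It bypasses the responder side of X3DH (not implemented in this module) by handing both parties the same stand-in shared secret; the names alice and bob are illustrative only.

# Demo: both parties are given the same stand-in "X3DH output"
alice, bob = SignalProtocol(), SignalProtocol()
shared_secret = os.urandom(32)
bob_ratchet = X25519PrivateKey.generate()

# Alice ratchets toward Bob's published ratchet key
alice.initialize_double_ratchet("alice:bob", shared_secret, bob_ratchet.public_key())

# Bob starts from the shared root key; his receiving chain is derived on
# the first DH ratchet triggered by Alice's message header
bob.sessions["alice:bob"] = DoubleRatchetState(
    root_key=shared_secret,
    chain_key_sending=b"",
    chain_key_receiving=b"",
    dh_sending=bob_ratchet,
)

encrypted = alice.encrypt_message("alice:bob", b"BP reading: 120/80")
assert bob.decrypt_message("alice:bob", encrypted) == b"BP reading: 120/80"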
Architecture 1: Secure Enclaves (TEE)
Trusted Execution Environment Approach
┌─────────────────────────────────────────────────────────────────────────────┐
│ SECURE ENCLAVE ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌────────────┐ ┌─────────────────────────────────────────┐ │
│ │ Client │ │ SECURE ENCLAVE (TEE) │ │
│ │ │ │ ┌───────────────────────────────────┐ │ │
│ │ Encrypt │ │ │ Decryption happens here │ │ │
│ │ with │──[E2E]──────►│ │ │ │ │
│ │ enclave │ │ │ ┌─────────────────────┐ │ │ │
│ │ public │ │ │ │ LLM Inference │ │ │ │
│ │ key │ │ │ │ (Plaintext) │ │ │ │
│ │ │◄─[E2E]───────│ │ └─────────────────────┘ │ │ │
│ └────────────┘ │ │ │ │ │
│ │ │ Result encrypted before │ │ │
│ │ │ leaving enclave │ │ │
│ │ └───────────────────────────────────┘ │ │
│ │ │ │
│ │ HARDWARE ISOLATION │ │
│ │ • Intel SGX / AMD SEV / ARM TrustZone │ │
│ │ • Memory encryption │ │
│ │ • Attestation │ │
│ └─────────────────────────────────────────┘ │
│ │
│ EVEN THE SERVER OPERATOR CANNOT SEE PLAINTEXT │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Implementation with AWS Nitro Enclaves
"""
AWS Nitro Enclaves implementation for secure AI inference
Nitro Enclaves provide:
- Isolated compute environment
- No persistent storage
- No external networking (only vsock to parent)
- Cryptographic attestation
"""
import socket
import json
from dataclasses import dataclass
from typing import Any
# Parent instance (outside enclave)
class EnclaveClient:
"""Client that communicates with the secure enclave"""
def __init__(self, enclave_cid: int, enclave_port: int = 5000):
self.enclave_cid = enclave_cid
self.enclave_port = enclave_port
    def process_phi_with_ai(
        self,
        encrypted_message: bytes,
        session_id: str,
    ) -> bytes:
        """
        Send encrypted PHI to the enclave for AI processing

        The enclave will:
        1. Decrypt the message using the session key
        2. Process it with the LLM
        3. Encrypt the response
        4. Return the encrypted response

        (Attestation runs the other way: the enclave proves its identity
        to KMS when fetching keys; see EnclaveServer._fetch_key_from_kms.)
        """
# Connect via vsock (only connection allowed to enclave)
sock = socket.socket(socket.AF_VSOCK, socket.SOCK_STREAM)
sock.connect((self.enclave_cid, self.enclave_port))
request = {
"action": "process_medical_query",
"encrypted_message": encrypted_message.hex(),
"session_id": session_id,
}
        sock.sendall(json.dumps(request).encode())  # sendall avoids short writes
response = sock.recv(65536)
sock.close()
return bytes.fromhex(json.loads(response)["encrypted_response"])
# Inside the enclave
class EnclaveServer:
"""
Server running inside Nitro Enclave
Has access to:
- Decryption keys (fetched via attestation from KMS)
- LLM model weights
Cannot:
- Access internet
- Write to persistent storage
- Be inspected by parent instance
"""
def __init__(self):
self.kms_key = self._fetch_key_from_kms()
self.llm = self._load_llm_model()
self.sessions = {}
def _fetch_key_from_kms(self) -> bytes:
"""
Fetch decryption key from KMS using attestation
KMS verifies the enclave's attestation document before
releasing the key. This ensures only valid enclave code
can access the key.
"""
        # In a real implementation, use the aws-nitro-enclaves-sdk; the
        # request carries an attestation document signed by the Nitro
        # Hypervisor. (kms_client and _get_attestation_document stand in
        # for the SDK's KMS proxy and NSM calls.)
        attestation_doc = self._get_attestation_document()
        # KMS verifies:
        # - Enclave image measurement (code hash / PCRs)
        # - Signing certificate
        # - Platform configuration
        # Only then does it release the key
        return kms_client.decrypt_with_attestation(
            attestation_doc=attestation_doc,
            key_id="alias/enclave-master-key"
        )
def _load_llm_model(self):
"""Load quantized LLM model into enclave memory"""
# Use a smaller, quantized model that fits in enclave memory
# Examples: Llama-7B-GPTQ, Mistral-7B-GGUF
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"/enclave/models/medical-llm-7b-q4",
device_map="auto",
load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained(
"/enclave/models/medical-llm-7b-q4"
)
return model, tokenizer
def handle_request(self, request: dict) -> dict:
"""Process request inside secure enclave"""
session_id = request["session_id"]
encrypted_message = bytes.fromhex(request["encrypted_message"])
# Decrypt using session key
session_key = self._get_session_key(session_id)
plaintext = self._decrypt(encrypted_message, session_key)
# Process with LLM (plaintext only exists in enclave memory)
medical_query = plaintext.decode()
response = self._generate_response(medical_query)
# Encrypt response before returning
encrypted_response = self._encrypt(response.encode(), session_key)
        # Drop plaintext references. (Caveat: `del` only removes the Python
        # reference; CPython does not guarantee memory zeroization.)
        del plaintext
        del medical_query
        del response
return {"encrypted_response": encrypted_response.hex()}
def _generate_response(self, query: str) -> str:
"""Generate medical AI response"""
model, tokenizer = self.llm
prompt = f"""You are a medical assistant. Answer the following health query.
Always recommend consulting a healthcare provider for medical decisions.
Query: {query}
Response:"""
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=500)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
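EnclaveServer references a few helpers (_get_attestation_document, _get_session_key, _encrypt, _decrypt) that the Nitro SDK or your session layer must supply. A minimal sketch of the AEAD pair, assuming session keys are 32-byte AES-GCM keys established during the E2E handshake:

import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

class SessionAEADEnclaveServer(EnclaveServer):
    """EnclaveServer with the assumed AEAD helpers filled in (sketch)."""

    def _encrypt(self, plaintext: bytes, key: bytes) -> bytes:
        # AES-256-GCM with a fresh 96-bit nonce prepended to the ciphertext
        nonce = os.urandom(12)
        return nonce + AESGCM(key).encrypt(nonce, plaintext, None)

    def _decrypt(self, blob: bytes, key: bytes) -> bytes:
        nonce, ciphertext = blob[:12], blob[12:]
        return AESGCM(key).decrypt(nonce, ciphertext, None)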
Architecture 2: On-Premise LLM Deployment
Self-Hosted LLM Architecture
┌─────────────────────────────────────────────────────────────────────────────┐
│ ON-PREMISE LLM ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ HEALTHCARE ORGANIZATION'S INFRASTRUCTURE │
│ ────────────────────────────────────────── │
│ │
│ ┌────────────┐ ┌─────────────────────────────────────────────────┐ │
│ │ Client │ │ PRIVATE CLOUD / DATA CENTER │ │
│ │ │ │ │ │
│ │ E2E │ │ ┌──────────────┐ ┌──────────────────────┐ │ │
│ │ Encrypted │────►│ │ API │───►│ LLM Inference │ │ │
│ │ to API │ │ │ Gateway │ │ Server │ │ │
│ │ Gateway │◄────│ │ │◄───│ (vLLM/TGI) │ │ │
│ └────────────┘ │ │ Decrypts │ │ │ │ │
│ │ │ here │ │ Llama-70B / Mixtral │ │ │
│ │ └──────────────┘ └──────────────────────┘ │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ ┌──────────────┐ ┌──────────────────────┐ │ │
│ │ │ Audit │ │ Vector Database │ │ │
│ │ │ Logging │ │ (RAG Context) │ │ │
│ │ └──────────────┘ └──────────────────────┘ │ │
│ │ │ │
│ │ ALL DATA STAYS WITHIN ORGANIZATION'S NETWORK │ │
│ └─────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Implementation with vLLM
"""
On-premise LLM deployment for HIPAA-compliant AI
Benefits:
- Complete data control
- No external API calls
- Full audit trail
- No HIPAA BAA with an external LLM provider needed (PHI never leaves your network)
"""
from vllm import LLM, SamplingParams
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from typing import Optional
import asyncio

# User, get_current_user, AuditLogger, and get_audit_logger are assumed to
# be provided by the application's existing auth and audit modules; they
# are not defined in this snippet.
# Initialize the LLM (runs on your infrastructure)
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2", # Or medical fine-tuned model
tensor_parallel_size=4, # For multi-GPU
dtype="float16",
max_model_len=8192,
)
app = FastAPI()
class MedicalQuery(BaseModel):
query: str
patient_context: Optional[str] = None
session_id: str
class MedicalResponse(BaseModel):
response: str
disclaimer: str
session_id: str
# Medical system prompt
MEDICAL_SYSTEM_PROMPT = """You are a medical AI assistant operating in a HIPAA-compliant environment.
IMPORTANT GUIDELINES:
1. Provide helpful medical information based on the query
2. Always recommend consulting with a healthcare provider
3. Do not diagnose conditions definitively
4. If information could be life-threatening, advise seeking emergency care
5. Keep responses factual and evidence-based
6. Respect patient privacy and confidentiality
"""
@app.post("/api/medical-ai", response_model=MedicalResponse)
async def process_medical_query(
query: MedicalQuery,
current_user: User = Depends(get_current_user),
audit_logger: AuditLogger = Depends(get_audit_logger),
):
"""Process medical query with on-premise LLM"""
# Log the access (without logging actual PHI content)
await audit_logger.log({
"event_type": "AI_MEDICAL_QUERY",
"user_id": current_user.id,
"session_id": query.session_id,
"query_length": len(query.query),
"has_context": bool(query.patient_context),
})
# Construct prompt
full_prompt = f"""{MEDICAL_SYSTEM_PROMPT}
Patient Context: {query.patient_context or 'Not provided'}
User Query: {query.query}
Medical AI Response:"""
# Generate response
sampling_params = SamplingParams(
temperature=0.7,
max_tokens=1000,
stop=["User Query:", "Patient Context:"],
)
outputs = llm.generate([full_prompt], sampling_params)
response_text = outputs[0].outputs[0].text.strip()
# Log completion
await audit_logger.log({
"event_type": "AI_RESPONSE_GENERATED",
"user_id": current_user.id,
"session_id": query.session_id,
"response_length": len(response_text),
})
return MedicalResponse(
response=response_text,
disclaimer="This AI response is for informational purposes only. Please consult a healthcare provider for medical advice.",
session_id=query.session_id,
)
# RAG Integration for medical knowledge
class MedicalRAG:
"""Retrieval-Augmented Generation for medical queries"""
def __init__(self, vector_store, llm):
self.vector_store = vector_store
self.llm = llm
async def query_with_context(
self,
query: str,
patient_history: Optional[str] = None
) -> str:
"""
Query with relevant medical knowledge retrieval
The vector store contains:
- Medical guidelines
- Drug interactions
- Treatment protocols
- Patient history (encrypted, decrypted for this query)
"""
# Retrieve relevant context
relevant_docs = await self.vector_store.similarity_search(
query,
k=5,
filter={"category": "medical_knowledge"}
)
context = "\n\n".join([doc.page_content for doc in relevant_docs])
prompt = f"""{MEDICAL_SYSTEM_PROMPT}
Relevant Medical Knowledge:
{context}
Patient History: {patient_history or 'Not available'}
Query: {query}
Based on the medical knowledge and patient context, provide a helpful response:"""
outputs = self.llm.generate([prompt], SamplingParams(
temperature=0.5, # Lower for more factual responses
max_tokens=1000,
))
return outputs[0].outputs[0].text.strip()
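A caller inside the organization's network might invoke the endpoint like this (hypothetical gateway host and bearer-token auth; adapt to your API gateway):

import httpx

async def ask_medical_ai(query: str, session_id: str, token: str) -> str:
    """Call the on-premise /api/medical-ai endpoint (host name is illustrative)."""
    async with httpx.AsyncClient(base_url="https://llm-gateway.internal") as client:
        resp = await client.post(
            "/api/medical-ai",
            json={"query": query, "session_id": session_id},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()["response"]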
Architecture 3: Hybrid Approach
The Practical Solution
Most real-world healthcare AI systems use a hybrid approach:
"""
Hybrid Architecture for Healthcare AI Chat
Strategy:
1. E2E encrypt all chat messages
2. Use on-premise LLM for PHI processing
3. Use cloud LLMs only for de-identified/general queries
4. Client-side AI for simple tasks
"""
from enum import Enum
from dataclasses import dataclass
from typing import Optional
import re  # used by the PHI pattern checks below
class QuerySensitivity(Enum):
PUBLIC = "public" # General health info, no PHI
CONTEXTUAL = "contextual" # Needs patient context but can be de-identified
SENSITIVE = "sensitive" # Contains PHI, must stay on-premise
@dataclass
class ProcessingDecision:
sensitivity: QuerySensitivity
processor: str # "client", "on_premise", "cloud"
requires_deidentification: bool
class HybridAIRouter:
"""
Route queries to appropriate processing environment
based on sensitivity and content
"""
def __init__(
self,
client_model,
on_premise_llm,
cloud_llm,
de_identifier,
):
self.client_model = client_model
self.on_premise_llm = on_premise_llm
self.cloud_llm = cloud_llm
self.de_identifier = de_identifier
def classify_query(self, query: str, context: dict) -> ProcessingDecision:
"""Classify query to determine processing approach"""
# Check for PHI indicators
phi_patterns = [
r'\b\d{3}-\d{2}-\d{4}\b', # SSN
r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', # Names
r'\bMRN:\s*\d+\b', # Medical record numbers
r'\bDOB:\s*\d{2}/\d{2}/\d{4}\b', # Date of birth
]
has_phi = any(re.search(p, query) for p in phi_patterns)
references_patient = "patient" in context or context.get("patient_id")
if has_phi:
return ProcessingDecision(
sensitivity=QuerySensitivity.SENSITIVE,
processor="on_premise",
requires_deidentification=False,
)
elif references_patient:
return ProcessingDecision(
sensitivity=QuerySensitivity.CONTEXTUAL,
processor="on_premise", # Or cloud with de-identification
requires_deidentification=True,
)
else:
return ProcessingDecision(
sensitivity=QuerySensitivity.PUBLIC,
processor="cloud", # Can use more powerful cloud model
requires_deidentification=False,
)
async def process_query(
self,
query: str,
context: dict,
session_id: str,
) -> str:
"""Process query using appropriate method"""
decision = self.classify_query(query, context)
if decision.processor == "client":
# Simple queries handled on device
return await self.client_model.generate(query)
elif decision.processor == "on_premise":
# PHI stays on-premise
return await self.on_premise_llm.generate(
query=query,
context=context,
)
elif decision.processor == "cloud":
if decision.requires_deidentification:
# De-identify before sending to cloud
safe_query = self.de_identifier.deidentify(query)
safe_context = self.de_identifier.deidentify_context(context)
response = await self.cloud_llm.generate(
query=safe_query,
context=safe_context,
)
# Re-identify response
return self.de_identifier.reidentify(response, context)
else:
return await self.cloud_llm.generate(query=query)
class PHIDeIdentifier:
"""
De-identify PHI for safe cloud processing
Uses consistent pseudonymization so context is preserved
but actual identifiers are removed.
"""
def __init__(self):
self.mapping = {} # Original -> Pseudonym
self.reverse_mapping = {} # Pseudonym -> Original
def deidentify(self, text: str) -> str:
"""Replace PHI with consistent pseudonyms"""
result = text
# Names
names = self._extract_names(text)
for i, name in enumerate(names):
pseudonym = f"[PATIENT_{i+1}]"
self.mapping[name] = pseudonym
self.reverse_mapping[pseudonym] = name
result = result.replace(name, pseudonym)
# Dates (shift by random consistent offset)
dates = self._extract_dates(text)
for date in dates:
shifted = self._shift_date(date)
self.mapping[date] = shifted
self.reverse_mapping[shifted] = date
result = result.replace(date, shifted)
# Other identifiers...
return result
def reidentify(self, text: str, original_context: dict) -> str:
"""Restore original identifiers in response"""
result = text
for pseudonym, original in self.reverse_mapping.items():
result = result.replace(pseudonym, original)
return result
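The extraction helpers above (_extract_names, _extract_dates, _shift_date) are placeholders. A naive regex-based sketch is shown below; a production system should use a clinical de-identification tool (e.g. Microsoft Presidio) rather than these patterns:

import random
import re
from datetime import datetime, timedelta

class SimplePHIDeIdentifier(PHIDeIdentifier):
    """PHIDeIdentifier with naive helpers filled in (illustrative only)."""

    def _extract_names(self, text: str) -> list:
        # Capitalized-pair heuristic; expect false positives and misses
        return re.findall(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", text)

    def _extract_dates(self, text: str) -> list:
        return re.findall(r"\b\d{2}/\d{2}/\d{4}\b", text)

    def _shift_date(self, date_str: str) -> str:
        # One consistent random offset per de-identifier instance
        if not hasattr(self, "_offset_days"):
            self._offset_days = random.randint(1, 30)
        shifted = datetime.strptime(date_str, "%m/%d/%Y") + timedelta(days=self._offset_days)
        return shifted.strftime("%m/%d/%Y")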
Architecture 4: Privacy-Preserving AI
Differential Privacy for Training
"""
Differential Privacy for Healthcare AI
Use DP when:
- Fine-tuning models on patient data
- Training on aggregate patterns
- Generating synthetic training data
"""
import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader
class DifferentiallyPrivateTraining:
"""
Train models with differential privacy guarantees
Provides mathematical guarantee that individual patient
data cannot be extracted from the model.
"""
def __init__(
self,
model: torch.nn.Module,
epsilon: float = 1.0, # Privacy budget
delta: float = 1e-5, # Probability of privacy breach
max_grad_norm: float = 1.0,
):
self.model = model
self.epsilon = epsilon
self.delta = delta
self.max_grad_norm = max_grad_norm
def train_with_dp(
self,
train_loader: DataLoader,
epochs: int,
learning_rate: float = 1e-4,
):
"""Train with differential privacy"""
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=learning_rate
)
# Wrap with Opacus privacy engine
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=self.model,
optimizer=optimizer,
data_loader=train_loader,
target_epsilon=self.epsilon,
target_delta=self.delta,
epochs=epochs,
max_grad_norm=self.max_grad_norm,
)
        for epoch in range(epochs):
            for batch in train_loader:
                optimizer.zero_grad()
                # _compute_loss is task-specific (see the subclass sketch
                # below); it must run the Opacus-wrapped `model`, not
                # self.model, so per-sample gradients are tracked
                loss = self._compute_loss(model, batch)
                loss.backward()
                optimizer.step()
# Check privacy spent
epsilon_spent = privacy_engine.get_epsilon(self.delta)
print(f"Epoch {epoch}: ε = {epsilon_spent:.2f}")
if epsilon_spent >= self.epsilon:
print("Privacy budget exhausted")
break
return model
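_compute_loss is task-specific and left undefined above. A minimal way to exercise the wrapper, assuming a toy classifier over synthetic embeddings (all names and shapes here are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

class NotesClassifierDP(DifferentiallyPrivateTraining):
    def _compute_loss(self, model, batch):
        inputs, labels = batch
        return torch.nn.functional.cross_entropy(model(inputs), labels)

# Synthetic stand-in data: 256 embedding vectors with binary labels
data = TensorDataset(torch.randn(256, 64), torch.randint(0, 2, (256,)))
loader = DataLoader(data, batch_size=32)

trainer = NotesClassifierDP(torch.nn.Linear(64, 2), epsilon=1.0, delta=1e-5)
private_model = trainer.train_with_dp(loader, epochs=3)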
Federated Learning for Multi-Hospital Collaboration
"""
Federated Learning for Healthcare AI
Train models across multiple hospitals without sharing patient data.
Each hospital keeps their data local; only model updates are shared.
"""
from typing import List, Dict
import asyncio
import numpy as np
class FederatedMedicalAI:
"""
Federated learning coordinator for healthcare
Architecture:
- Central server coordinates training
- Each hospital trains locally on their data
- Only gradients/weights are shared (never raw data)
- Secure aggregation prevents server from seeing individual updates
"""
def __init__(self, global_model, hospitals: List[str]):
self.global_model = global_model
self.hospitals = hospitals
async def federated_round(self) -> Dict:
"""Execute one round of federated learning"""
# 1. Distribute global model to all hospitals
global_weights = self.global_model.state_dict()
# 2. Each hospital trains locally (in parallel)
local_updates = await asyncio.gather(*[
self._train_at_hospital(hospital, global_weights)
for hospital in self.hospitals
])
        # 3. Securely aggregate the weight deltas
        aggregated_delta = self._secure_aggregate(local_updates)
        # 4. Apply the averaged delta to the global weights (the updates
        # are deltas, not full weight sets)
        new_weights = {
            k: global_weights[k] + aggregated_delta[k]
            for k in global_weights
        }
        self.global_model.load_state_dict(new_weights)
return {
"participating_hospitals": len(self.hospitals),
"round_completed": True,
}
async def _train_at_hospital(
self,
hospital: str,
global_weights: dict
) -> dict:
"""
Train on hospital's local data
This runs AT the hospital, not centrally.
Only the weight updates leave the hospital.
"""
# Hospital loads their local data (stays local)
local_data = await self._get_hospital_data(hospital)
# Create local model with global weights
local_model = self._create_model()
local_model.load_state_dict(global_weights)
# Train on local data
trained_model = self._local_training(local_model, local_data)
# Compute weight delta (what changed)
weight_delta = {
k: trained_model.state_dict()[k] - global_weights[k]
for k in global_weights.keys()
}
# Add noise for differential privacy
noisy_delta = self._add_dp_noise(weight_delta)
return noisy_delta
def _secure_aggregate(self, updates: List[dict]) -> dict:
"""
Securely aggregate model updates
Uses secure multi-party computation or homomorphic
encryption so server never sees individual updates.
"""
# Simple averaging (in production, use secure aggregation)
aggregated = {}
num_updates = len(updates)
for key in updates[0].keys():
aggregated[key] = sum(u[key] for u in updates) / num_updates
return aggregated
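_add_dp_noise is left abstract above. A minimal Gaussian-noise sketch (sigma is an assumed hyperparameter; in practice it must be calibrated to the update clipping norm and the target privacy budget):

import torch

def add_dp_noise(delta: dict, sigma: float = 0.01) -> dict:
    """Add Gaussian noise to each float weight-delta tensor."""
    return {k: v + torch.randn_like(v) * sigma for k, v in delta.items()}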
Complete E2E Chat + AI System
"""
Complete HIPAA-Compliant E2E Encrypted Chat with AI
Combines:
- Signal Protocol for E2E encryption
- Hybrid AI processing
- Comprehensive audit logging
"""
class HIPAACompliantChat:
"""
End-to-end encrypted healthcare chat with AI capabilities
"""
def __init__(
self,
signal_protocol: SignalProtocol,
ai_router: HybridAIRouter,
audit_logger: AuditLogger,
key_manager: VaultKeyManager,
):
self.signal = signal_protocol
self.ai = ai_router
self.audit = audit_logger
self.keys = key_manager
async def send_message(
self,
sender_id: str,
recipient_id: str,
message: str,
request_ai_response: bool = False,
) -> dict:
"""Send E2E encrypted message, optionally with AI processing"""
session_id = f"{sender_id}:{recipient_id}"
# 1. Encrypt message for recipient
encrypted = self.signal.encrypt_message(
session_id,
message.encode()
)
# 2. Log message send (not content)
await self.audit.log({
"event_type": "MESSAGE_SENT",
"sender_id": sender_id,
"recipient_id": recipient_id,
"message_size": len(message),
"encrypted": True,
})
result = {"encrypted_message": encrypted}
# 3. If AI response requested, process securely
if request_ai_response:
ai_response = await self._process_with_ai(
message,
sender_id,
session_id,
)
# Encrypt AI response
encrypted_ai = self.signal.encrypt_message(
session_id,
ai_response.encode()
)
result["ai_response"] = encrypted_ai
return result
async def _process_with_ai(
self,
message: str,
user_id: str,
session_id: str,
) -> str:
"""Process message with AI while maintaining security"""
# Get patient context (if authorized)
context = await self._get_authorized_context(user_id)
# Route to appropriate AI processor
response = await self.ai.process_query(
query=message,
context=context,
session_id=session_id,
)
# Log AI interaction
await self.audit.log({
"event_type": "AI_INTERACTION",
"user_id": user_id,
"session_id": session_id,
"processing_location": self.ai.last_processor,
})
return response
async def receive_message(
self,
recipient_id: str,
sender_id: str,
encrypted_message: dict,
) -> str:
"""Decrypt received message"""
session_id = f"{sender_id}:{recipient_id}"
plaintext = self.signal.decrypt_message(
session_id,
encrypted_message
)
await self.audit.log({
"event_type": "MESSAGE_RECEIVED",
"recipient_id": recipient_id,
"sender_id": sender_id,
"decrypted": True,
})
return plaintext.decode()
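Putting the pieces together (router, audit_logger, and key_manager are assumed to have been built as in the earlier sections; VaultKeyManager comes from a previous module):

async def demo(router, audit_logger, key_manager):
    chat = HIPAACompliantChat(
        signal_protocol=SignalProtocol(),
        ai_router=router,
        audit_logger=audit_logger,
        key_manager=key_manager,
    )
    # Assumes a Signal session already exists for this sender/recipient pair
    result = await chat.send_message(
        sender_id="clinician-17",
        recipient_id="patient-42",
        message="Your latest A1c is 6.1. Any questions before our visit?",
        request_ai_response=True,
    )
    # Both payloads are ciphertext; only the endpoints can read them
    print(result.keys())  # dict_keys(['encrypted_message', 'ai_response'])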
Key Takeaways
No Perfect Solution
LLMs fundamentally need plaintext. Choose architecture based on your risk tolerance.
Defense in Depth
Combine multiple approaches: E2E encryption, TEEs, on-premise deployment.
Minimize Cloud Exposure
Process sensitive PHI on-premise; use cloud only for de-identified data.
Audit Everything
Log all AI interactions without logging actual PHI content.
Decision Matrix
| Approach | Security Level | Cost | Capability | Complexity |
|---|---|---|---|---|
| Secure Enclaves | Highest | High | Limited by enclave resources | Very High |
| On-Premise LLM | High | High | Good (depends on hardware) | Medium |
| Hybrid + De-ID | Medium-High | Medium | Best (cloud + local) | Medium |
| Cloud Only | Lower | Low | Best | Low |
Practice Exercise
1. Implement Signal Protocol: build basic E2E encryption using the Signal Protocol patterns.
2. Deploy Local LLM: set up an on-premise LLM with vLLM or Ollama.
3. Build Hybrid Router: create a query classifier that routes to appropriate processing.
4. Add De-identification: implement PHI de-identification for cloud-safe queries.
5. Integrate Audit Logging: log all AI interactions with proper privacy controls.