WebSocket Chat
Basic WebSocket Server
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from openai import AsyncOpenAI
import json
from typing import List

app = FastAPI()

class ConnectionManager:
    """Manage WebSocket connections."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept new connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Remove connection."""
        self.active_connections.remove(websocket)

    async def send_message(self, message: str, websocket: WebSocket):
        """Send message to specific client."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Send message to all clients."""
        for connection in self.active_connections:
            await connection.send_text(message)

manager = ConnectionManager()
client = AsyncOpenAI()  # Async client keeps the event loop free while streaming

@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for streaming chat."""
    await manager.connect(websocket)
    conversation = []
    try:
        while True:
            # Receive message
            data = await websocket.receive_text()
            message = json.loads(data)
            user_content = message.get("content", "")
            conversation.append({"role": "user", "content": user_content})

            # Send acknowledgment
            await manager.send_message(
                json.dumps({"type": "ack", "status": "processing"}),
                websocket
            )

            # Stream response token by token
            stream = await client.chat.completions.create(
                model="gpt-4o-mini",
                messages=conversation,
                stream=True
            )
            full_response = ""
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    token = chunk.choices[0].delta.content
                    full_response += token
                    await manager.send_message(
                        json.dumps({"type": "token", "content": token}),
                        websocket
                    )

            # Send completion
            conversation.append({"role": "assistant", "content": full_response})
            await manager.send_message(
                json.dumps({"type": "complete", "content": full_response}),
                websocket
            )
    except WebSocketDisconnect:
        manager.disconnect(websocket)

# Run with: uvicorn main:app --host 0.0.0.0 --port 8000
WebSocket Client
import asyncio
import websockets
import json

async def chat_client(uri: str = "ws://localhost:8000/ws/chat"):
    """Interactive WebSocket chat client."""
    async with websockets.connect(uri) as websocket:
        print("Connected to chat server. Type 'quit' to exit.")
        while True:
            # Get user input (blocking input() is fine for a simple CLI client)
            user_input = input("\nYou: ")
            if user_input.lower() == "quit":
                break

            # Send message
            await websocket.send(json.dumps({"content": user_input}))

            # Receive streaming response
            print("\nAssistant: ", end="", flush=True)
            while True:
                response = await websocket.recv()
                data = json.loads(response)
                if data["type"] == "token":
                    print(data["content"], end="", flush=True)
                elif data["type"] == "complete":
                    print()  # Newline
                    break
                elif data["type"] == "ack":
                    continue

# Run the client
if __name__ == "__main__":
    asyncio.run(chat_client())
Server-Sent Events (SSE)
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
import json

app = FastAPI()
client = AsyncOpenAI()

async def generate_sse_stream(messages: list):
    """Generate SSE stream for chat response."""
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        stream=True
    )
    async for chunk in stream:
        if chunk.choices[0].delta.content:
            token = chunk.choices[0].delta.content
            data = json.dumps({"token": token})
            yield f"data: {data}\n\n"
    yield f"data: {json.dumps({'done': True})}\n\n"

@app.post("/api/chat/stream")
async def stream_chat(request: Request):
    """SSE endpoint for streaming chat responses."""
    body = await request.json()
    messages = body.get("messages", [])
    return StreamingResponse(
        generate_sse_stream(messages),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable nginx proxy buffering
        }
    )
# Client-side JavaScript (note: EventSource only supports GET requests,
# so a POST endpoint is consumed with fetch and a streamed response body):
"""
const response = await fetch('/api/chat/stream', {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({messages: [{role: 'user', content: 'Hello'}]})
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
    const {done, value} = await reader.read();
    if (done) break;
    for (const line of decoder.decode(value, {stream: true}).split('\n')) {
        if (!line.startsWith('data: ')) continue;
        const data = JSON.parse(line.slice(6));
        if (data.token) console.log(data.token);
    }
}
"""
OpenAI Realtime API
Audio Conversation
import asyncio
import websockets
import json
import base64

class RealtimeClient:
    """Client for OpenAI Realtime API."""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.url = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
        self.websocket = None

    async def connect(self):
        """Establish WebSocket connection."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1"
        }
        # Note: websockets >= 14 renamed extra_headers to additional_headers
        self.websocket = await websockets.connect(
            self.url,
            extra_headers=headers
        )
        # Configure session
        await self.send_event({
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": "You are a helpful assistant. Be concise.",
                "voice": "alloy",
                "input_audio_format": "pcm16",
                "output_audio_format": "pcm16",
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                    "prefix_padding_ms": 300,
                    "silence_duration_ms": 500
                }
            }
        })

    async def send_event(self, event: dict):
        """Send event to server."""
        await self.websocket.send(json.dumps(event))

    async def receive_events(self):
        """Receive and yield events from server."""
        while True:
            message = await self.websocket.recv()
            yield json.loads(message)

    async def send_audio(self, audio_data: bytes):
        """Send audio data to server."""
        # Audio should be base64-encoded PCM16 at 24kHz
        encoded = base64.b64encode(audio_data).decode()
        await self.send_event({
            "type": "input_audio_buffer.append",
            "audio": encoded
        })

    async def send_text(self, text: str):
        """Send text message."""
        await self.send_event({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })
        # Request response
        await self.send_event({"type": "response.create"})

    async def close(self):
        """Close connection."""
        if self.websocket:
            await self.websocket.close()

async def realtime_conversation():
    """Run a realtime conversation."""
    import os
    client = RealtimeClient(os.environ["OPENAI_API_KEY"])
    await client.connect()
    print("Connected to Realtime API. Type messages or 'quit' to exit.")

    # Handle incoming events in the background
    async def handle_events():
        async for event in client.receive_events():
            event_type = event.get("type", "")
            if event_type == "response.text.delta":
                print(event.get("delta", ""), end="", flush=True)
            elif event_type == "response.text.done":
                print()  # Newline
            elif event_type == "response.audio.delta":
                # Decode audio output; play it or save to file
                audio_data = base64.b64decode(event.get("delta", ""))
            elif event_type == "error":
                print(f"Error: {event.get('error', {}).get('message', 'Unknown')}")

    event_task = asyncio.create_task(handle_events())
    try:
        while True:
            # Run blocking input() in a thread so the event loop stays free
            user_input = await asyncio.get_running_loop().run_in_executor(
                None, input, "You: "
            )
            if user_input.lower() == "quit":
                break
            await client.send_text(user_input)
    finally:
        event_task.cancel()
        await client.close()

# Run with: asyncio.run(realtime_conversation())
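Audio deltas arrive as raw PCM16 at 24 kHz, which is not directly playable. A minimal sketch for persisting them to a WAV file (the save_pcm16_wav helper and the buffering pattern in the comments are illustrative, not part of the API):

import wave

def save_pcm16_wav(pcm_data: bytes, path: str, sample_rate: int = 24000):
    """Write raw PCM16 mono audio to a WAV file so it can be played back."""
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)        # Mono
        wav_file.setsampwidth(2)        # 16-bit = 2 bytes per sample
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_data)

# In handle_events, accumulate deltas and save once the response finishes:
#     audio_chunks.append(base64.b64decode(event["delta"]))
#     ...
#     save_pcm16_wav(b"".join(audio_chunks), "reply.wav")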
Voice Activity Detection
import numpy as np
from dataclasses import dataclass
from typing import Optional, Callable
import asyncio

@dataclass
class VADConfig:
    """Voice Activity Detection configuration."""
    threshold: float = 0.02
    min_speech_duration_ms: int = 200
    min_silence_duration_ms: int = 500
    sample_rate: int = 16000

class VoiceActivityDetector:
    """Detect speech in an audio stream."""

    def __init__(self, config: Optional[VADConfig] = None):
        self.config = config or VADConfig()
        self.is_speaking = False
        self.silence_start = None
        self.speech_start = None
        self.speech_buffer = []
        self.total_samples = 0  # Samples seen so far; drives the timeline

    def process_chunk(self, audio_chunk: bytes) -> Optional[dict]:
        """Process an audio chunk and return speech events."""
        # Convert bytes to numpy array
        audio = np.frombuffer(audio_chunk, dtype=np.int16)

        # Calculate RMS energy, normalized to [0, 1]
        rms = np.sqrt(np.mean(audio.astype(np.float32) ** 2)) / 32768

        # Track elapsed time from total samples processed, not buffer size
        # (the buffer resets between segments and silence is never buffered)
        current_time_ms = self.total_samples * 1000 / self.config.sample_rate
        self.total_samples += len(audio)

        if rms > self.config.threshold:
            # Speech detected
            if not self.is_speaking:
                self.speech_start = current_time_ms
                self.is_speaking = True
            self.silence_start = None
            self.speech_buffer.append(audio_chunk)
            return {"type": "speech_active", "rms": rms}

        # Silence detected
        if self.is_speaking:
            if self.silence_start is None:
                self.silence_start = current_time_ms
            silence_duration = current_time_ms - self.silence_start
            if silence_duration >= self.config.min_silence_duration_ms:
                # End of speech: measure speech up to where silence began
                speech_duration = self.silence_start - self.speech_start
                audio_data = b"".join(self.speech_buffer)
                self.speech_buffer = []
                self.is_speaking = False
                if speech_duration >= self.config.min_speech_duration_ms:
                    # Valid speech segment
                    return {
                        "type": "speech_end",
                        "audio": audio_data,
                        "duration_ms": speech_duration
                    }
        return None

class RealtimeAudioProcessor:
    """Process realtime audio with VAD."""

    def __init__(
        self,
        on_speech_segment: Callable[[bytes], None],
        vad_config: Optional[VADConfig] = None
    ):
        self.vad = VoiceActivityDetector(vad_config)
        self.on_speech_segment = on_speech_segment

    async def process_stream(self, audio_stream):
        """Process audio stream and emit speech segments."""
        async for chunk in audio_stream:
            result = self.vad.process_chunk(chunk)
            if result and result["type"] == "speech_end":
                await self.on_speech_segment(result["audio"])

# Usage with microphone input
async def process_microphone():
    """Process microphone input with VAD."""
    import pyaudio

    CHUNK_SIZE = 1600  # 100ms at 16kHz

    async def handle_speech(audio_data: bytes):
        print(f"Speech segment: {len(audio_data)} bytes")
        # Send to transcription or realtime API

    processor = RealtimeAudioProcessor(
        on_speech_segment=handle_speech,
        vad_config=VADConfig(threshold=0.02)
    )

    # Create audio stream generator
    async def audio_generator():
        p = pyaudio.PyAudio()
        stream = p.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=16000,
            input=True,
            frames_per_buffer=CHUNK_SIZE
        )
        try:
            while True:
                # stream.read blocks; offload it to a thread in production code
                chunk = stream.read(CHUNK_SIZE)
                yield chunk
                await asyncio.sleep(0)
        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

    await processor.process_stream(audio_generator())
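A quick offline sanity check for the detector, using synthetic audio in place of a microphone (the tone/silence layout below is illustrative):

import numpy as np

# 300 ms of a loud 440 Hz tone followed by 600 ms of silence should
# produce exactly one speech_end event with the default thresholds.
vad = VoiceActivityDetector(VADConfig(threshold=0.02, sample_rate=16000))
tone = (np.sin(2 * np.pi * 440 * np.arange(4800) / 16000) * 20000).astype(np.int16)
silence = np.zeros(9600, dtype=np.int16)
signal = np.concatenate([tone, silence])

for start in range(0, len(signal), 1600):  # Feed 100 ms chunks
    event = vad.process_chunk(signal[start:start + 1600].tobytes())
    if event and event["type"] == "speech_end":
        print(f"Speech segment: {event['duration_ms']:.0f} ms, {len(event['audio'])} bytes")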
Latency Optimization
from openai import OpenAI
from dataclasses import dataclass
import time
import asyncio

@dataclass
class LatencyMetrics:
    """Track latency metrics."""
    time_to_first_token_ms: float
    total_time_ms: float
    tokens_generated: int
    tokens_per_second: float

class LowLatencyClient:
    """Optimized client for low-latency inference."""

    def __init__(self, model: str = "gpt-4o-mini", timeout: float = 10.0):
        self.client = OpenAI(timeout=timeout)
        self.model = model

    def stream_with_metrics(
        self,
        messages: list,
        max_tokens: int = 256
    ) -> tuple[str, LatencyMetrics]:
        """Stream response and collect latency metrics."""
        start_time = time.time()
        first_token_time = None
        tokens = []

        stream = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            stream_options={"include_usage": True}
        )
        for chunk in stream:
            # The final usage chunk has no choices, hence the guard
            if chunk.choices and chunk.choices[0].delta.content:
                if first_token_time is None:
                    first_token_time = time.time()
                tokens.append(chunk.choices[0].delta.content)

        end_time = time.time()
        total_time_ms = (end_time - start_time) * 1000
        ttft_ms = (first_token_time - start_time) * 1000 if first_token_time else 0

        metrics = LatencyMetrics(
            time_to_first_token_ms=ttft_ms,
            total_time_ms=total_time_ms,
            tokens_generated=len(tokens),
            tokens_per_second=len(tokens) / (total_time_ms / 1000) if total_time_ms > 0 else 0
        )
        return "".join(tokens), metrics

    async def parallel_completions(
        self,
        prompts: list[str],
        max_concurrent: int = 5
    ) -> list[tuple[str, LatencyMetrics]]:
        """Run multiple completions in parallel."""
        semaphore = asyncio.Semaphore(max_concurrent)

        async def run_one(prompt: str):
            async with semaphore:
                # Run the sync client in a thread pool
                loop = asyncio.get_running_loop()
                return await loop.run_in_executor(
                    None,
                    lambda: self.stream_with_metrics(
                        [{"role": "user", "content": prompt}]
                    )
                )

        tasks = [run_one(p) for p in prompts]
        return await asyncio.gather(*tasks)

class LatencyOptimizer:
    """Strategies for optimizing latency."""

    @staticmethod
    def optimize_prompt(prompt: str) -> str:
        """Optimize prompt for lower latency."""
        # Remove unnecessary whitespace
        prompt = " ".join(prompt.split())
        # Truncate very long prompts
        if len(prompt) > 4000:
            prompt = prompt[:4000] + "..."
        return prompt

    @staticmethod
    def select_model_for_latency(
        complexity: str,
        max_latency_ms: float
    ) -> str:
        """Select appropriate model based on latency requirements."""
        # Rough illustrative latency estimates (ms) per task complexity
        model_latencies = {
            "gpt-4o-mini": {"simple": 100, "moderate": 200, "complex": 400},
            "gpt-4o": {"simple": 300, "moderate": 600, "complex": 1000},
        }
        for model, latencies in model_latencies.items():
            if latencies.get(complexity, 1000) <= max_latency_ms:
                return model
        return "gpt-4o-mini"  # Fallback to fastest

    @staticmethod
    def chunk_for_streaming(text: str, chunk_size: int = 50) -> list[str]:
        """Chunk text for streaming display."""
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size):
            chunks.append(" ".join(words[i:i + chunk_size]))
        return chunks

# Usage
client = LowLatencyClient(model="gpt-4o-mini")

# Single completion with metrics
response, metrics = client.stream_with_metrics(
    [{"role": "user", "content": "Hello!"}]
)
print(f"Response: {response}")
print(f"TTFT: {metrics.time_to_first_token_ms:.0f}ms")
print(f"Total: {metrics.total_time_ms:.0f}ms")
print(f"Speed: {metrics.tokens_per_second:.1f} tok/s")

# Parallel completions
prompts = ["What is AI?", "Explain Python", "Define cloud computing"]
results = asyncio.run(client.parallel_completions(prompts))
for (response, metrics), prompt in zip(results, prompts):
    print(f"{prompt[:20]}: {metrics.time_to_first_token_ms:.0f}ms TTFT")
Typing Indicators and Progress
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from openai import AsyncOpenAI
import json
import asyncio
import time

app = FastAPI()
client = AsyncOpenAI()

class ProgressiveResponse:
    """Handle progressive response display."""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        self.is_typing = False
        self.typing_task = None

    async def start_typing(self):
        """Start typing indicator."""
        self.is_typing = True

        async def send_typing():
            while self.is_typing:
                await self.websocket.send_json({
                    "type": "typing",
                    "status": True
                })
                await asyncio.sleep(2)

        self.typing_task = asyncio.create_task(send_typing())

    async def stop_typing(self):
        """Stop typing indicator."""
        self.is_typing = False
        if self.typing_task:
            self.typing_task.cancel()
            try:
                await self.typing_task
            except asyncio.CancelledError:
                pass

    async def send_token(self, token: str, index: int):
        """Send token with metadata."""
        await self.websocket.send_json({
            "type": "token",
            "content": token,
            "index": index,
            "timestamp": time.time()
        })

    async def send_progress(self, stage: str, progress: float, message: str = ""):
        """Send progress update."""
        await self.websocket.send_json({
            "type": "progress",
            "stage": stage,
            "progress": progress,
            "message": message
        })

@app.websocket("/ws/chat/progressive")
async def progressive_chat(websocket: WebSocket):
    """WebSocket with progressive response features."""
    await websocket.accept()
    responder = ProgressiveResponse(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            # Show typing indicator
            await responder.start_typing()

            # Send processing stages
            await responder.send_progress("thinking", 0.2, "Understanding your question...")
            await asyncio.sleep(0.1)
            await responder.send_progress("generating", 0.4, "Formulating response...")

            # Stop typing and start streaming
            await responder.stop_typing()
            stream = await client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": message["content"]}],
                stream=True
            )
            token_index = 0
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    await responder.send_token(
                        chunk.choices[0].delta.content,
                        token_index
                    )
                    token_index += 1

            await responder.send_progress("complete", 1.0, "Done!")
            await websocket.send_json({
                "type": "complete",
                "total_tokens": token_index
            })
    except WebSocketDisconnect:
        # Client went away: stop background tasks, don't send on a closed socket
        await responder.stop_typing()
    except Exception as e:
        await responder.stop_typing()
        await websocket.send_json({
            "type": "error",
            "message": str(e)
        })
# Client-side handling:
"""
const ws = new WebSocket('ws://localhost:8000/ws/chat/progressive');

ws.onmessage = (event) => {
    const data = JSON.parse(event.data);
    switch (data.type) {
        case 'typing':
            showTypingIndicator(data.status);
            break;
        case 'progress':
            updateProgressBar(data.stage, data.progress, data.message);
            break;
        case 'token':
            appendToken(data.content);
            break;
        case 'complete':
            hideProgress();
            console.log(`Completed with ${data.total_tokens} tokens`);
            break;
        case 'error':
            showError(data.message);
            break;
    }
};
"""
Realtime Collaboration
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from openai import AsyncOpenAI
from typing import Dict, Optional, Set
from dataclasses import dataclass, field
import json

@dataclass
class Room:
    """A collaborative chat room."""
    id: str
    participants: Set[WebSocket] = field(default_factory=set)
    conversation: list = field(default_factory=list)
    ai_responding: bool = False

class CollaborativeChat:
    """Multi-user realtime AI chat."""

    def __init__(self):
        self.rooms: Dict[str, Room] = {}
        self.client = AsyncOpenAI()

    def get_or_create_room(self, room_id: str) -> Room:
        """Get or create a chat room."""
        if room_id not in self.rooms:
            self.rooms[room_id] = Room(id=room_id)
        return self.rooms[room_id]

    async def join_room(self, room_id: str, websocket: WebSocket):
        """Add participant to room."""
        room = self.get_or_create_room(room_id)
        room.participants.add(websocket)

        # Send room state to new participant
        await websocket.send_json({
            "type": "room_state",
            "conversation": room.conversation,
            "participants": len(room.participants)
        })

        # Notify others
        await self.broadcast(room, {
            "type": "participant_joined",
            "count": len(room.participants)
        }, exclude=websocket)

    async def leave_room(self, room_id: str, websocket: WebSocket):
        """Remove participant from room."""
        if room_id in self.rooms:
            room = self.rooms[room_id]
            room.participants.discard(websocket)
            await self.broadcast(room, {
                "type": "participant_left",
                "count": len(room.participants)
            })

    async def broadcast(
        self,
        room: Room,
        message: dict,
        exclude: Optional[WebSocket] = None
    ):
        """Broadcast message to all room participants."""
        for ws in room.participants:
            if ws != exclude:
                try:
                    await ws.send_json(message)
                except Exception:
                    pass  # Skip participants whose sockets have gone away

    async def handle_message(self, room_id: str, user_id: str, content: str):
        """Handle incoming user message."""
        room = self.get_or_create_room(room_id)

        # Add and broadcast the user message
        room.conversation.append({
            "role": "user",
            "user_id": user_id,
            "content": content
        })
        await self.broadcast(room, {
            "type": "user_message",
            "user_id": user_id,
            "content": content
        })

        # Generate AI response if not already responding
        if not room.ai_responding:
            room.ai_responding = True
            try:
                await self.broadcast(room, {"type": "ai_typing", "status": True})

                # Strip user_id metadata when formatting for the API
                messages = [
                    {"role": m["role"], "content": m["content"]}
                    for m in room.conversation
                ]
                stream = await self.client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    stream=True
                )
                ai_response = ""
                async for chunk in stream:
                    if chunk.choices[0].delta.content:
                        token = chunk.choices[0].delta.content
                        ai_response += token
                        await self.broadcast(room, {
                            "type": "ai_token",
                            "content": token
                        })

                # Add AI response to conversation
                room.conversation.append({
                    "role": "assistant",
                    "content": ai_response
                })
                await self.broadcast(room, {
                    "type": "ai_complete",
                    "content": ai_response
                })
            finally:
                room.ai_responding = False
                await self.broadcast(room, {"type": "ai_typing", "status": False})

app = FastAPI()
collab = CollaborativeChat()

@app.websocket("/ws/room/{room_id}")
async def room_websocket(websocket: WebSocket, room_id: str):
    """WebSocket for collaborative room."""
    await websocket.accept()
    await collab.join_room(room_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            await collab.handle_message(
                room_id,
                message.get("user_id", "anonymous"),
                message.get("content", "")
            )
    except WebSocketDisconnect:
        await collab.leave_room(room_id, websocket)
Realtime Best Practices
- Use WebSockets for bidirectional communication
- Implement typing indicators for better UX
- Stream tokens as they are generated
- Handle disconnections gracefully (see the reconnection sketch below)
- Monitor latency metrics continuously
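A minimal sketch of graceful disconnection handling, assuming the /ws/chat endpoint from earlier; the backoff parameters are illustrative:

import asyncio
import websockets

async def resilient_client(uri: str = "ws://localhost:8000/ws/chat"):
    """Reconnect with exponential backoff when the connection drops."""
    backoff = 1.0
    while True:
        try:
            async with websockets.connect(uri) as ws:
                backoff = 1.0  # Reset backoff after a successful connect
                async for message in ws:
                    print(message)
        except (websockets.ConnectionClosed, OSError) as e:
            print(f"Disconnected ({e}); retrying in {backoff:.0f}s")
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 30)  # Cap the delay at 30s

asyncio.run(resilient_client())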
Practice Exercise
Build a realtime AI application that:
- Uses WebSockets for bidirectional communication
- Implements voice activity detection
- Streams responses with latency metrics
- Shows typing indicators and progress
- Supports multiple concurrent users
Focus on:
- Minimizing time to first token
- Smooth streaming experience
- Graceful error handling
- Scalable connection management