Multimodal Models

Connecting Vision and Language

Multimodal models process multiple types of data (images, text, audio) in a shared representation space. Key applications:
  • Image-text search
  • Visual question answering
  • Image captioning
  • Text-to-image generation

CLIP: Contrastive Language-Image Pretraining

CLIP Architecture

Core Idea

Learn a shared embedding space where matching image-text pairs are close together. For a batch of N image-text pairs, the image-to-text contrastive loss is

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(\text{sim}(I_i, T_i)/\tau)}{\sum_{j=1}^{N} \exp(\text{sim}(I_i, T_j)/\tau)}$$

where $I_i$ and $T_i$ are the normalized image and text embeddings, $\text{sim}$ is cosine similarity, and $\tau$ is a learnable temperature. The full training objective averages this with the symmetric text-to-image term.

CLIP Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class CLIP(nn.Module):
    """Simplified CLIP model."""
    
    def __init__(self, image_encoder, text_encoder, embed_dim=512, temperature=0.07):
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        
        # Projection heads
        self.image_proj = nn.Linear(image_encoder.embed_dim, embed_dim)
        self.text_proj = nn.Linear(text_encoder.embed_dim, embed_dim)
        
        self.temperature = nn.Parameter(torch.ones([]) * temperature)
    
    def encode_image(self, images):
        features = self.image_encoder(images)
        return F.normalize(self.image_proj(features), dim=-1)
    
    def encode_text(self, text):
        features = self.text_encoder(text)
        return F.normalize(self.text_proj(features), dim=-1)
    
    def forward(self, images, text):
        image_embeds = self.encode_image(images)
        text_embeds = self.encode_text(text)
        
        # Compute similarity matrix
        logits = (image_embeds @ text_embeds.T) / self.temperature
        
        return logits


def clip_loss(logits):
    """Symmetric contrastive loss."""
    labels = torch.arange(logits.shape[0], device=logits.device)
    
    # Image-to-text loss
    loss_i2t = F.cross_entropy(logits, labels)
    # Text-to-image loss
    loss_t2i = F.cross_entropy(logits.T, labels)
    
    return (loss_i2t + loss_t2i) / 2
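
A minimal training-step sketch tying the model and loss together. The ToyEncoder modules, feature dimensions, and random batch below are placeholder assumptions standing in for real image/text backbones and a real dataloader:

class ToyEncoder(nn.Module):
    """Stand-in encoder exposing the embed_dim attribute CLIP expects."""
    def __init__(self, in_dim, embed_dim=256):
        super().__init__()
        self.embed_dim = embed_dim
        self.net = nn.Linear(in_dim, embed_dim)

    def forward(self, x):
        return self.net(x)

image_encoder = ToyEncoder(in_dim=2048)   # stand-in for a ViT/ResNet backbone
text_encoder = ToyEncoder(in_dim=768)     # stand-in for a text transformer
model = CLIP(image_encoder, text_encoder, embed_dim=512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

images = torch.randn(8, 2048)   # one batch of "image features" (placeholder)
text = torch.randn(8, 768)      # matching batch of "text features" (placeholder)

logits = model(images, text)    # (8, 8) similarity matrix
loss = clip_loss(logits)
loss.backward()
optimizer.step()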

Using Pretrained CLIP

from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image

# Load pretrained CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Zero-shot classification
image = Image.open("cat.jpg")
texts = ["a photo of a cat", "a photo of a dog", "a photo of a bird"]

inputs = processor(
    text=texts,
    images=image,
    return_tensors="pt",
    padding=True
)

with torch.no_grad():
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

print(f"Predictions: {dict(zip(texts, probs[0].tolist()))}")

Zero-Shot Classification

def zero_shot_classify(model, processor, image, class_names):
    """Classify image using text descriptions."""
    # Create text prompts
    prompts = [f"a photo of a {cls}" for cls in class_names]
    
    inputs = processor(
        text=prompts,
        images=image,
        return_tensors="pt",
        padding=True
    )
    
    with torch.no_grad():
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    
    predicted_class = class_names[probs.argmax().item()]
    confidence = probs.max().item()
    
    return predicted_class, confidence
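
Example call, assuming the CLIP model and processor loaded above and the same "cat.jpg" file (a hypothetical local path):

image = Image.open("cat.jpg")
predicted, confidence = zero_shot_classify(model, processor, image, ["cat", "dog", "bird"])
print(f"Predicted: {predicted} ({confidence:.1%})")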

Visual Question Answering (VQA)

from transformers import BlipProcessor, BlipForQuestionAnswering

# Load BLIP for VQA
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

def answer_question(image, question):
    inputs = processor(image, question, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.generate(**inputs)
    
    answer = processor.decode(outputs[0], skip_special_tokens=True)
    return answer

# Usage
image = Image.open("scene.jpg")
question = "What color is the car?"
answer = answer_question(image, question)
print(f"Q: {question}\nA: {answer}")

Image Captioning

from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

def generate_caption(image, conditional_text=None):
    if conditional_text:
        inputs = processor(image, conditional_text, return_tensors="pt")
    else:
        inputs = processor(image, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=50)
    
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption

# Unconditional captioning
caption = generate_caption(image)
print(f"Caption: {caption}")

# Conditional captioning
caption = generate_caption(image, "a photograph of")
print(f"Caption: {caption}")

LLaVA: Large Language-and-Vision Assistant

LLaVA connects a pretrained vision encoder to a large language model, enabling conversational image understanding:

from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

def chat_with_image(image, prompt):
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(model.device)
    
    outputs = model.generate(**inputs, max_new_tokens=200)
    return processor.decode(outputs[0], skip_special_tokens=True)

# Usage
response = chat_with_image(image, "Describe this image in detail.")
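
Note that generate() returns the prompt tokens followed by the reply, so the decoded string above includes the chat template as well as the answer. A small post-processing sketch (assuming a recent transformers version where the processor expands image tokens in input_ids) decodes only the newly generated tokens:

def decode_reply(inputs, outputs):
    # Skip the prompt portion of the generated sequence
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return processor.decode(new_tokens, skip_special_tokens=True)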

Building a Multimodal Model

Vision-Language Projector

class VisionLanguageProjector(nn.Module):
    """Project vision features to language model space."""
    
    def __init__(self, vision_dim, language_dim, num_query_tokens=32):
        super().__init__()
        
        # Learnable query tokens (Q-Former style)
        self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, language_dim))
        
        # Cross-attention to extract from vision features
        self.cross_attention = nn.MultiheadAttention(
            language_dim, num_heads=8, batch_first=True
        )
        
        # Project vision features
        self.vision_proj = nn.Linear(vision_dim, language_dim)
        
        # MLP for final projection
        self.mlp = nn.Sequential(
            nn.Linear(language_dim, language_dim * 4),
            nn.GELU(),
            nn.Linear(language_dim * 4, language_dim),
        )
    
    def forward(self, vision_features):
        B = vision_features.shape[0]
        
        # Project vision features
        vision_features = self.vision_proj(vision_features)
        
        # Expand query tokens for batch
        queries = self.query_tokens.expand(B, -1, -1)
        
        # Cross-attention: queries attend to vision features
        attended, _ = self.cross_attention(queries, vision_features, vision_features)
        
        # MLP projection
        output = self.mlp(attended)
        
        return output  # (B, num_query_tokens, language_dim)
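
A shape-level usage sketch with made-up dimensions (1024-dim ViT patch features projected into a 4096-dim language-model space). The resulting visual tokens would typically be concatenated with the text token embeddings before entering the LLM:

projector = VisionLanguageProjector(vision_dim=1024, language_dim=4096)
vision_features = torch.randn(2, 257, 1024)   # (B, num_patches, vision_dim)
visual_tokens = projector(vision_features)    # (2, 32, 4096)

text_embeds = torch.randn(2, 16, 4096)        # placeholder text token embeddings
llm_inputs = torch.cat([visual_tokens, text_embeds], dim=1)   # (2, 48, 4096)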

Image-Text Retrieval

def build_image_text_index(model, images, texts, processor):
    """Build index for image-text retrieval."""
    model.eval()
    
    # Encode all images
    image_embeds = []
    for image in images:
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embed = model.get_image_features(**inputs)
            image_embeds.append(F.normalize(embed, dim=-1))
    image_embeds = torch.cat(image_embeds, dim=0)
    
    # Encode all texts
    text_embeds = []
    for text in texts:
        inputs = processor(text=text, return_tensors="pt")
        with torch.no_grad():
            embed = model.get_text_features(**inputs)
            text_embeds.append(F.normalize(embed, dim=-1))
    text_embeds = torch.cat(text_embeds, dim=0)
    
    return image_embeds, text_embeds


def retrieve_images(query_text, image_embeds, images, model, processor, top_k=5):
    """Retrieve images matching text query."""
    inputs = processor(text=query_text, return_tensors="pt")
    with torch.no_grad():
        text_embed = model.get_text_features(**inputs)
        text_embed = F.normalize(text_embed, dim=-1)
    
    # Compute similarities
    similarities = (text_embed @ image_embeds.T).squeeze(0)
    top_indices = similarities.argsort(descending=True)[:top_k]
    
    return [images[i] for i in top_indices], similarities[top_indices]
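
End-to-end sketch, assuming the pretrained CLIPModel and CLIPProcessor loaded in the zero-shot section (these helpers rely on get_image_features/get_text_features) and a small placeholder gallery:

gallery = [Image.open(p) for p in ["cat.jpg", "scene.jpg"]]   # hypothetical files
captions = ["a cat on a sofa", "a street scene with cars"]    # placeholder captions

image_embeds, text_embeds = build_image_text_index(model, gallery, captions, processor)
matches, scores = retrieve_images("a photo of a cat", image_embeds, gallery, model, processor, top_k=1)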

Multimodal Model Comparison

| Model    | Vision      | Language     | Training Data | Use Case             |
|----------|-------------|--------------|---------------|----------------------|
| CLIP     | ViT/ResNet  | Transformer  | 400M pairs    | Retrieval, zero-shot |
| BLIP     | ViT         | BERT         | 129M pairs    | Captioning, VQA      |
| LLaVA    | CLIP        | LLaMA/Vicuna | 600K          | Conversation         |
| GPT-4V   | Proprietary | GPT-4        | Unknown       | General              |
| Flamingo | NFNet       | Chinchilla   | Interleaved   | Few-shot             |

Exercises

  1. Build a zero-shot image classifier using CLIP for a custom set of classes.
  2. Create a text-to-image search system using CLIP embeddings.
  3. Fine-tune BLIP on a custom VQA dataset and evaluate performance.

What’s Next