3D Deep Learning

The World is 3D

We live in three dimensions, yet most deep learning operates on 2D projections (photographs). This is like trying to understand a building by looking at one photograph — you lose depth, occlusion relationships, and the ability to reason about the object from novel viewpoints. Real-world data is inherently 3D:
  • Autonomous driving: LiDAR sensors produce millions of 3D points per second
  • Robotics: Depth sensors enable grasping and manipulation in physical space
  • Medical imaging: CT and MRI scans are 3D volumes, not flat images
  • AR/VR: Scene reconstruction requires understanding 3D geometry
  • Manufacturing: Quality inspection needs precise 3D measurement
The fundamental challenge of 3D deep learning: different sensors produce different representations (point clouds, voxel grids, meshes), and each representation has different mathematical properties that require different neural network architectures. Choosing the right representation is often more important than choosing the right model.
When starting a 3D deep learning project, let the sensor dictate your architecture. If your data comes from LiDAR (point clouds), start with PointNet++ or a sparse voxel method. If you have depth cameras (organized point clouds / depth images), 2D CNN backbones on the depth map are a surprisingly strong baseline. If you have meshes from CAD software, consider MeshCNN or simply sample points from the surface and use point cloud methods.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List

torch.manual_seed(42)

3D Representations

Point Clouds

Unordered set of 3D points: $\{(x_i, y_i, z_i)\}_{i=1}^N$
class PointCloud:
    """
    Point cloud representation.
    
    A point cloud is the raw output of a LiDAR or depth sensor: just a bag
    of (x, y, z) coordinates in space. Think of it like a 3D scatter plot
    with potentially millions of dots.
    
    Properties:
    - Unordered (permutation invariant -- there is no "first" point)
    - Irregular sampling (denser near surfaces, sparser at distance)
    - Direct from sensors (LiDAR, stereo cameras, structured light)
    
    Pitfall: Unlike images, you cannot simply feed a point cloud into a
    standard CNN. The lack of grid structure requires specialized architectures.
    """
    
    def __init__(
        self,
        points: torch.Tensor,    # [N, 3] xyz coordinates
        features: Optional[torch.Tensor] = None,  # [N, C] per-point features
        normals: Optional[torch.Tensor] = None    # [N, 3] surface normals
    ):
        self.points = points
        self.features = features
        self.normals = normals
        self.num_points = points.shape[0]
    
    def normalize(self) -> 'PointCloud':
        """Center and scale to unit sphere."""
        centroid = self.points.mean(dim=0)
        points = self.points - centroid
        max_dist = points.norm(dim=1).max()
        points = points / max_dist
        return PointCloud(points, self.features, self.normals)
    
    def random_sample(self, n_points: int) -> 'PointCloud':
        """Randomly sample points."""
        idx = torch.randperm(self.num_points)[:n_points]
        return PointCloud(
            self.points[idx],
            self.features[idx] if self.features is not None else None,
            self.normals[idx] if self.normals is not None else None
        )
    
    def farthest_point_sample(self, n_points: int) -> 'PointCloud':
        """Farthest point sampling for uniform coverage."""
        N = self.num_points
        centroids = torch.zeros(n_points, dtype=torch.long)
        distance = torch.ones(N) * 1e10
        
        # Random starting point
        farthest = torch.randint(0, N, (1,)).item()
        
        for i in range(n_points):
            centroids[i] = farthest
            centroid_point = self.points[farthest]
            
            dist = ((self.points - centroid_point) ** 2).sum(dim=1)
            distance = torch.min(distance, dist)
            farthest = distance.argmax().item()
        
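        # Keep per-point features and normals aligned with the sampled points
        return PointCloud(
            self.points[centroids],
            self.features[centroids] if self.features is not None else None,
            self.normals[centroids] if self.normals is not None else None
        )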


# Example
points = torch.randn(1024, 3)
pc = PointCloud(points)
pc_normalized = pc.normalize()
pc_sampled = pc.farthest_point_sample(256)
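
To make the "uniform coverage" claim for farthest point sampling concrete, the sketch below compares it with random sampling using a coverage radius: the largest distance from any point in the cloud to its nearest selected sample. The helper names `farthest_point_indices` and `coverage_radius` are illustrative and not part of the classes above, and the fixed starting index is an assumption made only to keep the run deterministic.

```python
import torch


def farthest_point_indices(points: torch.Tensor, n_samples: int) -> torch.Tensor:
    """Greedy farthest point sampling on an [N, 3] tensor; returns [n_samples] indices."""
    n = points.shape[0]
    selected = torch.zeros(n_samples, dtype=torch.long)
    min_dist = torch.full((n,), float("inf"))
    farthest = 0  # fixed start (assumption) so the sketch is deterministic
    for i in range(n_samples):
        selected[i] = farthest
        # Squared distance from every point to the newest sample
        d = ((points - points[farthest]) ** 2).sum(dim=1)
        # Track each point's distance to its closest sample so far
        min_dist = torch.minimum(min_dist, d)
        farthest = int(min_dist.argmax())
    return selected


def coverage_radius(points: torch.Tensor, sample_idx: torch.Tensor) -> float:
    """Largest distance from any point to its nearest selected sample."""
    d = torch.cdist(points, points[sample_idx])  # [N, n_samples]
    return d.min(dim=1).values.max().item()


torch.manual_seed(0)
cloud = torch.randn(1024, 3)
fps_idx = farthest_point_indices(cloud, 64)
rand_idx = torch.randperm(1024)[:64]

fps_radius = coverage_radius(cloud, fps_idx)
rand_radius = coverage_radius(cloud, rand_idx)
print(f"coverage radius  FPS: {fps_radius:.3f}  random: {rand_radius:.3f}")
```

A smaller coverage radius means no region of the cloud is left far from a sample. Greedy farthest point sampling is guaranteed to stay within a factor of two of the best possible radius for the same number of samples, which random sampling cannot promise.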

Voxels

3D grid of values: $V \in \mathbb{R}^{D \times H \times W}$
class VoxelGrid:
    """
    Voxel grid representation.
    
    A voxel grid is the 3D equivalent of a pixel grid: discretize 3D space
    into tiny cubes (voxels = "volumetric pixels") and mark which cubes are
    occupied. Think of it like building with LEGO: you approximate any shape
    by stacking small cubes.
    
    Properties:
    - Regular 3D grid (like 3D pixels -- enables standard 3D convolutions)
    - Can use 3D convolutions (the math is identical to 2D conv, just one more dimension)
    - Memory intensive (O(n^3) -- doubling resolution requires 8x more memory!)
    
    Pitfall: A 512^3 voxel grid storing a single float32 per voxel uses
    ~512 MB just for occupancy. Most of those voxels are empty air, which
    is why sparse representations are critical for practical use.
    """
    
    def __init__(
        self,
        grid: torch.Tensor,  # [C, D, H, W] or [D, H, W]
        resolution: int
    ):
        self.grid = grid
        self.resolution = resolution
    
    @staticmethod
    def from_point_cloud(
        points: torch.Tensor,
        resolution: int = 32
    ) -> 'VoxelGrid':
        """Convert point cloud to voxel grid."""
        # Normalize to [0, 1]
        min_pt = points.min(dim=0)[0]
        max_pt = points.max(dim=0)[0]
        points = (points - min_pt) / (max_pt - min_pt + 1e-8)
        
        # Quantize to grid
        indices = (points * (resolution - 1)).long()
        indices = indices.clamp(0, resolution - 1)
        
        # Create occupancy grid
        grid = torch.zeros(resolution, resolution, resolution)
        grid[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
        
        return VoxelGrid(grid.unsqueeze(0), resolution)
    
    def to_point_cloud(self) -> torch.Tensor:
        """Convert voxel grid back to point cloud."""
        occupied = self.grid.squeeze() > 0.5
        indices = occupied.nonzero().float()
        points = indices / (self.resolution - 1)
        return points


class SparseVoxelGrid:
    """
    Sparse voxel representation for efficiency.
    Only stores occupied voxels.
    """
    
    def __init__(
        self,
        coords: torch.Tensor,    # [N, 3] voxel coordinates
        features: torch.Tensor,  # [N, C] voxel features
        resolution: int
    ):
        self.coords = coords
        self.features = features
        self.resolution = resolution
    
    @staticmethod
    def from_point_cloud(
        points: torch.Tensor,
        features: torch.Tensor,
        voxel_size: float = 0.05
    ) -> 'SparseVoxelGrid':
        """Voxelize point cloud with feature averaging."""
        # Quantize points
        coords = (points / voxel_size).floor().long()
        
        # Unique voxels
        unique_coords, inverse = torch.unique(
            coords, dim=0, return_inverse=True
        )
        
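        # Average features per voxel: scatter-add the point features into
        # their voxel, then divide by the number of points in each voxel
        num_voxels = unique_coords.shape[0]
        voxel_features = torch.zeros(
            num_voxels, features.shape[1],
            dtype=features.dtype, device=features.device
        )
        voxel_features.index_add_(0, inverse, features)
        counts = torch.bincount(inverse, minlength=num_voxels)
        voxel_features = voxel_features / counts.unsqueeze(1)
        
        # Grid extent along the longest axis. Coordinates can be negative,
        # so measure max - min per axis instead of using the raw maximum.
        extent = coords.max(dim=0)[0] - coords.min(dim=0)[0]
        resolution = int(extent.max().item()) + 1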
        
        return SparseVoxelGrid(unique_coords, voxel_features, resolution)
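
To see how empty a dense grid becomes as resolution grows, here is a small sketch that voxelizes a synthetic surface and reports the occupied fraction. The helper `occupancy_grid` mirrors the quantization logic of `VoxelGrid.from_point_cloud` above but is written standalone, and the sphere-surface test cloud is a made-up input chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def occupancy_grid(points: torch.Tensor, resolution: int) -> torch.Tensor:
    """Binary [R, R, R] occupancy grid from an [N, 3] point cloud."""
    lo = points.min(dim=0).values
    hi = points.max(dim=0).values
    unit = (points - lo) / (hi - lo + 1e-8)  # normalize to [0, 1]
    idx = (unit * (resolution - 1)).long().clamp(0, resolution - 1)
    grid = torch.zeros(resolution, resolution, resolution)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = 1
    return grid


torch.manual_seed(0)
# Hypothetical test shape: 4096 points on the surface of a unit sphere
surface = F.normalize(torch.randn(4096, 3), dim=1)

fractions = {}
for resolution in (16, 32, 64):
    grid = occupancy_grid(surface, resolution)
    fractions[resolution] = grid.mean().item()
    print(f"{resolution}^3 grid: {int(grid.sum())} occupied voxels "
          f"({fractions[resolution]:.1%})")
```

Because a surface grows with the square of the resolution while the grid grows with its cube, the occupied fraction shrinks as resolution rises. That gap is what sparse representations exploit.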

Meshes

Vertices and faces: $(V, F)$
class TriangleMesh:
    """
    Triangle mesh representation.
    
    A triangle mesh is the 3D equivalent of a polygon in 2D: connect vertices
    with triangular faces to approximate any surface. Think of it like
    origami -- you can approximate any curved surface with enough small
    triangles. This is the standard representation in computer graphics,
    gaming, and CAD.
    
    Properties:
    - Explicit surface geometry (vertices + faces define the surface exactly)
    - Good for rendering (GPUs are optimized for triangle rasterization)
    - Topology can be complex (handles, holes, non-manifold edges)
    
    Pitfall: Unlike point clouds, meshes have connectivity information
    (which vertices are connected by faces). This topology must be
    preserved during processing, which makes mesh-based deep learning
    significantly more complex than point-cloud-based approaches.
    """
    
    def __init__(
        self,
        vertices: torch.Tensor,  # [V, 3] vertex positions
        faces: torch.Tensor      # [F, 3] face indices
    ):
        self.vertices = vertices
        self.faces = faces
        self.num_vertices = vertices.shape[0]
        self.num_faces = faces.shape[0]
    
    def compute_normals(self) -> torch.Tensor:
        """Compute face normals."""
        v0 = self.vertices[self.faces[:, 0]]
        v1 = self.vertices[self.faces[:, 1]]
        v2 = self.vertices[self.faces[:, 2]]
        
        e1 = v1 - v0
        e2 = v2 - v0
        
        normals = torch.cross(e1, e2, dim=1)
        normals = F.normalize(normals, dim=1)
        
        return normals
    
    def compute_vertex_normals(self) -> torch.Tensor:
        """Compute per-vertex normals."""
        face_normals = self.compute_normals()
        
        vertex_normals = torch.zeros_like(self.vertices)
        
        for i, face in enumerate(self.faces):
            vertex_normals[face] += face_normals[i]
        
        vertex_normals = F.normalize(vertex_normals, dim=1)
        
        return vertex_normals
    
    def sample_surface(self, n_points: int) -> torch.Tensor:
        """Sample points uniformly on mesh surface."""
        # Compute face areas
        v0 = self.vertices[self.faces[:, 0]]
        v1 = self.vertices[self.faces[:, 1]]
        v2 = self.vertices[self.faces[:, 2]]
        
        cross = torch.cross(v1 - v0, v2 - v0, dim=1)
        areas = cross.norm(dim=1) / 2
        
        # Sample faces proportional to area
        probs = areas / areas.sum()
        face_idx = torch.multinomial(probs, n_points, replacement=True)
        
        # Random barycentric coordinates
        r1 = torch.sqrt(torch.rand(n_points))
        r2 = torch.rand(n_points)
        
        w0 = 1 - r1
        w1 = r1 * (1 - r2)
        w2 = r1 * r2
        
        # Interpolate
        points = (
            w0.unsqueeze(1) * self.vertices[self.faces[face_idx, 0]] +
            w1.unsqueeze(1) * self.vertices[self.faces[face_idx, 1]] +
            w2.unsqueeze(1) * self.vertices[self.faces[face_idx, 2]]
        )
        
        return points
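
As a quick sanity check of area-weighted surface sampling, the sketch below applies the same barycentric scheme as `sample_surface` to a unit square made of two triangles. The standalone function `sample_mesh_surface` and the square mesh are illustrative assumptions. The point is only that the samples stay on the surface and spread evenly over it.

```python
import torch


def sample_mesh_surface(vertices: torch.Tensor, faces: torch.Tensor,
                        n_points: int) -> torch.Tensor:
    """Area-weighted uniform sampling on a triangle mesh; returns [n_points, 3]."""
    v0, v1, v2 = (vertices[faces[:, i]] for i in range(3))
    areas = torch.cross(v1 - v0, v2 - v0, dim=1).norm(dim=1) / 2
    # Larger triangles receive proportionally more samples
    face_idx = torch.multinomial(areas / areas.sum(), n_points, replacement=True)
    # The square root keeps samples uniform inside each triangle instead of
    # clustering them near the first vertex
    r1 = torch.sqrt(torch.rand(n_points, 1))
    r2 = torch.rand(n_points, 1)
    return ((1 - r1) * v0[face_idx]
            + r1 * (1 - r2) * v1[face_idx]
            + r1 * r2 * v2[face_idx])


torch.manual_seed(0)
# Unit square in the z = 0 plane, split into two triangles
square_vertices = torch.tensor([[0.0, 0.0, 0.0],
                                [1.0, 0.0, 0.0],
                                [1.0, 1.0, 0.0],
                                [0.0, 1.0, 0.0]])
square_faces = torch.tensor([[0, 1, 2], [0, 2, 3]])

samples = sample_mesh_surface(square_vertices, square_faces, 2000)
print("mean xy:", samples[:, :2].mean(dim=0))
print("max |z|:", samples[:, 2].abs().max().item())
```

The resulting `[n_points, 3]` tensor can be fed to any of the point cloud models in this section, which is the simplest way to apply them to mesh data.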

PointNet

The foundational architecture for point cloud processing. PointNet was among the first deep learning models to work directly on raw, unordered point sets. Before PointNet (2017), the standard approach was to voxelize point clouds or render them from multiple viewpoints, both of which lose information. PointNet's key insight is to use a symmetric function (max pooling) to handle the unordered nature of point sets, which makes the output invariant to the order in which points are fed in:
class TNet(nn.Module):
    """
    T-Net: Spatial transformer network for point clouds.
    Learns transformation to canonicalize input.
    """
    
    def __init__(self, k: int = 3):
        super().__init__()
        self.k = k
        
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, k, N] input points or features
        
        Returns:
            transform: [B, k, k] transformation matrix
        """
        batch_size = x.size(0)
        
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Global max pooling
        x = x.max(dim=2)[0]
        
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        
        # Add the identity so the network predicts a residual around it
        identity = torch.eye(self.k, device=x.device).view(1, -1).repeat(batch_size, 1)
        x = x + identity
        
        x = x.view(batch_size, self.k, self.k)
        
        return x


class PointNet(nn.Module):
    """
    PointNet: Deep Learning on Point Sets (Qi et al., 2017).
    
    Key innovations:
    1. Permutation invariance via max pooling -- since max() gives the
       same result regardless of input order, the network's output is
       the same no matter how you shuffle the points
    2. Spatial transformer (T-Net) for alignment -- learns to rotate/align
       the input so the network does not have to learn separate features
       for every possible orientation
    3. Works directly on raw point clouds -- no voxelization or projection
    
    Limitation: PointNet processes each point independently before pooling,
    so it has no concept of local structure. Two points that are neighbors
    in 3D space are treated the same as two points far apart. PointNet++
    addresses this with hierarchical local feature learning.
    """
    
    def __init__(
        self,
        num_classes: int = 40,
        input_channels: int = 3,
        use_tnet: bool = True
    ):
        super().__init__()
        
        self.use_tnet = use_tnet
        
        # Input transform
        if use_tnet:
            self.input_tnet = TNet(k=input_channels)
            self.feature_tnet = TNet(k=64)
        
        # MLP layers
        self.conv1 = nn.Conv1d(input_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)
        
        # Classifier
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, N, 3] point cloud
        
        Returns:
            logits: [B, num_classes] classification logits
            transforms: Dictionary of transformation matrices
        """
        batch_size, num_points, _ = x.shape
        
        # [B, N, 3] -> [B, 3, N]
        x = x.transpose(1, 2)
        
        transforms = {}
        
        # Input transform
        if self.use_tnet:
            input_transform = self.input_tnet(x)
            transforms['input'] = input_transform
            x = torch.bmm(input_transform, x)
        
        # MLP (64, 64)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        
        # Feature transform
        if self.use_tnet:
            feature_transform = self.feature_tnet(x)
            transforms['feature'] = feature_transform
            x = torch.bmm(feature_transform, x)
        
        # MLP (64, 128, 1024)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        
        # Global max pooling (permutation invariance!)
        # This is the critical operation: max() over the point dimension
        # means the output is the same regardless of point ordering.
        # Each of the 1024 channels captures "the most activated point"
        # for that particular learned feature detector.
        x = x.max(dim=2)[0]  # [B, 1024]
        
        # Classifier
        x = F.relu(self.bn6(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn7(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x, transforms
    
    @staticmethod
    def feature_transform_regularization(transform: torch.Tensor) -> torch.Tensor:
        """Regularization loss for feature transform (should be orthogonal)."""
        batch_size = transform.size(0)
        k = transform.size(1)
        
        I = torch.eye(k, device=transform.device).unsqueeze(0)
        diff = torch.bmm(transform, transform.transpose(1, 2)) - I
        
        return (diff ** 2).sum() / batch_size


class PointNetSegmentation(nn.Module):
    """PointNet for per-point segmentation."""
    
    def __init__(self, num_classes: int = 50, input_channels: int = 3):
        super().__init__()
        
        self.input_tnet = TNet(k=input_channels)
        self.feature_tnet = TNet(k=64)
        
        # Encoder
        self.conv1 = nn.Conv1d(input_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 128, 1)
        self.conv4 = nn.Conv1d(128, 1024, 1)
        
        # Decoder (with skip connections)
        self.conv5 = nn.Conv1d(1088, 512, 1)  # 1024 + 64
        self.conv6 = nn.Conv1d(512, 256, 1)
        self.conv7 = nn.Conv1d(256, 128, 1)
        self.conv8 = nn.Conv1d(128, num_classes, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(1024)
        self.bn5 = nn.BatchNorm1d(512)
        self.bn6 = nn.BatchNorm1d(256)
        self.bn7 = nn.BatchNorm1d(128)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, 3] point cloud
        
        Returns:
            segmentation: [B, N, num_classes] per-point logits
        """
        batch_size, num_points, _ = x.shape
        x = x.transpose(1, 2)
        
        # Input transform
        input_transform = self.input_tnet(x)
        x = torch.bmm(input_transform, x)
        
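        # Encode
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        
        # Feature transform (aligns the 64-dim per-point features)
        feature_transform = self.feature_tnet(x)
        x = torch.bmm(feature_transform, x)
        local_features = x  # Save for skip connection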
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        
        # Global feature
        global_feature = x.max(dim=2, keepdim=True)[0]  # [B, 1024, 1]
        global_feature = global_feature.repeat(1, 1, num_points)  # [B, 1024, N]
        
        # Concatenate local and global
        x = torch.cat([local_features, global_feature], dim=1)  # [B, 1088, N]
        
        # Decode
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = self.conv8(x)
        
        return x.transpose(1, 2)  # [B, N, num_classes]
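
The permutation-invariance argument can be checked directly. The sketch below uses a deliberately tiny stand-in encoder (`TinyPointEncoder`, a hypothetical name, not the full PointNet above) with the same structure: a shared per-point MLP followed by max pooling over the point dimension. Shuffling the points should leave the pooled feature unchanged up to floating-point noise.

```python
import torch
import torch.nn as nn


class TinyPointEncoder(nn.Module):
    """Shared per-point MLP followed by a symmetric max pool."""

    def __init__(self, out_dim: int = 32):
        super().__init__()
        # Kernel size 1 means the same weights are applied to every point
        self.mlp = nn.Sequential(
            nn.Conv1d(3, 16, 1),
            nn.ReLU(),
            nn.Conv1d(16, out_dim, 1),
        )

    def forward(self, pts: torch.Tensor) -> torch.Tensor:
        per_point = self.mlp(pts.transpose(1, 2))  # [B, out_dim, N]
        return per_point.max(dim=2).values         # [B, out_dim]


torch.manual_seed(0)
encoder = TinyPointEncoder().eval()
cloud = torch.randn(2, 128, 3)
perm = torch.randperm(128)

with torch.no_grad():
    original = encoder(cloud)
    shuffled = encoder(cloud[:, perm])  # same points, different order

print("max difference:", (original - shuffled).abs().max().item())
```

Replacing the max with an order-dependent operation, such as flattening the per-point features before a linear layer, would break this property.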

PointNet++

PointNet++ addresses PointNet’s biggest weakness: it has no notion of local geometry. PointNet++ applies PointNet recursively on increasingly larger local regions, much like how CNNs build features from small patches to large receptive fields. The hierarchy works like a zoom-out: first learn fine-grained local features from small neighborhoods, then learn broader features from larger regions:
class SetAbstraction(nn.Module):
    """
    PointNet++ Set Abstraction layer.
    
    This is the workhorse of PointNet++. Each layer reduces the number of
    points while increasing the feature dimension -- analogous to a conv +
    pool layer in a CNN.
    
    Steps:
    1. Sample points (farthest point sampling) -- select representative
       points that are spread out evenly, like placing survey stations to
       maximize coverage of a landscape
    2. Group neighbors (ball query) -- for each sampled point, find all
       points within a radius to form a local patch
    3. Apply PointNet to each group -- extract local features using the
       same permutation-invariant approach as the original PointNet
    """
    
    def __init__(
        self,
        n_points: int,
        radius: float,
        n_samples: int,
        in_channels: int,
        mlp_channels: List[int]
    ):
        super().__init__()
        
        self.n_points = n_points
        self.radius = radius
        self.n_samples = n_samples
        
        # PointNet-style MLP
        self.mlps = nn.ModuleList()
        self.bns = nn.ModuleList()
        
        prev_channels = in_channels + 3  # xyz + features
        for out_channels in mlp_channels:
            self.mlps.append(nn.Conv2d(prev_channels, out_channels, 1))
            self.bns.append(nn.BatchNorm2d(out_channels))
            prev_channels = out_channels
        
        self.out_channels = mlp_channels[-1]
    
    def farthest_point_sample(
        self,
        xyz: torch.Tensor,
        n_points: int
    ) -> torch.Tensor:
        """FPS sampling."""
        B, N, _ = xyz.shape
        
        centroids = torch.zeros(B, n_points, dtype=torch.long, device=xyz.device)
        distance = torch.ones(B, N, device=xyz.device) * 1e10
        
        farthest = torch.randint(0, N, (B,), device=xyz.device)
        batch_indices = torch.arange(B, device=xyz.device)
        
        for i in range(n_points):
            centroids[:, i] = farthest
            centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
            dist = ((xyz - centroid) ** 2).sum(dim=-1)
            distance = torch.min(distance, dist)
            farthest = distance.argmax(dim=-1)
        
        return centroids
    
    def ball_query(
        self,
        xyz: torch.Tensor,
        new_xyz: torch.Tensor,
        radius: float,
        n_samples: int
    ) -> torch.Tensor:
        """Ball query grouping."""
        B, N, _ = xyz.shape
        _, S, _ = new_xyz.shape
        
        group_idx = torch.zeros(B, S, n_samples, dtype=torch.long, device=xyz.device)
        
        for b in range(B):
            for s in range(S):
                center = new_xyz[b, s]
                dists = ((xyz[b] - center) ** 2).sum(dim=-1)
                
                within_radius = dists < radius ** 2
                indices = within_radius.nonzero().squeeze(-1)
                
                if len(indices) >= n_samples:
                    group_idx[b, s] = indices[:n_samples]
                elif len(indices) > 0:
                    # Repeat to fill
                    repeat = n_samples // len(indices) + 1
                    indices = indices.repeat(repeat)[:n_samples]
                    group_idx[b, s] = indices
                # else: keep zeros (first point)
        
        return group_idx
    
    def forward(
        self,
        xyz: torch.Tensor,
        features: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            xyz: [B, N, 3] point positions
            features: [B, N, C] point features
        
        Returns:
            new_xyz: [B, n_points, 3] sampled positions
            new_features: [B, n_points, C'] new features
        """
        B, N, _ = xyz.shape
        
        # Sample points
        fps_idx = self.farthest_point_sample(xyz, self.n_points)
        
        # Get new xyz
        batch_indices = torch.arange(B, device=xyz.device).view(B, 1).expand(-1, self.n_points)
        new_xyz = xyz[batch_indices, fps_idx]  # [B, n_points, 3]
        
        # Group neighbors
        group_idx = self.ball_query(xyz, new_xyz, self.radius, self.n_samples)
        
        # Get grouped xyz (relative to center)
        batch_indices = batch_indices.unsqueeze(-1).expand(-1, -1, self.n_samples)
        grouped_xyz = xyz[batch_indices, group_idx]  # [B, n_points, n_samples, 3]
        grouped_xyz = grouped_xyz - new_xyz.unsqueeze(2)
        
        # Get grouped features
        if features is not None:
            grouped_features = features[batch_indices, group_idx]
            grouped_features = torch.cat([grouped_xyz, grouped_features], dim=-1)
        else:
            grouped_features = grouped_xyz
        
        # [B, n_points, n_samples, C] -> [B, C, n_points, n_samples]
        grouped_features = grouped_features.permute(0, 3, 1, 2)
        
        # Apply MLP
        for mlp, bn in zip(self.mlps, self.bns):
            grouped_features = F.relu(bn(mlp(grouped_features)))
        
        # Max pool over neighbors
        new_features = grouped_features.max(dim=-1)[0]  # [B, C', n_points]
        new_features = new_features.transpose(1, 2)  # [B, n_points, C']
        
        return new_xyz, new_features


class PointNetPlusPlus(nn.Module):
    """
    PointNet++: Deep Hierarchical Feature Learning (Qi et al., 2017).
    
    Improvements over PointNet:
    1. Local feature learning (hierarchical)
    2. Multi-scale grouping (proposed in the paper; this simplified
       version uses a single radius per level)
    3. Better handling of non-uniform density
    
    def __init__(self, num_classes: int = 40):
        super().__init__()
        
        # Set abstraction layers
        self.sa1 = SetAbstraction(512, 0.2, 32, 0, [64, 64, 128])
        self.sa2 = SetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = SetAbstraction(1, float('inf'), 128, 256, [256, 512, 1024])
        
        # Classifier
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Args:
            xyz: [B, N, 3] point cloud
        
        Returns:
            logits: [B, num_classes]
        """
        # Hierarchical abstraction -- each layer reduces points, increases features
        # Level 1: 1024 pts -> 512 pts, with local features from radius-0.2 neighborhoods
        l1_xyz, l1_features = self.sa1(xyz, None)
        # Level 2: 512 pts -> 128 pts, with larger radius-0.4 neighborhoods
        l2_xyz, l2_features = self.sa2(l1_xyz, l1_features)
        # Level 3: 128 pts -> 1 pt (global), captures the entire shape
        l3_xyz, l3_features = self.sa3(l2_xyz, l2_features)
        
        # Global feature
        x = l3_features.squeeze(1)  # [B, 1024]
        
        # Classifier
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x
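
The nested Python loops in `ball_query` above are easy to read but slow for large batches. One possible vectorized variant is sketched below. It is an illustrative rewrite, not the reference implementation, and it makes three assumptions: the cloud has at least `n_samples` points, short neighborhoods are padded by repeating the first in-radius neighbor (the loop version cycles through all of them), and centers with no neighbor fall back to index 0.

```python
import torch


def ball_query_vectorized(xyz: torch.Tensor, new_xyz: torch.Tensor,
                          radius: float, n_samples: int) -> torch.Tensor:
    """
    xyz:     [B, N, 3] all points
    new_xyz: [B, S, 3] query centers
    Returns  [B, S, n_samples] indices into xyz (assumes N >= n_samples).
    """
    B, N, _ = xyz.shape
    S = new_xyz.shape[1]

    # Squared distance between every center and every point: [B, S, N]
    sq_dist = ((new_xyz.unsqueeze(2) - xyz.unsqueeze(1)) ** 2).sum(dim=-1)

    # Start with every index as a candidate, then mark out-of-radius points
    # with the sentinel N so that sorting pushes them to the end
    group_idx = torch.arange(N, device=xyz.device).repeat(B, S, 1)
    group_idx[sq_dist >= radius ** 2] = N
    group_idx = group_idx.sort(dim=-1).values[:, :, :n_samples]

    # Pad short neighborhoods by repeating the first in-radius neighbor
    first = group_idx[:, :, :1].expand_as(group_idx)
    padding = group_idx == N
    group_idx[padding] = first[padding]

    # Centers with no neighbor at all fall back to index 0
    group_idx[group_idx == N] = 0
    return group_idx


torch.manual_seed(0)
points = torch.rand(2, 256, 3)
centers = points[:, :16]  # drawn from the cloud, so each center has a neighbor
idx = ball_query_vectorized(points, centers, radius=0.3, n_samples=8)

batch = torch.arange(2).view(2, 1, 1).expand_as(idx)
neighbors = points[batch, idx]  # [2, 16, 8, 3]
max_dist = (neighbors - centers.unsqueeze(2)).norm(dim=-1).max().item()
print(idx.shape, f"max neighbor distance: {max_dist:.3f}")
```

The trade-off is memory: the intermediate difference tensor has shape `[B, S, N, 3]`, so very large clouds may still need chunking or a dedicated CUDA kernel.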

3D Convolutions

For voxel and volumetric data, 3D convolutions are the natural extension of 2D convolutions: the same sliding-window approach, just with an extra spatial dimension. The math is identical, but the computational cost scales cubically. A 3D convolution with kernel size $k$ on a volume of resolution $n^3$ costs $O(k^3 \cdot n^3)$, compared to $O(k^2 \cdot n^2)$ for 2D. This cubic scaling is why sparse voxel methods (which only compute on occupied voxels) dominate practical 3D applications.
Dense 3D convolutions are memory-prohibitive for anything beyond about 64x64x64 resolution. A single 128x128x128 feature volume with 64 channels in float32 takes 512 MB per sample, and training multiplies that by the batch size, the number of layers whose activations are stored, and their gradients, so a modest network can need many gigabytes. For sparse real-world 3D data such as LiDAR scans, use sparse convolution libraries like MinkowskiEngine or TorchSparse, which only compute on non-empty voxels and can handle resolutions of 1000+ along each axis.
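To sanity-check memory figures for dense volumes, the arithmetic fits in a few lines. The helper `dense_volume_bytes` is a hypothetical name. It counts a single activation tensor for one sample and ignores gradients, batch size, and framework overhead.

```python
def dense_volume_bytes(resolution: int, channels: int = 1,
                       bytes_per_value: int = 4) -> int:
    """Bytes needed for one dense [channels, R, R, R] tensor (float32 by default)."""
    return resolution ** 3 * channels * bytes_per_value


MIB = 1024 ** 2
for resolution, channels in [(64, 64), (128, 64), (256, 64), (512, 1)]:
    size = dense_volume_bytes(resolution, channels)
    print(f"{resolution}^3 x {channels} ch: {size / MIB:,.0f} MiB")

# 64^3 x 64 ch: 64 MiB
# 128^3 x 64 ch: 512 MiB
# 256^3 x 64 ch: 4,096 MiB
# 512^3 x 1 ch: 512 MiB
```

Each doubling of resolution multiplies the footprint by eight, so the step from 128 to 256 alone moves a single 64-channel activation from half a gigabyte to four.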
class Conv3DBlock(nn.Module):
    """3D convolution block."""
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1
    ):
        super().__init__()
        
        self.conv = nn.Conv3d(
            in_channels, out_channels,
            kernel_size, stride, padding, bias=False
        )
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))


class VoxNet(nn.Module):
    """
    VoxNet: A 3D Convolutional Neural Network.
    Simple but effective for 3D classification.
    """
    
    def __init__(self, num_classes: int = 40, input_channels: int = 1):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv3d(input_channels, 32, 5, stride=2, padding=2),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv3d(32, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2),
            
            nn.Conv3d(32, 64, 3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2)
        )
        
        # Classifier (assuming 32³ input → 4³ after pooling)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 1, D, H, W] voxel grid
        
        Returns:
            logits: [B, num_classes]
        """
        x = self.features(x)
        x = self.classifier(x)
        return x


class SparseConv3D(nn.Module):
    """
    Sparse 3D convolution for efficiency.
    Only computes on occupied voxels.
    
    In a typical indoor scene, less than 5% of voxels are occupied. Dense
    3D convolution wastes 95%+ of compute on empty air. Sparse convolutions
    skip empty voxels entirely, making 3D CNNs practical for large scenes.
    
    Practical tip: Use MinkowskiEngine or TorchSparse for production sparse
    convolutions. The implementation below is conceptual -- real sparse conv
    requires specialized hash-table-based data structures for efficiency.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3
    ):
        super().__init__()
        
        self.kernel_size = kernel_size
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))
    
    def forward(
        self,
        coords: torch.Tensor,   # [N, 3] voxel coordinates
        features: torch.Tensor  # [N, C] voxel features
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sparse convolution (simplified)."""
        # In practice, use libraries like MinkowskiEngine or TorchSparse
        
        # Build coordinate hash map
        # For each voxel, find neighbors within kernel
        # Apply convolution only on active voxels
        
        # This is a placeholder - real implementation is complex
        return coords, features
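
To make the "coordinate hash map" idea in the placeholder concrete, here is a deliberately naive sketch that uses a Python dictionary as the hash map and produces outputs only at the already-active voxels (often called a submanifold sparse convolution). The function name `naive_submanifold_conv3d` is hypothetical, and this is not how MinkowskiEngine or TorchSparse are implemented. The loop is far too slow for real use. The final lines compare the result against a dense `F.conv3d` at the active sites.

```python
from itertools import product

import torch
import torch.nn.functional as F


def naive_submanifold_conv3d(coords: torch.Tensor, features: torch.Tensor,
                             weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """
    coords:   [N, 3] integer voxel coordinates (active voxels only)
    features: [N, C_in]
    weight:   [C_out, C_in, k, k, k] with odd k
    bias:     [C_out]
    Returns   [N, C_out] features at the same active coordinates.
    """
    r = weight.shape[-1] // 2
    coord_list = [tuple(c) for c in coords.tolist()]
    # Hash map from voxel coordinate to its row in `features`
    index_of = {c: i for i, c in enumerate(coord_list)}

    out = bias.unsqueeze(0).repeat(len(coord_list), 1)
    for i, (x, y, z) in enumerate(coord_list):
        # Visit every kernel offset and skip neighbors that are empty
        for dx, dy, dz in product(range(-r, r + 1), repeat=3):
            j = index_of.get((x + dx, y + dy, z + dz))
            if j is not None:
                out[i] += weight[:, :, dx + r, dy + r, dz + r] @ features[j]
    return out


torch.manual_seed(0)
coords = torch.tensor([[1, 1, 1], [1, 2, 1], [3, 3, 2], [0, 0, 0]])
features = torch.randn(4, 2)
weight = torch.randn(5, 2, 3, 3, 3)
bias = torch.randn(5)

sparse_out = naive_submanifold_conv3d(coords, features, weight, bias)

# Reference: densify into a 4^3 grid, run a dense conv, read the active sites
dense = torch.zeros(1, 2, 4, 4, 4)
dense[0][:, coords[:, 0], coords[:, 1], coords[:, 2]] = features.t()
dense_out = F.conv3d(dense, weight, bias, padding=1)
dense_at_active = dense_out[0][:, coords[:, 0], coords[:, 1], coords[:, 2]].t()

print("max difference:", (sparse_out - dense_at_active).abs().max().item())
```

The work here scales with the number of active voxels times the kernel volume rather than with the full grid, which is the property production libraries achieve with GPU hash tables and batched gather-scatter kernels.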

Point Transformer

Just as transformers revolutionized NLP and 2D vision, they are making their mark on 3D point clouds. Point Transformer applies self-attention within local neighborhoods of points, allowing each point to dynamically weight the contribution of its neighbors based on both their features and their relative spatial positions. This makes the aggregation more flexible than the fixed max pooling used in PointNet/PointNet++.
class PointTransformerLayer(nn.Module):
    """
    Point Transformer layer (Zhao et al., 2021).
    Self-attention for point clouds.
    
    Unlike standard transformers that attend to all tokens, Point Transformer
    attends only to k-nearest neighbors in 3D space. This is both more
    efficient (local attention) and more principled (distant points rarely
    matter for local geometry understanding).
    """
    
    def __init__(self, in_channels: int, out_channels: int, k: int = 16):
        super().__init__()
        
        self.k = k  # Number of neighbors
        
        self.to_q = nn.Linear(in_channels, out_channels)
        self.to_k = nn.Linear(in_channels, out_channels)
        self.to_v = nn.Linear(in_channels, out_channels)
        
        # Position encoding
        self.pos_enc = nn.Sequential(
            nn.Linear(3, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        
        # Attention MLP
        self.attn_mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(
        self,
        xyz: torch.Tensor,    # [B, N, 3]
        features: torch.Tensor  # [B, N, C]
    ) -> torch.Tensor:
        """
        Returns:
            new_features: [B, N, C']
        """
        B, N, _ = xyz.shape
        
        # Find k nearest neighbors (each point is its own nearest neighbor).
        # cdist materializes a [B, N, N] matrix, so memory grows as O(N^2).
        num_neighbors = min(self.k, N)  # topk raises an error if k > N
        dist = torch.cdist(xyz, xyz)  # [B, N, N]
        _, knn_idx = dist.topk(num_neighbors, dim=-1, largest=False)  # [B, N, k]
        
        # Gather neighbor coordinates and features via advanced indexing
        batch_idx = torch.arange(B, device=xyz.device).view(B, 1, 1).expand(-1, N, num_neighbors)
        
        neighbor_xyz = xyz[batch_idx, knn_idx]  # [B, N, k, 3]
        neighbor_features = features[batch_idx, knn_idx]  # [B, N, k, C]
        
        # Position encoding
        rel_pos = neighbor_xyz - xyz.unsqueeze(2)  # [B, N, k, 3]
        pos_enc = self.pos_enc(rel_pos)  # [B, N, k, C']
        
        # Q, K, V
        q = self.to_q(features).unsqueeze(2)  # [B, N, 1, C']
        k = self.to_k(neighbor_features)  # [B, N, k, C']
        v = self.to_v(neighbor_features)  # [B, N, k, C']
        
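        # Vector attention with position encoding: the subtraction (instead of
        # a dot product) yields one weight per neighbor and per channel
        attn = q - k + pos_enc  # [B, N, k, C']
        attn = self.attn_mlp(attn)  # [B, N, k, C']
        attn = F.softmax(attn, dim=2)  # normalize over the k neighbors: [B, N, k, C']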
        
        # Aggregate
        out = (attn * (v + pos_enc)).sum(dim=2)  # [B, N, C']
        
        return out


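class PointTransformer(nn.Module):
    """
    Simplified Point Transformer classifier.

    Stacks attention layers at full point resolution (no downsampling
    between layers), then max-pools over points for classification.
    """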
    
    def __init__(self, num_classes: int = 40):
        super().__init__()
        
        self.embed = nn.Linear(3, 32)
        
        self.transformer1 = PointTransformerLayer(32, 64)
        self.transformer2 = PointTransformerLayer(64, 128)
        self.transformer3 = PointTransformerLayer(128, 256)
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
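    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Args:
            xyz: [B, N, 3] point coordinates
        Returns:
            logits: [B, num_classes]
        """
        features = self.embed(xyz)  # [B, N, 32]
        
        # Every layer searches neighbors in the original coordinates
        features = self.transformer1(xyz, features)  # [B, N, 64]
        features = self.transformer2(xyz, features)  # [B, N, 128]
        features = self.transformer3(xyz, features)  # [B, N, 256]
        
        # Global max pooling over points (permutation invariant)
        global_feat = features.max(dim=1)[0]  # [B, 256]
        
        return self.classifier(global_feat)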
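
The attention layer above calls `torch.cdist` on the full cloud, which allocates a [B, N, N] distance matrix. For larger clouds, one possible workaround is to search neighbors in chunks of query points. The sketch below is a hedged illustration, not part of the Point Transformer paper; `knn_chunked`, `gather_neighbors`, and the `chunk_size` parameter are hypothetical names.

```python
import torch

def knn_chunked(xyz: torch.Tensor, k: int, chunk_size: int = 1024) -> torch.Tensor:
    """Indices of the k nearest neighbors of every point, shape [B, N, k].

    Query points are processed in chunks, so peak memory is
    O(chunk_size * N) per batch element instead of O(N^2).
    """
    N = xyz.shape[1]
    k = min(k, N)  # topk fails if k exceeds the number of points
    chunks = []
    for start in range(0, N, chunk_size):
        queries = xyz[:, start:start + chunk_size]   # [B, c, 3]
        dist = torch.cdist(queries, xyz)             # [B, c, N]
        chunks.append(dist.topk(k, dim=-1, largest=False).indices)
    return torch.cat(chunks, dim=1)

def gather_neighbors(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """x: [B, N, C], idx: [B, N, k] -> neighbor values [B, N, k, C]."""
    batch_idx = torch.arange(x.shape[0], device=x.device).view(-1, 1, 1)
    return x[batch_idx, idx]

torch.manual_seed(0)
xyz = torch.rand(2, 100, 3)
idx = knn_chunked(xyz, k=8, chunk_size=32)
neighbors = gather_neighbors(xyz, idx)
print(idx.shape, neighbors.shape)
```

Chunking lowers peak memory but not the total amount of distance computation; grid- or tree-based neighbor search would be needed to reduce that as well.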

Best Practices

def best_practices_3d():
    """3D deep learning guidelines."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                 3D DEEP LEARNING BEST PRACTICES                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. REPRESENTATION CHOICE                                      ║
    ║     • Point cloud: Raw sensor data, flexible                   ║
    ║     • Voxel: Regular, uses 3D conv, memory heavy               ║
    ║     • Mesh: Surface-based, good for rendering                  ║
    ║     • Multi-view: 2D CNNs on rendered views                    ║
    ║                                                                ║
    ║  2. DATA AUGMENTATION                                          ║
    ║     • Random rotation (SO(3) or SO(2))                         ║
    ║     • Random scaling                                           ║
    ║     • Random jitter (add noise)                                ║
    ║     • Random point dropout                                     ║
    ║                                                                ║
    ║  3. EFFICIENCY                                                 ║
    ║     • Farthest point sampling for uniform coverage             ║
    ║     • Sparse convolutions for voxels                           ║
    ║     • KNN with spatial hashing                                 ║
    ║                                                                ║
    ║  4. NORMALIZATION                                              ║
    ║     • Center to origin                                         ║
    ║     • Scale to unit sphere/cube                                ║
    ║     • Align principal axes                                     ║
    ║                                                                ║
    ║  5. EVALUATION                                                 ║
    ║     • Instance accuracy (per-object)                           ║
    ║     • Class-mean accuracy (balanced)                           ║
    ║     • mIoU for segmentation                                    ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

best_practices_3d()
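
The normalization and augmentation items above can be sketched in a few lines. This is a minimal illustration for a single [N, 3] cloud; the function names and the default ranges (`scale_range`, `sigma`, `clip`, `max_dropout`) are assumptions chosen for the example and should be tuned per dataset. Rotation is restricted to the up axis (the SO(2) case), and dropped points are replaced by the first point so that the tensor shape stays fixed.

```python
import math
import torch

def normalize_unit_sphere(points: torch.Tensor) -> torch.Tensor:
    """Center [N, 3] points at the origin; scale so the farthest point has norm 1."""
    centered = points - points.mean(dim=0, keepdim=True)
    radius = centered.norm(dim=1).max().clamp_min(1e-12)
    return centered / radius

def random_z_rotation(points: torch.Tensor) -> torch.Tensor:
    """Rotate about the z axis by a uniformly random angle."""
    theta = 2 * math.pi * torch.rand(1).item()
    c, s = math.cos(theta), math.sin(theta)
    rot = torch.tensor(
        [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]],
        dtype=points.dtype, device=points.device,
    )
    return points @ rot.T

def augment(points, scale_range=(0.8, 1.25), sigma=0.01, clip=0.05, max_dropout=0.5):
    """Random rotation, scaling, jitter, and point dropout (illustrative defaults)."""
    pts = random_z_rotation(points)
    lo, hi = scale_range
    pts = pts * (lo + (hi - lo) * torch.rand(1).item())
    # Jitter: small Gaussian noise, clipped so no point moves far
    pts = pts + (sigma * torch.randn_like(pts)).clamp(-clip, clip)
    # Dropout: overwrite a random subset with the first point (shape is unchanged)
    drop_ratio = max_dropout * torch.rand(1).item()
    drop = torch.rand(pts.shape[0], device=pts.device) < drop_ratio
    pts[drop] = pts[0].clone()
    return pts

torch.manual_seed(0)
cloud = torch.rand(1024, 3) * 10 + 5       # arbitrary offset and scale
normalized = normalize_unit_sphere(cloud)
augmented = augment(normalized)
print(normalized.norm(dim=1).max().item(), augmented.shape)
```

Normalization is typically applied to every sample, while augmentation is applied only during training.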

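The evaluation metrics listed above can disagree sharply on imbalanced data. The following sketch shows one common way to compute them from integer label tensors; the function names are hypothetical, and conventions vary between benchmarks (for example, whether classes absent from both prediction and ground truth are skipped, as done here, or whether mIoU is averaged per shape).

```python
import torch

def instance_accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of samples classified correctly."""
    return (pred == target).float().mean().item()

def class_mean_accuracy(pred: torch.Tensor, target: torch.Tensor, num_classes: int) -> float:
    """Average of per-class recall, so rare classes count as much as common ones."""
    accs = []
    for c in range(num_classes):
        mask = target == c
        if mask.any():  # skip classes that never occur in the ground truth
            accs.append((pred[mask] == c).float().mean())
    return torch.stack(accs).mean().item()

def mean_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int) -> float:
    """Mean intersection-over-union across classes present in pred or target."""
    ious = []
    for c in range(num_classes):
        inter = ((pred == c) & (target == c)).sum()
        union = ((pred == c) | (target == c)).sum()
        if union > 0:
            ious.append(inter.float() / union.float())
    return torch.stack(ious).mean().item()

# Imbalanced toy example: 8 samples of class 0, 2 of class 1, model predicts all 0
target = torch.tensor([0] * 8 + [1] * 2)
pred = torch.zeros(10, dtype=torch.long)

print(round(instance_accuracy(pred, target), 3))       # 0.8
print(round(class_mean_accuracy(pred, target, 2), 3))  # 0.5
print(round(mean_iou(pred, target, 2), 3))             # 0.4
```

A model that ignores the minority class still reaches 80% instance accuracy here, which is why class-mean accuracy and mIoU are usually reported alongside it.
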
Exercises

Implement MSG (Multi-Scale Grouping) for PointNet++:
# Use multiple ball queries with different radii around each centroid
# Concatenate the features from the different scales

Build a simple 3D detection model:
  • Backbone: PointNet++ or VoxelNet
  • Head: 3D bounding box regression
  • Output: (x, y, z, l, w, h, θ), i.e. box center, box dimensions, and heading angle

Implement an encoder-decoder for completing partial point clouds:
  • Encoder: PointNet feature extraction
  • Decoder: FoldingNet or MLP-based generation

What’s Next?

Object Detection

YOLO, Faster R-CNN, DETR

Semantic Segmentation

U-Net, DeepLab, panoptic segmentation