3D Deep Learning

The World is 3D

Real-world data is inherently 3D:
  • Autonomous driving: LiDAR point clouds
  • Robotics: Depth sensors, manipulation
  • Medical imaging: CT, MRI volumes
  • AR/VR: Scene reconstruction
  • Manufacturing: Quality inspection
Different 3D representations require different approaches.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List

torch.manual_seed(42)

3D Representations

Point Clouds

Unordered set of 3D points: $\{(x_i, y_i, z_i)\}_{i=1}^{N}$
class PointCloud:
    """
    Point cloud representation.
    
    Properties:
    - Unordered (permutation invariant)
    - Irregular sampling
    - Direct from sensors (LiDAR)
    """
    
    def __init__(
        self,
        points: torch.Tensor,    # [N, 3] xyz coordinates
        features: Optional[torch.Tensor] = None,  # [N, C] per-point features
        normals: Optional[torch.Tensor] = None    # [N, 3] surface normals
    ):
        self.points = points
        self.features = features
        self.normals = normals
        self.num_points = points.shape[0]
    
    def normalize(self) -> 'PointCloud':
        """Center and scale to unit sphere."""
        centroid = self.points.mean(dim=0)
        points = self.points - centroid
        max_dist = points.norm(dim=1).max()
        points = points / max_dist
        return PointCloud(points, self.features, self.normals)
    
    def random_sample(self, n_points: int) -> 'PointCloud':
        """Randomly sample points."""
        idx = torch.randperm(self.num_points)[:n_points]
        return PointCloud(
            self.points[idx],
            self.features[idx] if self.features is not None else None,
            self.normals[idx] if self.normals is not None else None
        )
    
    def farthest_point_sample(self, n_points: int) -> 'PointCloud':
        """Farthest point sampling for uniform coverage."""
        N = self.num_points
        centroids = torch.zeros(n_points, dtype=torch.long)
        distance = torch.ones(N) * 1e10
        
        # Random starting point
        farthest = torch.randint(0, N, (1,)).item()
        
        for i in range(n_points):
            centroids[i] = farthest
            centroid_point = self.points[farthest]
            
            dist = ((self.points - centroid_point) ** 2).sum(dim=1)
            distance = torch.min(distance, dist)
            farthest = distance.argmax().item()
        
        return PointCloud(
            self.points[centroids],
            self.features[centroids] if self.features is not None else None,
            self.normals[centroids] if self.normals is not None else None
        )


# Example
points = torch.randn(1024, 3)
pc = PointCloud(points)
pc_normalized = pc.normalize()
pc_sampled = pc.farthest_point_sample(256)

Voxels

3D grid of values: $V \in \mathbb{R}^{D \times H \times W}$
class VoxelGrid:
    """
    Voxel grid representation.
    
    Properties:
    - Regular 3D grid (like 3D pixels)
    - Can use 3D convolutions
    - Memory intensive (O(n³))
    """
    
    def __init__(
        self,
        grid: torch.Tensor,  # [C, D, H, W] or [D, H, W]
        resolution: int
    ):
        self.grid = grid
        self.resolution = resolution
    
    @staticmethod
    def from_point_cloud(
        points: torch.Tensor,
        resolution: int = 32
    ) -> 'VoxelGrid':
        """Convert point cloud to voxel grid."""
        # Normalize to [0, 1]
        min_pt = points.min(dim=0)[0]
        max_pt = points.max(dim=0)[0]
        points = (points - min_pt) / (max_pt - min_pt + 1e-8)
        
        # Quantize to grid
        indices = (points * (resolution - 1)).long()
        indices = indices.clamp(0, resolution - 1)
        
        # Create occupancy grid
        grid = torch.zeros(resolution, resolution, resolution)
        grid[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
        
        return VoxelGrid(grid.unsqueeze(0), resolution)
    
    def to_point_cloud(self) -> torch.Tensor:
        """Convert voxel grid back to point cloud."""
        occupied = self.grid.squeeze() > 0.5
        indices = occupied.nonzero().float()
        points = indices / (self.resolution - 1)
        return points

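A quick round-trip sketch using the classes above: voxelize the normalized point cloud from the earlier example, then recover one point per occupied voxel.
# Example: point cloud -> occupancy grid -> voxel centers
voxels = VoxelGrid.from_point_cloud(pc_normalized.points, resolution=32)
print(voxels.grid.shape)       # [1, 32, 32, 32]
recovered = voxels.to_point_cloud()
print(recovered.shape)         # [M, 3], one point per occupied voxel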

class SparseVoxelGrid:
    """
    Sparse voxel representation for efficiency.
    Only stores occupied voxels.
    """
    
    def __init__(
        self,
        coords: torch.Tensor,    # [N, 3] voxel coordinates
        features: torch.Tensor,  # [N, C] voxel features
        resolution: int
    ):
        self.coords = coords
        self.features = features
        self.resolution = resolution
    
    @staticmethod
    def from_point_cloud(
        points: torch.Tensor,
        features: torch.Tensor,
        voxel_size: float = 0.05
    ) -> 'SparseVoxelGrid':
        """Voxelize point cloud with feature averaging."""
        # Quantize points; shift so voxel coordinates start at zero
        # (points may have negative coordinates)
        coords = (points / voxel_size).floor().long()
        coords = coords - coords.min(dim=0)[0]
        
        # Unique voxels
        unique_coords, inverse = torch.unique(
            coords, dim=0, return_inverse=True
        )
        
        # Average features per voxel (scatter-add, then divide by counts)
        num_voxels = unique_coords.shape[0]
        voxel_features = torch.zeros(num_voxels, features.shape[1])
        counts = torch.zeros(num_voxels)
        
        voxel_features.index_add_(0, inverse, features)
        counts.index_add_(0, inverse, torch.ones(coords.shape[0]))
        
        voxel_features = voxel_features / counts.unsqueeze(1)
        
        resolution = unique_coords.max().item() + 1
        
        return SparseVoxelGrid(unique_coords, voxel_features, resolution)

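A small usage sketch for the sparse grid, using the point coordinates themselves as per-point features (any [N, C] feature tensor works):
# Example: sparse voxelization with feature averaging
pts = torch.randn(2048, 3)
sparse = SparseVoxelGrid.from_point_cloud(pts, features=pts, voxel_size=0.1)
print(sparse.coords.shape, sparse.features.shape)  # [M, 3], [M, 3]
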
Meshes

Vertices and faces: $(V, F)$
class TriangleMesh:
    """
    Triangle mesh representation.
    
    Properties:
    - Explicit surface geometry
    - Good for rendering
    - Topology can be complex
    """
    
    def __init__(
        self,
        vertices: torch.Tensor,  # [V, 3] vertex positions
        faces: torch.Tensor      # [F, 3] face indices
    ):
        self.vertices = vertices
        self.faces = faces
        self.num_vertices = vertices.shape[0]
        self.num_faces = faces.shape[0]
    
    def compute_normals(self) -> torch.Tensor:
        """Compute face normals."""
        v0 = self.vertices[self.faces[:, 0]]
        v1 = self.vertices[self.faces[:, 1]]
        v2 = self.vertices[self.faces[:, 2]]
        
        e1 = v1 - v0
        e2 = v2 - v0
        
        normals = torch.cross(e1, e2, dim=1)
        normals = F.normalize(normals, dim=1)
        
        return normals
    
    def compute_vertex_normals(self) -> torch.Tensor:
        """Compute per-vertex normals."""
        face_normals = self.compute_normals()
        
        vertex_normals = torch.zeros_like(self.vertices)
        
        for i, face in enumerate(self.faces):
            vertex_normals[face] += face_normals[i]
        
        vertex_normals = F.normalize(vertex_normals, dim=1)
        
        return vertex_normals
    
    def sample_surface(self, n_points: int) -> torch.Tensor:
        """Sample points uniformly on mesh surface."""
        # Compute face areas
        v0 = self.vertices[self.faces[:, 0]]
        v1 = self.vertices[self.faces[:, 1]]
        v2 = self.vertices[self.faces[:, 2]]
        
        cross = torch.cross(v1 - v0, v2 - v0, dim=1)
        areas = cross.norm(dim=1) / 2
        
        # Sample faces proportional to area
        probs = areas / areas.sum()
        face_idx = torch.multinomial(probs, n_points, replacement=True)
        
        # Random barycentric coordinates
        r1 = torch.sqrt(torch.rand(n_points))
        r2 = torch.rand(n_points)
        
        w0 = 1 - r1
        w1 = r1 * (1 - r2)
        w2 = r1 * r2
        
        # Interpolate
        points = (
            w0.unsqueeze(1) * self.vertices[self.faces[face_idx, 0]] +
            w1.unsqueeze(1) * self.vertices[self.faces[face_idx, 1]] +
            w2.unsqueeze(1) * self.vertices[self.faces[face_idx, 2]]
        )
        
        return points

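A small usage sketch with a unit tetrahedron (the vertices and faces here are just an illustrative toy mesh):
# Example: a tetrahedron mesh, its normals, and surface samples
vertices = torch.tensor([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
faces = torch.tensor([
    [0, 1, 2],
    [0, 1, 3],
    [0, 2, 3],
    [1, 2, 3],
])
mesh = TriangleMesh(vertices, faces)
face_normals = mesh.compute_normals()            # [4, 3]
vertex_normals = mesh.compute_vertex_normals()   # [4, 3]
surface_points = mesh.sample_surface(512)        # [512, 3]
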
PointNet

The foundational architecture for point cloud processing:
class TNet(nn.Module):
    """
    T-Net: Spatial transformer network for point clouds.
    Learns transformation to canonicalize input.
    """
    
    def __init__(self, k: int = 3):
        super().__init__()
        self.k = k
        
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, k, N] input points or features
        
        Returns:
            transform: [B, k, k] transformation matrix
        """
        batch_size = x.size(0)
        
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Global max pooling
        x = x.max(dim=2)[0]
        
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        
        # Initialize as identity
        identity = torch.eye(self.k, device=x.device).view(1, -1).repeat(batch_size, 1)
        x = x + identity
        
        x = x.view(batch_size, self.k, self.k)
        
        return x


class PointNet(nn.Module):
    """
    PointNet: Deep Learning on Point Sets (Qi et al., 2017).
    
    Key innovations:
    1. Permutation invariance via max pooling
    2. Spatial transformer (T-Net) for alignment
    3. Works directly on raw point clouds
    """
    
    def __init__(
        self,
        num_classes: int = 40,
        input_channels: int = 3,
        use_tnet: bool = True
    ):
        super().__init__()
        
        self.use_tnet = use_tnet
        
        # Input transform
        if use_tnet:
            self.input_tnet = TNet(k=input_channels)
            self.feature_tnet = TNet(k=64)
        
        # MLP layers
        self.conv1 = nn.Conv1d(input_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 64, 1)
        self.conv4 = nn.Conv1d(64, 128, 1)
        self.conv5 = nn.Conv1d(128, 1024, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)
        
        # Classifier
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, N, 3] point cloud
        
        Returns:
            logits: [B, num_classes] classification logits
            transforms: Dictionary of transformation matrices
        """
        batch_size, num_points, _ = x.shape
        
        # [B, N, 3] -> [B, 3, N]
        x = x.transpose(1, 2)
        
        transforms = {}
        
        # Input transform
        if self.use_tnet:
            input_transform = self.input_tnet(x)
            transforms['input'] = input_transform
            x = torch.bmm(input_transform, x)
        
        # MLP (64, 64)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        
        # Feature transform
        if self.use_tnet:
            feature_transform = self.feature_tnet(x)
            transforms['feature'] = feature_transform
            x = torch.bmm(feature_transform, x)
        
        # MLP (64, 128, 1024)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        
        # Global max pooling (permutation invariance!)
        x = x.max(dim=2)[0]  # [B, 1024]
        
        # Classifier
        x = F.relu(self.bn6(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn7(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x, transforms
    
    @staticmethod
    def feature_transform_regularization(transform: torch.Tensor) -> torch.Tensor:
        """Regularization loss for feature transform (should be orthogonal)."""
        batch_size = transform.size(0)
        k = transform.size(1)
        
        I = torch.eye(k, device=transform.device).unsqueeze(0)
        diff = torch.bmm(transform, transform.transpose(1, 2)) - I
        
        return (diff ** 2).sum() / batch_size

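In the original work, this orthogonality penalty on the feature transform is added to the classification loss with a small weight (0.001 in the paper). A minimal training-step sketch on dummy data:
# Example: one training step with the feature-transform regularizer
model = PointNet(num_classes=40)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

points = torch.randn(8, 1024, 3)       # dummy batch of point clouds
labels = torch.randint(0, 40, (8,))    # dummy class labels

logits, transforms = model(points)
loss = F.cross_entropy(logits, labels)
if 'feature' in transforms:
    loss = loss + 0.001 * PointNet.feature_transform_regularization(transforms['feature'])

loss.backward()
optimizer.step()
optimizer.zero_grad()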

class PointNetSegmentation(nn.Module):
    """PointNet for per-point segmentation."""
    
    def __init__(self, num_classes: int = 50, input_channels: int = 3):
        super().__init__()
        
        self.input_tnet = TNet(k=input_channels)
        self.feature_tnet = TNet(k=64)
        
        # Encoder
        self.conv1 = nn.Conv1d(input_channels, 64, 1)
        self.conv2 = nn.Conv1d(64, 64, 1)
        self.conv3 = nn.Conv1d(64, 128, 1)
        self.conv4 = nn.Conv1d(128, 1024, 1)
        
        # Decoder (with skip connections)
        self.conv5 = nn.Conv1d(1088, 512, 1)  # 1024 + 64
        self.conv6 = nn.Conv1d(512, 256, 1)
        self.conv7 = nn.Conv1d(256, 128, 1)
        self.conv8 = nn.Conv1d(128, num_classes, 1)
        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(1024)
        self.bn5 = nn.BatchNorm1d(512)
        self.bn6 = nn.BatchNorm1d(256)
        self.bn7 = nn.BatchNorm1d(128)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, 3] point cloud
        
        Returns:
            segmentation: [B, N, num_classes] per-point logits
        """
        batch_size, num_points, _ = x.shape
        x = x.transpose(1, 2)
        
        # Input transform
        input_transform = self.input_tnet(x)
        x = torch.bmm(input_transform, x)
        
        # Encode
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        local_features = x  # Save for skip connection
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        
        # Global feature
        global_feature = x.max(dim=2, keepdim=True)[0]  # [B, 1024, 1]
        global_feature = global_feature.repeat(1, 1, num_points)  # [B, 1024, N]
        
        # Concatenate local and global
        x = torch.cat([local_features, global_feature], dim=1)  # [B, 1088, N]
        
        # Decode
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = self.conv8(x)
        
        return x.transpose(1, 2)  # [B, N, num_classes]

PointNet++

Hierarchical learning on point clouds:
class SetAbstraction(nn.Module):
    """
    PointNet++ Set Abstraction layer.
    
    Steps:
    1. Sample points (farthest point sampling)
    2. Group neighbors (ball query)
    3. Apply PointNet to each group
    """
    
    def __init__(
        self,
        n_points: int,
        radius: float,
        n_samples: int,
        in_channels: int,
        mlp_channels: List[int]
    ):
        super().__init__()
        
        self.n_points = n_points
        self.radius = radius
        self.n_samples = n_samples
        
        # PointNet-style MLP
        self.mlps = nn.ModuleList()
        self.bns = nn.ModuleList()
        
        prev_channels = in_channels + 3  # xyz + features
        for out_channels in mlp_channels:
            self.mlps.append(nn.Conv2d(prev_channels, out_channels, 1))
            self.bns.append(nn.BatchNorm2d(out_channels))
            prev_channels = out_channels
        
        self.out_channels = mlp_channels[-1]
    
    def farthest_point_sample(
        self,
        xyz: torch.Tensor,
        n_points: int
    ) -> torch.Tensor:
        """FPS sampling."""
        B, N, _ = xyz.shape
        
        centroids = torch.zeros(B, n_points, dtype=torch.long, device=xyz.device)
        distance = torch.ones(B, N, device=xyz.device) * 1e10
        
        farthest = torch.randint(0, N, (B,), device=xyz.device)
        batch_indices = torch.arange(B, device=xyz.device)
        
        for i in range(n_points):
            centroids[:, i] = farthest
            centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
            dist = ((xyz - centroid) ** 2).sum(dim=-1)
            distance = torch.min(distance, dist)
            farthest = distance.argmax(dim=-1)
        
        return centroids
    
    def ball_query(
        self,
        xyz: torch.Tensor,
        new_xyz: torch.Tensor,
        radius: float,
        n_samples: int
    ) -> torch.Tensor:
        """Ball query grouping."""
        B, N, _ = xyz.shape
        _, S, _ = new_xyz.shape
        
        group_idx = torch.zeros(B, S, n_samples, dtype=torch.long, device=xyz.device)
        
        for b in range(B):
            for s in range(S):
                center = new_xyz[b, s]
                dists = ((xyz[b] - center) ** 2).sum(dim=-1)
                
                within_radius = dists < radius ** 2
                indices = within_radius.nonzero().squeeze(-1)
                
                if len(indices) >= n_samples:
                    group_idx[b, s] = indices[:n_samples]
                elif len(indices) > 0:
                    # Repeat to fill
                    repeat = n_samples // len(indices) + 1
                    indices = indices.repeat(repeat)[:n_samples]
                    group_idx[b, s] = indices
                # else: keep zeros (first point)
        
        return group_idx
    
    def forward(
        self,
        xyz: torch.Tensor,
        features: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            xyz: [B, N, 3] point positions
            features: [B, N, C] point features
        
        Returns:
            new_xyz: [B, n_points, 3] sampled positions
            new_features: [B, n_points, C'] new features
        """
        B, N, _ = xyz.shape
        
        # Sample points
        fps_idx = self.farthest_point_sample(xyz, self.n_points)
        
        # Get new xyz
        batch_indices = torch.arange(B, device=xyz.device).view(B, 1).expand(-1, self.n_points)
        new_xyz = xyz[batch_indices, fps_idx]  # [B, n_points, 3]
        
        # Group neighbors
        group_idx = self.ball_query(xyz, new_xyz, self.radius, self.n_samples)
        
        # Get grouped xyz (relative to center)
        batch_indices = batch_indices.unsqueeze(-1).expand(-1, -1, self.n_samples)
        grouped_xyz = xyz[batch_indices, group_idx]  # [B, n_points, n_samples, 3]
        grouped_xyz = grouped_xyz - new_xyz.unsqueeze(2)
        
        # Get grouped features
        if features is not None:
            grouped_features = features[batch_indices, group_idx]
            grouped_features = torch.cat([grouped_xyz, grouped_features], dim=-1)
        else:
            grouped_features = grouped_xyz
        
        # [B, n_points, n_samples, C] -> [B, C, n_points, n_samples]
        grouped_features = grouped_features.permute(0, 3, 1, 2)
        
        # Apply MLP
        for mlp, bn in zip(self.mlps, self.bns):
            grouped_features = F.relu(bn(mlp(grouped_features)))
        
        # Max pool over neighbors
        new_features = grouped_features.max(dim=-1)[0]  # [B, C', n_points]
        new_features = new_features.transpose(1, 2)  # [B, n_points, C']
        
        return new_xyz, new_features


class PointNetPlusPlus(nn.Module):
    """
    PointNet++: Deep Hierarchical Feature Learning (Qi et al., 2017).
    
    Improvements over PointNet:
    1. Local feature learning (hierarchical)
    2. Multi-scale grouping
    3. Better handling of non-uniform density
    """
    
    def __init__(self, num_classes: int = 40):
        super().__init__()
        
        # Set abstraction layers
        self.sa1 = SetAbstraction(512, 0.2, 32, 0, [64, 64, 128])
        self.sa2 = SetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
        self.sa3 = SetAbstraction(1, float('inf'), 128, 256, [256, 512, 1024])
        
        # Classifier
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Args:
            xyz: [B, N, 3] point cloud
        
        Returns:
            logits: [B, num_classes]
        """
        # Hierarchical abstraction
        l1_xyz, l1_features = self.sa1(xyz, None)
        l2_xyz, l2_features = self.sa2(l1_xyz, l1_features)
        l3_xyz, l3_features = self.sa3(l2_xyz, l2_features)
        
        # Global feature
        x = l3_features.squeeze(1)  # [B, 1024]
        
        # Classifier
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x

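A quick shape check on random data (the pure-Python FPS and ball query above are written for clarity, not speed):
# Example: classify a random batch of point clouds
model = PointNetPlusPlus(num_classes=40)
xyz = torch.randn(2, 1024, 3)
logits = model(xyz)
print(logits.shape)  # [2, 40]
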
3D Convolutions

For voxel and volumetric data:
class Conv3DBlock(nn.Module):
    """3D convolution block."""
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1
    ):
        super().__init__()
        
        self.conv = nn.Conv3d(
            in_channels, out_channels,
            kernel_size, stride, padding, bias=False
        )
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))


class VoxNet(nn.Module):
    """
    VoxNet: A 3D Convolutional Neural Network.
    Simple but effective for 3D classification.
    """
    
    def __init__(self, num_classes: int = 40, input_channels: int = 1):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv3d(input_channels, 32, 5, stride=2, padding=2),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            
            nn.Conv3d(32, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2),
            
            nn.Conv3d(32, 64, 3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2)
        )
        
        # Classifier (assuming 32³ input → 4³ after pooling)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 1, D, H, W] voxel grid
        
        Returns:
            logits: [B, num_classes]
        """
        x = self.features(x)
        x = self.classifier(x)
        return x

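Tying the representations together, a minimal sketch that voxelizes a point cloud at 32³ resolution and classifies it with VoxNet:
# Example: point cloud -> 32^3 occupancy grid -> VoxNet logits
voxnet = VoxNet(num_classes=40)
pts = torch.randn(1024, 3)
grid = VoxelGrid.from_point_cloud(pts, resolution=32).grid   # [1, 32, 32, 32]
logits = voxnet(grid.unsqueeze(0))                           # add batch dim
print(logits.shape)  # [1, 40]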

class SparseConv3D(nn.Module):
    """
    Sparse 3D convolution for efficiency.
    Only computes on occupied voxels.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3
    ):
        super().__init__()
        
        self.kernel_size = kernel_size
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))
    
    def forward(
        self,
        coords: torch.Tensor,   # [N, 3] voxel coordinates
        features: torch.Tensor  # [N, C] voxel features
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sparse convolution (simplified)."""
        # In practice, use libraries like MinkowskiEngine or TorchSparse
        
        # Build coordinate hash map
        # For each voxel, find neighbors within kernel
        # Apply convolution only on active voxels
        
        # This is a placeholder - real implementation is complex
        return coords, features

Point Transformer

Attention on point clouds:
class PointTransformerLayer(nn.Module):
    """
    Point Transformer layer.
    Self-attention for point clouds.
    """
    
    def __init__(self, in_channels: int, out_channels: int, k: int = 16):
        super().__init__()
        
        self.k = k  # Number of neighbors
        
        self.to_q = nn.Linear(in_channels, out_channels)
        self.to_k = nn.Linear(in_channels, out_channels)
        self.to_v = nn.Linear(in_channels, out_channels)
        
        # Position encoding
        self.pos_enc = nn.Sequential(
            nn.Linear(3, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        
        # Attention MLP
        self.attn_mlp = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(
        self,
        xyz: torch.Tensor,    # [B, N, 3]
        features: torch.Tensor  # [B, N, C]
    ) -> torch.Tensor:
        """
        Returns:
            new_features: [B, N, C']
        """
        B, N, _ = xyz.shape
        
        # Find k nearest neighbors
        dist = torch.cdist(xyz, xyz)  # [B, N, N]
        _, knn_idx = dist.topk(self.k, dim=-1, largest=False)  # [B, N, k]
        
        # Gather neighbor features
        batch_idx = torch.arange(B, device=xyz.device).view(B, 1, 1).expand(-1, N, self.k)
        
        neighbor_xyz = xyz[batch_idx, knn_idx]  # [B, N, k, 3]
        neighbor_features = features[batch_idx, knn_idx]  # [B, N, k, C]
        
        # Position encoding
        rel_pos = neighbor_xyz - xyz.unsqueeze(2)  # [B, N, k, 3]
        pos_enc = self.pos_enc(rel_pos)  # [B, N, k, C']
        
        # Q, K, V
        q = self.to_q(features).unsqueeze(2)  # [B, N, 1, C']
        k = self.to_k(neighbor_features)  # [B, N, k, C']
        v = self.to_v(neighbor_features)  # [B, N, k, C']
        
        # Attention with position encoding
        attn = q - k + pos_enc  # [B, N, k, C']
        attn = self.attn_mlp(attn)  # [B, N, k, C']
        attn = F.softmax(attn, dim=2)  # [B, N, k, C']
        
        # Aggregate
        out = (attn * (v + pos_enc)).sum(dim=2)  # [B, N, C']
        
        return out

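Per point, the layer above computes vector attention over its $k$ nearest neighbors, in the spirit of Point Transformer (Zhao et al., 2021): $y_i = \sum_{j \in \mathcal{N}(i)} \mathrm{softmax}_j\big(\gamma(q_i - k_j + \delta_{ij})\big) \odot (v_j + \delta_{ij})$, where $\delta_{ij} = \theta(p_j - p_i)$ is a learned encoding of the relative position (pos_enc in the code) and $\gamma$ is the attention MLP (attn_mlp). The softmax runs over the neighbors $j$, separately per channel.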

class PointTransformer(nn.Module):
    """Point Transformer for 3D understanding."""
    
    def __init__(self, num_classes: int = 40):
        super().__init__()
        
        self.embed = nn.Linear(3, 32)
        
        self.transformer1 = PointTransformerLayer(32, 64)
        self.transformer2 = PointTransformerLayer(64, 128)
        self.transformer3 = PointTransformerLayer(128, 256)
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        features = self.embed(xyz)
        
        features = self.transformer1(xyz, features)
        features = self.transformer2(xyz, features)
        features = self.transformer3(xyz, features)
        
        # Global pooling
        global_feat = features.max(dim=1)[0]
        
        return self.classifier(global_feat)

Best Practices

def best_practices_3d():
    """3D deep learning guidelines."""
    
    tips = """
    ╔════════════════════════════════════════════════════════════════╗
    ║                 3D DEEP LEARNING BEST PRACTICES                ║
    ╠════════════════════════════════════════════════════════════════╣
    ║                                                                ║
    ║  1. REPRESENTATION CHOICE                                      ║
    ║     • Point cloud: Raw sensor data, flexible                   ║
    ║     • Voxel: Regular, uses 3D conv, memory heavy               ║
    ║     • Mesh: Surface-based, good for rendering                  ║
    ║     • Multi-view: 2D CNNs on rendered views                    ║
    ║                                                                ║
    ║  2. DATA AUGMENTATION                                          ║
    ║     • Random rotation (SO(3) or SO(2))                         ║
    ║     • Random scaling                                           ║
    ║     • Random jitter (add noise)                                ║
    ║     • Random point dropout                                     ║
    ║                                                                ║
    ║  3. EFFICIENCY                                                 ║
    ║     • Farthest point sampling for uniform coverage             ║
    ║     • Sparse convolutions for voxels                           ║
    ║     • KNN with spatial hashing                                 ║
    ║                                                                ║
    ║  4. NORMALIZATION                                              ║
    ║     • Center to origin                                         ║
    ║     • Scale to unit sphere/cube                                ║
    ║     • Align principal axes                                     ║
    ║                                                                ║
    ║  5. EVALUATION                                                 ║
    ║     • Instance accuracy (per-object)                           ║
    ║     • Class-mean accuracy (balanced)                           ║
    ║     • mIoU for segmentation                                    ║
    ║                                                                ║
    ╚════════════════════════════════════════════════════════════════╝
    """
    print(tips)

best_practices_3d()
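
As a concrete sketch of the augmentation tips above, a helper with typical (not prescriptive) parameter choices: rotation about the up axis, uniform scaling, Gaussian jitter, and random point dropout.
def augment_point_cloud(points: torch.Tensor) -> torch.Tensor:
    """Random rotation about z, scaling, jitter, and point dropout."""
    # Rotate about the z (up) axis by a random angle
    theta = (torch.rand(1) * 2 * np.pi).item()
    c, s = np.cos(theta), np.sin(theta)
    rot = torch.tensor([[c, -s, 0.0],
                        [s,  c, 0.0],
                        [0.0, 0.0, 1.0]], dtype=points.dtype)
    points = points @ rot.T
    
    # Random uniform scaling
    points = points * torch.empty(1).uniform_(0.8, 1.2)
    
    # Jitter: small Gaussian noise on every point
    points = points + 0.01 * torch.randn_like(points)
    
    # Random point dropout: overwrite dropped points with the first point
    drop_mask = torch.rand(points.shape[0]) < 0.1
    points[drop_mask] = points[0].clone()
    
    return points


augmented = augment_point_cloud(torch.randn(1024, 3))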

Exercises

1. Implement MSG (Multi-Scale Grouping) for PointNet++:
  • Use multiple ball queries with different radii
  • Concatenate the features from the different scales
2. Build a simple 3D detection model:
  • Backbone: PointNet++ or VoxelNet
  • Head: 3D bounding box regression
  • Output: (x, y, z, l, w, h, θ)
3. Implement an encoder-decoder for completing partial point clouds:
  • Encoder: PointNet feature extraction
  • Decoder: FoldingNet or MLP-based generation

What’s Next?