3D Deep Learning
The World is 3D
Real-world data is inherently 3D:
- Autonomous driving: LiDAR point clouds
- Robotics: Depth sensors, manipulation
- Medical imaging: CT, MRI volumes
- AR/VR: Scene reconstruction
- Manufacturing: Quality inspection
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List
torch.manual_seed(42)
3D Representations
Point Clouds
An unordered set of 3D points: $\{(x_i, y_i, z_i)\}_{i=1}^{N}$
class PointCloud:
"""
Point cloud representation.
Properties:
- Unordered (permutation invariant)
- Irregular sampling
- Direct from sensors (LiDAR)
"""
def __init__(
self,
points: torch.Tensor, # [N, 3] xyz coordinates
features: Optional[torch.Tensor] = None, # [N, C] per-point features
normals: Optional[torch.Tensor] = None # [N, 3] surface normals
):
self.points = points
self.features = features
self.normals = normals
self.num_points = points.shape[0]
def normalize(self) -> 'PointCloud':
"""Center and scale to unit sphere."""
centroid = self.points.mean(dim=0)
points = self.points - centroid
max_dist = points.norm(dim=1).max()
points = points / max_dist
return PointCloud(points, self.features, self.normals)
def random_sample(self, n_points: int) -> 'PointCloud':
"""Randomly sample points."""
idx = torch.randperm(self.num_points)[:n_points]
return PointCloud(
self.points[idx],
self.features[idx] if self.features is not None else None,
self.normals[idx] if self.normals is not None else None
)
def farthest_point_sample(self, n_points: int) -> 'PointCloud':
"""Farthest point sampling for uniform coverage."""
N = self.num_points
centroids = torch.zeros(n_points, dtype=torch.long)
distance = torch.ones(N) * 1e10
# Random starting point
farthest = torch.randint(0, N, (1,)).item()
for i in range(n_points):
centroids[i] = farthest
centroid_point = self.points[farthest]
dist = ((self.points - centroid_point) ** 2).sum(dim=1)
distance = torch.min(distance, dist)
farthest = distance.argmax().item()
        return PointCloud(
            self.points[centroids],
            self.features[centroids] if self.features is not None else None,
            self.normals[centroids] if self.normals is not None else None
        )
# Example
points = torch.randn(1024, 3)
pc = PointCloud(points)
pc_normalized = pc.normalize()
pc_sampled = pc.farthest_point_sample(256)
Voxels
A 3D grid of values: $V \in \mathbb{R}^{D \times H \times W}$
class VoxelGrid:
"""
Voxel grid representation.
Properties:
- Regular 3D grid (like 3D pixels)
- Can use 3D convolutions
- Memory intensive (O(n³))
"""
def __init__(
self,
grid: torch.Tensor, # [C, D, H, W] or [D, H, W]
resolution: int
):
self.grid = grid
self.resolution = resolution
@staticmethod
def from_point_cloud(
points: torch.Tensor,
resolution: int = 32
) -> 'VoxelGrid':
"""Convert point cloud to voxel grid."""
# Normalize to [0, 1]
min_pt = points.min(dim=0)[0]
max_pt = points.max(dim=0)[0]
points = (points - min_pt) / (max_pt - min_pt + 1e-8)
# Quantize to grid
indices = (points * (resolution - 1)).long()
indices = indices.clamp(0, resolution - 1)
# Create occupancy grid
grid = torch.zeros(resolution, resolution, resolution)
grid[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
return VoxelGrid(grid.unsqueeze(0), resolution)
def to_point_cloud(self) -> torch.Tensor:
"""Convert voxel grid back to point cloud."""
occupied = self.grid.squeeze() > 0.5
indices = occupied.nonzero().float()
points = indices / (self.resolution - 1)
return points
class SparseVoxelGrid:
"""
Sparse voxel representation for efficiency.
Only stores occupied voxels.
"""
def __init__(
self,
coords: torch.Tensor, # [N, 3] voxel coordinates
features: torch.Tensor, # [N, C] voxel features
resolution: int
):
self.coords = coords
self.features = features
self.resolution = resolution
@staticmethod
def from_point_cloud(
points: torch.Tensor,
features: torch.Tensor,
voxel_size: float = 0.05
) -> 'SparseVoxelGrid':
"""Voxelize point cloud with feature averaging."""
# Quantize points
coords = (points / voxel_size).floor().long()
# Unique voxels
unique_coords, inverse = torch.unique(
coords, dim=0, return_inverse=True
)
# Average features per voxel
num_voxels = unique_coords.shape[0]
voxel_features = torch.zeros(num_voxels, features.shape[1])
counts = torch.zeros(num_voxels)
for i, idx in enumerate(inverse):
voxel_features[idx] += features[i]
counts[idx] += 1
voxel_features = voxel_features / counts.unsqueeze(1)
resolution = coords.max().item() + 1
return SparseVoxelGrid(unique_coords, voxel_features, resolution)
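A quick usage sketch, mirroring the point-cloud example above: voxelize a random cloud into a dense occupancy grid and into a sparse grid with averaged per-voxel features (the point and feature tensors here are random, purely for illustration):
pts = torch.rand(2048, 3)    # points in [0, 1]^3
feats = torch.rand(2048, 3)  # e.g. per-point RGB, illustrative only
dense = VoxelGrid.from_point_cloud(pts, resolution=32)
print(dense.grid.shape)      # torch.Size([1, 32, 32, 32])
print(int(dense.grid.sum())) # number of occupied voxels
sparse = SparseVoxelGrid.from_point_cloud(pts, feats, voxel_size=0.05)
print(sparse.coords.shape, sparse.features.shape)  # [M, 3], [M, 3] with M <= 2048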
Meshes
A set of vertices and faces: $(V, F)$
class TriangleMesh:
"""
Triangle mesh representation.
Properties:
- Explicit surface geometry
- Good for rendering
- Topology can be complex
"""
def __init__(
self,
vertices: torch.Tensor, # [V, 3] vertex positions
faces: torch.Tensor # [F, 3] face indices
):
self.vertices = vertices
self.faces = faces
self.num_vertices = vertices.shape[0]
self.num_faces = faces.shape[0]
def compute_normals(self) -> torch.Tensor:
"""Compute face normals."""
v0 = self.vertices[self.faces[:, 0]]
v1 = self.vertices[self.faces[:, 1]]
v2 = self.vertices[self.faces[:, 2]]
e1 = v1 - v0
e2 = v2 - v0
normals = torch.cross(e1, e2, dim=1)
normals = F.normalize(normals, dim=1)
return normals
def compute_vertex_normals(self) -> torch.Tensor:
"""Compute per-vertex normals."""
face_normals = self.compute_normals()
vertex_normals = torch.zeros_like(self.vertices)
for i, face in enumerate(self.faces):
vertex_normals[face] += face_normals[i]
vertex_normals = F.normalize(vertex_normals, dim=1)
return vertex_normals
def sample_surface(self, n_points: int) -> torch.Tensor:
"""Sample points uniformly on mesh surface."""
# Compute face areas
v0 = self.vertices[self.faces[:, 0]]
v1 = self.vertices[self.faces[:, 1]]
v2 = self.vertices[self.faces[:, 2]]
cross = torch.cross(v1 - v0, v2 - v0, dim=1)
areas = cross.norm(dim=1) / 2
# Sample faces proportional to area
probs = areas / areas.sum()
face_idx = torch.multinomial(probs, n_points, replacement=True)
# Random barycentric coordinates
r1 = torch.sqrt(torch.rand(n_points))
r2 = torch.rand(n_points)
w0 = 1 - r1
w1 = r1 * (1 - r2)
w2 = r1 * r2
# Interpolate
points = (
w0.unsqueeze(1) * self.vertices[self.faces[face_idx, 0]] +
w1.unsqueeze(1) * self.vertices[self.faces[face_idx, 1]] +
w2.unsqueeze(1) * self.vertices[self.faces[face_idx, 2]]
)
return points
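A minimal usage sketch with a hand-built tetrahedron (the vertices and faces are chosen purely for illustration): compute face and vertex normals, then sample points on the surface.
verts = torch.tensor([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
faces = torch.tensor([
    [0, 1, 2],
    [0, 1, 3],
    [0, 2, 3],
    [1, 2, 3],
])
mesh = TriangleMesh(verts, faces)
print(mesh.compute_normals().shape)         # torch.Size([4, 3]) face normals
print(mesh.compute_vertex_normals().shape)  # torch.Size([4, 3]) vertex normals
surface_pts = mesh.sample_surface(512)      # [512, 3] points on the surface
print(surface_pts.shape)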
PointNet
The foundational architecture for point cloud processing:
class TNet(nn.Module):
"""
T-Net: Spatial transformer network for point clouds.
Learns transformation to canonicalize input.
"""
def __init__(self, k: int = 3):
super().__init__()
self.k = k
self.conv1 = nn.Conv1d(k, 64, 1)
self.conv2 = nn.Conv1d(64, 128, 1)
self.conv3 = nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, k, N] input points or features
Returns:
transform: [B, k, k] transformation matrix
"""
batch_size = x.size(0)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
# Global max pooling
x = x.max(dim=2)[0]
x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)
# Initialize as identity
identity = torch.eye(self.k, device=x.device).view(1, -1).repeat(batch_size, 1)
x = x + identity
x = x.view(batch_size, self.k, self.k)
return x
class PointNet(nn.Module):
"""
PointNet: Deep Learning on Point Sets (Qi et al., 2017).
Key innovations:
1. Permutation invariance via max pooling
2. Spatial transformer (T-Net) for alignment
3. Works directly on raw point clouds
"""
def __init__(
self,
num_classes: int = 40,
input_channels: int = 3,
use_tnet: bool = True
):
super().__init__()
self.use_tnet = use_tnet
# Input transform
if use_tnet:
self.input_tnet = TNet(k=input_channels)
self.feature_tnet = TNet(k=64)
# MLP layers
self.conv1 = nn.Conv1d(input_channels, 64, 1)
self.conv2 = nn.Conv1d(64, 64, 1)
self.conv3 = nn.Conv1d(64, 64, 1)
self.conv4 = nn.Conv1d(64, 128, 1)
self.conv5 = nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(64)
self.bn3 = nn.BatchNorm1d(64)
self.bn4 = nn.BatchNorm1d(128)
self.bn5 = nn.BatchNorm1d(1024)
# Classifier
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, num_classes)
self.bn6 = nn.BatchNorm1d(512)
self.bn7 = nn.BatchNorm1d(256)
self.dropout = nn.Dropout(0.3)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
"""
Args:
x: [B, N, 3] point cloud
Returns:
logits: [B, num_classes] classification logits
transforms: Dictionary of transformation matrices
"""
batch_size, num_points, _ = x.shape
# [B, N, 3] -> [B, 3, N]
x = x.transpose(1, 2)
transforms = {}
# Input transform
if self.use_tnet:
input_transform = self.input_tnet(x)
transforms['input'] = input_transform
x = torch.bmm(input_transform, x)
# MLP (64, 64)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
# Feature transform
if self.use_tnet:
feature_transform = self.feature_tnet(x)
transforms['feature'] = feature_transform
x = torch.bmm(feature_transform, x)
# MLP (64, 128, 1024)
x = F.relu(self.bn3(self.conv3(x)))
x = F.relu(self.bn4(self.conv4(x)))
x = F.relu(self.bn5(self.conv5(x)))
# Global max pooling (permutation invariance!)
x = x.max(dim=2)[0] # [B, 1024]
# Classifier
x = F.relu(self.bn6(self.fc1(x)))
x = self.dropout(x)
x = F.relu(self.bn7(self.fc2(x)))
x = self.dropout(x)
x = self.fc3(x)
return x, transforms
@staticmethod
def feature_transform_regularization(transform: torch.Tensor) -> torch.Tensor:
"""Regularization loss for feature transform (should be orthogonal)."""
batch_size = transform.size(0)
k = transform.size(1)
I = torch.eye(k, device=transform.device).unsqueeze(0)
diff = torch.bmm(transform, transform.transpose(1, 2)) - I
return (diff ** 2).sum() / batch_size
class PointNetSegmentation(nn.Module):
"""PointNet for per-point segmentation."""
def __init__(self, num_classes: int = 50, input_channels: int = 3):
super().__init__()
self.input_tnet = TNet(k=input_channels)
self.feature_tnet = TNet(k=64)
# Encoder
self.conv1 = nn.Conv1d(input_channels, 64, 1)
self.conv2 = nn.Conv1d(64, 64, 1)
self.conv3 = nn.Conv1d(64, 128, 1)
self.conv4 = nn.Conv1d(128, 1024, 1)
# Decoder (with skip connections)
self.conv5 = nn.Conv1d(1088, 512, 1) # 1024 + 64
self.conv6 = nn.Conv1d(512, 256, 1)
self.conv7 = nn.Conv1d(256, 128, 1)
self.conv8 = nn.Conv1d(128, num_classes, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(64)
self.bn3 = nn.BatchNorm1d(128)
self.bn4 = nn.BatchNorm1d(1024)
self.bn5 = nn.BatchNorm1d(512)
self.bn6 = nn.BatchNorm1d(256)
self.bn7 = nn.BatchNorm1d(128)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, N, 3] point cloud
Returns:
segmentation: [B, N, num_classes] per-point logits
"""
batch_size, num_points, _ = x.shape
x = x.transpose(1, 2)
# Input transform
input_transform = self.input_tnet(x)
x = torch.bmm(input_transform, x)
# Encode
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
        # Feature transform (the T-Net declared above; aligns learned features)
        feature_transform = self.feature_tnet(x)
        x = torch.bmm(feature_transform, x)
        local_features = x  # Save for skip connection
x = F.relu(self.bn3(self.conv3(x)))
x = F.relu(self.bn4(self.conv4(x)))
# Global feature
global_feature = x.max(dim=2, keepdim=True)[0] # [B, 1024, 1]
global_feature = global_feature.repeat(1, 1, num_points) # [B, 1024, N]
# Concatenate local and global
x = torch.cat([local_features, global_feature], dim=1) # [B, 1088, N]
# Decode
x = F.relu(self.bn5(self.conv5(x)))
x = F.relu(self.bn6(self.conv6(x)))
x = F.relu(self.bn7(self.conv7(x)))
x = self.conv8(x)
return x.transpose(1, 2) # [B, N, num_classes]
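A usage sketch for the classification model: one forward/backward step on a random batch (inputs and labels are random, purely illustrative), adding the orthogonality regularizer on the 64x64 feature transform with the 0.001 weight used in the original paper.
model = PointNet(num_classes=40)
clouds = torch.randn(8, 1024, 3)
labels = torch.randint(0, 40, (8,))
logits, transforms = model(clouds)
loss = F.cross_entropy(logits, labels)
loss = loss + 0.001 * PointNet.feature_transform_regularization(transforms['feature'])
loss.backward()
print(logits.shape, loss.item())  # torch.Size([8, 40])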
PointNet++
Hierarchical learning on point clouds:
class SetAbstraction(nn.Module):
"""
PointNet++ Set Abstraction layer.
Steps:
1. Sample points (farthest point sampling)
2. Group neighbors (ball query)
3. Apply PointNet to each group
"""
def __init__(
self,
n_points: int,
radius: float,
n_samples: int,
in_channels: int,
mlp_channels: List[int]
):
super().__init__()
self.n_points = n_points
self.radius = radius
self.n_samples = n_samples
# PointNet-style MLP
self.mlps = nn.ModuleList()
self.bns = nn.ModuleList()
prev_channels = in_channels + 3 # xyz + features
for out_channels in mlp_channels:
self.mlps.append(nn.Conv2d(prev_channels, out_channels, 1))
self.bns.append(nn.BatchNorm2d(out_channels))
prev_channels = out_channels
self.out_channels = mlp_channels[-1]
def farthest_point_sample(
self,
xyz: torch.Tensor,
n_points: int
) -> torch.Tensor:
"""FPS sampling."""
B, N, _ = xyz.shape
centroids = torch.zeros(B, n_points, dtype=torch.long, device=xyz.device)
distance = torch.ones(B, N, device=xyz.device) * 1e10
farthest = torch.randint(0, N, (B,), device=xyz.device)
batch_indices = torch.arange(B, device=xyz.device)
for i in range(n_points):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = ((xyz - centroid) ** 2).sum(dim=-1)
distance = torch.min(distance, dist)
farthest = distance.argmax(dim=-1)
return centroids
def ball_query(
self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
radius: float,
n_samples: int
) -> torch.Tensor:
"""Ball query grouping."""
B, N, _ = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.zeros(B, S, n_samples, dtype=torch.long, device=xyz.device)
for b in range(B):
for s in range(S):
center = new_xyz[b, s]
dists = ((xyz[b] - center) ** 2).sum(dim=-1)
within_radius = dists < radius ** 2
indices = within_radius.nonzero().squeeze(-1)
if len(indices) >= n_samples:
group_idx[b, s] = indices[:n_samples]
elif len(indices) > 0:
# Repeat to fill
repeat = n_samples // len(indices) + 1
indices = indices.repeat(repeat)[:n_samples]
group_idx[b, s] = indices
# else: keep zeros (first point)
return group_idx
def forward(
self,
xyz: torch.Tensor,
features: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
xyz: [B, N, 3] point positions
features: [B, N, C] point features
Returns:
new_xyz: [B, n_points, 3] sampled positions
new_features: [B, n_points, C'] new features
"""
B, N, _ = xyz.shape
# Sample points
fps_idx = self.farthest_point_sample(xyz, self.n_points)
# Get new xyz
batch_indices = torch.arange(B, device=xyz.device).view(B, 1).expand(-1, self.n_points)
new_xyz = xyz[batch_indices, fps_idx] # [B, n_points, 3]
# Group neighbors
group_idx = self.ball_query(xyz, new_xyz, self.radius, self.n_samples)
# Get grouped xyz (relative to center)
batch_indices = batch_indices.unsqueeze(-1).expand(-1, -1, self.n_samples)
grouped_xyz = xyz[batch_indices, group_idx] # [B, n_points, n_samples, 3]
grouped_xyz = grouped_xyz - new_xyz.unsqueeze(2)
# Get grouped features
if features is not None:
grouped_features = features[batch_indices, group_idx]
grouped_features = torch.cat([grouped_xyz, grouped_features], dim=-1)
else:
grouped_features = grouped_xyz
# [B, n_points, n_samples, C] -> [B, C, n_points, n_samples]
grouped_features = grouped_features.permute(0, 3, 1, 2)
# Apply MLP
for mlp, bn in zip(self.mlps, self.bns):
grouped_features = F.relu(bn(mlp(grouped_features)))
# Max pool over neighbors
new_features = grouped_features.max(dim=-1)[0] # [B, C', n_points]
new_features = new_features.transpose(1, 2) # [B, n_points, C']
return new_xyz, new_features
class PointNetPlusPlus(nn.Module):
"""
PointNet++: Deep Hierarchical Feature Learning (Qi et al., 2017).
Improvements over PointNet:
1. Local feature learning (hierarchical)
2. Multi-scale grouping
3. Better handling of non-uniform density
"""
def __init__(self, num_classes: int = 40):
super().__init__()
# Set abstraction layers
self.sa1 = SetAbstraction(512, 0.2, 32, 0, [64, 64, 128])
self.sa2 = SetAbstraction(128, 0.4, 64, 128, [128, 128, 256])
self.sa3 = SetAbstraction(1, float('inf'), 128, 256, [256, 512, 1024])
# Classifier
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, num_classes)
self.bn1 = nn.BatchNorm1d(512)
self.bn2 = nn.BatchNorm1d(256)
self.dropout = nn.Dropout(0.5)
def forward(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Args:
xyz: [B, N, 3] point cloud
Returns:
logits: [B, num_classes]
"""
# Hierarchical abstraction
l1_xyz, l1_features = self.sa1(xyz, None)
l2_xyz, l2_features = self.sa2(l1_xyz, l1_features)
l3_xyz, l3_features = self.sa3(l2_xyz, l2_features)
# Global feature
x = l3_features.squeeze(1) # [B, 1024]
# Classifier
x = F.relu(self.bn1(self.fc1(x)))
x = self.dropout(x)
x = F.relu(self.bn2(self.fc2(x)))
x = self.dropout(x)
x = self.fc3(x)
return x
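A quick shape check on random input (the pure-Python ball query above is slow, so the batch is kept tiny; this is a sanity check, not a benchmark):
model = PointNetPlusPlus(num_classes=40)
clouds = torch.randn(2, 1024, 3)
logits = model(clouds)
print(logits.shape)  # torch.Size([2, 40])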
3D Convolutions
For voxel and volumetric data:
class Conv3DBlock(nn.Module):
"""3D convolution block."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1
):
super().__init__()
self.conv = nn.Conv3d(
in_channels, out_channels,
kernel_size, stride, padding, bias=False
)
self.bn = nn.BatchNorm3d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.relu(self.bn(self.conv(x)))
class VoxNet(nn.Module):
"""
VoxNet: A 3D Convolutional Neural Network.
Simple but effective for 3D classification.
"""
def __init__(self, num_classes: int = 40, input_channels: int = 1):
super().__init__()
self.features = nn.Sequential(
nn.Conv3d(input_channels, 32, 5, stride=2, padding=2),
nn.BatchNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(32, 32, 3, padding=1),
nn.BatchNorm3d(32),
nn.ReLU(inplace=True),
nn.MaxPool3d(2),
nn.Conv3d(32, 64, 3, padding=1),
nn.BatchNorm3d(64),
nn.ReLU(inplace=True),
nn.MaxPool3d(2)
)
# Classifier (assuming 32³ input → 4³ after pooling)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 4 * 4 * 4, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(128, num_classes)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [B, 1, D, H, W] voxel grid
Returns:
logits: [B, num_classes]
"""
x = self.features(x)
x = self.classifier(x)
return x
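# Quick shape check for VoxNet on a random 32^3 occupancy grid
# (the random input is purely illustrative).
voxnet = VoxNet(num_classes=10)
voxel_batch = (torch.rand(4, 1, 32, 32, 32) > 0.9).float()
print(voxnet(voxel_batch).shape)  # torch.Size([4, 10])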
class SparseConv3D(nn.Module):
"""
Sparse 3D convolution for efficiency.
Only computes on occupied voxels.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3
):
super().__init__()
self.kernel_size = kernel_size
self.weight = nn.Parameter(
torch.randn(out_channels, in_channels, kernel_size, kernel_size, kernel_size)
)
self.bias = nn.Parameter(torch.zeros(out_channels))
def forward(
self,
coords: torch.Tensor, # [N, 3] voxel coordinates
features: torch.Tensor # [N, C] voxel features
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Sparse convolution (simplified)."""
# In practice, use libraries like MinkowskiEngine or TorchSparse
# Build coordinate hash map
# For each voxel, find neighbors within kernel
# Apply convolution only on active voxels
# This is a placeholder - real implementation is complex
return coords, features
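To make the placeholder concrete, here is a deliberately naive reference sketch of the idea: build a hash map from voxel coordinate to row index and, for each active voxel, accumulate contributions from active neighbors inside the kernel window. It is far too slow for real workloads (production code uses libraries like MinkowskiEngine or TorchSparse), and the function name and arguments are ours, but it shows what a sparse convolution actually computes.
def naive_sparse_conv3d(
    coords: torch.Tensor,    # [N, 3] integer voxel coordinates
    features: torch.Tensor,  # [N, C_in] per-voxel features
    weight: torch.Tensor,    # [C_out, C_in, k, k, k]
    bias: torch.Tensor       # [C_out]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference sparse 3D convolution: only active voxels are visited."""
    k = weight.shape[-1]
    offset = k // 2
    # Hash map: voxel coordinate -> row index
    index = {tuple(c.tolist()): i for i, c in enumerate(coords)}
    out = bias.expand(coords.shape[0], -1).clone()
    for i, c in enumerate(coords.tolist()):
        for dx in range(-offset, offset + 1):
            for dy in range(-offset, offset + 1):
                for dz in range(-offset, offset + 1):
                    j = index.get((c[0] + dx, c[1] + dy, c[2] + dz))
                    if j is not None:
                        # [C_out, C_in] @ [C_in] -> [C_out]
                        w = weight[:, :, dx + offset, dy + offset, dz + offset]
                        out[i] = out[i] + w @ features[j]
    return coords, out
# Tiny sanity check using the (randomly initialized) SparseConv3D parameters
conv = SparseConv3D(in_channels=4, out_channels=8, kernel_size=3)
coords = torch.randint(0, 16, (20, 3))
feats = torch.randn(20, 4)
_, out = naive_sparse_conv3d(coords, feats, conv.weight, conv.bias)
print(out.shape)  # torch.Size([20, 8])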
Point Transformer
Attention on point clouds:
class PointTransformerLayer(nn.Module):
"""
Point Transformer layer.
Self-attention for point clouds.
"""
def __init__(self, in_channels: int, out_channels: int, k: int = 16):
super().__init__()
self.k = k # Number of neighbors
self.to_q = nn.Linear(in_channels, out_channels)
self.to_k = nn.Linear(in_channels, out_channels)
self.to_v = nn.Linear(in_channels, out_channels)
# Position encoding
self.pos_enc = nn.Sequential(
nn.Linear(3, out_channels),
nn.ReLU(),
nn.Linear(out_channels, out_channels)
)
# Attention MLP
self.attn_mlp = nn.Sequential(
nn.Linear(out_channels, out_channels),
nn.ReLU(),
nn.Linear(out_channels, out_channels)
)
def forward(
self,
xyz: torch.Tensor, # [B, N, 3]
features: torch.Tensor # [B, N, C]
) -> torch.Tensor:
"""
Returns:
new_features: [B, N, C']
"""
B, N, _ = xyz.shape
# Find k nearest neighbors
dist = torch.cdist(xyz, xyz) # [B, N, N]
_, knn_idx = dist.topk(self.k, dim=-1, largest=False) # [B, N, k]
# Gather neighbor features
batch_idx = torch.arange(B, device=xyz.device).view(B, 1, 1).expand(-1, N, self.k)
point_idx = torch.arange(N, device=xyz.device).view(1, N, 1).expand(B, -1, self.k)
neighbor_xyz = xyz[batch_idx, knn_idx] # [B, N, k, 3]
neighbor_features = features[batch_idx, knn_idx] # [B, N, k, C]
# Position encoding
rel_pos = neighbor_xyz - xyz.unsqueeze(2) # [B, N, k, 3]
pos_enc = self.pos_enc(rel_pos) # [B, N, k, C']
# Q, K, V
q = self.to_q(features).unsqueeze(2) # [B, N, 1, C']
k = self.to_k(neighbor_features) # [B, N, k, C']
v = self.to_v(neighbor_features) # [B, N, k, C']
# Attention with position encoding
attn = q - k + pos_enc # [B, N, k, C']
attn = self.attn_mlp(attn) # [B, N, k, C']
attn = F.softmax(attn, dim=2) # [B, N, k, C']
# Aggregate
out = (attn * (v + pos_enc)).sum(dim=2) # [B, N, C']
return out
class PointTransformer(nn.Module):
"""Point Transformer for 3D understanding."""
def __init__(self, num_classes: int = 40):
super().__init__()
self.embed = nn.Linear(3, 32)
self.transformer1 = PointTransformerLayer(32, 64)
self.transformer2 = PointTransformerLayer(64, 128)
self.transformer3 = PointTransformerLayer(128, 256)
self.classifier = nn.Sequential(
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(128, num_classes)
)
def forward(self, xyz: torch.Tensor) -> torch.Tensor:
features = self.embed(xyz)
features = self.transformer1(xyz, features)
features = self.transformer2(xyz, features)
features = self.transformer3(xyz, features)
# Global pooling
global_feat = features.max(dim=1)[0]
return self.classifier(global_feat)
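A shape-check sketch on random data (the cdist-based kNN is quadratic in the number of points, so keep the cloud modest):
model = PointTransformer(num_classes=40)
clouds = torch.randn(2, 512, 3)
logits = model(clouds)
print(logits.shape)  # torch.Size([2, 40])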
Best Practices
def best_practices_3d():
"""3D deep learning guidelines."""
tips = """
╔════════════════════════════════════════════════════════════════╗
║ 3D DEEP LEARNING BEST PRACTICES ║
╠════════════════════════════════════════════════════════════════╣
║ ║
║ 1. REPRESENTATION CHOICE ║
║ • Point cloud: Raw sensor data, flexible ║
║ • Voxel: Regular, uses 3D conv, memory heavy ║
║ • Mesh: Surface-based, good for rendering ║
║ • Multi-view: 2D CNNs on rendered views ║
║ ║
║ 2. DATA AUGMENTATION ║
║ • Random rotation (SO(3) or SO(2)) ║
║ • Random scaling ║
║ • Random jitter (add noise) ║
║ • Random point dropout ║
║ ║
║ 3. EFFICIENCY ║
║ • Farthest point sampling for uniform coverage ║
║ • Sparse convolutions for voxels ║
║ • KNN with spatial hashing ║
║ ║
║ 4. NORMALIZATION ║
║ • Center to origin ║
║ • Scale to unit sphere/cube ║
║ • Align principal axes ║
║ ║
║ 5. EVALUATION ║
║ • Instance accuracy (per-object) ║
║ • Class-mean accuracy (balanced) ║
║ • mIoU for segmentation ║
║ ║
╚════════════════════════════════════════════════════════════════╝
"""
print(tips)
best_practices_3d()
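The augmentations listed under item 2 above are simple tensor operations. A minimal sketch (rotation about the up axis, uniform scaling, Gaussian jitter, and random point dropout; the parameter ranges are common defaults, not fixed rules):
def augment_point_cloud(points: torch.Tensor) -> torch.Tensor:
    """Random SO(2) rotation, scaling, jitter and point dropout. points: [N, 3]."""
    # Rotate about the z (up) axis
    theta = torch.rand(1).item() * 2 * np.pi
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rot = torch.tensor([
        [cos_t, -sin_t, 0.0],
        [sin_t,  cos_t, 0.0],
        [0.0,    0.0,   1.0],
    ], dtype=torch.float32)
    points = points @ rot.T
    # Random uniform scaling in [0.8, 1.2]
    points = points * (0.8 + 0.4 * torch.rand(1))
    # Gaussian jitter
    points = points + 0.01 * torch.randn_like(points)
    # Random point dropout: dropped points are replaced by the first point
    keep = torch.rand(points.shape[0]) > 0.1
    points = torch.where(keep.unsqueeze(1), points, points[0:1])
    return points
augmented = augment_point_cloud(torch.randn(1024, 3))
print(augmented.shape)  # torch.Size([1024, 3])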
Exercises
Exercise 1: Multi-Scale Grouping
Implement MSG (Multi-Scale Grouping) for PointNet++:
# Use multiple ball queries with different radii
# Concatenate features from different scales
Exercise 2: 3D Object Detection
Build a simple 3D detection model:
- Backbone: PointNet++ or VoxelNet
- Head: 3D bounding box regression
- Output: (x, y, z, l, w, h, θ)
Exercise 3: Point Cloud Completion
Implement encoder-decoder for completing partial point clouds:
- Encoder: PointNet feature extraction
- Decoder: FoldingNet or MLP-based generation