Advanced Spark Topics
Module Duration: 4-5 hours
Focus: Advanced RDD operations, custom partitioners, accumulators, broadcast variables
Prerequisites: All previous modules
Overview
This module covers advanced Spark concepts including low-level RDD operations, custom partitioners, accumulators, broadcast variables, Delta Lake, GraphX, and integration with the broader data ecosystem.
Advanced RDD Operations
RDD Internals
from pyspark import SparkContext
sc = SparkContext("local[*]", "Advanced RDD")
# Create RDD
rdd = sc.parallelize(range(1000), numSlices=10)
# Inspect RDD
print(f"Partitions: {rdd.getNumPartitions()}")
print(f"Partitioner: {rdd.partitioner}")
# Peek at the first few elements of each partition
def show_partition(index, iterator):
    yield f"Partition {index}: {list(iterator)[:5]}..."
rdd.mapPartitionsWithIndex(show_partition).collect()
mapPartitions vs map
# map: Called for each element
squared = rdd.map(lambda x: x * x)
# mapPartitions: Called once per partition
# More efficient for expensive setup/teardown
def process_partition(iterator):
    # Setup (once per partition)
    connection = create_connection()
    # Process all elements
    for item in iterator:
        yield process_with_connection(connection, item)
    # Cleanup
    connection.close()
result = rdd.mapPartitions(process_partition)
# mapPartitionsWithIndex: Include partition index
def process_with_index(index, iterator):
    for item in iterator:
        yield (index, item)
indexed = rdd.mapPartitionsWithIndex(process_with_index)
glom - View Partition Contents
# glom: Convert each partition to a list
partitions = rdd.glom().collect()
for i, partition in enumerate(partitions):
    print(f"Partition {i}: {len(partition)} elements")
# Useful for debugging partition distribution
coalesce vs repartition
# repartition: Full shuffle, can increase or decrease the partition count
rdd_repart = rdd.repartition(20)  # 10 -> 20 partitions (shuffle)
# coalesce: Merges existing partitions; without shuffle=True it can only decrease and avoids a full shuffle
rdd_coal = rdd.coalesce(5)  # 10 -> 5 partitions (no shuffle)
# coalesce with shuffle=True behaves like repartition and can also increase
rdd_coal_shuffle = rdd.coalesce(15, shuffle=True)  # 10 -> 15 partitions (shuffle)
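A quick way to confirm the effect (a small sketch reusing the RDDs above):
# Sketch: confirm the resulting partition counts
print(rdd.getNumPartitions())               # 10
print(rdd_repart.getNumPartitions())        # 20
print(rdd_coal.getNumPartitions())          # 5
print(rdd_coal_shuffle.getNumPartitions())  # 15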
Sampling and Statistics
# Sample RDD
sample = rdd.sample(
    withReplacement=False,
    fraction=0.1,
    seed=42
)
# Statistics
stats = rdd.stats()
print(f"Count: {stats.count()}")
print(f"Mean: {stats.mean()}")
print(f"StdDev: {stats.stdev()}")
print(f"Min/Max: {stats.min()}, {stats.max()}")
# Histogram
histogram = rdd.histogram(10) # 10 buckets
print(f"Buckets: {histogram[0]}")
print(f"Counts: {histogram[1]}")
# Stratified sampling
pairs = sc.parallelize([(i % 3, i) for i in range(100)])
stratified = pairs.sampleByKey(
    withReplacement=False,
    fractions={0: 0.1, 1: 0.2, 2: 0.3},
    seed=42
)
Custom Partitioners
Why Custom Partitioners
# Default hash partitioner may not be optimal
# Custom partitioners allow domain-specific partitioning
# PySpark's partitionBy() takes a plain partition function (key -> partition index),
# so a custom "partitioner" is just an ordinary Python class or function
class DomainPartitioner:
    """Partition by domain for URL processing"""
    def __init__(self, num_partitions):
        self.num_partitions = num_partitions
    def getPartition(self, key):
        # Extract domain from URL
        domain = key.split('/')[2] if '/' in key else key
        return hash(domain) % self.num_partitions
# Use custom partitioner
urls = sc.parallelize([
    ("https://example.com/page1", 100),
    ("https://example.com/page2", 200),
    ("https://other.com/page1", 150)
])
domain_partitioner = DomainPartitioner(4)
partitioned = urls.partitionBy(
    numPartitions=4,
    partitionFunc=domain_partitioner.getPartition
)
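One environment-level caveat (an assumption about deployment, not part of the original example): Python salts hash() for strings per interpreter, so a hash-based partitionFunc is only consistent across executors when the hash seed is pinned.
# Sketch (assumes a spark-submit style deployment): pin the Python hash seed
# so hash(domain) maps each domain to the same partition on every executor
#   spark-submit --conf spark.executorEnv.PYTHONHASHSEED=0 your_job.py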
Range Partitioner
class RangePartitioner:
    """Partition based on value ranges"""
    def __init__(self, ranges):
        self.ranges = ranges  # [(0, 10), (10, 50), (50, 100)]
    def numPartitions(self):
        return len(self.ranges)
    def getPartition(self, key):
        for i, (low, high) in enumerate(self.ranges):
            if low <= key < high:
                return i
        return len(self.ranges) - 1  # Default to last partition
# Example: Partition users by age range (the key is the age)
users = sc.parallelize([
    (5, "child"),
    (25, "adult"),
    (65, "senior"),
    (15, "teen")
])
age_partitioner = RangePartitioner([
    (0, 18),    # Partition 0: minors
    (18, 65),   # Partition 1: adults
    (65, 150)   # Partition 2: seniors
])
partitioned_users = users.partitionBy(
    age_partitioner.numPartitions(),
    age_partitioner.getPartition
)
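To confirm the ranges landed where expected, the partition contents can be inspected with glom (a quick sketch reusing the RDD above):
# Sketch: show which (age, label) pairs ended up in each partition
for i, part in enumerate(partitioned_users.glom().collect()):
    print(f"Partition {i}: {part}")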
Preserving Partitioning
# Some operations preserve partitioning
pairs = sc.parallelize([(i, i*2) for i in range(100)])
partitioned = pairs.partitionBy(10)
# Preserved: mapValues, flatMapValues, filter (keys are untouched)
preserved = partitioned.mapValues(lambda x: x * 2)
print(f"Partitioner preserved: {preserved.partitioner is not None}")
# Lost: map on keys
not_preserved = partitioned.map(lambda kv: (kv[0] * 2, kv[1]))
print(f"Partitioner lost: {not_preserved.partitioner is None}")
# Co-partitioning for joins
rdd1 = sc.parallelize([(i, f"value{i}") for i in range(100)]).partitionBy(10)
rdd2 = sc.parallelize([(i, i*2) for i in range(100)]).partitionBy(10)
# No shuffle needed - same partitioner
joined = rdd1.join(rdd2) # Efficient!
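A quick sanity check of the co-partitioning claim (a small sketch using the RDDs above):
# Both inputs were partitioned with the same function and partition count,
# so the join can reuse that layout instead of reshuffling both sides
print(rdd1.partitioner == rdd2.partitioner)  # True -> co-partitioned
print(joined.getNumPartitions())             # partition count carried over from the inputs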
Accumulators
Basic Accumulators
# Accumulators: Shared variables for aggregating information across tasks
sc = SparkContext.getOrCreate()  # reuse the context created earlier in this module
# Long accumulator
counter = sc.accumulator(0)
def count_errors(line):
    if "ERROR" in line:
        counter.add(1)
    return line
logs = sc.textFile("/logs/*.log")
processed = logs.map(count_errors)
processed.count() # Trigger action
print(f"Total errors: {counter.value}")
Custom Accumulators
from pyspark import AccumulatorParam
class SetAccumulatorParam(AccumulatorParam):
    """Accumulator for collecting unique values"""
    def zero(self, initial_value):
        return set()
    def addInPlace(self, acc1, acc2):
        return acc1 | acc2  # Set union
# Create custom accumulator
unique_users = sc.accumulator(set(), SetAccumulatorParam())
def track_user(record):
    # Add a one-element set so addInPlace always merges two sets
    unique_users.add({record['user_id']})
    return record
# 'data' is assumed to be an RDD of dicts with a 'user_id' field
data.map(track_user).count()
print(f"Unique users: {len(unique_users.value)}")
Multiple Accumulators
# Track multiple metrics
total_records = sc.accumulator(0)
error_records = sc.accumulator(0)
skipped_records = sc.accumulator(0)
def process_record(record):
    total_records.add(1)
    try:
        # Process record
        if should_skip(record):
            skipped_records.add(1)
            return None
        return transform(record)
    except Exception as e:
        error_records.add(1)
        return None
result = data.map(process_record).filter(lambda x: x is not None)
result.count()
print(f"Total: {total_records.value}")
print(f"Errors: {error_records.value}")
print(f"Skipped: {skipped_records.value}")
print(f"Success: {total_records.value - error_records.value - skipped_records.value}")
Accumulator Best Practices
# 1. Only use in actions, not transformations
# Bad: May count multiple times due to re-computation
bad_counter = sc.accumulator(0)
rdd.map(lambda x: bad_counter.add(1) or x).take(10) # Wrong!
# Good: Use in actions or after caching
good_counter = sc.accumulator(0)
cached = rdd.cache()
cached.foreach(lambda x: good_counter.add(1)) # Correct!
# 2. Accumulators are write-only from executors
# Can only read on driver
print(f"Value on driver: {good_counter.value}")
# 3. Use for side effects, not core logic
# Accumulators for monitoring, not business logic
Broadcast Variables
Basic Broadcasting
# Broadcast: Efficiently share read-only data to all executors
lookup_table = {
    "US": "United States",
    "UK": "United Kingdom",
    "CA": "Canada"
}
# Broadcast to all executors
broadcast_lookup = sc.broadcast(lookup_table)
def expand_country(code):
    # Access broadcast value
    return broadcast_lookup.value.get(code, "Unknown")
codes = sc.parallelize(["US", "UK", "FR", "CA"])
expanded = codes.map(expand_country)
print(expanded.collect())
# Clean up
broadcast_lookup.unpersist()
Large Dataset Broadcasting
# Example: Broadcast reference data for joins
# Instead of shuffle join, broadcast small dataset
# Load small dimension table (e.g., 100MB)
# (parse_product is assumed to return (product_id, product_dict) pairs)
products = sc.textFile("/data/products.csv") \
    .map(parse_product) \
    .collectAsMap()  # Collect to driver as a dict
# Broadcast to executors
broadcast_products = sc.broadcast(products)
# Use in transformations
def enrich_transaction(txn):
    product = broadcast_products.value.get(txn['product_id'])
    return {
        **txn,
        'product_name': product['name'],
        'category': product['category']
    }
transactions = sc.textFile("/data/transactions.csv") \
    .map(parse_transaction) \
    .map(enrich_transaction)
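The same pattern exists at the DataFrame level as a broadcast join hint; a minimal sketch (file paths, header option, and join key are assumptions, not from the original):
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.getOrCreate()
products_df = spark.read.csv("/data/products.csv", header=True)          # small dimension table
transactions_df = spark.read.csv("/data/transactions.csv", header=True)
# Ship products_df to every executor instead of shuffling both sides
enriched_df = transactions_df.join(broadcast(products_df), on="product_id", how="left")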
Broadcast for Complex Objects
# Broadcast an ML model (Spark pickles the object when broadcasting it)
from sklearn.linear_model import LogisticRegression
# Train model (on driver)
model = LogisticRegression()
model.fit(X_train, y_train)
# Broadcast the fitted model to all executors
broadcast_model = sc.broadcast(model)
def predict_batch(partition):
    """Apply the model to one partition of feature rows"""
    model = broadcast_model.value
    for features in partition:
        yield model.predict([features])[0]
predictions = features_rdd.mapPartitions(predict_batch)
Broadcast Best Practices
# 1. Size considerations
# Good for < 1GB datasets
# Larger datasets: Use distributed caching
# 2. Update broadcasts
old_broadcast = sc.broadcast(old_data)
# ... use old_broadcast ...
# Update with new data
old_broadcast.unpersist()
new_broadcast = sc.broadcast(new_data)
# 3. Serialization
# Ensure broadcast data is serializable
# Use pickle-compatible objects
# 4. Memory management
# Unpersist when done
broadcast_var.unpersist()
broadcast_var.destroy() # Remove from all executors
Delta Lake Integration
Introduction to Delta Lake
# Delta Lake: ACID transactions, schema enforcement, time travel
from pyspark.sql import SparkSession
from delta import *
spark = SparkSession.builder \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()
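If delta-spark was installed with pip, the builder is commonly wrapped with configure_spark_with_delta_pip so the Delta jars are pulled in automatically; a sketch of that setup (the app name is arbitrary):
from delta import configure_spark_with_delta_pip
from pyspark.sql import SparkSession

builder = SparkSession.builder \
    .appName("delta-demo") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
spark = configure_spark_with_delta_pip(builder).getOrCreate()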
Creating Delta Tables
# Write DataFrame as Delta table
df = spark.read.json("/data/events.json")
df.write \
    .format("delta") \
    .mode("overwrite") \
    .save("/delta/events")
# Or as managed table
df.write \
    .format("delta") \
    .mode("overwrite") \
    .saveAsTable("events")
Reading Delta Tables
# Read Delta table
delta_df = spark.read \
    .format("delta") \
    .load("/delta/events")
# Or from table name
delta_df = spark.table("events")
# Read specific version (time travel)
historical_df = spark.read \
    .format("delta") \
    .option("versionAsOf", 0) \
    .load("/delta/events")
# Read as of timestamp
timestamp_df = spark.read \
    .format("delta") \
    .option("timestampAsOf", "2024-01-01") \
    .load("/delta/events")
ACID Transactions
from delta.tables import DeltaTable
# Upsert (merge)
delta_table = DeltaTable.forPath(spark, "/delta/events")
updates = spark.read.json("/data/updates.json")
delta_table.alias("target") \
    .merge(
        updates.alias("source"),
        "target.id = source.id"
    ) \
    .whenMatchedUpdate(set={
        "status": "source.status",
        "updated_at": "source.updated_at"
    }) \
    .whenNotMatchedInsert(values={
        "id": "source.id",
        "status": "source.status",
        "created_at": "source.created_at"
    }) \
    .execute()
Schema Evolution
# Enable schema merging
df_new_schema = spark.read.json("/data/new_events.json")
df_new_schema.write \
    .format("delta") \
    .mode("append") \
    .option("mergeSchema", "true") \
    .save("/delta/events")
# Schema enforcement
try:
    invalid_df.write \
        .format("delta") \
        .mode("append") \
        .save("/delta/events")
except Exception as e:
    print(f"Schema mismatch: {e}")
Optimization
# Optimize small files
delta_table = DeltaTable.forPath(spark, "/delta/events")
delta_table.optimize().executeCompaction()
# Z-ordering for better filtering
delta_table.optimize() \
    .where("date >= '2024-01-01'") \
    .executeZOrderBy("user_id", "timestamp")
# Vacuum old files
delta_table.vacuum(retentionHours=168) # 7 days
Time Travel
# View history
delta_table = DeltaTable.forPath(spark, "/delta/events")
history = delta_table.history()
history.select("version", "timestamp", "operation").show()
# Restore to previous version
delta_table.restoreToVersion(5)
# Or restore to timestamp
delta_table.restoreToTimestamp("2024-01-01")
# Clone table
spark.sql("""
    CREATE TABLE events_backup
    SHALLOW CLONE events
    VERSION AS OF 10
""")
GraphX for Graph Processing
Creating Graphs
// GraphX is a Scala/JVM API (there is no Python binding), so this section uses Scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
// Define vertices (id, property)
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
  (1L, "Alice"),
  (2L, "Bob"),
  (3L, "Charlie"),
  (4L, "David")
))
// Define edges (srcId, dstId, property)
val edges: RDD[Edge[String]] = sc.parallelize(Array(
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "colleague"),
  Edge(3L, 4L, "friend"),
  Edge(1L, 4L, "family")
))
// Create graph
val graph = Graph(vertices, edges)
Graph Operations
// Vertex count
val numVertices = graph.vertices.count()
// Edge count
val numEdges = graph.edges.count()
// Degree distribution
val degrees = graph.degrees.collect()
// Find triangles
val triangles = graph.triangleCount()
// Connected components
val cc = graph.connectedComponents()
// PageRank
val ranks = graph.pageRank(0.0001).vertices
// Join with original graph
val users = vertices.join(ranks).map {
  case (id, (name, rank)) => (id, name, rank)
}
Pregel API
// Pregel: Bulk synchronous parallel computation
// Example: single-source shortest paths; first give the graph numeric
// vertex distances and Double edge weights
val sourceId: VertexId = 1L
val weightedGraph = graph
  .mapEdges(_ => 1.0)
  .mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)
val shortestPaths = weightedGraph.pregel(
  initialMsg = Double.PositiveInfinity,
  maxIterations = Int.MaxValue,
  activeDirection = EdgeDirection.Out
)(
  // Vertex program: keep the smaller of the current and incoming distance
  (id, dist, newDist) => math.min(dist, newDist),
  // Send messages along edges that would improve the destination's distance
  triplet => {
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },
  // Merge messages: take the minimum distance
  (a, b) => math.min(a, b)
)
Performance Profiling
Spark UI Deep Dive
# Event logging is a static setting: enable it when the application starts
# (spark-defaults.conf or spark-submit --conf); setting it on a running session has no effect
#   spark.eventLog.enabled=true
#   spark.eventLog.dir=/tmp/spark-events
# Runtime SQL settings
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.adaptive.enabled", "true")
# Profile specific queries
df.explain("cost")      # Show cost-based optimizer details
df.explain("extended")  # Show parsed, analyzed, optimized, and physical plans
Python Profiling
# Python profiling must be enabled on the SparkConf before the SparkContext
# is created, e.g. SparkConf().set("spark.python.profile", "true");
# it cannot be turned on for an already-running context
# Run workload
result = rdd.map(expensive_function).collect()
# View profile results
sc.show_profiles()
# Dump to file
sc.dump_profiles("/tmp/profiles")
Memory Profiling
# Track driver-side memory usage with memory_profiler (third-party package)
from memory_profiler import profile
@profile
def process_large_dataset():
    df = spark.read.parquet("/large/dataset")
    result = df.groupBy("key").agg({"value": "sum"})
    result.write.parquet("/output")
process_large_dataset()  # run under the profiler
# Monitor executor memory (internal API; subject to change between versions)
spark.sparkContext._jsc.sc().getExecutorMemoryStatus()
Hands-On Exercises
Exercise 1: Custom Partitioner
# TODO: Implement custom partitioner
# 1. Create partitioner for geographic regions
# 2. Partition customer data by region
# 3. Verify partition distribution
# 4. Compare performance with hash partitioner
# Your code here
Exercise 2: Accumulator for Data Quality
# TODO: Create accumulators for data quality metrics
# 1. Track null values per column
# 2. Count invalid records
# 3. Monitor data distributions
# 4. Generate quality report
# Your code here
Exercise 3: Delta Lake Pipeline
# TODO: Build Delta Lake ETL pipeline
# 1. Ingest raw data to Bronze table
# 2. Transform and write to Silver table
# 3. Aggregate to Gold table
# 4. Implement time travel and optimization
# Your code here
Summary
Advanced Spark techniques enable sophisticated data processing:
- RDD Operations: Low-level control and optimization
- Custom Partitioners: Domain-specific data distribution
- Accumulators: Distributed aggregation and monitoring
- Broadcast Variables: Efficient data sharing
- Delta Lake: ACID transactions and time travel
- GraphX: Graph analytics at scale
Key Takeaways
- Use RDDs for low-level control when needed
- Custom partitioners improve co-location
- Accumulators for metrics, not core logic
- Broadcast for read-only reference data
- Delta Lake for reliable data lakes
- Profile and measure for optimization
Continue to the final capstone project to apply all learned concepts.