
Advanced Spark Topics

Module Duration: 4-5 hours
Focus: Advanced RDD operations, custom partitioners, accumulators, broadcast variables
Prerequisites: All previous modules

Overview

This module covers advanced Spark concepts including low-level RDD operations, custom partitioners, accumulators, broadcast variables, Delta Lake, GraphX, and integration with the broader data ecosystem.

Advanced RDD Operations

RDD Internals

from pyspark import SparkContext

sc = SparkContext("local[*]", "Advanced RDD")

# Create RDD
rdd = sc.parallelize(range(1000), numSlices=10)

# Inspect RDD
print(f"Partitions: {rdd.getNumPartitions()}")
print(f"Partitioner: {rdd.partitioner}")

# Get partition boundaries
def show_partition(index, iterator):
    yield f"Partition {index}: {list(iterator)[:5]}..."

for line in rdd.mapPartitionsWithIndex(show_partition).collect():
    print(line)

mapPartitions vs map

# map: Called for each element
squared = rdd.map(lambda x: x * x)

# mapPartitions: Called once per partition
# More efficient for expensive setup/teardown
def process_partition(iterator):
    # Setup (once per partition); create_connection() and
    # process_with_connection() below are placeholder helpers
    connection = create_connection()

    # Process all elements
    for item in iterator:
        yield process_with_connection(connection, item)

    # Cleanup
    connection.close()

result = rdd.mapPartitions(process_partition)

# mapPartitionsWithIndex: Include partition index
def process_with_index(index, iterator):
    for item in iterator:
        yield (index, item)

indexed = rdd.mapPartitionsWithIndex(process_with_index)

glom - View Partition Contents

# glom: Convert each partition to a list
partitions = rdd.glom().collect()
for i, partition in enumerate(partitions):
    print(f"Partition {i}: {len(partition)} elements")

# Useful for debugging partition distribution

coalesce vs repartition

# repartition: full shuffle; can increase or decrease the partition count
rdd_repart = rdd.repartition(20)  # 10 -> 20 partitions (shuffle)

# coalesce: narrow dependency, no shuffle; can only decrease the count
rdd_coal = rdd.coalesce(5)  # 10 -> 5 partitions (no shuffle)

# coalesce with shuffle=True behaves like repartition and can increase partitions
rdd_coal_shuffle = rdd.coalesce(15, shuffle=True)  # Force a shuffle

Sampling and Statistics

# Sample RDD
sample = rdd.sample(
    withReplacement=False,
    fraction=0.1,
    seed=42
)

# Statistics
stats = rdd.stats()
print(f"Count: {stats.count()}")
print(f"Mean: {stats.mean()}")
print(f"StdDev: {stats.stdev()}")
print(f"Min/Max: {stats.min()}, {stats.max()}")

# Histogram
histogram = rdd.histogram(10)  # 10 buckets
print(f"Buckets: {histogram[0]}")
print(f"Counts: {histogram[1]}")

# Stratified sampling
pairs = sc.parallelize([(i % 3, i) for i in range(100)])
stratified = pairs.sampleByKey(
    withReplacement=False,
    fractions={0: 0.1, 1: 0.2, 2: 0.3},
    seed=42
)
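
The requested per-key fractions can be sanity-checked by counting the sampled records in each stratum. A small follow-up using the pairs and stratified RDDs defined above:

# Compare sampled counts per key with the requested fractions
print(pairs.countByKey())       # {0: 34, 1: 33, 2: 33}
print(stratified.countByKey())  # roughly 10% / 20% / 30% of each stratum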

Custom Partitioners

Why Custom Partitioners

# The default hash partitioner may not be optimal;
# custom partitioning lets you apply domain-specific logic.
# Note: PySpark has no public Partitioner base class to subclass
# (that is the Scala/Java API) - RDD.partitionBy simply takes a
# partition function, so a plain Python class works here.

class DomainPartitioner:
    """Partition by domain for URL processing"""

    def __init__(self, num_partitions):
        self.num_partitions = num_partitions

    def getPartition(self, key):
        # Extract the domain from a URL such as https://example.com/page1
        domain = key.split('/')[2] if '/' in key else key
        # Caution: Python string hashing is randomized per process;
        # set PYTHONHASHSEED consistently so executors agree on partitions
        return hash(domain) % self.num_partitions

# Use the custom logic as the partition function
urls = sc.parallelize([
    ("https://example.com/page1", 100),
    ("https://example.com/page2", 200),
    ("https://other.com/page1", 150)
])

domain_partitioner = DomainPartitioner(4)
partitioned = urls.partitionBy(
    numPartitions=4,
    partitionFunc=domain_partitioner.getPartition
)

Range Partitioner

class RangePartitioner:
    """Partition keys into user-defined value ranges (a plain Python class,
    not Spark's built-in Scala RangePartitioner)"""

    def __init__(self, ranges):
        self.ranges = ranges  # [(0, 10), (10, 50), (50, 100)]

    def numPartitions(self):
        return len(self.ranges)

    def getPartition(self, key):
        for i, (low, high) in enumerate(self.ranges):
            if low <= key < high:
                return i
        return len(self.ranges) - 1  # Default to last partition

# Example: Partition records by age range (the key is the age)
users = sc.parallelize([
    (5, "child"),
    (25, "adult"),
    (65, "senior"),
    (15, "teen")
])

age_partitioner = RangePartitioner([
    (0, 18),    # Partition 0: minors
    (18, 65),   # Partition 1: adults
    (65, 150)   # Partition 2: seniors
])

partitioned_users = users.partitionBy(3, lambda age: age_partitioner.getPartition(age))

Preserving Partitioning

# Some operations preserve partitioning
pairs = sc.parallelize([(i, i*2) for i in range(100)])
partitioned = pairs.partitionBy(10)

# Preserved: mapValues, flatMapValues, filter (the keys are untouched)
preserved = partitioned.mapValues(lambda x: x * 2)
print(f"Partitioner preserved: {preserved.partitioner is not None}")

# Lost: plain map (Spark cannot guarantee the keys are unchanged)
not_preserved = partitioned.map(lambda kv: (kv[0] * 2, kv[1]))
print(f"Partitioner lost: {not_preserved.partitioner is None}")

# Co-partitioning for joins
rdd1 = sc.parallelize([(i, f"value{i}") for i in range(100)]).partitionBy(10)
rdd2 = sc.parallelize([(i, i*2) for i in range(100)]).partitionBy(10)

# No shuffle needed - same partitioner
joined = rdd1.join(rdd2)  # Efficient!
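
A quick way to confirm the two inputs really are co-partitioned before relying on the shuffle-free join (using rdd1 and rdd2 from above):

# Both sides report the same partition count and equal partitioners,
# which is the condition Spark checks to avoid re-shuffling either side
print(rdd1.getNumPartitions(), rdd2.getNumPartitions())  # 10 10
print(rdd1.partitioner == rdd2.partitioner)              # True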

Accumulators

Basic Accumulators

# Accumulators: write-only shared variables for aggregating side information
# (reusing the SparkContext created at the start of this module)

# Long accumulator
counter = sc.accumulator(0)

def count_errors(line):
    if "ERROR" in line:
        counter.add(1)  # executors can only add; the value is read back on the driver
    return line

logs = sc.textFile("/logs/*.log")
processed = logs.map(count_errors)
processed.count()  # Trigger action (cache `processed` first if it will be reused, or recomputation adds to the counter again)

print(f"Total errors: {counter.value}")

Custom Accumulators

from pyspark import AccumulatorParam

class SetAccumulatorParam(AccumulatorParam):
    """Accumulator for collecting unique values"""

    def zero(self, initial_value):
        return set()

    def addInPlace(self, acc1, acc2):
        return acc1 | acc2  # Set union

# Create custom accumulator
unique_users = sc.accumulator(set(), SetAccumulatorParam())

def track_user(record):
    # add() must pass a value that addInPlace can merge, so wrap the id in a set
    unique_users.add({record['user_id']})
    return record

# data is assumed to be an RDD of dicts that contain a 'user_id' field
data.map(track_user).count()
print(f"Unique users: {len(unique_users.value)}")

Multiple Accumulators

# Track multiple metrics
total_records = sc.accumulator(0)
error_records = sc.accumulator(0)
skipped_records = sc.accumulator(0)

def process_record(record):
    total_records.add(1)

    try:
        # should_skip() and transform() are placeholder business-logic helpers
        if should_skip(record):
            skipped_records.add(1)
            return None
        return transform(record)
    except Exception:
        error_records.add(1)
        return None

result = data.map(process_record).filter(lambda x: x is not None)
result.count()

print(f"Total: {total_records.value}")
print(f"Errors: {error_records.value}")
print(f"Skipped: {skipped_records.value}")
print(f"Success: {total_records.value - error_records.value - skipped_records.value}")

Accumulator Best Practices

# 1. Update accumulators in actions, not transformations
# Bad: lazy evaluation and re-computation make this count unreliable
bad_counter = sc.accumulator(0)
rdd.map(lambda x: bad_counter.add(1) or x).take(10)  # Wrong!

# Good: Use in actions or after caching
good_counter = sc.accumulator(0)
cached = rdd.cache()
cached.foreach(lambda x: good_counter.add(1))  # Correct!

# 2. Accumulators are write-only from executors
# Can only read on driver
print(f"Value on driver: {good_counter.value}")

# 3. Use for side effects, not core logic
# Accumulators for monitoring, not business logic

Broadcast Variables

Basic Broadcasting

# Broadcast: Efficiently share read-only data to all executors
lookup_table = {
    "US": "United States",
    "UK": "United Kingdom",
    "CA": "Canada"
}

# Broadcast to all executors
broadcast_lookup = sc.broadcast(lookup_table)

def expand_country(code):
    # Access broadcast value
    return broadcast_lookup.value.get(code, "Unknown")

codes = sc.parallelize(["US", "UK", "FR", "CA"])
expanded = codes.map(expand_country)
print(expanded.collect())

# Clean up
broadcast_lookup.unpersist()

Large Dataset Broadcasting

# Example: broadcast reference data for joins
# Instead of a shuffle join, broadcast the small dataset
# (a DataFrame-level broadcast-join hint is sketched after this block)

# Load a small dimension table (e.g., ~100MB)
# parse_product is a placeholder that returns (product_id, product_dict)
# pairs, which collectAsMap() requires
products = sc.textFile("/data/products.csv") \
    .map(parse_product) \
    .collectAsMap()  # Collect to the driver as a dict

# Broadcast to executors
broadcast_products = sc.broadcast(products)

# Use in transformations
def enrich_transaction(txn):
    # .get with a default avoids a crash when the product id is unknown
    product = broadcast_products.value.get(txn['product_id'], {})
    return {
        **txn,
        'product_name': product.get('name'),
        'category': product.get('category')
    }

transactions = sc.textFile("/data/transactions.csv") \
    .map(parse_transaction) \
    .map(enrich_transaction)
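
The same idea is available at the DataFrame level as a broadcast-join hint, which tells the optimizer to ship the small side to every executor instead of shuffling both sides. A minimal sketch (the paths and join column are illustrative):

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.getOrCreate()

products_df = spark.read.csv("/data/products.csv", header=True)
transactions_df = spark.read.csv("/data/transactions.csv", header=True)

# Hint that products_df is small enough to broadcast to every executor
enriched_df = transactions_df.join(broadcast(products_df), on="product_id")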

Broadcast for Complex Objects

# Broadcast a trained ML model (sc.broadcast pickles the object automatically)
from sklearn.linear_model import LogisticRegression

# Train the model on the driver (X_train and y_train are assumed to exist there)
model = LogisticRegression()
model.fit(X_train, y_train)

# Ship the fitted model to the executors
broadcast_model = sc.broadcast(model)

def predict_batch(partition):
    """Apply model to partition"""
    model = broadcast_model.value
    for features in partition:
        yield model.predict([features])[0]

predictions = features_rdd.mapPartitions(predict_batch)

Broadcast Best Practices

# 1. Size considerations
# Broadcasts work well for reference data up to roughly 1 GB;
# for anything larger, prefer a regular (shuffle) join or an external store
# (a quick size check is sketched after this block)

# 2. Updating broadcasts: they are immutable, so swap in a replacement
old_broadcast = sc.broadcast(old_data)

# ... use old_broadcast ...

# Replace with new data
old_broadcast.unpersist()
new_broadcast = sc.broadcast(new_data)

# 3. Serialization
# Broadcast data must be picklable (Spark serializes it with pickle)

# 4. Memory management: release broadcasts when finished
broadcast_var.unpersist()   # drop cached copies on executors
broadcast_var.destroy()     # permanently remove from driver and executors
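
The size check mentioned under point 1 can be approximated by measuring the pickled size of the candidate object before broadcasting. A minimal sketch; safe_broadcast and the 512 MB threshold are illustrative choices, not a Spark API:

import pickle

def safe_broadcast(sc, obj, max_bytes=512 * 1024 * 1024):
    """Broadcast obj only if its pickled size stays under max_bytes."""
    size = len(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
    if size > max_bytes:
        raise ValueError(f"Object is {size / 1e6:.0f} MB - too large to broadcast comfortably")
    return sc.broadcast(obj)

lookup_bc = safe_broadcast(sc, lookup_table)  # lookup_table from the earlier example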

Delta Lake Integration

Introduction to Delta Lake

# Delta Lake: ACID transactions, schema enforcement, time travel
from delta import *

spark = SparkSession.builder \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()

Creating Delta Tables

# Write DataFrame as Delta table
df = spark.read.json("/data/events.json")

df.write \
    .format("delta") \
    .mode("overwrite") \
    .save("/delta/events")

# Or as managed table
df.write \
    .format("delta") \
    .mode("overwrite") \
    .saveAsTable("events")

Reading Delta Tables

# Read Delta table
delta_df = spark.read \
    .format("delta") \
    .load("/delta/events")

# Or from table name
delta_df = spark.table("events")

# Read specific version (time travel)
historical_df = spark.read \
    .format("delta") \
    .option("versionAsOf", 0) \
    .load("/delta/events")

# Read as of timestamp
timestamp_df = spark.read \
    .format("delta") \
    .option("timestampAsOf", "2024-01-01") \
    .load("/delta/events")

ACID Transactions

from delta.tables import DeltaTable

# Upsert (merge)
delta_table = DeltaTable.forPath(spark, "/delta/events")

updates = spark.read.json("/data/updates.json")

delta_table.alias("target") \
    .merge(
        updates.alias("source"),
        "target.id = source.id"
    ) \
    .whenMatchedUpdate(set={
        "status": "source.status",
        "updated_at": "source.updated_at"
    }) \
    .whenNotMatchedInsert(values={
        "id": "source.id",
        "status": "source.status",
        "created_at": "source.created_at"
    }) \
    .execute()

Schema Evolution

# Enable schema merging
df_new_schema = spark.read.json("/data/new_events.json")

df_new_schema.write \
    .format("delta") \
    .mode("append") \
    .option("mergeSchema", "true") \
    .save("/delta/events")

# Schema enforcement: a write whose schema conflicts with the table (invalid_df here) is rejected
try:
    invalid_df.write \
        .format("delta") \
        .mode("append") \
        .save("/delta/events")
except Exception as e:
    print(f"Schema mismatch: {e}")

Optimization

# Optimize small files
delta_table = DeltaTable.forPath(spark, "/delta/events")
delta_table.optimize().executeCompaction()

# Z-ordering for better filtering
delta_table.optimize() \
    .where("date >= '2024-01-01'") \
    .executeZOrderBy("user_id", "timestamp")

# Vacuum old files
delta_table.vacuum(retentionHours=168)  # 7 days

Time Travel

# View history
delta_table = DeltaTable.forPath(spark, "/delta/events")
history = delta_table.history()
history.select("version", "timestamp", "operation").show()

# Restore to previous version
delta_table.restoreToVersion(5)

# Or restore to timestamp
delta_table.restoreToTimestamp("2024-01-01")

# Clone table
spark.sql("""
    CREATE TABLE events_backup
    SHALLOW CLONE events
    VERSION AS OF 10
""")

GraphX for Graph Processing

GraphX exposes only a Scala/Java API, so the examples below are written in Scala; Python users typically use the separate GraphFrames package for similar graph workloads.

Creating Graphs

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Define vertices (id, property)
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
  (1L, "Alice"),
  (2L, "Bob"),
  (3L, "Charlie"),
  (4L, "David")
))

// Define edges (srcId, dstId, property)
val edges: RDD[Edge[String]] = sc.parallelize(Array(
  Edge(1L, 2L, "friend"),
  Edge(2L, 3L, "colleague"),
  Edge(3L, 4L, "friend"),
  Edge(1L, 4L, "family")
))

// Create graph
val graph = Graph(vertices, edges)

Graph Operations

// Vertex count
val numVertices = graph.vertices.count()

// Edge count
val numEdges = graph.edges.count()

// Degree distribution
val degrees = graph.degrees.collect()

// Find triangles
val triangles = graph.triangleCount()

// Connected components
val cc = graph.connectedComponents()

// PageRank
val ranks = graph.pageRank(0.0001).vertices

// Join with original graph
val users = vertices.join(ranks).map {
  case (id, (name, rank)) => (id, name, rank)
}

Pregel API

// Pregel: bulk-synchronous, vertex-centric computation
// Single-source shortest paths: vertex attributes hold the current distance,
// so first re-attribute the graph with numeric values
val sourceId: VertexId = 1L
val distGraph = graph
  .mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)
  .mapEdges(_ => 1.0)  // unit edge weights for this example

val shortestPaths = distGraph.pregel(
  initialMsg = Double.PositiveInfinity,
  maxIterations = Int.MaxValue,
  activeDirection = EdgeDirection.Out
)(
  // Vertex program: keep the smaller of the current and incoming distance
  (id, dist, newDist) => math.min(dist, newDist),

  // Send message: propagate a shorter path along each out-edge
  triplet => {
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },

  // Merge messages: take the minimum distance
  (a, b) => math.min(a, b)
)

Performance Profiling

Spark UI Deep Dive

# Event logging must be configured before the SparkSession starts
# (spark-defaults.conf or the builder); it cannot be switched on
# with spark.conf.set against a running session
spark = SparkSession.builder \
    .config("spark.eventLog.enabled", "true") \
    .config("spark.eventLog.dir", "/tmp/spark-events") \
    .getOrCreate()

# Runtime-tunable SQL settings
spark.conf.set("spark.sql.codegen.wholeStage", "true")
spark.conf.set("spark.sql.adaptive.enabled", "true")

# Profile specific queries
df.explain("cost")      # Plan annotated with cost-based statistics
df.explain("extended")  # Parsed, analyzed, optimized, and physical plans

Python Profiling

from pyspark import SparkConf, SparkContext

# Python profiling must be enabled in the SparkConf before the
# SparkContext is created; it cannot be turned on at runtime
conf = SparkConf().set("spark.python.profile", "true")
sc = SparkContext(conf=conf)

# Run the workload (expensive_function is a placeholder)
result = rdd.map(expensive_function).collect()

# Show aggregated cProfile output on the driver
sc.show_profiles()

# Or dump the profiles to a directory
sc.dump_profiles("/tmp/profiles")

Memory Profiling

# Track driver-side memory with memory_profiler
# (this profiles the Python driver process only; executor memory lives in the JVM)
from memory_profiler import profile

@profile
def process_large_dataset():
    df = spark.read.parquet("/large/dataset")
    result = df.groupBy("key").agg({"value": "sum"})
    result.write.parquet("/output")

# Executor memory status via the internal JVM gateway (private API, version-dependent)
spark.sparkContext._jsc.sc().getExecutorMemoryStatus()

Hands-On Exercises

Exercise 1: Custom Partitioner

# TODO: Implement custom partitioner
# 1. Create partitioner for geographic regions
# 2. Partition customer data by region
# 3. Verify partition distribution
# 4. Compare performance with hash partitioner

# Your code here

Exercise 2: Accumulator for Data Quality

# TODO: Create accumulators for data quality metrics
# 1. Track null values per column
# 2. Count invalid records
# 3. Monitor data distributions
# 4. Generate quality report

# Your code here

Exercise 3: Delta Lake Pipeline

# TODO: Build Delta Lake ETL pipeline
# 1. Ingest raw data to Bronze table
# 2. Transform and write to Silver table
# 3. Aggregate to Gold table
# 4. Implement time travel and optimization

# Your code here

Summary

Advanced Spark techniques enable sophisticated data processing:
  • RDD Operations: Low-level control and optimization
  • Custom Partitioners: Domain-specific data distribution
  • Accumulators: Distributed aggregation and monitoring
  • Broadcast Variables: Efficient data sharing
  • Delta Lake: ACID transactions and time travel
  • GraphX: Graph analytics at scale

Key Takeaways

  1. Use RDDs for low-level control when needed
  2. Custom partitioners improve co-location
  3. Accumulators for metrics, not core logic
  4. Broadcast for read-only reference data
  5. Delta Lake for reliable data lakes
  6. Profile and measure for optimization

Continue to the final capstone project to apply all learned concepts.