January 12, 2026 · 6 min read

Modern Data Pipelines with PySpark: Best Practices & Patterns

Build production-grade data pipelines with PySpark. Learn optimization techniques, design patterns, and best practices from processing petabytes of data at scale.

PySpark · Spark · Data Pipelines

After four years of running PySpark pipelines on petabyte-scale datasets, I've learned that the gap between code that works in development and code that runs reliably in production is wide. This guide covers the patterns that close that gap.

Session Configuration: Start Right

The SparkSession configuration determines your job's ceiling before you write a single transformation. These settings matter most:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .appName("production-etl-job") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .config("spark.sql.adaptive.skewJoin.enabled", "true") \
        .config("spark.sql.shuffle.partitions", "400") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.sql.parquet.filterPushdown", "true") \
        .config("spark.sql.parquet.mergeSchema", "false") \
        .config("spark.dynamicAllocation.enabled", "true") \
        .config("spark.dynamicAllocation.minExecutors", "2") \
        .config("spark.dynamicAllocation.maxExecutors", "20") \
        .getOrCreate()

Note on spark.sql.shuffle.partitions: open-source Spark expects a number here ("auto" is Databricks-specific). Set a generous upper bound and let AQE coalesce downward.

Adaptive Query Execution (AQE): enabled by default since Spark 3.2, but worth confirming. AQE dynamically adjusts join strategies, coalesces shuffle partitions after they're computed, and handles skew joins automatically. It is the single highest-leverage configuration change for most jobs.

mergeSchema = false: schema inference on read is expensive and dangerous in production. Enforce schemas explicitly.
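Whether these flags actually took effect can be confirmed on a live session. A quick check, assuming an active session named spark:

```python
# Confirm AQE-related settings at runtime (assumes an active `spark` session)
print(spark.conf.get("spark.sql.adaptive.enabled"))                     # expect "true"
print(spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled"))  # expect "true"
print(spark.conf.get("spark.sql.parquet.mergeSchema", "false"))         # expect "false"
```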

Partition Management

Partition count is the most consequential tuning decision in a Spark job. Get it wrong and:

  • Too few: tasks run out of memory, GC pressure, slow execution
  • Too many: scheduling overhead, small files, slow writes

Rule of thumb: target 128MB–256MB of data per partition.
    from pyspark.sql import DataFrame

    def optimal_partition_count(df: DataFrame, target_mb: int = 200) -> int:
        """Estimate optimal partition count based on data size."""
        # Use Catalyst's size estimate for the optimized plan. Note: this
        # reaches into Spark's internal Java objects, which can change
        # between versions.
        estimated_bytes = df._jdf.queryExecution() \
            .optimizedPlan() \
            .stats() \
            .sizeInBytes()
        target_bytes = target_mb * 1024 * 1024
        return max(1, int(estimated_bytes / target_bytes))

    # Repartition before heavy transformations
    df = df.repartition(optimal_partition_count(df))

Coalesce vs. repartition:

  • repartition(n): full shuffle; use when increasing partitions or changing the partition key
  • coalesce(n): no shuffle, only merges partitions; use only when decreasing partition count

    # Before writing: coalesce to reduce output file count
    df.coalesce(10).write.format("delta").mode("append").save(output_path)

Join Optimization

Joins are the most common source of performance problems in PySpark.

Broadcast Joins

When one DataFrame is small enough to fit in executor memory, broadcast it to avoid a shuffle join:

    from pyspark.sql.functions import broadcast

    # Explicit broadcast hint. Spark will also do this automatically
    # if the table is below spark.sql.autoBroadcastJoinThreshold (default: 10MB)
    result = large_df.join(broadcast(small_lookup_df), "product_id")

Set the threshold appropriately:

    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50mb")
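The same knob works in the other direction: when a "small" table turns out not to be small and broadcasts start OOMing executors, setting the threshold to -1 disables automatic broadcasting entirely while leaving explicit hints in effect:

```python
# Disable automatic broadcast joins; explicit broadcast() hints still apply
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
```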

Handling Skewed Joins

Data skew — one partition containing disproportionately more data — causes one task to run 10× longer than all others.

Diagnosis:

    from pyspark.sql.functions import spark_partition_id, desc

    # Check partition sizes after a shuffle
    df.groupBy(spark_partition_id()).count().orderBy(desc("count")).show(20)

Solution — salting:

    from pyspark.sql.functions import (
        array, col, concat, explode, floor, lit, rand
    )

    # Add a random salt to the skewed key, explode the smaller table
    SALT_FACTOR = 10

    # Salt the large (skewed) table
    salted_large = large_df.withColumn(
        "salted_key",
        concat(col("user_id"), lit("_"), floor(rand() * SALT_FACTOR).cast("string"))
    )

    # Explode the small table to match all salt values
    salted_small = small_df.withColumn(
        "salt_range",
        array([lit(str(i)) for i in range(SALT_FACTOR)])
    ).withColumn("salt", explode("salt_range")) \
        .withColumn("salted_key", concat(col("user_id"), lit("_"), col("salt"))) \
        .drop("salt_range", "salt")

    result = salted_large.join(salted_small, "salted_key").drop("salted_key")
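The effect of salting is easy to see outside Spark. This plain-Python sketch (hypothetical key names) shows how a single hot key fans out into up to SALT_FACTOR distinct join keys, which Spark can then hash to different partitions:

```python
import random

SALT_FACTOR = 10

def salt_key(key: str) -> str:
    """Append a random salt in [0, SALT_FACTOR) to a join key."""
    return f"{key}_{random.randrange(SALT_FACTOR)}"

# One hot key ("user_42") repeated across many rows fans out into
# SALT_FACTOR distinct keys, breaking up the skewed partition.
salted = {salt_key("user_42") for _ in range(10_000)}
print(len(salted))  # 10: every salt value appears at this sample size
```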

Memory Management

Spark memory is divided into execution memory (shuffles, sorts, joins) and storage memory (caching). OOM errors usually mean one is starving the other.

    # Adjust the memory fractions if caching large DataFrames.
    # Both are static settings: supply them at session creation (or via
    # spark-submit --conf), not with spark.conf.set() on a running session.
    spark = SparkSession.builder \
        .config("spark.memory.fraction", "0.75") \
        .config("spark.memory.storageFraction", "0.3") \
        .getOrCreate()
    # spark.memory.fraction: share of JVM heap for execution + storage
    # spark.memory.storageFraction: within that, the share reserved for caching

When to cache:

    from pyspark import StorageLevel

    # Cache only when a DataFrame is used multiple times in the same job.
    # MEMORY_AND_DISK avoids OOM if the cache doesn't fit in memory.
    df_lookup.persist(StorageLevel.MEMORY_AND_DISK)

    # ... multiple operations with df_lookup ...

    # Unpersist when done; don't leak memory
    df_lookup.unpersist()

Do not cache DataFrames that are used only once. Caching has a write cost and wastes memory.

Schema Management

Explicit schema definition is mandatory in production. Relying on inference is slow and breaks silently when upstream data changes.

    from pyspark.sql.types import (
        StructType, StructField, StringType,
        TimestampType, DecimalType, BooleanType
    )

    ORDER_SCHEMA = StructType([
        StructField("order_id", StringType(), nullable=False),
        StructField("user_id", StringType(), nullable=False),
        StructField("product_id", StringType(), nullable=True),
        StructField("amount", DecimalType(18, 2), nullable=False),
        StructField("created_at", TimestampType(), nullable=False),
        StructField("is_test_order", BooleanType(), nullable=False),
        StructField("status", StringType(), nullable=True),
    ])

    df = spark.read.schema(ORDER_SCHEMA).json("s3://raw/orders/")

    # Fail fast: check the schema matches expectations. Caveat: file-based
    # sources may report every field as nullable, so if this strict equality
    # check fails, compare field names and types instead.
    assert df.schema == ORDER_SCHEMA, f"Schema mismatch: {df.schema}"
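Because some file sources relax nullability on read, a strict equality assert can be brittle. A framework-agnostic sketch (hypothetical helper, comparing only field names and type strings) is often more practical, and can be unit-tested without a cluster:

```python
def schema_mismatches(expected, actual):
    """Compare two schemas as (name, type) pairs, ignoring nullability.

    Both arguments are lists of (field_name, type_name) tuples, e.g. built
    from StructType fields: [(f.name, f.dataType.simpleString()) for f in s].
    """
    expected_map = dict(expected)
    actual_map = dict(actual)
    problems = []
    for name, dtype in expected_map.items():
        if name not in actual_map:
            problems.append(f"missing field: {name}")
        elif actual_map[name] != dtype:
            problems.append(f"type mismatch for {name}: {actual_map[name]} != {dtype}")
    for name in actual_map:
        if name not in expected_map:
            problems.append(f"unexpected field: {name}")
    return problems

# Example: one type drift, one extra column
expected = [("order_id", "string"), ("amount", "decimal(18,2)")]
actual = [("order_id", "string"), ("amount", "double"), ("debug_flag", "boolean")]
print(schema_mismatches(expected, actual))
```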

Writing Idiomatic PySpark

Common anti-patterns that hurt performance and readability:

Anti-pattern: UDFs for simple transformations

    # Bad: a Python UDF breaks Catalyst optimization and adds
    # serialization overhead
    from pyspark.sql.functions import udf

    @udf("string")
    def format_name(first, last):
        return f"{first} {last}"

    # Good: built-in functions stay in the JVM, and Catalyst can optimize them
    from pyspark.sql.functions import col, concat_ws

    df = df.withColumn("full_name", concat_ws(" ", col("first_name"), col("last_name")))

Anti-pattern: Iterating rows with .collect()

    # Bad: brings all data to the driver, defeats distributed processing
    for row in df.collect():
        process(row)

    # Good: push processing to the executors with foreachPartition
    df.foreachPartition(lambda partition: [process(row) for row in partition])

Anti-pattern: Chaining .withColumn() for many columns

    # Bad: each withColumn creates a new DataFrame plan node; slow
    # with many columns
    df = df.withColumn("col1", ...).withColumn("col2", ...).withColumn("col3", ...)

    # Good: use a single select with multiple expressions
    from pyspark.sql.functions import expr

    df = df.select(
        "*",
        expr("...").alias("col1"),
        expr("...").alias("col2"),
        expr("...").alias("col3"),
    )

Production Observability

Instrument your jobs before they reach production:

    import logging
    from datetime import datetime

    from pyspark.sql.functions import col

    logger = logging.getLogger(__name__)

    def process_partition(input_path: str, output_path: str, processing_date: str):
        start = datetime.utcnow()

        df = spark.read.format("delta").load(input_path) \
            .filter(col("processing_date") == processing_date)

        input_count = df.count()
        logger.info(f"Input records: {input_count:,}")

        # --- transformations ---
        result = transform(df)

        # Validate before write
        output_count = result.count()
        if output_count == 0:
            raise ValueError(f"Output is empty for {processing_date} — aborting write")

        drop_rate = 1 - (output_count / input_count)
        if drop_rate > 0.05:
            raise ValueError(f"Drop rate {drop_rate:.1%} exceeds 5% threshold")

        result.write.format("delta").mode("overwrite") \
            .option("replaceWhere", f"processing_date = '{processing_date}'") \
            .save(output_path)

        duration = (datetime.utcnow() - start).total_seconds()
        logger.info(f"Completed: {output_count:,} records in {duration:.1f}s")

Pre-write validation — asserting that output is non-empty and that the drop rate is within expected bounds — prevents silent data loss. This check takes seconds and has caught real bugs more times than I can count.
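The same guard can be factored into a small reusable function (hypothetical name), which keeps the threshold in one place and makes the check unit-testable without a cluster:

```python
def check_drop_rate(input_count: int, output_count: int,
                    max_drop_rate: float = 0.05) -> float:
    """Raise if the pipeline dropped more rows than expected; return the rate."""
    if input_count <= 0:
        raise ValueError("Input is empty; nothing to validate against")
    if output_count == 0:
        raise ValueError("Output is empty; aborting write")
    drop_rate = 1 - (output_count / input_count)
    if drop_rate > max_drop_rate:
        raise ValueError(
            f"Drop rate {drop_rate:.1%} exceeds {max_drop_rate:.0%} threshold"
        )
    return drop_rate

print(check_drop_rate(1_000_000, 980_000))  # prints the observed drop rate (~2%)
```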

The discipline of PySpark at scale is less about advanced API knowledge and more about operational hygiene: explicit schemas, measured partitioning, appropriate caching, and validating before committing writes.