Modern Data Pipelines with PySpark: Best Practices & Patterns
Introduction
PySpark has become the lingua franca of big data processing. Its combination of Python's ease of use and Spark's distributed computing power makes it perfect for building scalable data pipelines.
After optimizing PySpark pipelines processing petabytes of data, I'll share patterns and techniques that make the difference between slow, expensive jobs and fast, cost-effective ones.
Why PySpark?
Advantages:
- ✅ Python ecosystem (pandas, scikit-learn, etc.)
- ✅ Distributed computing (handle datasets larger than memory)
- ✅ Unified batch & streaming
- ✅ Built-in optimization (Catalyst, Tungsten)
- ✅ Integration with cloud data platforms
When to use PySpark:
- Data size > 100GB
- Complex transformations requiring parallelization
- Unified batch and streaming workflows
- Integration with ML libraries
Pipeline Architecture Pattern
# Standard pipeline structure
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

class DataPipeline:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def extract(self):
        """Read data from sources"""
        pass

    def transform(self, df):
        """Apply business logic"""
        pass

    def load(self, df):
        """Write to destination"""
        pass

    def run(self):
        """Execute pipeline"""
        df = self.extract()
        df_transformed = self.transform(df)
        self.load(df_transformed)
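To make the skeleton concrete, here is a minimal sketch of how it might be filled in; the S3 paths and column names are illustrative assumptions, not part of any real pipeline.

# Hypothetical concrete pipeline built on the skeleton above
class DailyEventsPipeline(DataPipeline):
    def extract(self):
        # Illustrative bronze-layer path
        return self.spark.read.parquet("s3://bucket/raw/events/")

    def transform(self, df):
        return df \
            .filter(col("event_date") == "2026-01-01") \
            .withColumn("event_hour", hour(col("timestamp")))

    def load(self, df):
        df.write.mode("overwrite").parquet("s3://bucket/curated/events/")

# pipeline = DailyEventsPipeline(spark)
# pipeline.run()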
Spark Session Configuration
from pyspark.sql import SparkSession

def create_spark_session(app_name: str) -> SparkSession:
    """Create optimized Spark session"""
    return SparkSession.builder \
        .appName(app_name) \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .config("spark.sql.adaptive.skewJoin.enabled", "true") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.sql.parquet.int96RebaseModeInWrite", "CORRECTED") \
        .config("spark.sql.parquet.datetimeRebaseModeInWrite", "CORRECTED") \
        .config("spark.sql.session.timeZone", "UTC") \
        .getOrCreate()

spark = create_spark_session("ETL Pipeline")
Reading Data Efficiently
1. Parquet (Recommended)
# Read Parquet with partition pruning
df = spark.read \
    .format("parquet") \
    .option("mergeSchema", "false") \
    .load("s3://bucket/data/year=2026/month=01/")

# Predicate pushdown happens automatically
filtered_df = df.filter(col("date") >= "2026-01-01")
2. CSV with Schema
from pyspark.sql.types import *

# Define schema (faster than inference)
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("amount", DecimalType(10, 2), True),
    StructField("timestamp", TimestampType(), True)
])

df = spark.read \
    .format("csv") \
    .option("header", "true") \
    .option("mode", "DROPMALFORMED") \
    .schema(schema) \
    .load("s3://bucket/data.csv")
3. Delta Lake (Best for Analytics)
# Read Delta table
df = spark.read \
    .format("delta") \
    .load("s3://bucket/delta/table")

# Time travel: read an earlier version of the table
df_v5 = spark.read \
    .format("delta") \
    .option("versionAsOf", 5) \
    .load("s3://bucket/delta/table")
Transformation Patterns
1. Column Operations
from pyspark.sql.functions import *
df_transformed = df \
    .withColumn("year", year(col("timestamp"))) \
    .withColumn("month", month(col("timestamp"))) \
    .withColumn("amount_usd", col("amount") * col("exchange_rate")) \
    .withColumn("is_active", when(col("status") == "active", True).otherwise(False))
2. Window Functions
from pyspark.sql.window import Window
# Running total per user
window_spec = Window.partitionBy("user_id").orderBy("timestamp")

df_with_running_total = df.withColumn(
    "running_total",
    sum("amount").over(window_spec)
)

# Rank within partition
rank_window = Window.partitionBy("category").orderBy(col("revenue").desc())

df_ranked = df.withColumn(
    "rank",
    rank().over(rank_window)
).filter(col("rank") <= 10)  # Top 10 per category
3. Aggregations
# Group by with multiple aggregations
df_agg = df.groupBy("user_id", "date") \
    .agg(
        count("*").alias("event_count"),
        countDistinct("session_id").alias("session_count"),
        sum("revenue").alias("total_revenue"),
        avg("session_duration").alias("avg_session_duration"),
        max("timestamp").alias("last_activity")
    )
4. Joins Optimization
# Broadcast join (small lookup table; Spark auto-broadcasts tables under ~10MB by default)
from pyspark.sql.functions import broadcast

large_df = spark.table("large_table")
small_df = spark.table("small_lookup")

joined = large_df.join(
    broadcast(small_df),
    "key"
)
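The ~10MB figure is Spark's default spark.sql.autoBroadcastJoinThreshold, below which the optimizer broadcasts automatically. If a lookup table is somewhat larger but still comfortably fits in executor memory, one option is to raise the threshold; the 50MB value below is purely illustrative.

# Raise the auto-broadcast threshold (default is 10MB = 10485760 bytes); 50MB is illustrative
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))

# Setting it to -1 disables automatic broadcasts; explicit broadcast() hints still apply
# spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")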
# Salted join for skewed data
from pyspark.sql.functions import rand, concat, lit, explode, array

def salted_join(left_df, right_df, join_key, salt_factor=10):
    # Add a random salt to the skewed (left) side
    left_salted = left_df.withColumn(
        "salt", (rand() * salt_factor).cast("int")
    ).withColumn(
        "salted_key", concat(col(join_key), lit("_"), col("salt"))
    )

    # Replicate the right side once per salt value so every salted key finds a match
    right_salted = right_df.withColumn(
        "salt", explode(array([lit(i) for i in range(salt_factor)]))
    ).withColumn(
        "salted_key", concat(col(join_key), lit("_"), col("salt"))
    )

    # Note: both sides still carry the original join_key column after the join;
    # drop or rename one of them downstream to avoid ambiguous references
    return left_salted.join(right_salted, "salted_key") \
        .drop("salt", "salted_key")
Data Quality Checks
def validate_data(df, rules):
    """Apply data quality rules"""
    quality_df = df.withColumn("quality_issues", array())

    for rule_name, rule_expr in rules.items():
        quality_df = quality_df.withColumn(
            "quality_issues",
            when(~rule_expr, array_union(col("quality_issues"), array(lit(rule_name))))
            .otherwise(col("quality_issues"))
        )

    # Separate good and bad records
    good_records = quality_df.filter(size(col("quality_issues")) == 0)
    bad_records = quality_df.filter(size(col("quality_issues")) > 0)
    return good_records, bad_records

# Define rules
rules = {
    "non_null_user_id": col("user_id").isNotNull(),
    "positive_amount": col("amount") > 0,
    "valid_date": col("date").between("2020-01-01", current_date())
}

good_df, bad_df = validate_data(df, rules)

# Write bad records to quarantine
bad_df.write.mode("append").parquet("s3://bucket/quarantine/")
Partitioning Strategies
1. File Partitioning
# Write with partitioning
df.write \
    .partitionBy("year", "month", "day") \
    .mode("overwrite") \
    .parquet("s3://bucket/partitioned/")

# Read specific partitions
df = spark.read.parquet("s3://bucket/partitioned/year=2026/month=01/")
2. DataFrame Partitioning
# Check current partitions
print(f"Current partitions: {df.rdd.getNumPartitions()}")
# Repartition (full shuffle)
df_repartitioned = df.repartition(200, "user_id")
# Coalesce (no shuffle, reduce only)
df_coalesced = df.coalesce(10)
# Optimal partition size: 128MB - 1GB
data_size_gb = 100
optimal_partitions = int(data_size_gb * 8) # ~125MB per partition
df = df.repartition(optimal_partitions)
Caching Strategy
# Cache when reusing data multiple times
df_cached = df.cache()
# Use in multiple operations
result1 = df_cached.filter(col("amount") > 100).count()
result2 = df_cached.groupBy("category").count()
# Unpersist when done
df_cached.unpersist()
# Choose storage level based on needs
from pyspark import StorageLevel
# Memory only (fastest, may cause OOM)
df.persist(StorageLevel.MEMORY_ONLY)
# Memory + disk (safer)
df.persist(StorageLevel.MEMORY_AND_DISK)
# Disk only (frees executor memory; slowest to re-read)
df.persist(StorageLevel.DISK_ONLY)
# Note: the serialized (_SER) levels exist only in Scala/Java; PySpark always stores cached data serialized
UDFs: When and How
Python UDFs (Slow - Use Sparingly)
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Define UDF
@udf(returnType=StringType())
def categorize_amount(amount):
    if amount < 100:
        return "low"
    elif amount < 1000:
        return "medium"
    else:
        return "high"

# Use UDF
df = df.withColumn("category", categorize_amount(col("amount")))
Pandas UDFs (Faster - Vectorized)
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf("double")
def complex_calculation(amounts: pd.Series) -> pd.Series:
    # Vectorized operations on pandas Series
    return amounts * 1.1 + (amounts ** 0.5)

df = df.withColumn("result", complex_calculation(col("amount")))
Best Practice: Avoid UDFs whenever possible; built-in functions execute inside the JVM and benefit from Catalyst optimization.
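For example, the categorize_amount UDF above collapses into a single built-in expression:

# Same logic as categorize_amount, expressed with built-in when/otherwise
df = df.withColumn(
    "category",
    when(col("amount") < 100, "low")
    .when(col("amount") < 1000, "medium")
    .otherwise("high")
)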
Writing Data
1. Parquet with Compression
df.write \
    .mode("overwrite") \
    .option("compression", "snappy") \
    .parquet("s3://bucket/output/")
2. Delta Lake (Recommended)
# Write with optimization
df.write \
    .format("delta") \
    .mode("overwrite") \
    .option("overwriteSchema", "true") \
    .save("s3://bucket/delta/table")

# Merge (Upsert)
from delta.tables import DeltaTable

deltaTable = DeltaTable.forPath(spark, "s3://bucket/delta/table")

deltaTable.alias("target").merge(
    df.alias("source"),
    "target.id = source.id"
).whenMatchedUpdate(set={
    "name": "source.name",
    "updated_at": "source.updated_at"
}).whenNotMatchedInsert(values={
    "id": "source.id",
    "name": "source.name",
    "created_at": "source.created_at"
}).execute()
Performance Optimization Checklist
1. Before Writing Code
- [ ] Understand data size and distribution
- [ ] Identify skewed keys
- [ ] Plan partition strategy
- [ ] Choose right file format (Parquet/Delta)
2. During Development
- [ ] Use DataFrame API (not RDD)
- [ ] Leverage built-in functions
- [ ] Minimize UDFs
- [ ] Avoid collect() on large datasets
- [ ] Use broadcast joins for small tables
3. Before Production
- [ ] Enable Adaptive Query Execution (AQE)
- [ ] Set appropriate shuffle partitions
- [ ] Configure executor resources (see the configuration sketch after this list)
- [ ] Add monitoring and logging
- [ ] Test with production data volumes
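A rough sketch of how the AQE, shuffle-partition, and executor-resource items above translate into configuration; the partition count and executor sizes below are illustrative assumptions that depend entirely on your cluster and data volume.

from pyspark.sql import SparkSession

# Illustrative tuning values only; derive real numbers from your data size and cluster shape
spark = SparkSession.builder \
    .appName("ETL-production") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.shuffle.partitions", "400") \
    .config("spark.executor.memory", "8g") \
    .config("spark.executor.cores", "4") \
    .config("spark.executor.instances", "20") \
    .getOrCreate()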
Monitoring & Debugging
# Reduce log noise
spark.sparkContext.setLogLevel("WARN")

# Formatted physical plan with per-node details
df.explain(mode="formatted")

# Logical plan with cost statistics (where available)
df.explain(mode="cost")

# Monitor stage metrics via the status tracker
tracker = spark.sparkContext.statusTracker()
for stage_id in tracker.getActiveStageIds():
    print(tracker.getStageInfo(stage_id))

# Custom metrics with an accumulator (updates are only guaranteed inside actions such as foreach;
# transform(row) is a placeholder for your row-level logic)
processed_records = spark.sparkContext.accumulator(0)

def process_with_counter(row):
    processed_records.add(1)
    return transform(row)
Common Pitfalls to Avoid
1. Using collect() on Large DataFrames
# BAD: Brings all data to driver
all_data = df.collect() # OOM on large datasets!
# GOOD: Process distributedly
df.write.parquet("s3://bucket/output/")
2. Not Filtering Early
# BAD: Read everything then filter
df = spark.read.parquet("s3://bucket/data/")
df_filtered = df.filter(col("date") == "2026-01-01")
# GOOD: Filter during read (partition pruning)
df = spark.read.parquet("s3://bucket/data/date=2026-01-01/")
3. Creating Too Many Small Files
# BAD: Partitioning by a high-cardinality column creates thousands of tiny files
df.write.partitionBy("user_id").parquet("s3://bucket/output/")
# GOOD: Partition by a low-cardinality column and control file counts with repartition
df.repartition(100) \
    .write.partitionBy("date") \
    .parquet("s3://bucket/output/")
Production Pipeline Example
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from delta.tables import DeltaTable
import logging

class ProductionETL:
    def __init__(self, env: str):
        self.env = env
        self.spark = self._create_spark_session()
        self.logger = logging.getLogger(__name__)

    def _create_spark_session(self):
        return SparkSession.builder \
            .appName(f"ETL-{self.env}") \
            .config("spark.sql.adaptive.enabled", "true") \
            .getOrCreate()

    def extract_users(self, start_date: str, end_date: str):
        """Extract user data"""
        self.logger.info(f"Extracting users from {start_date} to {end_date}")
        return self.spark.read \
            .format("delta") \
            .load("s3://bucket/bronze/users") \
            .filter(col("created_date").between(start_date, end_date))

    def transform_users(self, df):
        """Apply transformations"""
        return df \
            .withColumn("created_year", year(col("created_date"))) \
            .withColumn("user_segment",
                when(col("total_purchases") > 1000, "premium")
                .when(col("total_purchases") > 100, "standard")
                .otherwise("basic")
            ) \
            .filter(col("is_active"))

    def load_users(self, df):
        """Load to gold layer"""
        self.logger.info(f"Loading {df.count()} records")  # count() triggers a full pass; acceptable here for logging
        df.write \
            .format("delta") \
            .mode("overwrite") \
            .partitionBy("created_year") \
            .save("s3://bucket/gold/users")

    def run(self, start_date: str, end_date: str):
        """Execute pipeline"""
        try:
            df = self.extract_users(start_date, end_date)
            df_transformed = self.transform_users(df)
            self.load_users(df_transformed)
            self.logger.info("Pipeline completed successfully")
        except Exception as e:
            self.logger.error(f"Pipeline failed: {str(e)}")
            raise

# Run pipeline
etl = ProductionETL(env="production")
etl.run("2026-01-01", "2026-01-31")
Conclusion
Building production-grade PySpark pipelines requires that you:
- Understand your data: Size, distribution, skew
- Optimize early: Partition pruning, broadcast joins
- Monitor everything: Spark UI, metrics, logs
- Test at scale: Use production data volumes
- Iterate: Profile, optimize, repeat
Master these patterns and your PySpark pipelines will be fast, reliable, and cost-effective.
Questions about PySpark optimization? Let's discuss on LinkedIn!