
3 Tips for Unit Testing PySpark Pipelines

I’m not sure what it is, but some prevailing evil in the Data Engineering world has made it uncommon for PySpark pipelines to be unit tested. It’s probably a combination of things. Data Engineers have been accused of lacking good Software Engineering principles. Functional testing is a hot commodity in the Software Engineering world but takes a while to trickle into mainstream Data Engineering. It can require good Docker skills. And, generally speaking, the old-school Data and ETL Developers who preceded Data Engineers in the bygone days never unit tested … so neither do their descendants.

Who knows? All that being said, I want to give you 3 tips to help you unit test your PySpark ETL data pipelines.

Basics of Unit Testing PySpark.

I already wrote a blog post with a basic introduction to unit testing PySpark code here.

Generally, it requires a few prerequisite steps.

  • Docker
  • Docker compose
  • conftest with pytest
  • functional pipeline code

Please read the above link to the other blog post for a more in-depth review of setting up PySpark unit testing. Here are links to my code for each step.

Click here to get a Dockerfile with Spark etc. installed. A nice docker-compose file to easily run the unit tests. A pytest conftest.py file to attach the needed SparkSession to each unit test.
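If you just want the gist, here is a minimal sketch of the kind of conftest.py I’m talking about. The fixture name spark_session and the local-mode settings are my assumptions, not necessarily what’s in the linked repo.

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark_session():
    # One local SparkSession shared across the whole test run.
    spark = (
        SparkSession.builder
        .master('local[1]')
        .appName('pyspark-unit-tests')
        .getOrCreate()
    )
    yield spark
    spark.stop()

Any test that declares a spark_session argument gets this shared session injected by pytest.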

Easy peasy.

That’s the basic baseline you need to start with PySpark unit testing. Again check out my other blog post to get started.

3 Tips for Unit Testing PySpark pipelines.

Writing PySpark pipeline code that can be unit tested requires a different mindset than many Data Engineers bring to pipeline code. But, there are a few easy steps that anyone can implement that will make unit testing PySpark code not only easier but addicting! Being able to unit test code when adding new functionality or exploring a new codebase is key for Data Engineering success and for keeping pipelines from failing.

So, here are the 3 tips that are essential for unit testing PySpark.

  • modular and functional code.
  • configurable functions … no hard coding.
  • sample data or create your own.

Let’s dig into each one of these tips.

Modular and Functional code.

From my perspective, this is the baseline for writing PySpark that can be unit tested well. I’ve seen how overwhelming, even impossible, unit testing becomes when the code wasn’t written in a modular and functional manner.

The two major errors I see regularly in PySpark code that lead to a lack of unit testing are …

  • not encapsulating ALL logic inside functions / methods.
  • functions that are too large, not functional, with many side effects and tangled business logic.

Enough talk, what does this look like in real life?

Not encapsulating logic inside functions or methods.

The main problem I see with a lot of PySpark code is that it can’t readily be unit tested at all! Everything is just written inside a main method, and business and transformation logic is applied one step after another without any modular thought.

It usually looks something like this using a contrived example…

import logging
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import os

spark = SparkSession.builder.appName('MyTerriblePipeline').getOrCreate()
logging.basicConfig(format='%(asctime)s %(levelname)s - My Terrible Pipeline - %(message)s', level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

BUCKET = os.environ['ENV']

raw_data = spark.read.csv(f's3a://{BUCKET}/raw/data/files/*.csv', sep='\t', header='true').select('order_id', 'customer_id', 'amount', 'order_date')
logging.info(f"Pulled raw data from bucket {BUCKET}")
with_dates = raw_data.withColumn('order_date', F.to_date(F.col('order_date'), 'yyyy-MM-dd')).filter(F.year(F.col('order_date')) == F.lit(2021))
summary_metrics = with_dates.groupBy(F.col('customer_id')).agg(F.sum(F.regexp_replace(F.col('amount'), '[$]', '')).alias('total_amount')).filter(F.col('total_amount') > 10)
summary_metrics.coalesce(1).write.mode('overwrite').csv(f's3a://{BUCKET}/results', header='true')

Now to some people, this all might look fine, but I assure you all is not well.

First, when you look at the code there is so much going on that you’re not really sure where to look. And this is a contrived example; what about the real world, where most transformation pipelines are many times larger and more complex … how could you possibly read code like that?

It’s impossible to track down what transformations are happening where, debugging takes longer, adding new logic takes longer, just reading it takes longer. Oh, and it’s pretty much guaranteed to create more bugs.

Last but not least, how can you unit test such code? You really can’t.

Step 1, encapsulate and separate logic into methods or functions.

Let’s take the first step of simply breaking the logic down into its pieces and parts. We won’t make this code perfect, but we will take one step in the right direction by simply grouping logic together.

import logging
from pyspark.sql import SparkSession, DataFrame
import pyspark.sql.functions as F
import os

spark = SparkSession.builder.appName('MyTerriblePipeline').getOrCreate()
logging.basicConfig(format='%(asctime)s %(levelname)s - My Terrible Pipeline - %(message)s', level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

BUCKET = os.environ['ENV']

def read_raw_data(spark: SparkSession) -> DataFrame:
    raw_data = spark.read.csv(f's3a://{BUCKET}/raw/data/files/*.csv', sep='\t', header='true')
    output_df = raw_data.select('order_id', 'customer_id', 'amount', 'order_date')
    logging.info(f"Pulled raw data from bucket {BUCKET}")
    return output_df

def modify_and_filter_dates(input_df: DataFrame) -> DataFrame:
    intermediate = input_df.withColumn('order_date', F.to_date(F.col('order_date'), 'yyyy-MM-dd'))
    output_df = intermediate.filter(F.year(F.col('order_date')) == F.lit(2021))
    return output_df

def transform_amount(input_df: DataFrame) -> DataFrame:
    output_df = input_df.withColumn('amount', F.regexp_replace(F.col('amount'), '[$]', ''))
    return output_df

def create_metrics(input_df: DataFrame) -> DataFrame:
    intermediate = input_df.groupBy(F.col('customer_id')).agg(F.sum(F.col('amount')).alias('total_amount'))
    output_df = intermediate.filter(F.col('total_amount') > 10)
    return output_df

raw_data = read_raw_data(spark)
with_dates = modify_and_filter_dates(raw_data)
with_amts = transform_amount(with_dates)
summary_metrics = create_metrics(with_amts)
summary_metrics.coalesce(1).write.mode('overwrite').csv(f's3a://{BUCKET}/results', header='true')

Ok, definitely not perfect, but a big improvement! Just encapsulating and moving logic together in functions makes all the difference in the world. The code is more readable, it’s easy to debug and find spots to add or remove logic. Also, it makes the code actually unit testable as we now have singular methods.
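To make that concrete, here is a minimal sketch of a unit test for modify_and_filter_dates, assuming the functions have been moved into an importable module (I’m calling it pipeline here, a hypothetical name) and the spark_session fixture from the conftest setup mentioned earlier.

from datetime import date

from pipeline import modify_and_filter_dates  # hypothetical module name for your pipeline functions

def test_modify_and_filter_dates(spark_session):
    # Made-up input rows; only the 2021 order should survive the filter.
    input_df = spark_session.createDataFrame(
        [('1', '101', '$10.00', '2021-03-01'),
         ('2', '102', '$25.00', '2019-07-15')],
        ['order_id', 'customer_id', 'amount', 'order_date']
    )
    result = modify_and_filter_dates(input_df)
    rows = result.collect()
    assert len(rows) == 1
    assert rows[0]['order_date'] == date(2021, 3, 1)

One small, focused function gets one small, focused test.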

It really isn’t that hard to do either; all you have to do is define a method and copy and paste your logic inside.

But I’m sure, for the astute Data Engineer, it still leaves some things to be desired: now you can see the many hardcoded configurations, and the code still isn’t all that reusable because of them.

That leads us to the next topic.

Configurable functions … no hard coding.

The next step in the process should always be the ability to re-use functions and methods in other pipelines so we don’t have to write the same code over and over again, only changing hard-coded values.

This seems obvious, but it truly escapes many Data Engineers for some reason. We want to reduce the amount of code we have to write and maintain. If we can write a general method that can be used in multiple spots it will reduce the size of our codebase and make debugging and other tasks easier.

Let’s take a stab at our new code and functions and see what we can do.

import logging
from pyspark.sql import SparkSession, DataFrame
import pyspark.sql.functions as F
import os

spark = SparkSession.builder.appName('MyTerriblePipeline').getOrCreate()
logging.basicConfig(format='%(asctime)s %(levelname)s - My Terrible Pipeline - %(message)s', level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

BUCKET = os.environ['ENV']

def read_raw_data(spark: SparkSession, 
                  read_loc: str = f's3a://{BUCKET}/raw/data/files/*.csv',
                  csv_sep: str = '\t',
                  hdr: str = 'true',
                  columns: list = ['order_id', 'customer_id', 'amount', 'order_date']
                  ) -> DataFrame:
    raw_data = spark.read.csv(f'{read_loc}', sep=f'{csv_sep}', header=f'{hdr}')
    output_df = raw_data.select(*columns)
    logging.info(f"Pulled raw data from bucket {BUCKET}")
    return output_df

def modify_and_filter_dates(input_df: DataFrame, 
                            new_column: str = 'order_date',
                            date_column: str = 'order_date',
                            date_format: str = 'yyyy-MM-dd',
                            yr: int = 2021
                            ) -> DataFrame:
    intermediate = input_df.withColumn(f'{new_column}', F.to_date(F.col(f'{date_column}'), f'{date_format}'))
    output_df = intermediate.filter(F.year(F.col(f'{new_column}')) == F.lit(yr))
    return output_df

def transform_amount(input_df: DataFrame,
                     new_column: str = 'amount',
                     old_column: str = 'amount',
                     reg_replace_str: str = '[$]') -> DataFrame:
    output_df = input_df.withColumn(f'{new_column}',
                                    F.regexp_replace(F.col(f'{old_column}'), f'{reg_replace_str}', ''))
    return output_df


def create_metrics(input_df: DataFrame, filter_amt: int = 10) -> DataFrame:
    intermediate = input_df.groupBy(F.col('customer_id')).agg(F.sum(F.col('amount')).alias('total_amount'))
    output_df = intermediate.filter(F.col('total_amount') > filter_amt)
    return output_df

raw_data = read_raw_data(spark)
with_dates = modify_and_filter_dates(raw_data)
with_amts = transform_amount(with_dates)
summary_metrics = create_metrics(with_amts)
summary_metrics.coalesce(1).write.mode('overwrite').csv(f's3a://{BUCKET}/results', header='true')

Again, this is a contrived example, but you, being smart people, can of course see the difference this will make. When we pull out hardcoded values like read locations, column names, and formats, it becomes very simple to re-use the method somewhere else. Let’s take the example of a very common task: taking a STRING column and converting it to a date.

When you are able to modify the date format … date_format: str = 'yyyy-MM-dd' … as well as the column names in question, along with the filter year, you could use this on another dataset very easily.

def modify_and_filter_dates(input_df: DataFrame, 
                            new_column: str = 'order_date',
                            date_column: str = 'order_date',
                            date_format: str = 'yyyy-MM-dd',
                            yr: int = 2021
                            ) -> DataFrame:
    intermediate = input_df.withColumn(f'{new_column}', F.to_date(F.col(f'{date_column}'), f'{date_format}'))
    output_df = intermediate.filter(F.year(F.col(f'{new_column}')) == F.lit(yr))
    return output_df
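As a quick illustration (the invoices_df and its column names here are hypothetical), the same function drops onto a completely different dataset without touching its internals:

# invoices_df is a hypothetical DataFrame whose 'invoice_dt' column holds 'MM/dd/yyyy' strings.
invoices_2020 = modify_and_filter_dates(
    invoices_df,
    new_column='invoice_date',
    date_column='invoice_dt',
    date_format='MM/dd/yyyy',
    yr=2020
)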

Even the reading of a CSV file is a good example where a person could easily re-use the code, swapping out file locations, and read options as needed.

def read_raw_data(spark: SparkSession, 
                  read_loc: str = f's3a://{BUCKET}/raw/data/files/*.csv',
                  csv_sep: str = '\t',
                  hdr: str = 'true',
                  columns: list = ['order_id', 'customer_id', 'amount', 'order_date']
                  ) -> DataFrame:
    raw_data = spark.read.csv(f'{read_loc}', sep=f'{csv_sep}', header=f'{hdr}')
    output_df = raw_data.select(*columns)
    logging.info(f"Pulled raw data from bucket {BUCKET}")
    return output_df

The added benefit is that these changes make the code even more unit-testable: reading from a local test data sample becomes trivial when you’re not dealing with a function that is hardcoded to read from some remote s3 location.
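For example, a test can simply point read_raw_data at a small local file instead of s3. A rough sketch, with an assumed tests/data layout and the pipeline functions importable as before:

def test_read_raw_data(spark_session):
    # Point the configurable reader at a local sample instead of the s3 bucket.
    df = read_raw_data(
        spark_session,
        read_loc='tests/data/sample_orders.csv',
        csv_sep='\t',
        hdr='true'
    )
    assert df.columns == ['order_id', 'customer_id', 'amount', 'order_date']
    assert df.count() > 0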

This brings us to the last tip.

Sample data, both real and generated.

This is my last tip for PySpark unit tests. If you are going to unit test your PySpark pipeline code, you are going to need sample data of some kind to do it. There are really two approaches you can take, and one serves the purpose better than the other.

Do you know how you’ve always been told that taking the easy road isn’t the best choice? Well, that can be true when coming up with sample data.

The easy route is just to use the .createDataFrame() call that can be found on any SparkSession object.

df = spark_session.createDataFrame([('1', 'Billbo'), ('2', 'Frodo')], ['Some_Column', 'Another_Column'])

This of course is a simple way to create a DataFrame that can be used inside any unit test, it’s easy and quick, and you are in total control.

But, this comes with one big drawback.

This has burned me many times in the past. When we grade our own papers we are too nice to ourselves. When we create our own sample data this way, we are too nice to ourselves.

When you don’t use real sample data, say 10 records from a real file, you can mask data problems and trick yourself and your code into thinking that the method or function you are testing is perfect, when in fact that amount column might have some weird values you were not expecting.

Use real sample data files when possible; it’s a better real-world test.
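Concretely, that can be as simple as running a handful of real rows through the transform functions instead of a handcrafted DataFrame. A rough sketch, assuming a tests/data/sample_orders.csv carved out of a production file and the same importable pipeline functions:

import pyspark.sql.functions as F

def test_amounts_on_real_sample(spark_session):
    # Real rows are less forgiving than hand-built ones; odd amount values
    # (blanks, '1,234.56', stray characters) surface as nulls after the cast.
    sample = read_raw_data(spark_session, read_loc='tests/data/sample_orders.csv')
    cleaned = transform_amount(sample)
    bad_amounts = cleaned.filter(F.col('amount').cast('double').isNull()).count()
    assert bad_amounts == 0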

Musings

I hope these 3 PySpark unit testing tips were helpful for some of you out there in the ether. Writing code that is modular and functional in PySpark pays off in ways you can’t even imagine when you begin your journey. Every downstream task becomes slightly more efficient and easier.

Debugging some code? Code broken into small bite-sized chunks is easier to digest and work on.

Looking to reduce the amount and complexity of code you have to manage in your Spark pipelines? Using these three tips will allow you to do the best part of Data Engineering … delete some code.

Looking to take your unit tests to the next level? Instead of creating your own DataFrame designed to pass your unit tests … grab a handful of rows from a real data source and use that instead … it’s a bug catcher for sure!