
Darren Fuller


Unit testing your PySpark library

In software development we often unit test our code (hopefully). And code written for Spark is no different. So here I want to run through an example of building a small library using PySpark and unit testing it. I'm using Visual Studio Code as my editor here, mostly because I think it's brilliant, but other editors are available.

Building the demo library

I really wanted to avoid the usual "let's add two numbers together" example and use something closer to what you might come across day-to-day. So, I'm going to create a library with some methods around creating a date dimension.

A date dimension is something you will tend to come across in any Data Warehouse or Data Lakehouse environment where you're working with dates. And really, that's pretty much everywhere! This kind of dimension typically contains information about the days of a year which users might want to slice-and-dice data with later on. If you can join your fact tables to a date dimension on a date key, then seeing all data for a given month, quarter, last week, or just the weekends becomes much easier.
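To make that concrete, here's a rough sketch of the kind of query that becomes easy. The names here are made up: sales_df stands in for a hypothetical fact DataFrame, and date_dim_df is the dimension we'll build further down.

from pyspark.sql.functions import col

# Hypothetical fact table joined to the date dimension on its key,
# then sliced down to weekend days in January 2021
weekend_sales = (
    sales_df
    .join(date_dim_df, on="date_key", how="inner")
    .filter((col("year") == 2021) & (col("month") == 1))
    .filter(col("day_short_name").isin("Sat", "Sun"))
)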

I'm going to stick with a fairly simple date dimension for brevity, so it will just include the key, basic information about the day, month, and year, and the ISO week number.

The ISO week number is the kicker here, and it's better explained over on Wikipedia, but it gets interesting because, contrary to popular belief, there aren't always 52 weeks in a year (dates are always so much fun). This isn't something for which a function in Spark exists, but fortunately it's a piece of functionality we can get straight out of Python, as the datetime class has an isocalendar method which gives us the ISO values back as a tuple.
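For example (up to Python 3.8 the result is a plain tuple, from 3.9 it's a named tuple, but it unpacks the same way):

from datetime import date

# 1st January 2021 is a Friday in ISO week 53 of ISO year 2020
iso_year, iso_week, iso_weekday = date(2021, 1, 1).isocalendar()  # (2020, 53, 5)

# Week 1 of 2021 doesn't start until Monday 4th January
date(2021, 1, 4).isocalendar()[1]  # 1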

But... we can't just call Python code directly from PySpark (well, we can, but we shouldn't because it's really bad for performance). So we're going to need to create a User Defined Function.

PySpark has gotten a lot easier to work with in recent years, and that's especially true of UDFs with the arrival of Pandas UDFs. If you want to know why then I suggest you read the docs, as it's a bit long-winded for here, but it means we can write User Defined Functions in Python which are significantly more performant. They're still not quite as good as Scala, but when you think about what Spark needs to do under the hood, they're very impressive.

The ISO week number function

Right, let's dive into the code for the User Defined Function then.

from datetime import date, datetime
from typing import Optional, Union
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType


def get_iso_weeknumber(dt: Union[date, datetime]) -> Optional[int]:
    # Plain Python, kept separate so it can be tested without a Spark Session
    if type(dt) is not datetime and type(dt) is not date:
        return None

    return dt.isocalendar()[1]


@pandas_udf(IntegerType())
def iso_week_of_year(dt: pd.Series) -> pd.Series:
    # Series-to-Series UDF: apply the plain function to each non-null value
    return dt.apply(lambda x: get_iso_weeknumber(x) if not pd.isnull(x) else pd.NA)

I've deliberately broken it into two functions here, and I probably could have used some better names, but meh. The reason for breaking it down is testing. Where possible we want to test custom logic outside of Spark so we don't need to spin up a new session; this means we can write a lot of tests and execute them quickly to check everything is working, and that's the purpose of the get_iso_weeknumber function.

In the iso_week_of_year function we're getting a Pandas Series as our input and returning another Pandas Series. This is a Series-to-Series function, and the output must be the same length as the input. The code in this function is deliberately brief because we want to keep as much logic out of here as possible; it should really just care about applying the get_iso_weeknumber function to the input. And that's all it's doing, as long as the input value isn't null.
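As a rough illustration of what that apply is doing, here's the same logic run against a plain Pandas Series outside of Spark (just a sketch):

import pandas as pd
from datetime import date

dates = pd.Series([date(2021, 1, 1), None, date(2021, 1, 4)])

# Output has the same length as the input: 53, <NA>, 1
weeks = dates.apply(lambda x: get_iso_weeknumber(x) if not pd.isnull(x) else pd.NA)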

The date dimension

Now that we're able to calculate our ISO week number we can create a date dimension on demand. There are a few different ways of doing this, but my preferred method is using the SparkSession.range function with UNIX timestamps.

from datetime import datetime, time, timedelta, timezone

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import date_format, dayofmonth, from_unixtime, month, to_date, year


def create_date_dimension(spark: SparkSession, start: datetime, end: datetime) -> DataFrame:
    if spark is None or type(spark) is not SparkSession:
        raise ValueError("A valid SparkSession instance must be provided")

    if type(start) is not datetime:
        raise ValueError("Start date must be a datetime.datetime object")

    if type(end) is not datetime:
        raise ValueError("End date must be a datetime.datetime object")

    if start >= end:
        raise ValueError("Start date must be before the end date")

    # Normalise naive datetimes to midnight UTC
    if start.tzinfo is None:
        start = datetime.combine(start.date(), time(0, 0, 0), tzinfo=timezone.utc)

    if end.tzinfo is None:
        end = datetime.combine(end.date(), time(0, 0, 0), tzinfo=timezone.utc)

    # Add a day so the end date is included in the range
    end = end + timedelta(days=1)

    return (
        spark.range(start=start.timestamp(), end=end.timestamp(), step=24 * 60 * 60)
             .withColumn("date", to_date(from_unixtime("id")))
             .withColumn("date_key", date_format("date", "yyyyMMdd").cast("int"))
             .withColumn("day", dayofmonth("date"))
             .withColumn("day_name", date_format("date", "EEEE"))
             .withColumn("day_short_name", date_format("date", "EEE"))
             .withColumn("month", month("date"))
             .withColumn("month_name", date_format("date", "MMMM"))
             .withColumn("month_short_name", date_format("date", "MMM"))
             .withColumn("year", year("date"))
             .withColumn("week_number", iso_week_of_year("date"))  # the Pandas UDF from earlier
    )

There's a lot of setup in this function, just to make sure we've been given all the right information to create the dimension, but the actual generation is pretty straightforward. You can see on the last line that we're using our iso_week_of_year UDF to get the week number, calling it just like any other Spark function.

Using the SparkSession.range function with a step of 86,400 means we're incrementing by one day at a time, regardless of any time component passed into the function in the first place.
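As a quick sanity check of those UNIX timestamps (these are the values you'll see in the id column of the output below):

from datetime import datetime, timezone

start = datetime(2021, 1, 1, tzinfo=timezone.utc).timestamp()  # 1609459200.0

# Stepping by 86,400 seconds lands on midnight UTC of each following day
[start + i * (24 * 60 * 60) for i in range(3)]  # 1609459200.0, 1609545600.0, 1609632000.0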

Running this gives us a table looking like the following:

>>> start = datetime(2021, 1, 1)
>>> end = datetime(2021, 1, 10)
>>> create_date_dimension(spark, start, end).show()
+----------+----------+--------+---+---------+--------------+-----+----------+----------------+----+-----------+
|        id|      date|date_key|day| day_name|day_short_name|month|month_name|month_short_name|year|week_number|
+----------+----------+--------+---+---------+--------------+-----+----------+----------------+----+-----------+
|1609459200|2021-01-01|20210101|  1|   Friday|           Fri|    1|   January|             Jan|2021|         53|
|1609545600|2021-01-02|20210102|  2| Saturday|           Sat|    1|   January|             Jan|2021|         53|
|1609632000|2021-01-03|20210103|  3|   Sunday|           Sun|    1|   January|             Jan|2021|         53|
|1609718400|2021-01-04|20210104|  4|   Monday|           Mon|    1|   January|             Jan|2021|          1|
|1609804800|2021-01-05|20210105|  5|  Tuesday|           Tue|    1|   January|             Jan|2021|          1|
|1609891200|2021-01-06|20210106|  6|Wednesday|           Wed|    1|   January|             Jan|2021|          1|
|1609977600|2021-01-07|20210107|  7| Thursday|           Thu|    1|   January|             Jan|2021|          1|
|1610064000|2021-01-08|20210108|  8|   Friday|           Fri|    1|   January|             Jan|2021|          1|
|1610150400|2021-01-09|20210109|  9| Saturday|           Sat|    1|   January|             Jan|2021|          1|
|1610236800|2021-01-10|20210110| 10|   Sunday|           Sun|    1|   January|             Jan|2021|          1|
+----------+----------+--------+---+---------+--------------+-----+----------+----------------+----+-----------+

You can see that the first few days of 2021 were part of week 53, and week 1 didn't actually start until the 4th January.

Unit testing the code

But the title of this post is about unit testing the code, so how is that done?

Well, we've seen that the ISO week number functions are broken out, so unit testing get_iso_weeknumber is pretty straightforward, and it can be tested just like any other Python code.
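For example, a few plain unittest cases along these lines would cover it (a sketch; importing get_iso_weeknumber from wherever your library module lives is assumed):

import unittest
from datetime import date, datetime


class IsoWeekNumberTests(unittest.TestCase):
    def test_first_of_january_2021_is_week_53(self):
        self.assertEqual(53, get_iso_weeknumber(date(2021, 1, 1)))

    def test_fourth_of_january_2021_is_week_1(self):
        self.assertEqual(1, get_iso_weeknumber(datetime(2021, 1, 4)))

    def test_non_date_input_returns_none(self):
        self.assertIsNone(get_iso_weeknumber("2021-01-01"))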

Whilst we can test the iso_week_of_year function, we lose the ability to capture any coverage information because of the way Spark calls the function. If you're using Coverage.py then you can add patterns to your .coveragerc file to exclude these functions from your coverage results, such as the following.

[report]
exclude_lines =
    def .* -> pd.Series

We're going to test our create_date_dimension function which will call the iso_week_of_year function.

We'll need a Spark Session to test the code with, and the most obvious way of doing that would be like this (I'm using the Python unittest library here).

class DateDimensionTests(unittest.TestCase):
    def test_simple_creation(self):
        spark = SparkSession.builder.appName("unittest").master("local[*]").getOrCreate()

        df = create_date_dimension(spark, datetime(2021, 1, 1), datetime(2021, 1, 5))

But there's a problem with this. For every single test we're going to be creating a brand new Spark Session, which is going to get expensive. And if we need to change any of the spark configuration, we're going to need to do it in every single test!

So what's a better way? Well, we can create our own test class based on unittest.TestCase which does all of the set up for us.

import unittest

from pyspark import SparkConf
from pyspark.sql import DataFrame, SparkSession


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        conf = SparkConf()
        conf.setAppName("demo-unittests")
        conf.setMaster("local[*]")
        conf.set("spark.sql.session.timezone", "UTC")

        cls.spark = SparkSession.builder.config(conf=conf).getOrCreate()

    @classmethod
    def tearDownClass(cls) -> None:
        cls.spark.stop()

    def tearDown(self) -> None:
        self.spark.catalog.clearCache()

    def assert_dataframes_equal(self, expected: DataFrame, actual: DataFrame, display_differences: bool = False) -> None:
        diff: DataFrame = actual.subtract(expected).cache()
        diff_count: int = diff.count()

        try:
            if diff_count != 0 and display_differences:
                diff.show()
            self.assertEqual(0, diff_count)
        finally:
            diff.unpersist()

What have we got now? Well, we have our setUpClass method which is doing the work of setting up our Spark Session for us; and we have our tearDownClass method which clears up the Spark Session once we're done.

We've also added a tearDown method which cleans up after each test; in this case it's clearing out the cache, just in case. And there's a new method for comparing data frames which abstracts away a lot of the repetitive code we might otherwise write in our unit tests.

One of the other things that's different in this setup is the line conf.set("spark.sql.session.timezone", "UTC"). Timestamps are evil. It's that simple. A date is sort of okay; it represents a point in the calendar, so dates are nice to work with. But the range we're creating uses seconds from the epoch, so we're in the world of timestamps, and that world has Daylight Saving Time and timezones! And how they're handled varies based on your OS of choice.

Running the following on my Windows 10 laptop using Python 3 gives an error of OSError: [Errno 22] Invalid argument

from datetime import datetime
datetime(1066, 10, 14).timestamp()

But running it in Ubuntu 20.04 on WSL gives a result of -28502755125.0. Now let's run it again in both, but with timezone information provided.

from datetime import datetime, timezone
datetime(1066, 10, 14, tzinfo=timezone.utc).timestamp()

Well, now they both give an answer of -28502755200.0. So, setting the Spark session timezone to UTC means it will handle timestamps consistently, regardless of where in the world I am or which OS I'm running on.

Writing the tests themselves

Now that we have our test case class, we can write some unit tests, and by putting a number of them under the same test class we're going to use the same Spark Session, which means that we're not going to incur a performance penalty for constantly re-creating it.

from datetime import datetime

from pyspark.sql.functions import col


class DateDimensionTests(PySparkTestCase):
    def test_simple_creation(self):
        # get_input_file and get_dimension_schema are helpers from the full example project
        expected = self.spark.read.csv(
            self.get_input_file("simple_date_dim.csv"),
            header=True,
            schema=DateDimensionTests.get_dimension_schema())

        result = create_date_dimension(self.spark, datetime(2021, 1, 1), datetime(2021, 1, 5))

        self.assert_dataframes_equal(expected, result, display_differences=True)

    def test_simple_range_week53(self):
        week_53_day_count = (
            create_date_dimension(self.spark, datetime(2020, 12, 1), datetime(2021, 1, 31))
            .filter(col("week_number") == 53)
            .count()
        )

        self.assertEqual(7, week_53_day_count)

These can now be executed using the Visual Studio Code test runner...

Tests executed using the VSCode test runner

... or from the command line.

python -m unittest discover -v -s ./tests -p test_*.py
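And if you're collecting coverage with Coverage.py, as mentioned earlier, the same discovery run works under it (assuming the coverage package is installed):

coverage run -m unittest discover -v -s ./tests -p test_*.py
coverage report -m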

They're still not as performant as tests which don't require a Spark Session, so where possible make sure business logic is tested outside of one. But if you do need a Spark Session, then this is definitely a way to go.

Summary

Phew. So this is my first post to Dev.to. I've been sat reading content from others for a while and thought it was time I contributed something of my own. Hopefully it helps someone out, or gives someone a few ideas for doing something else.

If you can think of anything which could be improved, or you have a different way of doing things, then let me know; things only improve with collaboration.

If you want to have a look at the code, or use it in any way (it's MIT licensed), then you can find it over on GitHub.
