Unit test pyspark and accumulator


I'm trying to test my Spark code in Python, but whenever my test code runs, all my accumulators are empty. However, when I run the code locally without mocks, the code works fine and the accumulators have values. Here's a trimmed-down version of the code:

Code:

from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession

columns: Any = []

class SetAccumulator(AccumulatorParam):
    def zero(self, value):
        return value.copy()

    def addInPlace(self, value1, value2):
        return value1.union(value2)

def read_columns(obj: dict) -> None:
    global columns

    for key in obj.keys():
        columns += {key}

def run(spark: SparkSession, df: DataFrame) -> list:
    global columns
    columns = spark.sparkContext.accumulator(set(), SetAccumulator())
    df.rdd.foreach(lambda row: read_columns(row.asDict()))
    return list(columns.value)
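
For reference, a minimal local driver along these lines (the session setup and sample data here are just for illustration, not part of my actual script) shows the behaviour described above, with the accumulator filling up as expected outside the test runner:

from pyspark.sql import SparkSession
from foo.bar import run

if __name__ == "__main__":
    spark = SparkSession.builder.appName("accumulator-demo").master("local").getOrCreate()
    df = spark.createDataFrame([("Hello!", 100, "Foobar")], ["test", "bar", "name"])
    # run() returns the accumulated set as a list, so the order is not guaranteed.
    print(run(spark, df))  # e.g. ['test', 'bar', 'name']
    spark.stop()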

Mock Spark test code:

import pydeequ
from unittest import TestCase
from pyspark.sql import SparkSession

class SparkTestCase(TestCase):
    spark: SparkSession

    @classmethod
    def setUpClass(cls) -> None:
        cls.spark = (
            SparkSession.builder.appName("testspark")
              .master("local")
              .enableHiveSupport()
              .config("spark.jars.packages", pydeequ.deequ_maven_coord)
              .config("spark.jars.excludes", pydeequ.f2j_maven_coord)
              .config("spark.sql.shuffle.partitions", 8)
              .getOrCreate()
        )
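
One thing worth adding to this base class (not in my original code, just a sketch) is a matching tearDownClass that stops the shared session, so the local Spark JVM is released once the test class finishes:

    @classmethod
    def tearDownClass(cls) -> None:
        # Stop the SparkSession created in setUpClass so repeated test runs
        # don't keep a local Spark JVM alive.
        cls.spark.stop()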

Test code:

from tests.spark.testcase import SparkTestCase
from foo.bar import run

class TestFoo(SparkTestCase):
    def test_foo(self):
        columns = [
            "test",
            "bar",
            "name"
        ]
        data = [
            (
                "Hello!",
                100,
                "Foobar"
            )
        ]

        df = self.spark.createDataFrame(data, columns)
        response = run(self.spark, df)
        print(response)

The test prints out an empty list. But as mentioned, when I run this outside of the testing framework (locally, on my computer), it prints out ["test", "bar", "name"].

What am I doing wrong or what do I need to add to make it work in the test case?

CodePudding user response:

I figured out one way to make the unit test work: I created a dictionary of the accumulators and passed it to each task, and with that the test updated the values correctly. I'm assuming a module-level global doesn't get shared properly with the Spark workers under the unit test runner.

Here's what the updated code looks like. The test code from the question above remains the same.

from typing import Any
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import DataFrame, SparkSession

class SetAccumulator(AccumulatorParam):
    def zero(self, value):
        return value.copy()

    def addInPlace(self, value1, value2):
        return value1.union(value2)

def read_columns(obj: dict, accumulators: dict) -> None:
    for key in obj.keys():
        accumulators["columns"] += {key}

def run(spark: SparkSession, df: DataFrame) -> list:
    columns = spark.sparkContext.accumulator(set(), SetAccumulator())
    accumulators = {"columns": columns}
    df.rdd.foreach(lambda row: read_columns(row.asDict(), accumulators))
    return list(columns.value)

I removed the module-level columns variable and all global references to it. Instead, I create a dictionary, accumulators = {"columns": columns}, pass it to the read_columns function, and look up the accumulator by key inside it.

The test correctly prints out ["test", "bar", "name"] now and it still works outside of the test environment.
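
If you want the test to assert the result instead of printing it, the body of test_foo can be extended along these lines (just a sketch; assertCountEqual compares contents while ignoring order, which suits a set-backed accumulator):

        df = self.spark.createDataFrame(data, columns)
        response = run(self.spark, df)
        # The accumulator is backed by a set, so the order of the returned
        # list is not guaranteed; compare contents only.
        self.assertCountEqual(response, columns)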
