How do I test this function?

Time:11-02

I have this function:

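import pandas as pd
from pyspark.sql import SparkSession

# spark is already defined somewhere as:
spark = SparkSession.builder.appName("App").getOrCreate()

# conn is a database connection defined elsewhere in the module
def read_data(spark):
    query = "SELECT * FROM table"
    pandas_df = pd.read_sql(query, conn)
    return spark.createDataFrame(pandas_df)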

To test it:

from unittest import mock

from your_module import read_data

@mock.patch("pandas.read_sql")
@mock.patch("pyspark.sql.SparkSession", autospec=True)
def test_read_data(spark_session, pandas_read_sql):
    result = read_data(spark_session)
    # assert result == ???  <- what should be checked here?

What is a sensible way to test this? Any help appreciated.

Answer:

To test your function, mock only pandas.read_sql. Don't mock the SparkSession: if you do, createDataFrame returns a mock and there is nothing meaningful left to assert. You need a real session instance, which you can provide with your own pytest.fixture.

from unittest.mock import patch

import pandas
import pyspark.sql
import pytest
from pyspark.sql import SparkSession

from your_module import read_data


@pytest.fixture
def spark_session():
    _spark_session = SparkSession.builder.appName("unit-tests").getOrCreate()
    yield _spark_session
    _spark_session.stop()


@patch("pandas.read_sql")
def test_read_data(mock_read_sql, spark_session):
    # given:
    mock_read_sql.return_value = pandas.DataFrame(
        [(1, "row1"), (2, "row2")], columns=["id", "column1"]
    )

    # when:
    spark_df = read_data(spark_session)

    # then:
    assert isinstance(spark_df, pyspark.sql.DataFrame)

You can add more assertions, for example checking that the created DataFrame has the expected schema and contains the values you expect.

Tip: look at what Spark SQL can do on its own, such as its built-in JDBC data source; you probably don't need pandas to query your database at all.