I want to compute by hand some custom summary statistics of a large dataframe on PySpark. For the sake of simplicity, let me use a simpler dummy dataset, as the following:
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import DataType, NumericType, DateType, TimestampType
import pyspark.sql.types as t
import pyspark.sql.functions as f
from datetime import datetime
from typing import List, Tuple
spark = (
    SparkSession.builder
    .appName("pyspark")
    .master("local[*]")
    .getOrCreate()
)
dd = [
    ("Alice", 18.0, datetime(2022, 1, 1)),
    ("Bob", None, datetime(2022, 2, 1)),
    ("Mark", 33.0, None),
    (None, 80.0, datetime(2022, 4, 1)),
]
schema = t.StructType(
    [
        t.StructField("T", t.StringType()),
        t.StructField("C", t.DoubleType()),
        t.StructField("D", t.DateType()),
    ]
)
df = spark.createDataFrame(dd, schema)
Ok, the thing is, I want to compute some aggregations: missing counts, stddev, max and min from all the columns, and of course I'd want to do it in parallel. Well, I can take two approaches for this:
Approach 1: One select query
This way, I let the Spark engine do the parallel computing by building one big select query. Let's see:
def df_dtypes(df: DataFrame) -> List[Tuple[str, DataType]]:
    """
    Like the df.dtypes attribute of a Spark DataFrame, but returning DataType objects
    instead of strings.
    """
    return [(str(field.name), field.dataType) for field in df.schema.fields]
def get_missing(df: DataFrame) -> Tuple:
    suffix = "__missing"
    result = (
        *(
            (
                f.count(
                    f.when(
                        (f.isnan(c) | f.isnull(c)),
                        c,
                    )
                )
                / f.count("*")
                * 100
                if isinstance(t, NumericType)  # isnan only works for numeric types
                else f.count(
                    f.when(
                        f.isnull(c),
                        c,
                    )
                )
                / f.count("*")
                * 100
            )
            .cast("double")
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result
def get_min(df: DataFrame) -> Tuple:
    suffix = "__min"
    result = (
        *(
            (f.min(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result
def get_max(df: DataFrame) -> Tuple:
    suffix = "__max"
    result = (
        *(
            (f.max(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result
def get_std(df: DataFrame) -> Tuple:
    suffix = "__std"
    result = (
        *(
            (f.stddev(c) if isinstance(t, NumericType) else f.lit(None)).cast(t).alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result
# build the big query
query = get_min(df) + get_max(df) + get_missing(df) + get_std(df)
# run the job
df.select(*query).show()
As far as I know, this job will run in parallel because of how Spark's internals work. Is this approach efficient? The problem with it might be the huge number of suffixed columns that it creates; could that be a bottleneck?
Approach 2: Using threads
In this approach, I can make use of Python threads to try to perform each calculation concurrently.
from pyspark import InheritableThread
from queue import Queue
def get_min(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("min").alias("summary"),
        *(
            (f.min(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)
def get_max(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("max").alias("summary"),
        *(
            (f.max(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)
def get_std(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("std").alias("summary"),
        *(
            (f.stddev(c) if isinstance(t, NumericType) else f.lit(None)).cast(t).alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)
def get_missing(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("missing").alias("summary"),
        *(
            (
                f.count(
                    f.when(
                        (f.isnan(c) | f.isnull(c)),
                        c,
                    )
                )
                / f.count("*")
                * 100
                if isinstance(t, NumericType)  # isnan only works for numeric types
                else f.count(
                    f.when(
                        f.isnull(c),
                        c,
                    )
                )
                / f.count("*")
                * 100
            )
            .cast("double")
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)
# caching the dataframe to reuse it for all the jobs?
df.cache()
# I use a queue to retrieve the results from the threads
q = Queue()
threads = [
    InheritableThread(target=fun, args=(df, q))
    for fun in (get_min, get_max, get_missing, get_std)
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
# and then some code to recover the results from the queue
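# for example (just a sketch): drain one collected result per thread
results = [q.get() for _ in range(len(threads))]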
This way has the advantage of not ending up with dozens of suffixed columns, just the original ones. But I'm not sure how this approach deals with the GIL: is it actually parallel?
Could you tell me which one you prefer? Or give me some suggestions about different ways to compute these statistics?
At the end I want to build a JSON with all of these aggregated statistics. The structure of the JSON is not relevant; it would depend on the approach taken. For the first one I'd get something like {"T__min": None, "T__max": None, "T__missing": 1, "T__std": None, "C__min": 18.0, "C__max": 80.0, ...}, so I'd end up with tons of fields and the select query would be huge. For the second approach I would get one JSON per variable with those statistics.
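For reference, with the first approach I imagine building that JSON from the single result row with something like this (just a sketch):
import json

row = df.select(*query).collect()[0]  # a single row holding every aggregated value
print(json.dumps(row.asDict(), default=str))  # default=str so dates serialize too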
CodePudding user response:
I'm not really familiar with InheritableThread and Queue, but as far as I can see, you want to create threads based on statistics, meaning every thread calculates a different statistic. This doesn't look optimized by design: some statistics will likely be calculated quicker than others, and the processing power of those threads will then sit unused.
As you know, Spark is a distributed computing system which performs all the parallelism for you. I highly doubt you can outperform Spark's optimization using Python's tools; if we could do that, it would already be integrated into Spark.
The first approach is very nicely written: conditional statements based on data types, inclusion of isnan, type hints - well done. It will probably perform about as well as possible; it's definitely written efficiently. The biggest drawback is that it has to run over the whole dataframe, but you can't really escape that. Regarding the number of columns, you shouldn't be worried: the select statement will be very long, but it's still just one operation, and the logical/physical plan should be efficient. In the worst case, you could persist/cache the dataframe before this operation, as you may have problems if this dataframe is created using some complex code. But other than that you should be fine.
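If you want to verify that yourself, a quick check (a sketch reusing the query tuple from your first approach) is to persist the input and look at the plan of the single big select:
from pyspark import StorageLevel

df.persist(StorageLevel.MEMORY_AND_DISK)  # optional; useful if df comes from complex upstream code
summary_df = df.select(*query)
summary_df.explain()  # the plan should show a single aggregation over one scan of the data
print(summary_df.collect()[0].asDict())
df.unpersist()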
As an alternative, for some statistics you may consider using summary
:
df.summary().show()
# +-------+-----+------------------+
# |summary|    T|                 C|
# +-------+-----+------------------+
# |  count|    3|                 3|
# |   mean| null|43.666666666666664|
# | stddev| null| 32.34707611722168|
# |    min|Alice|              18.0|
# |    25%| null|              18.0|
# |    50%| null|              33.0|
# |    75%| null|              80.0|
# |    max| Mark|              80.0|
# +-------+-----+------------------+
This approach only works for numeric and string columns; date/timestamp columns (e.g. "D") are automatically excluded. But I'm not sure if it would be more efficient, and it would definitely be less clear, as it adds additional logic to your code, which is now quite straightforward.
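If you do go this way, note that summary accepts the statistics to compute, so you can skip the percentiles you don't need, and the result can be turned into JSON fairly directly. A small sketch (stddev, min and max are among the statistics summary supports; toJSON may omit null columns):
import json

stats_df = df.summary("min", "max", "stddev")  # compute only the statistics you care about
stats_df.show()

# since the end goal is JSON: one dict per statistic, keyed by column name
stats = [json.loads(r) for r in stats_df.toJSON().collect()]
print(stats)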