PySpark - Collect vs CrossJoin, which to choose to create a max column?

Time:11-18

Spark Masters!

Does anyone have tips on which option is better or faster in PySpark for creating a column that holds the maximum value of another column?

Option A:

max_num = df.agg({"number": "max"}).collect()[0][0]
df = df.withColumn("max", f.lit(max_num))

Option B:

max_num = df2.select(f.max(f.col("number")).alias("max"))
df2 = df2.crossJoin(max_num)

Please feel free to add any other comments, even ones not directly related; this is mostly for learning purposes.

Feel free to suggest an option C, D, …

Below is some testable code I wrote (comments on the code are welcome too):

Testing code:

import time
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

# --------------------------------------------------------------------------------------
# 01 - Data creation
spark = SparkSession.builder.getOrCreate()

data = []
for i in range(10000):
    data.append(
        {
            "1": "adsadasd",
            "number": 1323,
            "3": "andfja"
         }
    )
    data.append(
        {
            "1": "afasdf",
            "number": 8908,
            "3": "fdssfv"
         }
    )
df = spark.createDataFrame(data)
df2 = spark.createDataFrame(data)
df.count()
df2.count()
print(df.rdd.getNumPartitions())
print(df2.rdd.getNumPartitions())
# --------------------------------------------------------------------------------------
# 02 - Tests

# B) Crossjoin
start_time = time.time()
max_num = df2.select(f.max(f.col("number")).alias("max"))
df2 = df2.crossJoin(max_num)
print(df2.count())
print("Crossjoin time: ", time.time() - start_time)

# A) Collect
start_time = time.time()
max_num = df.agg({"number": "max"}).collect()[0][0]
df = df.withColumn("max", f.lit(max_num))
print(df.count())
print("Collect time: ", time.time() - start_time)


df2.show()
df.show()


Answer:

I added another method, similar to your method B, which defines a window over the entire DataFrame and takes the maximum value over it:

df3.withColumn("max", F.max("number").over(Window.partitionBy()))

Here is how the three methods performed on a DataFrame of 100 million rows (I couldn't fit much more into memory):

import time
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# --------------------------------------------------------------------------------------
# 01 - Data creation
spark = SparkSession.builder.getOrCreate()

data = pd.DataFrame({
  'aaa': '1',
  'number': np.random.randint(0, 100, size=100000000)
})
df = spark.createDataFrame(data)
df2 = spark.createDataFrame(data)
df3 = spark.createDataFrame(data)

# --------------------------------------------------------------------------------------
# 02 - Tests

# A) Collect
method = 'A'
start_time = time.time()
max_num = df.agg({"number": "max"}).collect()[0][0]
df = df.withColumn("max", F.lit(max_num))
print(f"Collect time method {method}: ", time.time() - start_time)

# B) Crossjoin
method = 'B'
start_time = time.time()
max_num = df2.select(F.max(F.col("number")).alias("max"))
df2 = df2.crossJoin(max_num)
print(f"Collect time method {method}: ", time.time() - start_time)

# C) Window
method = 'C'
start_time = time.time()
df3 = df3.withColumn("max", F.max("number").over(Window.partitionBy()))
print(f"Collect time method {method}: ", time.time() - start_time)

Results:

Collect time method A:  1.890228033065796
Collect time method B:  0.01714015007019043
Collect time method C:  0.03456592559814453

I also tried the same code with 100k rows: method A's time roughly halves (~0.9 s) but remains high, whereas methods B and C stay about the same.

No other sensible methods came to mind. Based on these timings, method B seems to be the most efficient one.

Answer:

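I made some changes to @ric-s's great suggestion: the test now uses 1 million rows, calls count() on each DataFrame before timing, and ends each timed block with count() so the lazy transformations are actually executed.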

import time
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# --------------------------------------------------------------------------------------
# 01 - Data creation
spark = SparkSession.builder.getOrCreate()

data = pd.DataFrame({
  'aaa': '1',
  'number': np.random.randint(0, 100, size=1000000)
})
df = spark.createDataFrame(data)
df2 = spark.createDataFrame(data)
df3 = spark.createDataFrame(data)

df.count()
df2.count()
df3.count()
# --------------------------------------------------------------------------------------
# 02 - Tests

# A) Collect
method = 'A'
start_time = time.time()
max_num = df.agg({"number": "max"}).collect()[0][0]
df = df.withColumn("max", F.lit(max_num))
df.count()
print(f"Collect time method {method}: ", time.time() - start_time)

# B) Crossjoin
method = 'B'
start_time = time.time()
max_num = df2.select(F.max(F.col("number")).alias("max"))
df2 = df2.crossJoin(max_num)
df2.count()
print(f"Collect time method {method}: ", time.time() - start_time)

# C) Window
method = 'C'
start_time = time.time()
df3 = df3.withColumn("max", F.max("number").over(Window.partitionBy()))
df3.count()
print(f"Collect time method {method}: ", time.time() - start_time)

And got these results:

Collect time method A:  1.8250329494476318
Collect time method B:  1.373009204864502
Collect time method C:  0.4454350471496582