Spark Count Large Number of Columns

Time:09-24

Ran into this a little while ago, and I think there should be a better/more efficient way of doing this:

I have a DataFrame with about 70k columns and roughly 10k rows. For each column, I want the number of rows in which that column's value equals 1.

import org.apache.spark.sql.functions.column

// One Spark job (a full scan of df) per column
df.columns.map(c => df.where(column(c) === 1).count())

This works for a small number of columns, but with this many columns it takes hours. Every count() is a separate Spark action, so the data is scanned once per column, roughly 70,000 times.

What optimizations can I do to get to the results faster?

Answer:

You can replace the value of each column with 1 or 0, depending on whether the column's original value matches the condition, and then sum every column in a single aggregation. Finally, collect the single row of the resulting DataFrame and turn it into an array.

So the code would be as follows:

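import org.apache.spark.sql.functions.{col, lit, sum, when}

// One sum() aggregate per column
val aggregation_columns = df.columns.map(c => sum(col(c)))

val counts: Array[Long] = df
  .columns
  // Replace each column's value with 1 if it equals 1, otherwise 0
  .foldLeft(df)((acc, elem) => acc.withColumn(elem, when(col(elem) === 1, lit(1)).otherwise(lit(0))))
  // Sum all columns in one aggregation (a single pass over the data)
  .agg(aggregation_columns.head, aggregation_columns.tail: _*)
  .collect()
  // The result has exactly one row; sum() over integer columns yields Long
  .flatMap(row => df.columns.indices.map(i => row.getLong(i)))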
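The Spark API documentation for withColumn warns that calling it in a loop adds one projection per call, which can build very large plans, cause performance problems and even a StackOverflowException, and recommends a single select instead. With 70k columns that warning applies here, so a minimal sketch of the same "map to 1/0, then sum" idea moves the 1/0 mapping inside the aggregate. The object name CountOnesSingleAgg, the helper countOnes and the sample data are hypothetical, for illustration only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, sum, when}

// Hypothetical helper: same idea as the foldLeft version (map each value to
// 1 or 0, then sum), but the mapping lives inside the aggregate, so no
// per-column withColumn projection is added to the plan.
object CountOnesSingleAgg {
  def countOnes(df: DataFrame): Map[String, Long] = {
    val aggs = df.columns.map(c => sum(when(col(c) === 1, 1).otherwise(0)).as(c))
    // Global aggregation: one pass over the data, one result row.
    // Note: on an empty DataFrame sum() returns null and getLong would throw.
    val row = df.agg(aggs.head, aggs.tail: _*).head()
    df.columns.zipWithIndex.map { case (c, i) => c -> row.getLong(i) }.toMap
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count-ones").getOrCreate()
    import spark.implicits._
    // Hypothetical sample data with three columns a, b and c
    val df = Seq((1, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 0)).toDF("a", "b", "c")
    println(countOnes(df)) // Map(a -> 3, b -> 2, c -> 2)
    spark.stop()
  }
}
```

Returning a Map keyed by column name, rather than a bare Array[Long], keeps each count attached to its column when the result is used later.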

Answer:

count_if counts the rows for which a condition is true. Because it is an aggregate function, one count_if expression per column lets all columns be evaluated in a single pass:

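val df = ... // the DataFrame to analyze
// Backticks protect column names that contain spaces, dots or other special characters
df.selectExpr(df.columns.map(c => s"count_if(`$c` = 1) AS `$c`"): _*).show()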
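count_if is a Spark SQL function added in Spark 3.0. On older versions, or to stay in the DataFrame API, a minimal sketch of the same single-pass count can use count(when(...)): when() without otherwise() returns null whenever the condition is not true, and count() skips nulls. The names CountOnesCountWhen and countOnes and the sample data are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, when}

object CountOnesCountWhen {
  // Hypothetical helper: count(when(cond, true)) counts the rows where cond
  // is true, because when() without otherwise() yields null in every other
  // case and count() ignores nulls. All columns are aggregated in one select.
  def countOnes(df: DataFrame): Map[String, Long] = {
    val aggs = df.columns.map(c => count(when(col(c) === 1, true)).as(c))
    // count() never returns null, so this also works on an empty DataFrame
    val row = df.select(aggs: _*).head()
    df.columns.zipWithIndex.map { case (c, i) => c -> row.getLong(i) }.toMap
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count-when").getOrCreate()
    import spark.implicits._
    // Hypothetical sample data; the null in column a is simply not counted
    val df = Seq((Some(1), 0), (None, 1), (Some(1), 1)).toDF("a", "b")
    println(countOnes(df)) // Map(a -> 2, b -> 2)
    spark.stop()
  }
}
```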

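explain() shows that all of the counts are computed by a single partial/final pair of HashAggregate operators over one scan of the data. For three columns a, b and c it prints: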

== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(if (((a#10 = 1) = false)) null else (a#10 = 1)), count(if (((b#11 = 1) = false)) null else (b#11 = 1)), count(if (((c#12 = 1) = false)) null else (c#12 = 1))])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
   +- *(1) HashAggregate(keys=[], functions=[partial_count(if (((a#10 = 1) = false)) null else (a#10 = 1)), partial_count(if (((b#11 = 1) = false)) null else (b#11 = 1)), partial_count(if (((c#12 = 1) = false)) null else (c#12 = 1))])
      +- *(1) LocalTableScan [a#10, b#11, c#12]
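To reproduce a plan like the one above locally, here is a small self-contained sketch that builds a three-column DataFrame, runs the count_if query and calls explain(). It assumes Spark 3.x (for count_if). Expression ids such as #10 or #13, and any wrapper operators, vary by Spark version, so no exact plan output is claimed. The object name CountIfPlanDemo, the helper countQuery and the sample data are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object CountIfPlanDemo {
  // Hypothetical helper: one count_if aggregate per column;
  // backticks guard unusual column names.
  def countQuery(df: DataFrame): DataFrame =
    df.selectExpr(df.columns.map(c => s"count_if(`$c` = 1) AS `$c`"): _*)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count-if-plan").getOrCreate()
    import spark.implicits._
    // Hypothetical sample data with three columns a, b and c
    val df = Seq((1, 0, 1), (1, 1, 1), (0, 1, 1)).toDF("a", "b", "c")
    val query = countQuery(df)
    query.explain() // prints the physical plan; ids and wrappers vary by Spark version
    val row = query.head() // count_if returns BIGINT, read as Long
    println(df.columns.zipWithIndex.map { case (c, i) => c -> row.getLong(i) }.toMap) // Map(a -> 2, b -> 2, c -> 3)
    spark.stop()
  }
}
```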