Can I interrogate a PySpark DataFrame to get the list of referenced columns?


Given a PySpark DataFrame is it possible to obtain a list of source columns that are being referenced by the DataFrame?

Perhaps a more concrete example might help explain what I'm after. Say I have a DataFrame defined as:

import pyspark.sql.functions as func
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
source_df = spark.createDataFrame(
    [("pru", 23, "finance"), ("paul", 26, "HR"), ("noel", 20, "HR")],
    ["name", "age", "department"],
)
source_df.createOrReplaceTempView("people")
sqlDF = spark.sql("SELECT name, age, department FROM people")
df = sqlDF.groupBy("department").agg(func.max("age").alias("max_age"))
df.show()

which returns:

+----------+-------+
|department|max_age|
+----------+-------+
|   finance|     23|
|        HR|     26|
+----------+-------+

The columns that are referenced by df are [department, age]. Is it possible to get that list of referenced columns programmatically?

Thanks to "Capturing the result of explain() in pyspark" I know I can extract the plan as a string:

df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "formatted")

which returns:

== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
   +- Exchange (4)
      +- HashAggregate (3)
         +- Project (2)
            +- Scan ExistingRDD (1)


(1) Scan ExistingRDD
Output [3]: [name#0, age#1L, department#2]
Arguments: [name#0, age#1L, department#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(2) Project
Output [2]: [age#1L, department#2]
Input [3]: [name#0, age#1L, department#2]

(3) HashAggregate
Input [2]: [age#1L, department#2]
Keys [1]: [department#2]
Functions [1]: [partial_max(age#1L)]
Aggregate Attributes [1]: [max#22L]
Results [2]: [department#2, max#23L]

(4) Exchange
Input [2]: [department#2, max#23L]
Arguments: hashpartitioning(department#2, 200), ENSURE_REQUIREMENTS, [plan_id=60]

(5) HashAggregate
Input [2]: [department#2, max#23L]
Keys [1]: [department#2]
Functions [1]: [max(age#1L)]
Aggregate Attributes [1]: [max(age#1L)#12L]
Results [2]: [department#2, max(age#1L)#12L AS max_age#13L]

(6) AdaptiveSparkPlan
Output [2]: [department#2, max_age#13L]
Arguments: isFinalPlan=false

which is useful; however, it's not what I need. I need a list of the referenced columns. Is this possible?

Perhaps another way of asking the question is... is there a way to obtain the explain plan as an object that I can iterate over/explore?

CodePudding user response:

There is an object for that; unfortunately it's a Java object and is not translated to PySpark.

You can still access it with Spark constructs:

>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(0).toString()
u'department#1621'
>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(1).toString()
u'max_age#1632L'
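
The output() of a node is a Scala Seq, so (another sketch that relies on the same private py4j objects, which may change between Spark versions) you can iterate it by its size instead of probing indices one by one:

node = df._jdf.queryExecution().executedPlan().apply(0)
outputs = node.output()             # Scala Seq[Attribute] exposed through py4j
for j in range(outputs.size()):     # number of output attributes of this node
    print(outputs.apply(j).name())  # e.g. department, max_age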

You could loop through the numbered nodes with apply to get the information you are looking for, with something like:

plan = df._jdf.queryExecution().executedPlan()
# apply(i) returns the numbered node of the plan tree, or None once i is past the last node
steps = [plan.apply(i).toString() for i in range(1, 100) if plan.apply(i) is not None]

It's a bit of a hack, but apparently size() doesn't work here.
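
If what you are after is specifically the set of referenced columns, a rough sketch along the same lines is to walk the optimized logical plan and collect each node's references(). This again goes through internal Catalyst objects via py4j, and the referenced_columns helper below is just for illustration, so treat it as a sketch rather than a stable solution:

def referenced_columns(df):
    # Walk the numbered nodes of the optimized logical plan and collect the
    # names of the attributes each node references. These are private APIs
    # and may change between Spark versions.
    plan = df._jdf.queryExecution().optimizedPlan()
    names = set()
    i = 0
    node = plan.apply(i)
    while node is not None:
        refs = node.references().iterator()  # AttributeSet -> Scala iterator
        while refs.hasNext():
            names.add(refs.next().name())
        i += 1
        node = plan.apply(i)
    return names

referenced_columns(df)  # for the example above this gives {'age', 'department'}

Using the logical rather than the physical plan avoids picking up intermediate attributes such as max#22L that the partial aggregation introduces.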

CodePudding user response:

You can try the code below; it will give you the list of columns and their data types in the DataFrame.

for field in df.schema.fields:
    print(field.name + " , " + str(field.dataType))
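
Equivalently, if you only need the name/type pairs, the public df.dtypes attribute returns them as a plain Python list (for the df above that is [('department', 'string'), ('max_age', 'bigint')]):

print(df.dtypes)  # [('department', 'string'), ('max_age', 'bigint')]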