Given a PySpark DataFrame, is it possible to obtain a list of the source columns that the DataFrame references?
Perhaps a more concrete example might help explain what I'm after. Say I have a DataFrame defined as:
import pyspark.sql.functions as func
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
source_df = spark.createDataFrame(
[("pru", 23, "finance"), ("paul", 26, "HR"), ("noel", 20, "HR")],
["name", "age", "department"],
)
source_df.createOrReplaceTempView("people")
sqlDF = spark.sql("SELECT name, age, department FROM people")
df = sqlDF.groupBy("department").agg(func.max("age").alias("max_age"))
df.show()
which returns:
+----------+-------+
|department|max_age|
+----------+-------+
|   finance|     23|
|        HR|     26|
+----------+-------+
The columns that df references are [department, age]. Is it possible to get that list of referenced columns programmatically?
Thanks to the question "Capturing the result of explain() in pyspark", I know I can extract the plan as a string:
df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "formatted")
which returns:
== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
   +- Exchange (4)
      +- HashAggregate (3)
         +- Project (2)
            +- Scan ExistingRDD (1)
(1) Scan ExistingRDD
Output [3]: [name#0, age#1L, department#2]
Arguments: [name#0, age#1L, department#2], MapPartitionsRDD[4] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Project
Output [2]: [age#1L, department#2]
Input [3]: [name#0, age#1L, department#2]
(3) HashAggregate
Input [2]: [age#1L, department#2]
Keys [1]: [department#2]
Functions [1]: [partial_max(age#1L)]
Aggregate Attributes [1]: [max#22L]
Results [2]: [department#2, max#23L]
(4) Exchange
Input [2]: [department#2, max#23L]
Arguments: hashpartitioning(department#2, 200), ENSURE_REQUIREMENTS, [plan_id=60]
(5) HashAggregate
Input [2]: [department#2, max#23L]
Keys [1]: [department#2]
Functions [1]: [max(age#1L)]
Aggregate Attributes [1]: [max(age#1L)#12L]
Results [2]: [department#2, max(age#1L)#12L AS max_age#13L]
(6) AdaptiveSparkPlan
Output [2]: [department#2, max_age#13L]
Arguments: isFinalPlan=false
This is useful, but it's not what I need: I need a list of the referenced columns. Is this possible?
Perhaps another way of asking the question is: is there a way to obtain the explain plan as an object that I can iterate over and explore?
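As a stopgap that only needs the "formatted" string from explainString, here is a minimal sketch of a text heuristic. It is an assumption of my own, not something Spark provides: it treats the attributes on a Scan node's Output line as the source columns, and counts a source column as referenced if it appears in any other node's lines except Input lines (those only list what a node receives, not what it reads). The function names and regular expressions are hypothetical, and the heuristic may misfire on other leaf node types or on output formats from other Spark versions.

```python
from __future__ import annotations

import re

# Catalyst attribute references as printed by explain(), e.g. "age#1L" or "department#2".
ATTR_RE = re.compile(r"\b([A-Za-z_]\w*)#(\d+)L?")
# Header of a node section in "formatted" mode, e.g. "(1) Scan ExistingRDD".
NODE_RE = re.compile(r"^\((\d+)\)\s+(.*)$")


def _split_sections(plan: str) -> list[tuple[str, list[str]]]:
    """Split a 'formatted' explain string into (node header, body lines) pairs.

    Lines before the first "(n) ..." header (the tree overview) are ignored.
    """
    sections: list[tuple[str, list[str]]] = []
    for raw in plan.splitlines():
        line = raw.strip()
        header = NODE_RE.match(line)
        if header:
            sections.append((header.group(2), []))
        elif sections and line:
            sections[-1][1].append(line)
    return sections


def referenced_source_columns(plan: str) -> list[str]:
    """Heuristic: Scan output columns that some non-Scan node mentions outside its Input line."""
    sections = _split_sections(plan)
    # (name, expression id) -> name, kept in the order the Scan outputs them
    source: dict[tuple[str, str], str] = {}
    for header, body in sections:
        if "Scan" in header:
            for line in body:
                if line.startswith("Output"):
                    for name, attr_id in ATTR_RE.findall(line):
                        source.setdefault((name, attr_id), name)
    used: set[tuple[str, str]] = set()
    for header, body in sections:
        if "Scan" in header:
            continue
        for line in body:
            if line.startswith("Input"):
                continue  # inputs are merely passed in, not necessarily read
            used.update(key for key in ATTR_RE.findall(line) if key in source)
    return [name for key, name in source.items() if key in used]
```

Passing it the string returned by the explainString(...) call above, it keeps age and department and drops name, because name#0 appears only in the Scan node and in the Project's Input line.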
Answer:
There is an object for that; unfortunately, it's a Java object and is not translated to PySpark.
You can still access it through the underlying JVM objects:
>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(0).toString()
u'department#1621'
>>> df._jdf.queryExecution().executedPlan().apply(0).output().apply(1).toString()
u'max_age#1632L'
You could loop over the indices passed to apply above to collect the information you are looking for, with something like:
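plan = df._jdf.queryExecution().executedPlan()
# apply(i) returns the i-th node of the plan tree (pre-order, starting at the root), or None past the last node
nodes = [plan.apply(i) for i in range(100)]  # 100 is an arbitrary upper bound
steps = [node.toString() for node in nodes if node is not None]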
It's a bit of a hack, but apparently size doesn't work.
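Instead of probing apply(i) up to a fixed bound, a sketch can walk the plan tree through children() with an explicit stack and compare each node's references against the leaf nodes' output. This assumes the Catalyst members children(), output(), references() and exprId() are reachable through py4j on the JVM plan nodes, that AttributeSet exposes toSeq(), and that Scala collections are read with size() and apply(i) as in the answer above. The helper names are hypothetical, and the code has only been exercised against stand-in Python objects, not a live Spark session.

```python
from __future__ import annotations


def _scala_seq(seq) -> list:
    """Copy a (py4j-proxied) Scala Seq into a Python list via size() and apply(i)."""
    return [seq.apply(i) for i in range(seq.size())]


def _key(attr) -> tuple[str, int]:
    # Identify attributes by name plus expression id, so same-named columns stay distinct.
    return (attr.name(), attr.exprId().id())


def referenced_leaf_columns(plan) -> list[str]:
    """Return the leaf (source) columns that at least one node in the plan references."""
    leaf_attrs: dict[tuple[str, int], str] = {}
    referenced: set[tuple[str, int]] = set()
    stack = [plan]
    while stack:
        node = stack.pop()
        children = _scala_seq(node.children())
        if not children:
            # A leaf node (e.g. the scan of the source data): its output are the source columns.
            for attr in _scala_seq(node.output()):
                leaf_attrs.setdefault(_key(attr), attr.name())
        referenced.update(_key(a) for a in _scala_seq(node.references().toSeq()))
        stack.extend(children)
    return [name for key, name in leaf_attrs.items() if key in referenced]
```

Usage would be something like referenced_leaf_columns(df._jdf.queryExecution().optimizedPlan()). I would pass the optimized logical plan rather than the analyzed one: in the analyzed plan, the Project from SELECT name, age, department still references name, whereas after column pruning it should not.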
Answer:
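You can try the code below, which prints each column of the DataFrame along with its data type. Note that these are the DataFrame's output columns (department and max_age here), not the source columns it references.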
# df.schema describes the DataFrame's output columns, each with its Spark SQL type
for field in df.schema.fields:
    print(field.name + " , " + str(field.dataType))