PySpark - Filter dataframe columns based on list


I have a dataframe and I want to keep only certain of its columns, based on a list.

I have a list of columns I would like to have in my final dataframe:

final_columns = ['A','C','E']

My dataframe is this:

from pyspark.sql.types import StructType, StructField, StringType

data1 = [("James",  "Lee", "Smith","36636"),
         ("Michael","Rose","Boots","40288")]

schema1 = StructType([StructField("A",StringType(),True),
                      StructField("B",StringType(),True),
                      StructField("C",StringType(),True),
                      StructField("D",StringType(),True)])

df1 = spark.createDataFrame(data=data1, schema=schema1)

I would like to transform df1 in order to have the columns of this final_columns list.

So, basically, I expect the resulting dataframe to look like this

+-------+-----+---+
|      A|    C|  E|
+-------+-----+---+
|  James|Smith|   |
|Michael|Boots|   |
+-------+-----+---+

Is there any smart way to do this?

Thank you in advance

CodePudding user response:

Based on your requirement, I have written dynamic code. It selects the columns given in the list, and also creates an empty column for any column in the list that is not present in the source/original dataframe.

from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, StringType

data1 = [("James",  "Lee", "Smith","36636"),
         ("Michael","Rose","Boots","40288")]

schema1 = StructType([StructField("A",StringType(),True),
                      StructField("B",StringType(),True),
                      StructField("C",StringType(),True),
                      StructField("D",StringType(),True)])

df1 = spark.createDataFrame(data=data1, schema=schema1)
actual_columns = df1.schema.names
final_columns = ['A','C','E']


def Diff(li1, li2):
  # Columns requested in li2 but missing from li1
  return list(set(li2) - set(li1))

def Same(li1, li2):
  # Columns present in both lists, sorted alphabetically
  return sorted(set(li1).intersection(li2))

# Select the existing columns, then add each missing one as an empty column.
# A loop handles the general case where more than one column is missing,
# which the one-shot withColumn(*Diff(...), lit('')) call would not.
df1 = df1.select(*Same(actual_columns, final_columns))
for c in Diff(actual_columns, final_columns):
    df1 = df1.withColumn(c, lit(''))
display(df1)
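To see what the two helpers compute for the example above, here is the same set logic in plain Python (no Spark needed); the column names are taken from the question's dataframe:

```python
actual_columns = ['A', 'B', 'C', 'D']   # columns present in df1
final_columns = ['A', 'C', 'E']         # desired final columns

# Columns to select: present in both lists, sorted alphabetically
same = sorted(set(actual_columns) & set(final_columns))
# Columns to create as empty: requested but missing from the dataframe
diff = sorted(set(final_columns) - set(actual_columns))

print(same)  # ['A', 'C']
print(diff)  # ['E']
```

Note that because `Same` sorts its result, the selected columns come out in alphabetical order, which happens to match `final_columns` here but may not in general.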

CodePudding user response:

Here is one way: use the DataFrame drop() method with a list which represents the symmetric difference between the DataFrame's current columns and your list of final columns.

df = spark.createDataFrame([(1, 1, "1", 0.1),(1, 2, "1", 0.2),(3, 3, "3", 0.3)],('a','b','c','d'))

df.show()
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  1|  1|0.1|
|  1|  2|  1|0.2|
|  3|  3|  3|0.3|
+---+---+---+---+

# list of desired final columns
final_cols = ['a', 'c', 'd']

df2 = df.drop( *set(final_cols).symmetric_difference(df.columns) )

Note an alternate syntax for the symmetric difference operation:

df2 = df.drop( *(set(final_cols) ^ set(df.columns)) )
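In plain Python, the symmetric difference picks out exactly the columns to drop (using the column names from this answer's example):

```python
cols = ['a', 'b', 'c', 'd']   # df.columns
final_cols = ['a', 'c', 'd']  # desired final columns

# Elements that appear in exactly one of the two sets. A column requested
# in final_cols but absent from the dataframe would also land here, which
# is harmless because drop() ignores names that don't exist.
to_drop = sorted(set(final_cols) ^ set(cols))
print(to_drop)  # ['b']
```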

This gives me:

+---+---+---+
|  a|  c|  d|
+---+---+---+
|  1|  1|0.1|
|  1|  1|0.2|
|  3|  3|0.3|
+---+---+---+

Which I believe is what you want.
