Use enumerate to get partition columns from dataframe

Time:04-29

I am trying to get all columns and their data types into one variable, and only the partition columns into another variable (a list) in Python.

I get the details from DESCRIBE EXTENDED:

    df = spark.sql("describe extended schema_name.table_name")

    +----------------------------+-----------+
    |col_name                    |data_type  |
    +----------------------------+-----------+
    |col1                        |string     |
    |col2                        |int        |
    |col3                        |string     |
    |col4                        |int        |
    |col5                        |string     |
    |# Partition Information     |           |
    |# col_name                  |data_type  |
    |col4                        |int        |
    |col5                        |string     |
    |                            |           |
    |# Detailed Table Information|           |
    |Database                    |schema_name|
    |Table                       |table_name |
    |Owner                       |owner.name |

Then I convert the result into a list:

    des_list = df.select(df.col_name, df.data_type).rdd.map(lambda x: (x[0], x[1])).collect()

Here is how I am trying to get all columns (every row before '# Partition Information'):

    all_cols_name_type = []
    for index, item in enumerate(des_list):
        if item[0] == '# Partition Information':
            # every row before this marker is a regular column
            all_cols_name_type.extend(des_list[:index])
            break
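
A minimal pure-Python sketch of the same step, assuming des_list holds (col_name, data_type) tuples as produced above. The sample data is copied from the DESCRIBE EXTENDED output, and data_columns is a hypothetical helper name. It stops at the first blank or '#'-prefixed row, so it should also cover a non-partitioned table, where the column list is followed directly by the blank row:

```python
from itertools import takewhile

# Hypothetical sample: (col_name, data_type) tuples shaped like the
# DESCRIBE EXTENDED output in the question.
des_list = [
    ("col1", "string"), ("col2", "int"), ("col3", "string"),
    ("col4", "int"), ("col5", "string"),
    ("# Partition Information", ""),
    ("# col_name", "data_type"),
    ("col4", "int"), ("col5", "string"),
    ("", ""),
    ("# Detailed Table Information", ""),
    ("Database", "schema_name"),
]


def data_columns(rows):
    """Return the leading (name, type) rows, up to the first blank or '#' row."""
    # takewhile stops at the first row that fails the test, so this also works
    # when there is no '# Partition Information' row and a blank row comes first.
    return list(takewhile(lambda r: bool(r[0]) and not r[0].startswith("#"), rows))


all_cols_name_type = data_columns(des_list)
print(all_cols_name_type)
# [('col1', 'string'), ('col2', 'int'), ('col3', 'string'), ('col4', 'int'), ('col5', 'string')]
```

Partition columns (col4 and col5 here) also show up in this list, because DESCRIBE lists them among the regular columns too.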

    

For the partition columns, I would like to get every row after '# col_name', up to (but not including) the empty row that comes just before '# Detailed Table Information'.
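
The same enumerate loop can remember where the '# col_name' header is and cut the list at the next blank (or '#') row. A minimal sketch, assuming the same (col_name, data_type) tuples; partition_columns and the trimmed sample data are hypothetical:

```python
def partition_columns(rows):
    """Return the (name, type) rows between '# col_name' and the next blank or '#' row."""
    start = None
    for index, (name, _) in enumerate(rows):
        if name == "# col_name":
            start = index + 1  # partition rows begin right after the header
        elif start is not None and (not name or name.startswith("#")):
            return rows[start:index]  # blank row or next section ends the block
    # No terminating row: take everything after the header; no header: no partitions.
    return rows[start:] if start is not None else []


# Hypothetical, trimmed sample from the DESCRIBE EXTENDED output in the question.
des_list = [
    ("col1", "string"), ("col4", "int"), ("col5", "string"),
    ("# Partition Information", ""),
    ("# col_name", "data_type"),
    ("col4", "int"), ("col5", "string"),
    ("", ""),
    ("# Detailed Table Information", ""),
]

part_cols = partition_columns(des_list)
part_names = [name for name, _ in part_cols]
print(part_cols)   # [('col4', 'int'), ('col5', 'string')]
print(part_names)  # ['col4', 'col5']
```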

Any help with this is appreciated.

Answer:

You can try the following Scala code, or its equivalent in PySpark:

    val (partitionCols, dataCols) = spark.catalog.listColumns("schema_name.table_name")
      .collect()
      .partition(c => c.isPartition)

    val parCols = partitionCols.map(c => (c.name, c.dataType))
    val datCols = dataCols.map(c => (c.name, c.dataType))
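
The same split in PySpark, as a sketch: spark.catalog.listColumns(tableName, dbName) returns Column entries with name, dataType and isPartition fields. split_columns is a hypothetical helper, and FakeColumn is only a stand-in so the snippet runs without a SparkSession:

```python
from collections import namedtuple


def split_columns(columns):
    """Split catalog columns into (partition_cols, data_cols) lists of (name, dataType)."""
    par_cols, dat_cols = [], []
    for c in columns:
        # same split as the Scala .partition(c => c.isPartition)
        (par_cols if c.isPartition else dat_cols).append((c.name, c.dataType))
    return par_cols, dat_cols


# With a live SparkSession this would be called as (not executed here):
#   par_cols, dat_cols = split_columns(
#       spark.catalog.listColumns("table_name", dbName="schema_name"))

# FakeColumn only mimics the fields used above, so the sketch runs without Spark.
FakeColumn = namedtuple("FakeColumn", ["name", "dataType", "isPartition"])
cols = [FakeColumn("col1", "string", False), FakeColumn("col4", "int", True)]
par_cols, dat_cols = split_columns(cols)
print(par_cols)  # [('col4', 'int')]
print(dat_cols)  # [('col1', 'string')]
```

As in the Scala version, the data list here excludes the partition columns.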

If the table is not defined in the catalog (e.g. you read a Parquet dataset directly from S3 with spark.read.parquet("s3://path/...")), you can use the following Scala snippet instead:

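    import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
    import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan}
    import org.apache.spark.sql.types.StructType

    // The number of fields in these extractor patterns varies between Spark versions;
    // adjust the underscores if this does not compile against yours.
    val (partitionSchema, dataSchema) = df.queryExecution.optimizedPlan match {
      // V1 file source: the relation knows its partition and data schemas
      case LogicalRelation(hfs: HadoopFsRelation, _, _, _) =>
        (hfs.partitionSchema, hfs.dataSchema)
      // V2 file source: the scan exposes the schemas it reads
      case DataSourceV2ScanRelation(_, scan: FileScan, _) =>
        (scan.readPartitionSchema, scan.readDataSchema)
      // any other plan shape: fall back to empty schemas
      case _ => (StructType(Seq()), StructType(Seq()))
    }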

    val parCols = partitionSchema.map(f => (f.name, f.dataType))
    val datCols = dataSchema.map(f => (f.name, f.dataType))

Answer:

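There is a trick for this: use monotonically_increasing_id to give each row a number, find the row whose col_name is '# col_name', take its id, and keep the rows with a larger id. The generated ids are guaranteed to be increasing and unique but not consecutive, which is fine here because only their order is compared. Something like this: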

My sample table:

    from pyspark.sql import functions as F

    df = spark.sql('describe data')
    # number every row so that positions can be compared
    df = df.withColumn('id', F.monotonically_increasing_id())
    df.show()

    +--------------------+---------+-------+---+
    |            col_name|data_type|comment| id|
    +--------------------+---------+-------+---+
    |                  c1|      int|   null|  0|
    |                  c2|   string|   null|  1|
    |# Partition Infor...|         |       |  2|
    |          # col_name|data_type|comment|  3|
    |                  c2|   string|   null|  4|
    +--------------------+---------+-------+---+
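
The tricky part:

    # id of the '# col_name' header row
    idx = df.where(F.col('col_name') == '# col_name').first()['id']
    # 3

    # rows after the header are the partition columns; with DESCRIBE EXTENDED this
    # would also include the blank row and the detailed table rows that follow
    partition_cols = [r['col_name'] for r in df.where(F.col('id') > idx).collect()]
    # ['c2']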