Reorder PySpark dataframe columns on specific sort logic

I have a PySpark dataframe with the column order below. I need to reorder the columns so they are grouped by branch. How do I do it? df.select(sorted(df.columns)) doesn't give the order I want.

Existing column order:

store_id,
store_name,
month_1_branch_A_profit,
month_1_branch_B_profit,
month_1_branch_C_profit,
month_1_branch_D_profit,
month_2_branch_A_profit,
month_2_branch_B_profit,
month_2_branch_C_profit,
month_2_branch_D_profit,
.
.
month_12_branch_A_profit,
month_12_branch_B_profit,
month_12_branch_C_profit,
month_12_branch_D_profit

Desired column order:

store_id,
store_name,
month_1_branch_A_profit,
month_2_branch_A_profit,
month_3_branch_A_profit,
month_4_branch_A_profit,
.
.
month_12_branch_A_profit,
month_1_branch_B_profit,
month_2_branch_B_profit,
month_3_branch_B_profit,
.
.
month_12_branch_B_profit,
..

CodePudding user response:

You could manually build the list of columns in the desired order and select them.

col_fmt = 'month_{}_branch_{}_profit'
cols = ['store_id', 'store_name']
for branch in ['A', 'B', 'C', 'D']:
    for month in range(1, 13):
        cols.append(col_fmt.format(month, branch))
df = df.select(cols)  # select returns a new dataframe; reassign to keep it

Alternatively, I'd recommend building a better dataframe that takes advantage of array/struct/map datatypes, e.g.:

months - array (size 12)
  - branches: map<string, struct>
    - key: string  (branch name)
    - value: struct
      - profit: float

This way, the array is already "sorted" by month. Map order doesn't really matter, and the layout makes SQL queries that target specific months and branches easier to read (and probably faster, thanks to predicate pushdown).

CodePudding user response:

You may need some plain Python. In the following script I split the column names on underscores (_) and sort them by element [3] (the branch name) and element [1] (the month number).

Input df:

cols = ['store_id',
        'store_name',
        'month_1_branch_A_profit',
        'month_1_branch_B_profit',
        'month_1_branch_C_profit',
        'month_1_branch_D_profit',
        'month_2_branch_A_profit',
        'month_2_branch_B_profit',
        'month_2_branch_C_profit',
        'month_2_branch_D_profit',
        'month_12_branch_A_profit',
        'month_12_branch_B_profit',
        'month_12_branch_C_profit',
        'month_12_branch_D_profit']
df = spark.createDataFrame([], ','.join([f'{c} int' for c in cols]))

Script:

branch_cols = [c for c in df.columns if c not in {'store_id', 'store_name'}]
d = {tuple(c.split('_')):c for c in branch_cols}
df = df.select(
    'store_id', 'store_name',
    *[d[c] for c in sorted(d, key=lambda x: f'{x[3]}_{int(x[1]):02}')]
)

df.printSchema()
# root
#  |-- store_id: integer (nullable = true)
#  |-- store_name: integer (nullable = true)
#  |-- month_1_branch_A_profit: integer (nullable = true)
#  |-- month_2_branch_A_profit: integer (nullable = true)
#  |-- month_12_branch_A_profit: integer (nullable = true)
#  |-- month_1_branch_B_profit: integer (nullable = true)
#  |-- month_2_branch_B_profit: integer (nullable = true)
#  |-- month_12_branch_B_profit: integer (nullable = true)
#  |-- month_1_branch_C_profit: integer (nullable = true)
#  |-- month_2_branch_C_profit: integer (nullable = true)
#  |-- month_12_branch_C_profit: integer (nullable = true)
#  |-- month_1_branch_D_profit: integer (nullable = true)
#  |-- month_2_branch_D_profit: integer (nullable = true)
#  |-- month_12_branch_D_profit: integer (nullable = true)
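The :02 zero-padding in the sort key is what keeps month_2 ahead of month_12 within each branch (a plain string sort would put '12' before '2'). A quick pure-Python check of that key, on a hypothetical subset of the columns:

```python
cols = ['month_12_branch_A_profit', 'month_1_branch_B_profit',
        'month_2_branch_A_profit', 'month_1_branch_A_profit']

def sort_key(c):
    parts = c.split('_')                      # ['month', '1', 'branch', 'A', 'profit']
    return f'{parts[3]}_{int(parts[1]):02}'   # e.g. 'A_01'

print(sorted(cols, key=sort_key))
# ['month_1_branch_A_profit', 'month_2_branch_A_profit',
#  'month_12_branch_A_profit', 'month_1_branch_B_profit']
```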