Flattening a pandas dataframe by creating new columns resulting in unique ID pairs

Time:01-28

I have a pandas DataFrame like this:

   id sid X_animal X_class Y_animal Y_class
0   1   A       88    Home   Monkey  Mammal
1   1   A       88    Home   Parrot    Bird
2   1   B
3   2   C       11    Work
4   2   C       11    Work
5   2   C       33  School      Dog  Mammal
6   3   D       44    Home   Salmon    Fish
7   3   D       44    Home     Bear  Mammal
8   3   D       44    Home      Dog  Mammal
9   4   E       55  School

and I want to flatten it so that each (id, sid) pair appears in exactly one row. In the process, whenever a group has more than one distinct value in the *_animal and *_class columns, I want each distinct (animal, class) pair to go into its own numbered column pair (X_animal_1/X_class_1, X_animal_2/X_class_2, and so on). This is the DataFrame I want:

   id sid X_animal_1 X_class_1 X_animal_2 X_class_2 Y_animal_1 Y_class_1 Y_animal_2 Y_class_2 Y_animal_3 Y_class_3
0   1   A         88      Home                          Monkey    Mammal     Parrot      Bird
1   1   B
2   2   C         11      Work         33    School        Dog    Mammal
3   3   D         44      Home                          Salmon      Fish       Bear    Mammal        Dog    Mammal
4   4   E         55    School

To build the initial and final dataframes, the code is:

import pandas as pd
from numpy import nan

cols = ['id', 'sid', 'X_animal', 'X_class', 'Y_animal', 'Y_class']
l = [
    [1, 'A', 88, 'Home', 'Monkey', 'Mammal'],
    [1, 'A', 88, 'Home', 'Parrot', 'Bird'],
    [1, 'B', nan, nan, nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 33, 'School', 'Dog', 'Mammal'],
    [3, 'D', 44, 'Home', 'Salmon', 'Fish'],
    [3, 'D', 44, 'Home', 'Bear', 'Mammal'],
    [3, 'D', 44, 'Home', 'Dog', 'Mammal'],
    [4, 'E', 55, 'School', nan, nan],
]

df = pd.DataFrame(data=l, columns=cols)
print(df.fillna(''))

cols2 = ['id', 'sid', 'X_animal_1', 'X_class_1', 'X_animal_2', 'X_class_2', 'Y_animal_1', 'Y_class_1', 'Y_animal_2', 'Y_class_2', 'Y_animal_3', 'Y_class_3']
# every row is padded with nan to len(cols2) so it lines up with the column names
l2 = [
    [1, 'A', 88, 'Home', nan, nan, 'Monkey', 'Mammal', 'Parrot', 'Bird', nan, nan],
    [1, 'B', nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
    [2, 'C', 11, 'Work', 33, 'School', 'Dog', 'Mammal', nan, nan, nan, nan],
    [3, 'D', 44, 'Home', nan, nan, 'Salmon', 'Fish', 'Bear', 'Mammal', 'Dog', 'Mammal'],
    [4, 'E', 55, 'School', nan, nan, nan, nan, nan, nan, nan, nan],
]

df2 = pd.DataFrame(data=l2, columns=cols2)
print(df2.fillna(''))
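The number of numbered column pairs in the desired output is driven by the data: for each prefix it equals the largest number of distinct, non-empty (animal, class) pairs found in any single (id, sid) group. A minimal sketch that computes that width, assuming the column pairs are always named X_animal/X_class and Y_animal/Y_class (the `widths` dict is illustrative only):

```python
import pandas as pd
from numpy import nan

cols = ['id', 'sid', 'X_animal', 'X_class', 'Y_animal', 'Y_class']
l = [
    [1, 'A', 88, 'Home', 'Monkey', 'Mammal'],
    [1, 'A', 88, 'Home', 'Parrot', 'Bird'],
    [1, 'B', nan, nan, nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 33, 'School', 'Dog', 'Mammal'],
    [3, 'D', 44, 'Home', 'Salmon', 'Fish'],
    [3, 'D', 44, 'Home', 'Bear', 'Mammal'],
    [3, 'D', 44, 'Home', 'Dog', 'Mammal'],
    [4, 'E', 55, 'School', nan, nan],
]
df = pd.DataFrame(data=l, columns=cols)

widths = {}
for prefix in ['X', 'Y']:
    pair = [f'{prefix}_animal', f'{prefix}_class']
    # keep one row per distinct (animal, class) pair; all-NaN pairs are ignored
    distinct = (df[['id', 'sid'] + pair]
                .dropna(subset=pair, how='all')
                .drop_duplicates())
    # the widest (id, sid) group decides how many numbered pairs are needed
    widths[prefix] = int(distinct.groupby(['id', 'sid']).size().max())

print(widths)  # {'X': 2, 'Y': 3}
```

This matches the target layout: X_*_1 and X_*_2 (because of (2, 'C')) and Y_*_1 through Y_*_3 (because of (3, 'D')).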

I've tried using pivot() and pivot_table() without success. Because the number of output columns varies from group to group, that approach fails for me with a KeyError.

Answer:

This is a pivot over two index columns. Basically, you need to number the rows within each (id, sid) group (with groupby().cumcount()) before pivoting:

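# number the rows within each (id, sid) group: 1, 2, 3, ...
out = (df.assign(count=df.groupby(['id', 'sid']).cumcount().add(1))
   .pivot(index=['id', 'sid'], columns='count')   # pandas >= 2.0 requires keyword arguments here
   .fillna('')
)

# flatten the (column, count) MultiIndex into names such as X_animal_1
out.columns = [f'{x}_{y}' for x, y in out.columns]
out = out.reset_index()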

Output:

   id sid X_animal_1 X_animal_2 X_animal_3 X_class_1 X_class_2 X_class_3 Y_animal_1 Y_animal_2 Y_animal_3 Y_class_1 Y_class_2 Y_class_3
0   1   A         88         88                 Home      Home               Monkey     Parrot               Mammal      Bird          
1   1   B                                                                                                                              
2   2   C         11         11         33      Work      Work    School                              Dog                        Mammal
3   3   D         44         44         44      Home      Home      Home     Salmon       Bear        Dog      Fish    Mammal    Mammal
4   4   E         55                          School                                                                                   
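This pivot keeps repeated values (88 and Home appear twice for (1, A)), so it does not reproduce the desired output exactly. Below is a sketch of the same cumcount-then-pivot idea that deduplicates each X/Y column pair separately before numbering. It assumes the column pairs are always named {prefix}_animal/{prefix}_class with prefixes X and Y; the names `parts`, `part`, `wide`, and `keys` are illustrative only.

```python
import pandas as pd
from numpy import nan

cols = ['id', 'sid', 'X_animal', 'X_class', 'Y_animal', 'Y_class']
l = [
    [1, 'A', 88, 'Home', 'Monkey', 'Mammal'],
    [1, 'A', 88, 'Home', 'Parrot', 'Bird'],
    [1, 'B', nan, nan, nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 11, 'Work', nan, nan],
    [2, 'C', 33, 'School', 'Dog', 'Mammal'],
    [3, 'D', 44, 'Home', 'Salmon', 'Fish'],
    [3, 'D', 44, 'Home', 'Bear', 'Mammal'],
    [3, 'D', 44, 'Home', 'Dog', 'Mammal'],
    [4, 'E', 55, 'School', nan, nan],
]
df = pd.DataFrame(data=l, columns=cols)

parts = []
for prefix in ['X', 'Y']:
    pair = [f'{prefix}_animal', f'{prefix}_class']
    # one row per distinct, non-empty (animal, class) pair within each (id, sid)
    part = (df[['id', 'sid'] + pair]
            .dropna(subset=pair, how='all')
            .drop_duplicates()
            .rename(columns={pair[0]: 'animal', pair[1]: 'class'}))
    # number the surviving pairs 1, 2, 3, ... inside each (id, sid) group
    part = part.assign(n=part.groupby(['id', 'sid']).cumcount().add(1))
    wide = part.pivot(index=['id', 'sid'], columns='n')
    # columns are (field, n) tuples such as ('animal', 2) -> 'X_animal_2'
    wide.columns = [f'{prefix}_{field}_{n}' for field, n in wide.columns]
    parts.append(wide)

# outer-join the X and Y blocks, then restore pairs with no values at all, e.g. (1, 'B')
keys = pd.MultiIndex.from_frame(df[['id', 'sid']].drop_duplicates())
out = pd.concat(parts, axis=1).reindex(keys).reset_index()
print(out.fillna(''))
```

The values match the desired DataFrame, but the column order differs slightly (all X_animal_* columns come before the X_class_* columns); select the columns in an explicit order if that matters.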

Answer:

import pandas as pd

new_rows = []
# groupby visits each (id, sid) pair exactly once, in sorted order
for (i, s), _df in df.groupby(['id', 'sid']):
    row = {'id': i, 'sid': s}
    for pair in [['X_animal', 'X_class'], ['Y_animal', 'Y_class']]:
        # distinct (animal, class) pairs in order of first appearance;
        # rows where both values are NaN are skipped
        zets = _df[pair].dropna(how='all').drop_duplicates()
        # spread the pairs over numbered columns: X_animal_1, X_class_1, X_animal_2, ...
        for zet_i, (animal, cls) in enumerate(zets.itertuples(index=False), start=1):
            row[f'{pair[0]}_{zet_i}'] = animal
            row[f'{pair[1]}_{zet_i}'] = cls
    new_rows.append(row)

df_f = pd.DataFrame(new_rows)

# order the columns as X_animal_1, X_class_1, X_animal_2, ..., Y_class_3
value_cols = sorted(df_f.columns.drop(['id', 'sid']),
                    key=lambda c: (c.split('_')[0], int(c.rsplit('_', 1)[1]), c))
df_f = df_f[['id', 'sid'] + value_cols]
print(df_f.fillna(''))