How can I flatten a multi-dimensional CSV file using pandas?

Time:05-21

I have a CSV with two dimensions (tract and age group) that I'm reading into pandas:

import pandas as pd

# pretend I'm reading in a CSV here; I'm just dumping some of the data in directly so it's easier to follow
columns = ["tract_id", "age_group", "nq(x)", "l(x)", "e(x)"]
input = [["1", "Under 1", "0.019736", "100000", "73.1"],
["1", "1-4", "0.008884", "98026", "73.5"],
["1", "5-14", "0.00618", "97156", "70.2"],
["1", "15-24", "0.008613", "96555", "60.6"],
["1", "25-34", "0.023923", "95723", "51.1"],
["1", "35-44", "0.020311", "93433", "42.2"],
["1", "45-54", "0.017889", "91536", "33"],
["1", "55-64", "0.084746", "89898", "23.5"],
["1", "65-74", "0.226782", "82280", "15.2"],
["1", "75-84", "0.505319", "63620", "8.2"],
["1", "85 and older", "1", "31472", "1.4"],
["2", "Under 1", "0.008647", "100000", "76.9"],
["2", "1-4", "0.009001", "99135", "76.5"],
["2", "5-14", "0.00143", "98243", "73.2"],
["2", "15-24", "0.004869", "98103", "63.3"],
["2", "25-34", "0.01039", "97625", "53.6"],
["2", "35-44", "0.026667", "96611", "44.1"],
["2", "45-54", "0.061996", "94034", "35.2"],
["2", "55-64", "0.133411", "88205", "27.2"],
["2", "65-74", "0.289017", "76437", "20.6"],
["2", "75-84", "0.282686", "54345", "16.9"],
["2", "85 and older", "1", "38983", "11.6"],
["3", "Under 1", "0.00426", "100000", "75.4"],
["3", "1-4", "0.001026", "99574", "74.7"],
["3", "5-14", "0.002392", "99472", "70.8"],
["3", "15-24", "0.002786", "99234", "61"],
["3", "25-34", "0.012019", "98957", "51.1"],
["3", "35-44", "0.025924", "97768", "41.7"],
["3", "45-54", "0.057895", "95234", "32.7"],
["3", "55-64", "0.113929", "89720", "24.4"],
["3", "65-74", "0.197795", "79498", "16.9"],
["3", "75-84", "0.32413", "63774", "9.8"],
["3", "85 and older", "1", "43103", "2.1"]]

df = pd.DataFrame(input, columns = columns)
print(df.to_string())

Output:

   tract_id     age_group     nq(x)    l(x)  e(x)
0         1       Under 1  0.019736  100000  73.1
1         1           1-4  0.008884   98026  73.5
2         1          5-14   0.00618   97156  70.2
3         1         15-24  0.008613   96555  60.6
4         1         25-34  0.023923   95723  51.1
5         1         35-44  0.020311   93433  42.2
6         1         45-54  0.017889   91536    33
7         1         55-64  0.084746   89898  23.5
8         1         65-74  0.226782   82280  15.2
9         1         75-84  0.505319   63620   8.2
10        1  85 and older         1   31472   1.4
11        2       Under 1  0.008647  100000  76.9
12        2           1-4  0.009001   99135  76.5
13        2          5-14   0.00143   98243  73.2
14        2         15-24  0.004869   98103  63.3
15        2         25-34   0.01039   97625  53.6
16        2         35-44  0.026667   96611  44.1
17        2         45-54  0.061996   94034  35.2
18        2         55-64  0.133411   88205  27.2
19        2         65-74  0.289017   76437  20.6
20        2         75-84  0.282686   54345  16.9
21        2  85 and older         1   38983  11.6
22        3       Under 1   0.00426  100000  75.4
23        3           1-4  0.001026   99574  74.7
24        3          5-14  0.002392   99472  70.8
25        3         15-24  0.002786   99234    61
26        3         25-34  0.012019   98957  51.1
27        3         35-44  0.025924   97768  41.7
28        3         45-54  0.057895   95234  32.7
29        3         55-64  0.113929   89720  24.4
30        3         65-74  0.197795   79498  16.9
31        3         75-84   0.32413   63774   9.8
32        3  85 and older         1   43103   2.1
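If the data really comes from a file, here is a minimal sketch of the reading step, assuming a comma-separated file with the same header. `pd.read_csv` infers numeric dtypes, whereas the DataFrame built above from lists of strings keeps every value as text. `csv_text` is a hypothetical stand-in for the real file, and `io.StringIO` only keeps the example self-contained:

```python
import io

import pandas as pd

# Hypothetical excerpt of the CSV, inlined so the sketch runs on its own
csv_text = """tract_id,age_group,nq(x),l(x),e(x)
1,Under 1,0.019736,100000,73.1
1,1-4,0.008884,98026,73.5
2,Under 1,0.008647,100000,76.9
2,1-4,0.009001,99135,76.5
"""

# read_csv parses numeric columns as int64/float64 instead of strings
df = pd.read_csv(io.StringIO(csv_text))
print(df.dtypes)
```

With a real file you would pass its path to `pd.read_csv` instead of the `StringIO` object.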

What I want to do is flatten the dataset so that each tract_id has a single row containing every combination of age_group and the other columns, so the columns would look like this:

new_columns = ["tract_id"]
col1 = ["under_1", "1_4", "5_14", "15_24", "25_34", "35_44", "45_54", "55_64", "65_74", "75_84", "85_older"]
col2 = ["nq(x)", "l(x)", "e(x)"]
for l1 in col1:
    for l2 in col2:
        new_columns.append(f"{l1}_{l2}")
print(new_columns)

Output:

['tract_id', 'under_1_nq(x)', 'under_1_l(x)', 'under_1_e(x)', '1_4_nq(x)', '1_4_l(x)', '1_4_e(x)', '5_14_nq(x)', '5_14_l(x)', '5_14_e(x)', '15_24_nq(x)', '15_24_l(x)', '15_24_e(x)', '25_34_nq(x)', '25_34_l(x)', '25_34_e(x)', '35_44_nq(x)', '35_44_l(x)', '35_44_e(x)', '45_54_nq(x)', '45_54_l(x)', '45_54_e(x)', '55_64_nq(x)', '55_64_l(x)', '55_64_e(x)', '65_74_nq(x)', '65_74_l(x)', '65_74_e(x)', '75_84_nq(x)', '75_84_l(x)', '75_84_e(x)', '85_older_nq(x)', '85_older_l(x)', '85_older_e(x)']
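The nested loops above build the Cartesian product of the age labels and the statistic names. A more compact sketch of the same thing uses `itertools.product`, which yields pairs in the same order as the loops (age on the outside, statistic on the inside):

```python
from itertools import product

col1 = ["under_1", "1_4", "5_14", "15_24", "25_34", "35_44",
        "45_54", "55_64", "65_74", "75_84", "85_older"]
col2 = ["nq(x)", "l(x)", "e(x)"]

# product(col1, col2) iterates col2 fastest, matching the nested for loops
new_columns = ["tract_id"] + [f"{age}_{stat}" for age, stat in product(col1, col2)]
print(new_columns[:4])
# ['tract_id', 'under_1_nq(x)', 'under_1_l(x)', 'under_1_e(x)']
```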

You get the idea. I know how to do this in a brute-force way, but my intuition is that there must be some pandas method that would be both less verbose and less computationally intensive (the full CSV is fairly large):

all_data = []
for tract in df["tract_id"].unique():
    # rows for this tract, kept in their original age_group order
    k = df[df["tract_id"] == tract][col2].values.tolist()
    flat_list = [tract] + [item for sublist in k for item in sublist]
    all_data.append(flat_list)

new_df = pd.DataFrame(all_data, columns=new_columns)
print(new_df.to_string())

Output:

  tract_id under_1_nq(x) under_1_l(x) under_1_e(x) 1_4_nq(x) 1_4_l(x) 1_4_e(x) 5_14_nq(x) 5_14_l(x) 5_14_e(x) 15_24_nq(x) 15_24_l(x) 15_24_e(x) 25_34_nq(x) 25_34_l(x) 25_34_e(x) 35_44_nq(x) 35_44_l(x) 35_44_e(x) 45_54_nq(x) 45_54_l(x) 45_54_e(x) 55_64_nq(x) 55_64_l(x) 55_64_e(x) 65_74_nq(x) 65_74_l(x) 65_74_e(x) 75_84_nq(x) 75_84_l(x) 75_84_e(x) 85_older_nq(x) 85_older_l(x) 85_older_e(x)
0        1      0.019736       100000         73.1  0.008884    98026     73.5    0.00618     97156      70.2    0.008613      96555       60.6    0.023923      95723       51.1    0.020311      93433       42.2    0.017889      91536         33    0.084746      89898       23.5    0.226782      82280       15.2    0.505319      63620        8.2              1         31472           1.4
1        2      0.008647       100000         76.9  0.009001    99135     76.5    0.00143     98243      73.2    0.004869      98103       63.3     0.01039      97625       53.6    0.026667      96611       44.1    0.061996      94034       35.2    0.133411      88205       27.2    0.289017      76437       20.6    0.282686      54345       16.9              1         38983          11.6
2        3       0.00426       100000         75.4  0.001026    99574     74.7   0.002392     99472      70.8    0.002786      99234         61    0.012019      98957       51.1    0.025924      97768       41.7    0.057895      95234       32.7    0.113929      89720       24.4    0.197795      79498       16.9     0.32413      63774        9.8              1         43103           2.1
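For comparison, here is a minimal sketch of a vectorized alternative to the loop, using `set_index` plus `unstack` (not the asker's code, and shown on a small subset of the question's data). It assumes each (tract_id, age_group) pair occurs only once, since `unstack` raises an error on duplicate index entries. Because `unstack` sorts the age labels, the columns are reindexed back to the order in which the ages first appear:

```python
import pandas as pd

# Small subset of the question's data (two tracts, two age groups)
df = pd.DataFrame(
    [["1", "Under 1", "0.019736", "100000", "73.1"],
     ["1", "1-4", "0.008884", "98026", "73.5"],
     ["2", "Under 1", "0.008647", "100000", "76.9"],
     ["2", "1-4", "0.009001", "99135", "76.5"]],
    columns=["tract_id", "age_group", "nq(x)", "l(x)", "e(x)"],
)
stats = ["nq(x)", "l(x)", "e(x)"]

# Move age_group from the rows into the columns: MultiIndex columns (stat, age_group)
wide = df.set_index(["tract_id", "age_group"]).unstack("age_group")

# Put age_group on the outer level and restore the original age order
wide = wide.swaplevel(axis=1)
order = pd.MultiIndex.from_product([df["age_group"].unique(), stats])
wide = wide.reindex(columns=order)

# Flatten the two column levels into "age_stat" strings and bring tract_id back as a column
wide.columns = [f"{age}_{stat}" for age, stat in wide.columns]
wide = wide.reset_index()
print(wide.to_string())
```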

So, how would I go about doing this more efficiently in pandas? I tried searching for the right method, but I guess I didn't have the right search terms.

Answer:

You can use melt and then pivot:

df = df.melt(id_vars=['tract_id', 'age_group'])
df = (df.assign(col_names=df['age_group'] + '_' + df['variable'])
      .pivot(index='tract_id',columns='col_names', values='value'))
df
Out[1]: 
col_names 1-4_e(x) 1-4_l(x) 1-4_nq(x) 15-24_e(x) 15-24_l(x) 15-24_nq(x)  \
tract_id                                                                  
1             73.5    98026  0.008884       60.6      96555    0.008613   
2             76.5    99135  0.009001       63.3      98103    0.004869   
3             74.7    99574  0.001026         61      99234    0.002786   

col_names 25-34_e(x) 25-34_l(x) 25-34_nq(x) 35-44_e(x)  ... 65-74_nq(x)  \
tract_id                                                ...               
1               51.1      95723    0.023923       42.2  ...    0.226782   
2               53.6      97625     0.01039       44.1  ...    0.289017   
3               51.1      98957    0.012019       41.7  ...    0.197795   

col_names 75-84_e(x) 75-84_l(x) 75-84_nq(x) 85 and older_e(x)  \
tract_id                                                        
1                8.2      63620    0.505319               1.4   
2               16.9      54345    0.282686              11.6   
3                9.8      63774     0.32413               2.1   

col_names 85 and older_l(x) 85 and older_nq(x) Under 1_e(x) Under 1_l(x)  \
tract_id                                                                   
1                     31472                  1         73.1       100000   
2                     38983                  1         76.9       100000   
3                     43103                  1         75.4       100000   

col_names Under 1_nq(x)  
tract_id                 
1              0.019736  
2              0.008647  
3               0.00426  

[3 rows x 33 columns]
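Note that the output above has its columns sorted alphabetically, and the values are still strings. A minimal sketch of a variation on the same pivot approach that keeps the original age order, converts the statistics to numbers, and produces names like `under_1_nq(x)` from the question follows. `clean_age` is a hypothetical helper written for these particular labels, and passing a list to `values=` makes `pivot` return (stat, age_group) MultiIndex columns without a separate `melt` step:

```python
import pandas as pd

df = pd.DataFrame(
    [["1", "Under 1", "0.019736", "100000", "73.1"],
     ["1", "85 and older", "1", "31472", "1.4"],
     ["2", "Under 1", "0.008647", "100000", "76.9"],
     ["2", "85 and older", "1", "38983", "11.6"]],
    columns=["tract_id", "age_group", "nq(x)", "l(x)", "e(x)"],
)
stats = ["nq(x)", "l(x)", "e(x)"]


def clean_age(label):
    # Hypothetical cleanup: "Under 1" -> "under_1", "1-4" -> "1_4", "85 and older" -> "85_older"
    return label.lower().replace(" and ", "_").replace(" ", "_").replace("-", "_")


df["age_group"] = df["age_group"].map(clean_age)
df[stats] = df[stats].apply(pd.to_numeric)  # strings -> int/float columns

# pivot sorts the age labels, so remember their original order first
ages = df["age_group"].unique()
wide = df.pivot(index="tract_id", columns="age_group", values=stats)

# (stat, age) -> (age, stat), reorder, then flatten to "age_stat" names
wide = wide.swaplevel(axis=1).reindex(columns=pd.MultiIndex.from_product([ages, stats]))
wide.columns = [f"{age}_{stat}" for age, stat in wide.columns]
wide = wide.reset_index()
print(wide.to_string())
```

Like the `melt`/`pivot` version, this assumes each (tract_id, age_group) pair appears only once; `pivot` raises a `ValueError` on duplicates.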