I have a dataframe like below:
groupid | datacol1 | datacol2 | datacol3 | datacol* | corr_col |
---|---|---|---|---|---|
00001 | 1 | 2 | 3 | 4 | 5 |
00001 | 2 | 3 | 4 | 6 | 5 |
00002 | 4 | 2 | 1 | 7 | 5 |
00002 | 8 | 9 | 3 | 2 | 5 |
00003 | 7 | 1 | 2 | 3 | 5 |
00003 | 3 | 5 | 3 | 1 | 5 |
I want to calculate the correlation between each of the datacol* columns and the corr_col column, grouped by groupid.
So I used the following Spark Scala code:
df.groupBy("groupid").agg(functions.corr("datacol1","corr_col"), functions.corr("datacol2","corr_col"), functions.corr("datacol3","corr_col"), .....)
This is very inefficient; is there a more efficient way to do this?
[EDIT] What I mean is that if I have 30 datacol* columns, I have to write functions.corr 30 times to calculate the correlations.
I have searched, and it seems that functions.corr doesn't accept a List/Array parameter, and df.agg doesn't accept a function as a parameter.
So is there any way to do this efficiently? I'd prefer to use the Spark Scala API for this.
Thanks
CodePudding user response:
I have found one solution. The steps are as below:
- Use the following code to create a mutable data frame df_all: df.groupBy("groupid").agg(functions.corr("datacol1","corr_col"))
- Iterate over all the remaining datacol* columns, creating a temp data frame in each iteration. In each iteration, join df_all with the temp data frame on the groupid column, then drop the duplicated groupid column.
- After the iterations, I get a dataframe that contains all the correlation data. I still need to verify the data. (A sketch of this approach is shown below.)
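A minimal sketch of this iterative-join approach, assuming the column names from the example table above (datacol1..., corr_col, groupid) and that df is the original dataframe; joining on Seq("groupid") keeps a single groupid column, so the explicit drop isn't needed here:

```scala
import org.apache.spark.sql.{DataFrame, functions}

// Assumed data column names from the example; extend up to datacol30 as needed.
val dataCols = Seq("datacol1", "datacol2", "datacol3")

// Step 1: per-group correlation of the first data column.
var dfAll: DataFrame = df.groupBy("groupid")
  .agg(functions.corr(dataCols.head, "corr_col").alias(s"corr_${dataCols.head}"))

// Step 2: for each remaining column, compute its per-group correlation
// in a temp dataframe and join it back to dfAll on groupid.
for (c <- dataCols.tail) {
  val tmp = df.groupBy("groupid")
    .agg(functions.corr(c, "corr_col").alias(s"corr_$c"))
  // Joining on Seq("groupid") avoids a duplicated groupid column.
  dfAll = dfAll.join(tmp, Seq("groupid"))
}
```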
UPDATE: I found an efficient way to do it: generate a list of columns that calculate the correlations, like List(corr(), corr(), ..., corr()), and then pass this list into the agg function to generate the correlation data frame.
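A minimal sketch of that single-pass approach, assuming the data columns can be selected by the "datacol" prefix and the target column is named corr_col (adjust to your real schema):

```scala
import org.apache.spark.sql.{Column, functions}

// Pick up all data columns by prefix (assumption: they all start with "datacol").
val dataCols: Seq[String] = df.columns.filter(_.startsWith("datacol")).toSeq

// Build one corr(...) expression per data column.
val corrExprs: Seq[Column] =
  dataCols.map(c => functions.corr(c, "corr_col").alias(s"corr_$c"))

// agg takes (Column, Column*), so split the list into head and tail.
val result = df.groupBy("groupid").agg(corrExprs.head, corrExprs.tail: _*)
result.show()
```

This computes all the correlations in a single groupBy/agg pass instead of one join per column.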