How to use one-hot encoding or get_dummies for pyspark with lists as values in column?

Time:05-12

I have a Spark DataFrame:

+-----------+--------------------+
|columnIndex|               lists|
+-----------+--------------------+
|          1|[1.0,2.0]           |
|          2|[2.0]               |
|          3|[1.0]               |
|          4|[1.0,3.0]           |
+-----------+--------------------+

I need to get the following after one-hot encoding or get_dummies:

+-----------+--------------------+---+---+---+
|columnIndex|               lists|1.0|2.0|3.0|
+-----------+--------------------+---+---+---+
|          1|[1.0,2.0]           |  1|  1|  0|
|          2|[2.0]               |  0|  1|  0|
|          3|[1.0]               |  1|  0|  0|
|          4|[1.0,3.0]           |  1|  0|  1|
+-----------+--------------------+---+---+---+

I have tried CountVectorizer(), but could not get the needed output. This is only a sample of the data I have.

Answer:

Here's a solution for one-hot encoding from a column of lists of categories.

In Pandas

import pandas as pd

data = {'col1': list(range(1,5)), 'lists': [[1.0,2.0], [2.0],[1.0],[1.0,3.0]]}
df = pd.DataFrame.from_dict(data)
df
# Out:
#    col1       lists
# 0     1  [1.0, 2.0]
# 1     2       [2.0]
# 2     3       [1.0]
# 3     4  [1.0, 3.0]

s = df['lists'].explode()
df[['col1']].join(pd.crosstab(s.index, s))
# Out: 
#    col1  1.0  2.0  3.0
# 0     1    1    1    0
# 1     2    0    1    0
# 2     3    1    0    0
# 3     4    1    0    1

In PySpark

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode

spark = SparkSession.builder.getOrCreate()
sparkDF = spark.createDataFrame(df) # spark is the Spark session
sparkDF.show()
# Out:
# +----+----------+
# |col1|     lists|
# +----+----------+
# |   1|[1.0, 2.0]|
# |   2|     [2.0]|
# |   3|     [1.0]|
# |   4|[1.0, 3.0]|
# +----+----------+

sparkDF2 = sparkDF.select(sparkDF.col1,explode(sparkDF.lists).alias('newcol'))
sparkDF2.show()
# Out:
# +----+------+
# |col1|newcol|
# +----+------+
# |   1|   1.0|
# |   1|   2.0|
# |   2|   2.0|
# |   3|   1.0|
# |   4|   1.0|
# |   4|   3.0|
# +----+------+

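# crosstab names its first column 'col1_newcol'; rename it back to 'col1' for the join
crosstabDF = sparkDF2.crosstab('col1', 'newcol').withColumnRenamed('col1_newcol', 'col1')
# join the counts back onto the original rows; orderBy makes the displayed order reproducible
sparkDF.join(crosstabDF, 'col1').orderBy('col1').show()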
# Out:
# +----+----------+---+---+---+
# |col1|     lists|1.0|2.0|3.0|
# +----+----------+---+---+---+
# |   1|[1.0, 2.0]|  1|  1|  0|
# |   2|     [2.0]|  0|  1|  0|
# |   3|     [1.0]|  1|  0|  0|
# |   4|[1.0, 3.0]|  1|  0|  1|
# +----+----------+---+---+---+

Note that this assumes each list contains unique categories. If a category is repeated within a list, crosstab counts every occurrence, so the cell shows that count instead of 1.

For example, if sparkDF is

+----+---------------+
|col1|          lists|
+----+---------------+
|   1|     [1.0, 2.0]|
|   2|          [2.0]|
|   3|          [1.0]|
|   4|[1.0, 3.0, 3.0]|
+----+---------------+
           ^^^^^^^^^^

then the result is:

+----+---------------+---+---+---+
|col1|          lists|1.0|2.0|3.0|
+----+---------------+---+---+---+
|   1|     [1.0, 2.0]|  1|  1|  0|
|   2|          [2.0]|  0|  1|  0|
|   3|          [1.0]|  1|  0|  0|
|   4|[1.0, 3.0, 3.0]|  1|  0|  2|
+----+---------------+---+---+---+
                              ^^^^

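This can be avoided by removing duplicates from each list before exploding (for example with array_distinct, available since Spark 2.4), or by capping every count at 1 after the crosstab.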

Answer:

import more_itertools as mit
import numpy as np
import pandas as pd

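df = pd.DataFrame({'columnIndex': [1, 2, 3, 4], 'lists': [[1.0, 2.0], [2.0], [1.0], [1.0, 3.0]]})
# add the indicator columns 1, 2, 3 and fill them with 0
df = pd.concat([df, pd.DataFrame(columns=[1, 2, 3])])
df[[1, 2, 3]] = [0, 0, 0]

for i in range(0, len(df['lists'])):
    # 0-based positions of the categories 1, 2, 3 that occur in this row's list
    index = list(mit.locate([1, 2, 3], lambda x: x in df.loc[i, 'lists']))
    # shift the positions by 1 so they match the column labels 1, 2, 3
    index = np.array(index) + 1
    df.loc[i, index] = 1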

Output

   columnIndex       lists  1  2  3
0          1.0  [1.0, 2.0]  1  1  0
1          2.0       [2.0]  0  1  0
2          3.0       [1.0]  1  0  0
3          4.0  [1.0, 3.0]  1  0  1

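This is an option based on pandas. First, the missing indicator columns are created and filled with 0. Then, for each row, mit.locate returns the 0-based positions of the categories 1, 2 and 3 that occur in that row's list. Adding 1 to those positions turns them into the column labels 1, 2 and 3, and those columns are set to 1.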
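The same fill step can also be written without more_itertools and numpy by testing each column label for membership directly. This is a minimal sketch, assuming the categories are known in advance; categories is an illustrative name. It relies on 1 == 1.0 in Python, so the integer labels match the float list elements:

```python
import pandas as pd

categories = [1, 2, 3]  # column labels, assumed known in advance

df = pd.DataFrame({'columnIndex': [1, 2, 3, 4],
                   'lists': [[1.0, 2.0], [2.0], [1.0], [1.0, 3.0]]})

# Create every indicator column up front, filled with 0
for c in categories:
    df[c] = 0

# For each row, set to 1 the columns whose category appears in that row's list
for i, lst in df['lists'].items():
    present = [c for c in categories if c in lst]
    if present:
        df.loc[i, present] = 1

print(df)
```

Because no concat with an empty frame is involved, columnIndex keeps its integer dtype here.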
