Reshape THEN explode an array in a spark dataframe


I have a spark dataframe in the following format:

+--------------------+--------------------+
|            profiles|__record_timestamp__|
+--------------------+--------------------+
|[0, 1, 1, 1, 3, 1...| 1651737406300000000|
|[1, 0, 1, 2, 1, 0...| 1651736986300000000|
|[2, 1, 3, 1, 0, 0...| 1651737232300000000|
|[1, 1, 3, 1, 2, 0...| 1651737352300000000|
|[0, 1, 0, 0, 0, 1...| 1651737412300000000|
|[0, 1, 0, 1, 1, 1...| 1651737142300000000|
|[3, 1, 0, 1, 1, 1...| 1651737574300000000|
|[2, 0, 3, 1, 0, 1...| 1651737178300000000|
|[0, 0, 0, 1, 2, 1...| 1651737364300000000|
|[0, 0, 1, 0, 0, 0...| 1651737280300000000|
|[1, 0, 0, 1, 0, 0...| 1651737196300000000|
|[0, 0, 0, 0, 0, 1...| 1651737436300000000|
|[8, 2, 0, 0, 0, 3...| 1651737166300000000|
|[4, 0, 1, 2, 0, 0...| 1651737538300000000|
|[1, 2, 0, 1, 1, 0...| 1651737052300000000|
|[1, 3, 0, 1, 0, 1...| 1651737082300000000|
|[1, 1, 1, 2, 0, 0...| 1651737100300000000|
|[1, 0, 0, 0, 1, 0...| 1651736980300000000|
|[1, 1, 0, 0, 0, 0...| 1651737040300000000|
|[1, 0, 1, 0, 1, 1...| 1651737004300000000|
+--------------------+--------------------+
only showing top 20 rows

The array in profiles is 91260 elements long. I need to first reshape it into 90 sub-arrays of 1024 elements each, and then I plan to explode, so that each resulting row carries an integer from 0-89 matching its place in the original array.

Any ideas how this can be done? f.explode() will only give me one element per row, split() only seems to work on strings, and I can't find a reshape or array_split() function or anything. TIA

CodePudding user response:

I think what you are looking for is posexplode:

from pyspark.sql import Row
from pyspark.sql.functions import posexplode

eDF = spark.createDataFrame([Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})])
eDF.select(posexplode(eDF.intlist)).collect()
# [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
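
Note that posexplode on its own still yields one row per array element, so on the full profiles column it has to be combined with a slicing step, as the next answer shows. As a small sketch, both generated columns can also be given names ("val" here is just an illustrative alias; .alias() accepts one name per output column of posexplode):

eDF.select(posexplode(eDF.intlist).alias("pos", "val")).show()
# +---+---+
# |pos|val|
# +---+---+
# |  0|  1|
# |  1|  2|
# |  2|  3|
# +---+---+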

CodePudding user response:

You can do so by creating an array of the required slices and then using the posexplode function. For practicality, I've created a smaller DataFrame to show how it works:

import pyspark.sql.functions as F
from pyspark.sql import Row

SPLIT_COUNT = 3
SPLIT_SIZE = 5
ARRAY_SIZE = SPLIT_COUNT * SPLIT_SIZE

df = spark.createDataFrame([
    Row(profiles=list(range(ARRAY_SIZE)), timestamp=12345)
    ])

# F.slice is 1-based, so each chunk starts at i * SPLIT_SIZE + 1
slices = [F.slice(F.col('profiles'), i * SPLIT_SIZE + 1, SPLIT_SIZE) for i in range(SPLIT_COUNT)]

df.select(
    F.posexplode(F.array(*slices)),
    F.col('timestamp')
).show()


+---+--------------------+---------+
|pos|                 col|timestamp|
+---+--------------------+---------+
|  0|     [0, 1, 2, 3, 4]|    12345|
|  1|     [5, 6, 7, 8, 9]|    12345|
|  2|[10, 11, 12, 13, 14]|    12345|
+---+--------------------+---------+
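
Scaled up to the sizes in the question, the same pattern looks roughly like this. This is an untested sketch: it assumes df is the original DataFrame, reuses the column names from the post, and profile_chunk is just an illustrative alias. (Note that 90 slices of 1024 cover 92160 elements, so check this against the actual array length of 91260 before relying on it.)

import pyspark.sql.functions as F

SPLIT_COUNT = 90    # number of sub-arrays, giving pos values 0-89
SPLIT_SIZE = 1024   # elements per sub-array

slices = [F.slice(F.col('profiles'), i * SPLIT_SIZE + 1, SPLIT_SIZE)
          for i in range(SPLIT_COUNT)]

result = df.select(
    F.posexplode(F.array(*slices)).alias('pos', 'profile_chunk'),
    F.col('__record_timestamp__')
)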