How to increment a field in PySpark?


I have to generate an incremental number based on a condition. For example, I have the following dataframe:

+---+---+-----+---+
|seq|cod|trans|ant|
+---+---+-----+---+
| 01| 05|   00|  1|
| 02| 05|   01| 00|
| 03| 05|   02| 01|
| 04| 05|   05| 02|
| 05| 05|   00| 05|
| 06| 05|   01| 00|
| 07| 05|   02| 01|
| 08| 05|   05| 02|
| 09| 05|   07| 05|
| 10| 05|   00| 07|
| 11| 05|   01| 00|
| 12| 05|   02| 01|
| 13| 05|   05| 02|
+---+---+-----+---+
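
For reference, here is a minimal sketch of how this sample frame might be built; the DataFrame name df and the string-typed columns are assumptions based on the table above and on the .cast("int") calls used later:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data copied from the table above; every column is kept as a string,
# which is why the comparisons below need an explicit cast to int.
rows = [
    ("01", "05", "00", "1"),  ("02", "05", "01", "00"), ("03", "05", "02", "01"),
    ("04", "05", "05", "02"), ("05", "05", "00", "05"), ("06", "05", "01", "00"),
    ("07", "05", "02", "01"), ("08", "05", "05", "02"), ("09", "05", "07", "05"),
    ("10", "05", "00", "07"), ("11", "05", "01", "00"), ("12", "05", "02", "01"),
    ("13", "05", "05", "02"),
]
df = spark.createDataFrame(rows, ["seq", "cod", "trans", "ant"])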

I use:

from pyspark.sql.functions import when, col

cont = 0  # plain Python variable, evaluated once on the driver, so it never actually increments per row
df1 = df.withColumn("id", when(col("trans").cast("int") < col("ant").cast("int"), cont + 1).otherwise(cont))

With that I get the following output:

+---+---+-----+---+---+
|seq|cod|trans|ant| id|
+---+---+-----+---+---+
| 01| 05|   00|  1|  1|
| 02| 05|   01| 00|  0|
| 03| 05|   02| 01|  0|
| 04| 05|   05| 02|  0|
| 05| 05|   00| 05|  1|
| 06| 05|   01| 00|  0|
| 07| 05|   02| 01|  0|
| 08| 05|   05| 02|  0|
| 09| 05|   07| 05|  0|
| 10| 05|   00| 07|  1|
| 11| 05|   01| 00|  0|
| 12| 05|   02| 01|  0|
| 13| 05|   05| 02|  0|
+---+---+-----+---+---+

but I expect something like:

+---+---+-----+---+---+
|seq|cod|trans|ant| id|
+---+---+-----+---+---+
| 01| 05|   00|  1|  1|
| 02| 05|   01| 00|  1|
| 03| 05|   02| 01|  1|
| 04| 05|   05| 02|  1|
| 05| 05|   00| 05|  2|
| 06| 05|   01| 00|  2|
| 07| 05|   02| 01|  2|
| 08| 05|   05| 02|  2|
| 09| 05|   07| 05|  2|
| 10| 05|   00| 07|  3|
| 11| 05|   01| 00|  3|
| 12| 05|   02| 01|  3|
| 13| 05|   05| 02|  3|
+---+---+-----+---+---+

Does anyone have any suggestions to help me?

CodePudding user response:

You need a cumulative sum, for which you can use a window function here:

from pyspark.sql import functions as F, Window as W

# cumulative frame per "cod", ordered by the numeric value of "seq"
w = (W.partitionBy("cod")
     .orderBy(F.col("seq").cast("int"))
     .rangeBetween(W.unboundedPreceding, W.currentRow))

# flag each row that starts a new block (trans < ant) and take a running sum of the flag
df1 = df.withColumn(
    "id",
    F.sum(
        F.when(F.col("trans").cast("int") < F.col("ant").cast("int"), 1).otherwise(0)
    ).over(w),
)

df1.show()
+---+---+-----+---+---+
|seq|cod|trans|ant| id|
+---+---+-----+---+---+
| 01| 05|   00|  1|  1|
| 02| 05|   01| 00|  1|
| 03| 05|   02| 01|  1|
| 04| 05|   05| 02|  1|
| 05| 05|   00| 05|  2|
| 06| 05|   01| 00|  2|
| 07| 05|   02| 01|  2|
| 08| 05|   05| 02|  2|
| 09| 05|   07| 05|  2|
| 10| 05|   00| 07|  3|
| 11| 05|   01| 00|  3|
| 12| 05|   02| 01|  3|
| 13| 05|   05| 02|  3|
+---+---+-----+---+---+
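
As a side note, the same running count can also be written with an explicit rows-based frame and a boolean cast instead of when/otherwise. This is only an equivalent sketch reusing the imports above; it behaves the same here because seq is unique within each cod:

# Equivalent formulation: compare, cast the boolean to int, and running-sum it
# over a rows-based frame instead of the range-based one used above.
w_rows = (W.partitionBy("cod")
          .orderBy(F.col("seq").cast("int"))
          .rowsBetween(W.unboundedPreceding, W.currentRow))

df2 = df.withColumn(
    "id",
    F.sum((F.col("trans").cast("int") < F.col("ant").cast("int")).cast("int")).over(w_rows),
)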