Scala/Spark: Add column to DataFrame that increments by 1 when a value is repeated in another column

I have a dataframe called rawEDI that looks something like this:

Line_number  Segment
1            ST
2            BPT
3            SE
4            ST
5            BPT
6            N1
7            SE
8            ST
9            PTD
10           SE

Each row represents a line in a file. Each line is called a segment and is denoted by a segment identifier: a short string. Segments are grouped together in chunks that start with an ST segment identifier and end with an SE segment identifier. There can be any number of ST chunks in a given file, and the size of any ST chunk is not fixed.

I want to create a new column on the dataframe that represents numerically what ST group a given segment belongs to. This will allow me to use groupBy to perform aggregate operations across all ST segments without having to loop over each individual ST segment, which is too slow.

The final DataFrame would look like this:

Line_number  Segment  ST_Group
1            ST       1
2            BPT      1
3            SE       1
4            ST       2
5            BPT      2
6            N1       2
7            SE       2
8            ST       3
9            PTD      3
10           SE       3

In short, I want to create and populate a DataFrame column with a number that increments by one whenever the value "ST" appears in the Segment column.

I am using Spark 2.3.2 and Scala 2.11.8.

My initial thought was to use iteration. I collected another DataFrame, df, containing the starting and ending line_number of each ST group, which looks like this:

Start  End
1      3
4      7
8      10

I then iterated over the rows of df and used them to populate the new column, like this:

var st = 1
for (row <- df.collect()) {
    val start = row(0)
    val end   = row(1)
    // Note: these filter strings are missing the s interpolator, so ${start}
    // and ${end} are never substituted with the row values, and labelSTs is
    // overwritten rather than accumulated on each iteration.
    var labelSTs = rawEDI.filter("line_number >= ${start}")
                         .filter("line_number <= ${end}")
                         .withColumn("ST_Group", lit(st))
    st = st + 1
}
However, this yields an empty DataFrame. Additionally, the use of a for loop is time-prohibitive, taking over 20s on my machine for this. Achieving this result without the use of a loop would be huge, but a solution with a loop may also be acceptable if performant.

I have a hunch this can be accomplished using a udf or a Window, but I'm not certain how to attack that.

This

val func = udf((s: String) => if (s == "ST") 1 else 0)
var labelSTs = rawEDI.withColumn("ST_Group", func(col("segment")))

only populates the column with a 1 at the start of each ST segment.

And this

val w = Window.partitionBy("Segment").orderBy("line_number")
val labelSTs = rawEDI.withColumn("ST_Group", row_number().over(w))

returns a nonsense DataFrame, since partitioning by Segment numbers the rows within each distinct segment identifier rather than within each ST chunk.

CodePudding user response:

One way is to create an intermediate dataframe of "groups" that would tell you on which line each group starts and ends (sort of what you've already done), and then join it to the original table using greater-than/less-than conditions.

Sample data

scala> val input = Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),
                       (6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE"))
                   .toDF("linenumber","segment")
scala> input.show(false)
+----------+-------+
|linenumber|segment|
+----------+-------+
|1         |ST     |
|2         |BPT    |
|3         |SE     |
|4         |ST     |
|5         |BPT    |
|6         |N1     |
|7         |SE     |
|8         |ST     |
|9         |PTD    |
|10        |SE     |
+----------+-------+

Create a dataframe for groups, using Window just as your hunch was telling you:

scala> import org.apache.spark.sql.functions._
scala> import org.apache.spark.sql.expressions.Window

scala> val groups = input.where("segment='ST'")
                    .withColumn("endline",lead("linenumber",1) over Window.orderBy("linenumber"))
                    .withColumn("groupnumber",row_number() over Window.orderBy("linenumber"))
                    .withColumnRenamed("linenumber","startline")
                    .drop("segment")

scala> groups.show(false)
+---------+-----------+-------+
|startline|groupnumber|endline|
+---------+-----------+-------+
|1        |1          |4      |
|4        |2          |8      |
|8        |3          |null   |
+---------+-----------+-------+

Join both to get the result

scala> input.join(groups,
                  input("linenumber") >= groups("startline") && 
                  (input("linenumber") < groups("endline") || groups("endline").isNull))
            .select("linenumber","segment","groupnumber") 
            .show(false)
+----------+-------+-----------+
|linenumber|segment|groupnumber|
+----------+-------+-----------+
|1         |ST     |1          |
|2         |BPT    |1          |
|3         |SE     |1          |
|4         |ST     |2          |
|5         |BPT    |2          |
|6         |N1     |2          |
|7         |SE     |2          |
|8         |ST     |3          |
|9         |PTD    |3          |
|10        |SE     |3          |
+----------+-------+-----------+

The only problem with this is the Window.orderBy() over an unpartitioned dataframe, which collects all the data into a single partition and thus could be a performance killer.
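That said, the window here only runs over the ST rows, which is typically a small subset. If the join itself becomes the bottleneck, one possible mitigation (a sketch of mine, not part of the original answer) is to broadcast the small groups dataframe explicitly, so the non-equi join is planned as a broadcast nested-loop join rather than something shuffle-heavy:

// Sketch (assumption: groups is small, since it holds one row per ST chunk).
import org.apache.spark.sql.functions.broadcast

val result = input.join(broadcast(groups),
                  input("linenumber") >= groups("startline") &&
                  (input("linenumber") < groups("endline") || groups("endline").isNull))
            .select("linenumber","segment","groupnumber")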

CodePudding user response:

If you just want to add a column with a number that increments by one whenever the value "ST" appears in the Segment column, you can first filter the lines with the ST segment into a separate dataframe:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// filter the lines with the ST segment into a separate dataframe
var labelSTs = rawEDI.filter("Segment == 'ST'")
// then group by Segment and collect the line numbers into a list
var groupedDf = labelSTs.groupBy("Segment").agg(collect_list("Line_number").alias("Line_numbers"))
// now flatten the dataframe back out, one row per line number
var flattenedDf = groupedDf.select($"Segment", explode($"Line_numbers").as("Line_number"))
// number the line numbers in order to get the target column ST_Group
val withIndexDF = flattenedDf.withColumn("ST_Group",
  row_number().over(Window.partitionBy($"Segment").orderBy($"Line_number")))

and you get this as a result:

+-------+-----------+--------+
|Segment|Line_number|ST_Group|
+-------+-----------+--------+
|     ST|          1|       1|
|     ST|          4|       2|
|     ST|          8|       3|
+-------+-----------+--------+

Then you can join this back with the other segments in the initial dataframe, for example as sketched below.
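The answer leaves that last join unstated; here is one possible sketch (my assumption, reusing rawEDI and withIndexDF from above): give every line the ST_Group of the most recent ST line at or before it.

// Sketch, not part of the original answer: attach each line to the latest
// ST line whose Line_number is <= its own, then keep the max ST_Group.
val result = rawEDI.as("all")
  .join(withIndexDF.as("st"), $"all.Line_number" >= $"st.Line_number")
  .groupBy($"all.Line_number", $"all.Segment")
  .agg(max($"st.ST_Group").as("ST_Group"))
  .orderBy("Line_number")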

CodePudding user response:

Found a simpler way: add a column that holds 1 when the segment column value is ST and 0 otherwise. Then use a Window function to take the cumulative sum of that new column. This gives the desired result.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val rawEDI=Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),(6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE")).toDF("line_number","segment")

val newDf=rawEDI.withColumn("ST_Group", ($"segment" === "ST").cast("bigint"))

val windowSpec = Window.orderBy("line_number")
newDf.withColumn("ST_Group", sum("ST_Group").over(windowSpec))
.show
+-----------+-------+--------+
|line_number|segment|ST_Group|
+-----------+-------+--------+
|          1|     ST|       1|
|          2|    BPT|       1|
|          3|     SE|       1|
|          4|     ST|       2|
|          5|    BPT|       2|
|          6|     N1|       2|
|          7|     SE|       2|
|          8|     ST|       3|
|          9|    PTD|       3|
|         10|     SE|       3|
+-----------+-------+--------+
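With ST_Group in place, the per-chunk aggregation the question was building toward no longer needs any looping. For example (counting segments per chunk is my assumed illustration, not from the answer):

// Assumed example: count how many segments each ST chunk contains.
val withGroups = newDf.withColumn("ST_Group", sum("ST_Group").over(windowSpec))
withGroups.groupBy("ST_Group")
  .agg(count("*").as("segment_count"))
  .orderBy("ST_Group")
  .show()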