Using case when in Spark Scala


I am new to Apache Spark and I am using Scala to work with it.

I have a few questions, and one of them is how to use case when for my example. I am working with distributed systems, and my task is to classify some instances. To do this I have a DataFrame, as you can see here:

 --------------------ive------------
+--------------------+------------+
|       group        |   info     |
+--------------------+------------+
|gr=nat, dfn         |   x1betdfn |
|gr=pjc, ntp         |   x2b1trkn |
|gr=ntp, ntt         |   x3b2td3n |
+--------------------+------------+

So, using the column called group, I have to filter or apply a regex to get just the three letters after gr= (which would be nat, pjc, or ntp) and then write them to a new column.

What is the most efficient way to do this?

This is what I am doing:

val df2 =
  df.withColumn("tgroup", when(col("group").match === "nat", "nat_1_nm")
  .when(col("group").match === "pjc", "pjc_determined")
  .when(col("group").match === "ntp", "ntp_dway")
  .otherwise("Unknown"))

But it is not working; it does not even compile. I am trying to get this:

+--------------------+------------+----------------+
|       group        |   info     |   tgroup       |
+--------------------+------------+----------------+
|gr=nat, dfn         |   x1betdfn | nat_1_nm       |
|gr=pjc, ntp         |   x2b1trkn | pjc_determined |
|gr=e2p, ntt         |   x3b2td3n | Unknown        |
|gr=ntp, ntt         |   x3b2td3n | ntp_dway       |
+--------------------+------------+----------------+

What am I doing wrong? Thanks in advance.

CodePudding user response:

Try something like this, using rlike; just build your own regexp:

val df2 = df.withColumn(
  "tgroup",
  // (?i) makes the match case-insensitive; ^ and $ anchor the whole value
  when(col("group").rlike("(?i)^gr=nat,\\s[a-zA-Z]*$"), "nat_1_nm")
    .when(col("group").rlike("(?i)^gr=pjc,\\s[a-zA-Z]*$"), "pjc_determined")
    .when(col("group").rlike("(?i)^gr=ntp,\\s[a-zA-Z]*$"), "ntp_dway")
    .otherwise("Unknown")
)
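
Note that rlike uses Java regular expressions and succeeds if the pattern matches anywhere in the string, so anchor the pattern with ^ and $ when you want it to cover the whole value.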

I tried it with an example and got your expected output:

import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  "gr=nat, dfn",
  "gr=pjc, ntp",
  "gr=e2p, ntt",
  "gr=ntp, ntt"
).toDF("group")
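
With that sample (which only has the group column), df2.show should print:

df2.show
// +-----------+--------------+
// |      group|        tgroup|
// +-----------+--------------+
// |gr=nat, dfn|      nat_1_nm|
// |gr=pjc, ntp|pjc_determined|
// |gr=e2p, ntt|       Unknown|
// |gr=ntp, ntt|      ntp_dway|
// +-----------+--------------+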

CodePudding user response:

If the values of the group column are always of the form gr=XXX, the startsWith method should be enough in your case:

val df2 = df.withColumn(
  "tgroup",
  when(col("group").startsWith("gr=nat"), "nat_1_nm")
    .when(col("group").startsWith("gr=pjc"), "pjc_determined")
    .when(col("group").startsWith("gr=ntp"), "ntp_dway")
    .otherwise("Unknown")
)

df2.show
// +-----------+--------+--------------+
// |      group|    info|        tgroup|
// +-----------+--------+--------------+
// |gr=nat, dfn|x1betdfn|      nat_1_nm|
// |gr=pjc, ntp|x2b1trkn|pjc_determined|
// |gr=e2p, ntt|x3b2td3n|       Unknown|
// |gr=ntp, ntt|x3b2td3n|      ntp_dway|
// +-----------+--------+--------------+
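
Since startsWith is a plain prefix comparison, it avoids regular expression evaluation entirely, which also makes it the cheapest of these approaches and answers the efficiency part of your question.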

You may also want to extract the value XXX after gr= using the regexp_extract function:

val df2 = df.withColumn(
    "tgroup",
    regexp_extract(col("group"), "^gr=(.{3}),.*", 1)
).withColumn(
    "tgroup",
    when(col("tgroup") === "nat", "nat_1_nm")
      .when(col("tgroup") === "pjc", "pjc_determined")
      .when(col("tgroup") === "ntp", "ntp_dway")
      .otherwise("Unknown")
)
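
If the list of groups keeps growing, one way to avoid extending the when chain by hand is to fold the labels out of an ordinary Scala Map. This is just a sketch; the labels map and the "Unknown" default are assumptions carried over from the examples above:

import org.apache.spark.sql.functions._

// Hypothetical label table: extend this Map instead of editing the when chain
val labels = Map(
  "nat" -> "nat_1_nm",
  "pjc" -> "pjc_determined",
  "ntp" -> "ntp_dway"
)

// Extract the three letters after "gr=" once, then fold the Map into a
// single chain of when(...) expressions that falls back to "Unknown"
val key = regexp_extract(col("group"), "^gr=(.{3}),.*", 1)
val tgroup = labels.foldLeft(lit("Unknown")) { case (acc, (k, v)) =>
  when(key === k, v).otherwise(acc)
}

val df2 = df.withColumn("tgroup", tgroup)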