Spark CaseWhen as Map with ADTs


This question is a follow-up to this one.

Quick context reminder:

Spark's CaseWhen takes a Seq[(Expression, Expression)], where the first Expression is the condition and the second is the value to use when that condition is satisfied:

CaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None): ...

I want to be able to do Spark's CaseWhen using a Map object.

The Map can be simple, such as:

val spec = Map(
    ($"column_one" === 1) -> lit(2),
    ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
    ($"column_one" === 3) -> lit(4)
)

It can also be nested and simple at the same time:

val spec: Map[Column, Any] = Map(
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map(
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        ),
        ($"column_two" === 8) -> lit(524)
    ),
    ($"column_one" === 3) -> lit(4)
)

So, using the answer given by @aminmal on the other question, I made this:

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression

sealed trait ConditionValue {
  def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)]
}
object ConditionValue {

  object implicits {

    // NB: a two-parameter def cannot be used as an implicit conversion
    implicit def test(condition: Column, value: Column): ConditionValue = {
        print("test")
        SingleLevelCaseWhen(Map(condition -> value))
    }

    implicit def testTuple(conditionValue: (Column, Column)): ConditionValue = {
        print("testTuple")
        SingleLevelCaseWhen(Map(conditionValue))
    }

    implicit def testNested(spec: Map[Column, ConditionValue]): ConditionValue = {
        print("testNested")
        NestedCaseWhen(spec)
    }

    implicit def testMap(spec: Map[Column, Column]): ConditionValue = {
        print("testMap")
        SingleLevelCaseWhen(spec)
    }

    implicit def expressionToColumn(expr: Expression): Column = new Column(expr)

    implicit def columnToExpression(col: Column): Expression = col.expr
  }
    
  import implicits._

  final case class SingleLevelCaseWhen(specificationMap: Map[Column, Column]) extends ConditionValue {
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.map(x => (x._1.expr, x._2.expr)).toSeq
  }
  
  final case class NestedCaseWhen(specificationMap: Map[Column, ConditionValue]) extends ConditionValue {
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.mapValues(_.enumerate(reduceFunc)).map {
        case (outerCondition, innerExpressions) => innerExpressions.map {
          case (innerCondition, innerValue) =>
            val conditions: Expression = reduceFunc(outerCondition, innerCondition)
            conditions -> innerValue
        }
      }.reduce(_ ++ _)
  }

}
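The flattening that enumerate performs can be exercised without a Spark session. Below is a minimal, Spark-free sketch of the same idea, using String as a stand-in for Column and Expression (all names here are illustrative, not Spark API):

```scala
// Spark-free model: String stands in for both Column and Expression.
sealed trait Cond {
  def enumerate(reduce: (String, String) => String): Seq[(String, String)]
}
// A flat level: each condition maps directly to a value.
final case class Single(spec: Map[String, String]) extends Cond {
  def enumerate(reduce: (String, String) => String): Seq[(String, String)] =
    spec.toSeq
}
// A nested level: each outer condition is combined with every inner condition.
final case class Nested(spec: Map[String, Cond]) extends Cond {
  def enumerate(reduce: (String, String) => String): Seq[(String, String)] =
    spec.toSeq.flatMap { case (outer, inner) =>
      inner.enumerate(reduce).map { case (c, v) => reduce(outer, c) -> v }
    }
}

val tree: Cond = Nested(Map(
  "a=1" -> Single(Map("b=2" -> "54", "b=5" -> "524")),
  "a=3" -> Single(Map("b=7" -> "12"))
))
val branches = tree.enumerate((l, r) => s"($l) AND ($r)")
```

Note that flatMap plays the role of map followed by reduce(_ ++ _), with the advantage of not throwing on an empty Map.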

My problem now is how to convert Map objects to ConditionValue objects. As you can see, I've provided some implicits in the code:

  • test converts (condition, value) parameters to a ConditionValue object. Not sure this one is useful
  • testTuple converts a (condition, value) tuple to a ConditionValue object. Not sure this one is useful either
  • testMap converts a single-level Map[Column, Column] to a ConditionValue object
  • testNested converts a nested Map to a ConditionValue object

It works great with single-level Maps:

import ConditionValue.implicits._
val spec = Map(
    ($"column_one" === 1) -> lit(2),
    ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
    ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = spec
>> d: ConditionValue = SingleLevelCaseWhen(Map((column_one = 1) -> 2, ((column_one = 2) AND (column_two = 1)) -> 1, (column_one = 3) -> 4))

It also works with nested-only Maps:

val spec = Map[Column, ConditionValue](
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map[Column, ConditionValue](
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        )
    )
)
val d: ConditionValue = spec
>> d: ConditionValue = NestedCaseWhen(Map((column_one = 1) -> SingleLevelCaseWhen(Map((column_two = 2) -> 54, (column_two = 5) -> 524)), (column_one = 2) -> NestedCaseWhen(Map((column_two = 7) -> SingleLevelCaseWhen(Map((whatever_column = whatever) -> 12, (whatever_column = whatever_two) -> 13))))))

There are two things bothering me right now:

  • It doesn't work with mixed Maps (nested and simple at the same time)
  • When dealing with nested Maps, I have to explicitly specify the type of the Map, Map[Column, ConditionValue]

Can anyone help me with that?

Thanks,

EDIT

I kinda "fixed" the mixed-Maps problem (nested and simple at the same time) with this implicit:

implicit def testVal(value: Column): ConditionValue = {
    testMap(Map(lit(true) -> value))
}

Not sure if it's the best possible solution.
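For reference, the reason lifting a bare value into the ADT helps with mixed maps can be seen in a small Spark-free model (Int plays the role of a leaf value, String the role of a condition; all names are illustrative): once every possible value type has an implicit conversion to the ADT, each entry of a mixed literal picks its own conversion.

```scala
import scala.language.implicitConversions

sealed trait CV
final case class Leaf(v: Int) extends CV
final case class Node(m: Map[String, CV]) extends CV

// Lifting a bare leaf value into the ADT, analogous to testVal above.
implicit def intToCV(v: Int): CV = Leaf(v)
implicit def mapToCV(m: Map[String, CV]): CV = Node(m)

// Mixed literal: the Int leaf and the nested Map are each converted to CV.
val mixed: Map[String, CV] = Map(
  "a=1" -> 1,
  "a=2" -> Map[String, CV]("b=1" -> 2)
)
```

Note that the inner Map still needs the explicit Map[String, CV] ascription; without it, its value type is inferred before the conversion can fire, which is exactly the second problem from the question.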

CodePudding user response:

I think you can solve both issues by converting the Map recursively. For example, you can define a function that takes a Map[Column, Any] and returns a ConditionValue:

def mapToConditionValue(spec: Map[Column, Any]): ConditionValue = {
  val converted: Map[Column, ConditionValue] = spec.map {
    // a bare Column value becomes a single always-true branch
    case (condition, value: Column) =>
      condition -> SingleLevelCaseWhen(Map(lit(true) -> value))
    // a nested Map is converted recursively
    case (condition, value: Map[_, _]) =>
      condition -> mapToConditionValue(value.asInstanceOf[Map[Column, Any]])
    case (_, value) =>
      throw new IllegalArgumentException(s"Unsupported value: $value")
  }
  NestedCaseWhen(converted)
}

Then you can call it like this:

val spec: Map[Column, Any] = Map(
    ($"column_one" === 1) -> Map(
        ($"column_two" === 2) -> lit(54),
        ($"column_two" === 5) -> lit(524)
    ),
    ($"column_one" === 2) -> Map(
        ($"column_two" === 7) -> Map(
            ($"whatever_column" === "whatever") -> lit(12),
            ($"whatever_column" === "whatever_two") -> lit(13)
        ),
        ($"column_two" === 8) -> lit(524)
    ),
    ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = mapToConditionValue(spec)

This should work for both nested and single-level maps.
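To sanity-check the recursion without a Spark session, here is a minimal Spark-free sketch of the same idea: String stands in for Column, and a bare leaf value is lifted to a single branch guarded by an always-true condition, mirroring the testVal trick from the question (all names are illustrative):

```scala
// Spark-free model of the recursive conversion.
sealed trait Spec
final case class Branches(m: Map[String, String]) extends Spec
final case class Subtree(m: Map[String, Spec]) extends Spec

def convert(spec: Map[String, Any]): Spec =
  Subtree(spec.map {
    // a bare value becomes a single branch guarded by an always-true condition
    case (cond, leaf: String) => cond -> Branches(Map("TRUE" -> leaf))
    // element types are erased at runtime; we can only check "is it a Map?"
    case (cond, m: Map[_, _]) => cond -> convert(m.asInstanceOf[Map[String, Any]])
    case (_, other) => throw new IllegalArgumentException(s"Unsupported: $other")
  })

val d = convert(Map(
  "a=1" -> "2",
  "a=2" -> Map("b=1" -> "3")
))
```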

Edit: note that matching on Map[Column, Any] is subject to type erasure, so at runtime you can only check that a value is a Map, not what its element types are. You can inspect each value's runtime type and pattern-match on it to decide which kind of ConditionValue to return.

Here's an example that checks the runtime type of each value and returns the appropriate ConditionValue:

def mapToConditionValue(spec: Map[Column, Any]): ConditionValue = {
  val converted: Map[Column, ConditionValue] = spec.map {
    case (condition, value) =>
      value match {
        // element types are erased, so we can only check that this is a Map
        case m: Map[_, _] =>
          condition -> mapToConditionValue(m.asInstanceOf[Map[Column, Any]])
        case c: Column =>
          condition -> SingleLevelCaseWhen(Map(lit(true) -> c))
        case other =>
          throw new IllegalArgumentException(
            s"Unsupported type: ${other.getClass.getTypeName}")
      }
  }
  NestedCaseWhen(converted)
}

It can be called on the same mixed spec as above:

val d: ConditionValue = mapToConditionValue(spec)
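One caveat, easy to check without Spark: the runtime class of a small immutable Map literal is a specialized subclass such as scala.collection.immutable.Map$Map2, not scala.collection.immutable.Map itself, so comparing getTypeName against the base class name never matches. A subtype check (isInstanceOf or a type pattern) is the safer test:

```scala
val value: Any = Map(1 -> "a", 2 -> "b")

// The concrete runtime class is a specialized small-Map implementation...
val runtimeName = value.getClass.getTypeName

// ...so string comparison against the base name fails, while a subtype check works.
val nameMatches = runtimeName == "scala.collection.immutable.Map"
val isMap = value.isInstanceOf[Map[_, _]]
```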