This question is a follow-up to this one.
Quick context reminder: Spark's CaseWhen takes a Seq[(Expression, Expression)], where the first Expression is the condition and the second is the value to use when that condition is satisfied:
CaseWhen(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression] = None): ...
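To make the branches concrete: each pair in the Seq is one WHEN ... THEN ... clause. Here's a Spark-free toy renderer (my own helper, not part of Spark's API) that shows the SQL shape such a structure corresponds to:

```scala
// Toy sketch: strings stand in for condition and value Expressions.
// Each branch renders as a WHEN ... THEN ... clause; elseValue as ELSE.
def renderCaseWhen(branches: Seq[(String, String)], elseValue: Option[String] = None): String = {
  val whens = branches.map { case (cond, value) => s"WHEN $cond THEN $value" }.mkString(" ")
  val orElse = elseValue.map(e => s" ELSE $e").getOrElse("")
  s"CASE $whens$orElse END"
}
```

For example, renderCaseWhen(Seq("column_one = 1" -> "2"), Some("0")) produces "CASE WHEN column_one = 1 THEN 2 ELSE 0 END".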
I want to be able to build Spark's CaseWhen from a Map object. The Map can be simple, such as:
val spec = Map(
  ($"column_one" === 1) -> lit(2),
  ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
  ($"column_one" === 3) -> lit(4)
)
It can also be nested and simple at the same time:
val spec: Map[Column, Any] = Map(
  ($"column_one" === 1) -> Map(
    ($"column_two" === 2) -> lit(54),
    ($"column_two" === 5) -> lit(524)
  ),
  ($"column_one" === 2) -> Map(
    ($"column_two" === 7) -> Map(
      ($"whatever_column" === "whatever") -> lit(12),
      ($"whatever_column" === "whatever_two") -> lit(13)
    ),
    ($"column_two" === 8) -> lit(524)
  ),
  ($"column_one" === 3) -> lit(4)
)
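The intended semantics of nesting is that each inner condition gets combined (AND-ed) with its parent's. Here's a Spark-free sketch of that flattening, where strings stand in for condition Columns and ints for values (both are my own simplification, not the actual Spark types):

```scala
// Spark-free model of the nested spec: a plain value is a leaf branch,
// a nested Map contributes branches whose conditions are the parent
// condition AND-ed with the inner one.
def flatten(spec: Map[String, Any]): Seq[(String, Int)] =
  spec.toSeq.flatMap {
    case (cond, value: Int) => Seq(cond -> value)
    case (cond, nested: Map[_, _]) =>
      flatten(nested.asInstanceOf[Map[String, Any]]).map {
        case (innerCond, value) => s"($cond) AND ($innerCond)" -> value
      }
  }
```

For instance, flatten(Map("column_one = 1" -> Map("column_two = 2" -> 54))) yields the single branch "(column_one = 1) AND (column_two = 2)" -> 54.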
So, using the answer given by @aminmal on the other question, I came up with this:
sealed trait ConditionValue {
  def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)]
}

object ConditionValue {

  object implicits {
    implicit def test(condition: Column, value: Column): ConditionValue = {
      print("test")
      SingleLevelCaseWhen(Map(condition -> value))
    }

    implicit def testTuple(conditionValue: (Column, Column)): ConditionValue = {
      print("testTuple")
      SingleLevelCaseWhen(Map(conditionValue))
    }

    implicit def testNested(spec: Map[Column, ConditionValue]): ConditionValue = {
      print("testNested")
      NestedCaseWhen(spec)
    }

    implicit def testMap(spec: Map[Column, Column]): ConditionValue = {
      print("testMap")
      SingleLevelCaseWhen(spec)
    }

    implicit def expressionToColumn(expr: Expression): Column = new Column(expr)
    implicit def columnToExpression(col: Column): Expression = col.expr
  }

  import implicits._

  final case class SingleLevelCaseWhen(specificationMap: Map[Column, Column]) extends ConditionValue {
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.map(x => (x._1.expr, x._2.expr)).toSeq
  }

  final case class NestedCaseWhen(specificationMap: Map[Column, ConditionValue]) extends ConditionValue {
    override def enumerate(reduceFunc: (Column, Column) => Column): Seq[(Expression, Expression)] =
      specificationMap.mapValues(_.enumerate(reduceFunc)).map {
        case (outerCondition, innerExpressions) => innerExpressions.map {
          case (innerCondition, innerValue) =>
            val conditions: Expression = reduceFunc(outerCondition, innerCondition)
            conditions -> innerValue
        }
      }.reduce(_ ++ _)
  }
}
My problem now is how to convert Map objects to ConditionValue objects. As you can see, I've provided some implicits in the code:
- test converts (condition, value) parameters to a ConditionValue object. Not sure this one is useful.
- testTuple converts a (condition, value) tuple to a ConditionValue object. Not sure this one is useful either.
- testMap converts a single-level Map[Column, Column] to a ConditionValue object.
- testNested converts a nested Map to a ConditionValue object.
It works great with single-level Maps:
import ConditionValue.implicits._
val spec = Map(
  ($"column_one" === 1) -> lit(2),
  ($"column_one" === 2 && $"column_two" === 1) -> lit(1),
  ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = spec
>> d: ConditionValue = SingleLevelCaseWhen(Map((column_one = 1) -> 2, ((column_one = 2) AND (column_two = 1)) -> 1, (column_one = 3) -> 4))
It also works with purely nested Maps:
val spec = Map[Column, ConditionValue](
  ($"column_one" === 1) -> Map(
    ($"column_two" === 2) -> lit(54),
    ($"column_two" === 5) -> lit(524)
  ),
  ($"column_one" === 2) -> Map[Column, ConditionValue](
    ($"column_two" === 7) -> Map(
      ($"whatever_column" === "whatever") -> lit(12),
      ($"whatever_column" === "whatever_two") -> lit(13)
    )
  )
)
val d: ConditionValue = spec
>> d: ConditionValue = NestedCaseWhen(Map((column_one = 1) -> SingleLevelCaseWhen(Map((column_two = 2) -> 54, (column_two = 5) -> 524)), (column_one = 2) -> NestedCaseWhen(Map((column_two = 7) -> SingleLevelCaseWhen(Map((whatever_column = whatever) -> 12, (whatever_column = whatever_two) -> 13))))))
There are two things bothering me right now:
- It doesn't work with mixed Maps (nested and simple at the same time)
- When dealing with nested Maps, I have to explicitly annotate the type of the Map as Map[Column, ConditionValue]
Can anyone help me with that?
Thanks,
EDIT
I kinda "fixed" the mixed Maps problem (nested and simple at the same time) with this implicit:
implicit def testVal(value: Column): ConditionValue =
  testMap(Map(lit(true) -> value))
Not sure if it's the best possible solution.
CodePudding user response:
I think you can solve both issues by converting the Map explicitly instead of relying on implicits. For example, you can define a recursive function that takes a Map[Column, Any] and returns a ConditionValue:
def mapToConditionValue(spec: Map[Column, Any]): ConditionValue =
  NestedCaseWhen(spec.map {
    // a plain value becomes a single always-true branch, like your testVal fix
    case (condition, value: Column) =>
      condition -> (SingleLevelCaseWhen(Map(lit(true) -> value)): ConditionValue)
    // a nested Map is converted recursively; the cast is needed because of type erasure
    case (condition, value: Map[_, _]) =>
      condition -> mapToConditionValue(value.asInstanceOf[Map[Column, Any]])
  })
Then you can call it like this:
val spec: Map[Column, Any] = Map(
  ($"column_one" === 1) -> Map(
    ($"column_two" === 2) -> lit(54),
    ($"column_two" === 5) -> lit(524)
  ),
  ($"column_one" === 2) -> Map(
    ($"column_two" === 7) -> Map(
      ($"whatever_column" === "whatever") -> lit(12),
      ($"whatever_column" === "whatever_two") -> lit(13)
    ),
    ($"column_two" === 8) -> lit(524)
  ),
  ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = mapToConditionValue(spec)
This should work for both nested and single-level maps.
Edit: because of type erasure you can't pattern match on Map[Column, Any] directly, but you can inspect the runtime type of each value and then decide which ConditionValue to build. Note that comparing the class name to an exact string is fragile (small immutable Maps have specialized runtime classes such as Map$Map2), so it's safer to match on the runtime type:
def mapToConditionValue(spec: Map[Column, Any]): ConditionValue =
  NestedCaseWhen(spec.map { case (condition, value) =>
    value match {
      case m: Map[_, _] =>
        condition -> mapToConditionValue(m.asInstanceOf[Map[Column, Any]])
      case c: Column =>
        condition -> (SingleLevelCaseWhen(Map(lit(true) -> c)): ConditionValue)
      case other =>
        throw new IllegalArgumentException(s"Unsupported type: ${other.getClass.getTypeName}")
    }
  })
val spec: Map[Column, Any] = Map(
  ($"column_one" === 1) -> Map(
    ($"column_two" === 2) -> lit(54),
    ($"column_two" === 5) -> lit(524)
  ),
  ($"column_one" === 2) -> Map(
    ($"column_two" === 7) -> Map(
      ($"whatever_column" === "whatever") -> lit(12),
      ($"whatever_column" === "whatever_two") -> lit(13)
    ),
    ($"column_two" === 8) -> lit(524)
  ),
  ($"column_one" === 3) -> lit(4)
)
val d: ConditionValue = mapToConditionValue(spec)
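As a side note on the class-name approach: matching getTypeName against the exact string "scala.collection.immutable.Map" would never fire, because small immutable Maps are instances of specialized subclasses. A quick Spark-free check:

```scala
// Small immutable Maps are instances of specialized subclasses, so comparing
// getTypeName to "scala.collection.immutable.Map" never matches; a runtime
// type test (isInstanceOf / pattern matching) is more robust.
val value: Any = Map("a" -> 1, "b" -> 2)
println(value.getClass.getTypeName)    // a specialized class, e.g. scala.collection.immutable.Map$Map2
println(value.isInstanceOf[Map[_, _]]) // true
```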