Pass/Parameterize the function's return subtype

Time:09-04

I have to deal with classes defined in a library like this (I can't change them):

trait parent {
  def joke: String
}
trait accountant {
  def number: Int
}
case class Person1(val joke: String, val number: Int, val age: Int, val isSmart: Boolean) extends parent with accountant
case class Person2(val joke: String, val number: Int, val age: Int, val isStrong: Boolean) extends parent with accountant
case class Person3(val joke: String, val age: Int, val isTall: Boolean) extends parent

I need to do some calculation based on the field values and emit a signal defined in the API like this (I can't change it either):

trait MySignal
case class GotPerson1() extends MySignal
case class GotPerson2() extends MySignal

Keep in mind that this is a scaled-down example: there are around 10 different PersonX classes, and some common calculation has to be done for certain pairs/groups of these Person types:

class DoingWork(message: parent) {

  type WithAge = {def age: Int; def number: Int}

  def dispatcher(msg: parent) = {
    msg match {
      case something: Person1 => doSomeMapping(something)
      case something: Person2 => doSomeMapping(something)
      case something: Person3 => doSomethingElse(something)
    }
  }

  def doSomeMapping(msg: WithAge): MySignal = {
    // do some common cal based on WithAge fields
    if(msg.number > 10 && msg.isInstanceOf[Person1]) GotPerson1()
    else if(msg.number > 10 && msg.isInstanceOf[Person2]) GotPerson2()
    else throw new RuntimeException()
  }

  def doSomethingElse(msg: Person3): Unit = println("blah")
}

To avoid duplicating the common calculation, I introduced the type WithAge. But because I have to emit a different signal depending on the concrete type, I have to test for the type.
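For reference, a structural type such as WithAge lists its members separated by `;` or newlines, and any class that has those members conforms to it without sharing a common trait. A minimal sketch, assuming Scala 3 (where member calls on a structural type need the `reflectiveSelectable` import); `ageTimesNumber` is a hypothetical stand-in for the common calculation:

```scala
// Assumes Scala 3: member calls on a structural type go through reflection
// and need this import (Scala 2.13 uses scala.language.reflectiveCalls instead).
import scala.reflect.Selectable.reflectiveSelectable

trait parent { def joke: String }
trait accountant { def number: Int }
case class Person1(joke: String, number: Int, age: Int, isSmart: Boolean) extends parent with accountant
case class Person2(joke: String, number: Int, age: Int, isStrong: Boolean) extends parent with accountant

// Members of a structural type are separated by ';' (or newlines).
type WithAge = { def age: Int; def number: Int }

// Hypothetical common calculation: accepts anything with `age` and `number`.
def ageTimesNumber(p: WithAge): Int = p.age * p.number

@main def structuralDemo(): Unit = {
  println(ageTimesNumber(Person1("a", 2, 30, isSmart = true)))  // 60
  println(ageTimesNumber(Person2("b", 3, 40, isStrong = true))) // 120
}
```

Person1 and Person2 are accepted because their `age` and `number` vals match the parameterless defs in the refinement.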

How can I avoid the isInstanceOf test in doSomeMapping?

Can I pass the MySignal subtype as a parameter to doSomeMapping so that the function can instantiate it?

Answer:

You could introduce a typeclass:

trait SignalProvider[T] {
  def signal: MySignal
}

object SignalProvider {
  // instances for Person1 and Person2
  implicit val person1Instance: SignalProvider[Person1] =
    new SignalProvider[Person1] {
      def signal: MySignal = GotPerson1()
    }
  
  implicit val person2Instance: SignalProvider[Person2] =
    new SignalProvider[Person2] {
      def signal: MySignal = GotPerson2()
    }
}

Then this would work:

  def dispatcher(msg: parent) = {
    msg match {
      case something: Person1 => doSomeMapping[Person1](something)
      case something: Person2 => doSomeMapping[Person2](something)
      case something: Person3 => doSomethingElse(something)
    }
  }

  def doSomeMapping[T <: WithAge](msg: WithAge)(implicit s: SignalProvider[T]): MySignal = {
    // do some common cal based on WithAge fields
    if(msg.number > 10) s.signal
    else throw new RuntimeException()
  }

Note that you have to provide the type parameter explicitly when calling doSomeMapping. Since msg is typed as WithAge, the compiler cannot infer T from the argument, so if you just write doSomeMapping(something) it won't know which typeclass instance to use, and compilation fails with "ambiguous implicit values".
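Alternatively, typing the parameter as T instead of WithAge lets the compiler infer T from the argument, so the call sites no longer need explicit type arguments. A minimal sketch of that variant, assuming Scala 3 (for the `reflectiveSelectable` import); the error message in the exception is an illustrative addition:

```scala
import scala.reflect.Selectable.reflectiveSelectable // Scala 3 structural calls

trait parent { def joke: String }
trait accountant { def number: Int }
case class Person1(joke: String, number: Int, age: Int, isSmart: Boolean) extends parent with accountant
case class Person2(joke: String, number: Int, age: Int, isStrong: Boolean) extends parent with accountant
case class Person3(joke: String, age: Int, isTall: Boolean) extends parent

trait MySignal
case class GotPerson1() extends MySignal
case class GotPerson2() extends MySignal

type WithAge = { def age: Int; def number: Int }

trait SignalProvider[T] { def signal: MySignal }

object SignalProvider {
  implicit val person1Instance: SignalProvider[Person1] =
    new SignalProvider[Person1] { def signal: MySignal = GotPerson1() }
  implicit val person2Instance: SignalProvider[Person2] =
    new SignalProvider[Person2] { def signal: MySignal = GotPerson2() }
}

// msg has type T, so T is inferred from the argument (Person1 or Person2)
// and the matching SignalProvider[T] is found in the companion object.
def doSomeMapping[T <: WithAge](msg: T)(implicit s: SignalProvider[T]): MySignal =
  if (msg.number > 10) s.signal
  else throw new RuntimeException(s"number too small: ${msg.number}")

def dispatcher(msg: parent): Any = msg match {
  case p: Person1 => doSomeMapping(p) // no [Person1] needed
  case p: Person2 => doSomeMapping(p)
  case p: Person3 => println("blah")
}
```

The bound `T <: WithAge` still restricts doSomeMapping to types with `age` and `number`, and a type without a SignalProvider instance (such as Person3) is rejected at compile time.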

Update: alternative solution

While the above works, if it were my code I would probably just go for this simpler solution, which doesn't involve concepts like implicits and type classes (that would also depend on who else needs to work with the code and how familiar they are with these concepts):

  def dispatcher(msg: parent) = {
    msg match {
      case something: Person1 => doSomeMapping(something, GotPerson1())
      case something: Person2 => doSomeMapping(something, GotPerson2())
      case something: Person3 => doSomethingElse(something)
    }
  }

  // note that signal is passed by name, so no instance will be created
  // if one ends up in the else-branch that throws the exception:
  def doSomeMapping(msg: WithAge, signal: => MySignal): MySignal = {
    // do some common cal based on WithAge fields
    if(msg.number > 10) signal
    else throw new RuntimeException()
  }
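To see the by-name behaviour described in the comment above, the sketch below counts how often the signal argument is actually constructed. It assumes Scala 3; `created` and `makeSignal` are hypothetical helpers added only to observe evaluation:

```scala
import scala.reflect.Selectable.reflectiveSelectable // Scala 3 structural calls
import scala.util.Try

trait parent { def joke: String }
trait accountant { def number: Int }
case class Person1(joke: String, number: Int, age: Int, isSmart: Boolean) extends parent with accountant

trait MySignal
case class GotPerson1() extends MySignal

type WithAge = { def age: Int; def number: Int }

// Hypothetical counter, only here to observe when the by-name argument runs.
var created = 0
def makeSignal(): MySignal = { created += 1; GotPerson1() }

def doSomeMapping(msg: WithAge, signal: => MySignal): MySignal =
  if (msg.number > 10) signal
  else throw new RuntimeException()

@main def byNameDemo(): Unit = {
  // number = 5: the else-branch throws before `signal` is evaluated
  Try(doSomeMapping(Person1("a", 5, 30, isSmart = true), makeSignal()))
  println(created) // 0
  // number = 20: `signal` is evaluated once, creating the signal
  doSomeMapping(Person1("a", 20, 30, isSmart = true), makeSignal())
  println(created) // 1
}
```

With a plain `signal: MySignal` parameter, `makeSignal()` would run before the call in both cases and the first `println` would print 1.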