Leveraging a generic return type in Scala

Time:10-28

I would like to use a generic return type and be able to use information about that type inside the function. I'm not sure this is possible, but here is what I would like:

  def getStuff[A](a: MyObj, b: String): Option[A] = {
    // do some stuff
    A match {
      case String => Some(a.getString(b))
      case Integer => Some(a.getInt(b))
      ...
      case _ => None
    }
  }

However, as you know, A match is not possible, since A is a type rather than a value. Any ideas on how I could achieve this?

Answer:

This is a classic case for using a typeclass:

trait StuffGetter[T] { // typeclass
  def get(obj: MyObj, s: String): Option[T]
}  

implicit val stringGetter: StuffGetter[String] = new StuffGetter[String] {
  def get(o: MyObj, s: String): Option[String] = ??? // look up a String in o
}
implicit val intGetter: StuffGetter[Int] = new StuffGetter[Int] {
  def get(o: MyObj, s: String): Option[Int] = ??? // look up an Int in o
}

def getStuff[A](a: MyObj, b: String)(implicit ev: StuffGetter[A]): Option[A] =
  ev.get(a, b)

val stuff0 = getStuff[String](obj, "Hello") // calls get on stringGetter
val stuff1 = getStuff[Int](obj, "World")    // calls get on intGetter
val stuff2 = getStuff[Boolean](obj, "!")    // compile-time error: no StuffGetter[Boolean]

The StuffGetter trait defines the operations that you want to perform on the generic type, and each implicit value of that trait provides the implementation for a specific type. (For a custom type, these instances are typically placed in the type's companion object, where the compiler will look for them.)
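A minimal sketch of the companion-object placement, assuming a hypothetical MyObj that wraps a Map[String, String] and a hypothetical custom type Point (neither comes from the question). Because the companion of a type argument is part of the implicit scope, getStuff[Point] should find pointGetter without any import. Written for Scala 3 (top-level definitions):

```scala
// Hypothetical stand-in for the question's MyObj: raw values keyed by name.
final class MyObj(val raw: Map[String, String])

trait StuffGetter[T] {
  def get(obj: MyObj, s: String): Option[T]
}

// Hypothetical custom type whose StuffGetter instance lives in its companion.
final case class Point(x: Int, y: Int)

object Point {
  // The implicit scope of StuffGetter[Point] includes Point's companion,
  // so callers need no import to use this instance.
  implicit val pointGetter: StuffGetter[Point] = new StuffGetter[Point] {
    // Expects stored values of the form "x,y", e.g. "1,2".
    def get(obj: MyObj, s: String): Option[Point] =
      obj.raw.get(s).flatMap { v =>
        v.split(',') match {
          case Array(x, y) =>
            for (px <- x.trim.toIntOption; py <- y.trim.toIntOption) yield Point(px, py)
          case _ => None
        }
      }
  }
}

def getStuff[A](a: MyObj, b: String)(implicit ev: StuffGetter[A]): Option[A] =
  ev.get(a, b)

val obj = new MyObj(Map("origin" -> "0,0", "bad" -> "oops"))
```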

When getStuff is called, the compiler looks for an implicit instance of StuffGetter with the matching type. Compilation fails if no such instance exists; otherwise, the instance is passed in as the ev parameter.

The advantage of this is that the "match" is done at compile time, so unsupported types are also detected at compile time rather than at run-time.
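Putting the pieces together, a self-contained sketch that should run under Scala 3. MyObj here is a hypothetical stand-in whose getString/getInt return Option (the question does not show the real class), and the instances are placed in StuffGetter's own companion object, which is another place the compiler searches. The viaExplicit line passes the evidence by hand, which is what the compiler does for you in viaImplicit:

```scala
// Hypothetical MyObj exposing the accessors named in the question,
// backed by two maps and returning Option instead of throwing.
final class MyObj(strings: Map[String, String], ints: Map[String, Int]) {
  def getString(key: String): Option[String] = strings.get(key)
  def getInt(key: String): Option[Int] = ints.get(key)
}

trait StuffGetter[T] {
  def get(obj: MyObj, s: String): Option[T]
}

object StuffGetter {
  // Instances in the typeclass's own companion are always in implicit scope.
  implicit val stringGetter: StuffGetter[String] = new StuffGetter[String] {
    def get(o: MyObj, s: String): Option[String] = o.getString(s)
  }
  implicit val intGetter: StuffGetter[Int] = new StuffGetter[Int] {
    def get(o: MyObj, s: String): Option[Int] = o.getInt(s)
  }
}

def getStuff[A](a: MyObj, b: String)(implicit ev: StuffGetter[A]): Option[A] =
  ev.get(a, b)

val obj = new MyObj(Map("Hello" -> "hi"), Map("World" -> 42))

// The compiler fills in the implicit argument; passing it explicitly is equivalent.
val viaImplicit: Option[Int] = getStuff[Int](obj, "World")
val viaExplicit: Option[Int] = getStuff[Int](obj, "World")(StuffGetter.intGetter)

// getStuff[Boolean](obj, "!") would not compile: no StuffGetter[Boolean] exists.
```

Uncommenting the last line should produce a compile error reporting the missing StuffGetter[Boolean] instance, rather than a run-time None.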

Answer:

Conceptually, we can differentiate between pattern matching at run-time, which would look something like this

def getStuff[A](...) =
  A match {
    ...
  }

and pattern matching at compile-time, which looks something like this

def getStuff[A](...)(implicit ev: Foo[A]) = {
   ev.bar(...)
}

The key concept to understand is that type arguments do not exist at run-time: on the JVM they are "erased" during compilation, so once the program is running there is not enough information to pattern match on a type parameter such as A. At compile-time, however, that is before the program runs, types do exist, and Scala lets you ask the compiler to effectively pattern match on them via the implicit/given mechanism, which looks something like so

// type class requirements for type A 
trait StringConverter[A] {
  def getOptValue(b: String): Option[A]
}

// evidence which types satisfy the type class 
implicit val intStringConverter: StringConverter[Int] = (b: String) => b.toIntOption
implicit val strStringConverter: StringConverter[String] = (b: String) => Some(b)
implicit def anyStringConverter[A]: StringConverter[A] = (b: String) => None

// compile-time pattern matching on type A
def getStuff[A](b: String)(implicit ev: StringConverter[A]): Option[A] = {
  ev.getOptValue(b)
}

getStuff[Int]("3")     // : Option[Int] = Some(value = 3)
getStuff[String]("3")  // : Option[String] = Some(value = "3")
getStuff[Double]("3")  // : Option[Double] = None

This compile-time pattern matching is known as the type class pattern.
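With Scala 3 syntax, the same evidence can be declared with given/using instead of implicit. A sketch under that assumption; the Boolean instance is an illustrative addition, and the catch-all anyStringConverter is deliberately omitted, so getStuff[Double] is rejected at compile time instead of returning None:

```scala
// Type class: what getStuff needs to know about A.
trait StringConverter[A]:
  def getOptValue(b: String): Option[A]

// Evidence: one given per supported type.
given StringConverter[Int] with
  def getOptValue(b: String): Option[Int] = b.toIntOption

given StringConverter[String] with
  def getOptValue(b: String): Option[String] = Some(b)

// Illustrative extra instance, not part of the original answer.
given StringConverter[Boolean] with
  def getOptValue(b: String): Option[Boolean] = b.toBooleanOption

// Compile-time "pattern match": the compiler selects the given for A.
def getStuff[A](b: String)(using ev: StringConverter[A]): Option[A] =
  ev.getOptValue(b)
```

summon[StringConverter[Int]] retrieves the same instance the compiler passes to getStuff[Int].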

The distinction between types and classes is one of the fundamental concepts in Scala (https://docs.scala-lang.org/tutorials/FAQ/index.html#whats-the-difference-between-types-and-classes), and grokking it will help you understand how to write type classes.
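A small illustration of types versus classes, assuming Scala 3: List[Int] and List[String] are different types but share one runtime class, while a hypothetical TypeName type class can still tell them apart because the compiler resolves it from the static type:

```scala
// Hypothetical type class: a printable name computed from the static type.
trait TypeName[A] {
  def name: String
}

object TypeName {
  implicit val intName: TypeName[Int] = new TypeName[Int] { val name = "Int" }
  implicit val stringName: TypeName[String] = new TypeName[String] { val name = "String" }
  // Derived instance: the name for List[A] is built from the name for A.
  implicit def listName[A](implicit inner: TypeName[A]): TypeName[List[A]] =
    new TypeName[List[A]] { val name = s"List[${inner.name}]" }
}

def nameOf[A](implicit tn: TypeName[A]): String = tn.name

// At run-time the element type is gone: both lists have the same class.
val sameRuntimeClass: Boolean = List(1, 2, 3).getClass == List("a", "b").getClass
val sameClassLiteral: Boolean = classOf[List[Int]] == classOf[List[String]]

// At compile-time the types differ, so different instances are chosen.
val intListName: String = nameOf[List[Int]]       // "List[Int]"
val stringListName: String = nameOf[List[String]] // "List[String]"
```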

Answer:

Using a custom typeclass similar to Getter:

trait KeyedGetter[S, K, A]:
  def get(s: S, key: K): Option[A]

case class MyObj(ints: Map[String, Int], strs: Map[String, String])

object MyObj:
  given KeyedGetter[MyObj, String, Int] with
    def get(m: MyObj, k: String) = m.ints.get(k)

  given KeyedGetter[MyObj, String, String] with
    def get(m: MyObj, k: String) = m.strs.get(k)

def getStuff[A](m: MyObj, key: String)(using g: KeyedGetter[MyObj, String, A]): Option[A] =
  g.get(m, key)
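A usage sketch for the KeyedGetter version, assuming Scala 3. The mapGetter instance and getFrom function are hypothetical additions meant to show why S and K are type parameters: the same typeclass can describe other sources, such as a plain Map, without touching getStuff:

```scala
trait KeyedGetter[S, K, A]:
  def get(s: S, key: K): Option[A]

case class MyObj(ints: Map[String, Int], strs: Map[String, String])

object MyObj:
  // Found via MyObj's companion, which is in the implicit scope.
  given KeyedGetter[MyObj, String, Int] with
    def get(m: MyObj, k: String) = m.ints.get(k)

  given KeyedGetter[MyObj, String, String] with
    def get(m: MyObj, k: String) = m.strs.get(k)

def getStuff[A](m: MyObj, key: String)(using g: KeyedGetter[MyObj, String, A]): Option[A] =
  g.get(m, key)

// Hypothetical extra instance: a plain Map as a keyed source.
given mapGetter[K, V]: KeyedGetter[Map[K, V], K, V] with
  def get(m: Map[K, V], k: K) = m.get(k)

// Hypothetical generalisation of getStuff over any source S and key K.
def getFrom[S, K, A](s: S, key: K)(using g: KeyedGetter[S, K, A]): Option[A] =
  g.get(s, key)

val obj = MyObj(ints = Map("n" -> 1), strs = Map("s" -> "one"))
```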

Using class tags:

case class MyObj(ints: Map[String, Int], strs: Map[String, String])

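import scala.reflect.{ClassTag, classTag}

// Compare the runtime class tag of A against the supported types,
// then cast the result back to Option[A].
def getStuff[A](m: MyObj, key: String)(using ct: ClassTag[A]): Option[A] = (ct match
  case _ if ct == classTag[String] => m.strs.get(key)
  case _ if ct == classTag[Int] => m.ints.get(key)
  case _ => None
).asInstanceOf[Option[A]]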

If the erased types are insufficient (a ClassTag cannot, for example, distinguish List[Int] from List[String]), see this answer for a similar approach with type tags (and ignore the rest).
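To see where the class-tag approach stops, a sketch assuming Scala 3: ClassTag equality compares only runtime classes, so List[Int] and List[String] produce equal tags. The describe function is a hypothetical helper, not part of the answer above:

```scala
import scala.reflect.{ClassTag, classTag}

// Both tags wrap the erased class scala.collection.immutable.List.
val sameTag: Boolean = classTag[List[Int]] == classTag[List[String]]

// Hypothetical helper: any dispatch on a ClassTag sees only the erased class.
def describe[A](using ct: ClassTag[A]): String =
  if ct == classTag[List[Any]] then "a List (element type unknown)"
  else ct.runtimeClass.getName
```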
