How to change a base class field in a functional fashion in Scala

Time:10-10

Say I have this hierarchy

trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base
//etc ...

and I want to define a method with the following signature

def tag[T <: Base](instance: T, tag: String): T

that returns an instance of type T with a modified tag: String. For example, when a Derived1 instance is passed in, a modified instance of the same type is returned.

This could easily be accomplished by making the field mutable (var tag: String). How can I achieve the desired behaviour in Scala in a functional style?

My thought:

I could create a type class and its instances

trait Tagger[T] {
  def tag(t: T, state: String): T
}

implicit object TaggerDerived1 extends Tagger[Derived1] {
  override def tag(t: Derived1, state: String): Derived1 = ???
}

implicit object TaggerDerived2 extends Tagger[Derived2] {
  override def tag(t: Derived2, state: String): Derived2 = ???
}

implicit object TaggerBase extends Tagger[Base] {
  override def tag(t: Base, state: String): Base = ???
}

and a method

def tag[T <: Base](instance: T, tag: String)(implicit tagger: Tagger[T]): T = tagger.tag(instance, tag)

This is not ideal because, first of all, users must be aware of it when defining their own derived classes. If they do not define an instance, implicit resolution would fall back to the Base implementation, and the return type would be Base rather than the derived type.

case class Derived3(tag: String = "Derived 3") extends Base


tag(Derived3(), "test") // falls back to `tag[Base](...)`
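For reference, the `???` bodies of the per-class instances above could be filled in with each case class's own `copy`. The following is only a minimal, script-style sketch under that assumption: the instances are placed in the companion object of Tagger (a choice made here so that they are found without imports), and no instance for Base is given, because Base itself has no `copy`.

```scala
trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base

trait Tagger[T] {
  def tag(t: T, state: String): T
}

object Tagger {
  // One hand-written instance per case class; each delegates to `copy`,
  // so the original value is left untouched and a new one is returned.
  implicit val taggerDerived1: Tagger[Derived1] = new Tagger[Derived1] {
    override def tag(t: Derived1, state: String): Derived1 = t.copy(tag = state)
  }
  implicit val taggerDerived2: Tagger[Derived2] = new Tagger[Derived2] {
    override def tag(t: Derived2, state: String): Derived2 = t.copy(tag = state)
  }
}

def tag[T <: Base](instance: T, tag: String)(implicit tagger: Tagger[T]): T =
  tagger.tag(instance, tag)

val original = Derived1("aaa")
val retagged: Derived1 = tag(original, "bbb") // static type stays Derived1
```

The drawback described above remains: every new subclass of Base still needs its own instance.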

Now I am leaning towards using mutable state by employing var tag: String. However, I would love to hear some opinions on how to resolve this in a purely functional way in Scala.

CodePudding user response:

You can derive your type class Tagger generically (here with Shapeless), so that users will not have to define instances for every new case class extending Base:

// libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.10"
import shapeless.labelled.{FieldType, field}
import shapeless.{::, HList, HNil, LabelledGeneric, Witness}

trait Tagger[T] {
  def tag(t: T, state: String): T
}

trait LowPriorityTagger {
  implicit def notTagFieldTagger[K <: Symbol : Witness.Aux, V, T <: HList](implicit
    tagger: Tagger[T]
  ): Tagger[FieldType[K, V] :: T] =
    (t, state) => t.head :: tagger.tag(t.tail, state)
}

object Tagger extends LowPriorityTagger {
  implicit def genericTagger[T <: Base with Product, L <: HList](implicit
    generic: LabelledGeneric.Aux[T, L],
    tagger: Tagger[L]
  ): Tagger[T] = (t, state) => generic.from(tagger.tag(generic.to(t), state))

  implicit val hnilTagger: Tagger[HNil] = (_, _) => HNil

  implicit def tagFieldTagger[T <: HList]:
    Tagger[FieldType[Witness.`'tag`.T, String] :: T] = 
    (t, state) => field[Witness.`'tag`.T](state) :: t.tail
}
case class Derived1(tag: String = "Derived 1") extends Base
case class Derived2(tag: String = "Derived 2") extends Base
case class Derived3(i: Int, tag: String = "Derived 3", s: String) extends Base

tag(Derived1("aaa"), "bbb") // Derived1(bbb)
tag(Derived2("ccc"), "ddd") // Derived2(ddd)
tag(Derived3(1, "ccc", "xxx"), "ddd") // Derived3(1,ddd,xxx)

Alternatively, for single-parameter case classes, you can constrain T with a structural type so that it has .copy:

import scala.language.reflectiveCalls
def tag[T <: Base {def copy(tag: String): T}](instance: T, tag: String): T =
  instance.copy(tag = tag)

For multi-parameter case classes it is harder to express the existence of .copy in types, because the method signature depends on the fields of each class and is not known in advance (it has to be calculated).
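To make the difficulty concrete, here is a small sketch of what the compiler-generated `copy` methods roughly look like for two of the classes above. The signatures in the comments are an approximation written for illustration. When the concrete class is known statically, `copy(tag = ...)` works directly; the problem only appears once the code is generic in T.

```scala
trait Base {
  val tag: String
}

case class Derived1(tag: String = "Derived 1") extends Base
case class Derived3(i: Int, tag: String = "Derived 3", s: String) extends Base

// Roughly what the compiler generates for each case class:
//   Derived1: def copy(tag: String = this.tag): Derived1
//   Derived3: def copy(i: Int = this.i, tag: String = this.tag, s: String = this.s): Derived3
// The parameter lists differ, so no single refinement such as
// { def copy(tag: String): T } describes both.

val d1 = Derived1("aaa").copy(tag = "bbb")
val d3 = Derived3(1, "ccc", "xxx").copy(tag = "ddd")
```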

So you can make tag a macro (a Scala 2 def macro, which has to be compiled before the code that calls it):

// libraryDependencies += scalaOrganization.value % "scala-reflect" % scalaVersion.value
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

def tag[T <: Base](instance: T, tag: String): T = macro tagImpl

def tagImpl(c: blackbox.Context)(instance: c.Tree, tag: c.Tree): c.Tree = {
  import c.universe._
  q"$instance.copy(tag = $tag)"
}

Or you can use runtime reflection (Java or Scala, using Product functionality or not)

import scala.reflect.{ClassTag, classTag}
import scala.reflect.runtime.{currentMirror => rm}
import scala.reflect.runtime.universe.{TermName, termNames}

def tag[T <: Base with Product : ClassTag](instance: T, tag: String): T = {
  // Product: find the index of the "tag" field and rebuild the field values
  // (productElementNames requires Scala 2.13+)
  val tagIdx = instance.productElementNames.zipWithIndex
    .find { case (fieldName, _) => fieldName == "tag" }.map(_._2).get
  val values = instance.productIterator.zipWithIndex
    .map { case (fieldValue, idx) => if (idx == tagIdx) tag else fieldValue }.toSeq

  // Java reflection (arguments must be passed as AnyRef to the Java varargs methods)
  // val clazz = instance.getClass
  // val args = values.map(_.asInstanceOf[AnyRef])
  // clazz.getMethods.find(_.getName == "copy").get.invoke(instance, args: _*).asInstanceOf[T]
  // clazz.getConstructors.head.newInstance(args: _*).asInstanceOf[T]

  // Scala reflection
  val clazz = classTag[T].runtimeClass
  val classSymbol = rm.classSymbol(clazz)
  // val copyMethodSymbol = classSymbol.typeSignature.decl(TermName("copy")).asMethod
  // rm.reflect(instance).reflectMethod(copyMethodSymbol)(values: _*).asInstanceOf[T]
  val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
  rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)(values: _*).asInstanceOf[T]
}
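The commented-out Java reflection route can also be written on its own, without scala-reflect. The following is a minimal sketch, not a definitive implementation: the helper name `retagViaConstructor` is hypothetical, it needs Scala 2.13+ for `productElementNames`, and it assumes a top-level case class with a single public constructor whose parameters are exactly the case class fields, in order.

```scala
trait Base {
  val tag: String
}

case class Derived3(i: Int, tag: String = "Derived 3", s: String) extends Base

// Hypothetical helper: rebuilds the instance through its primary constructor.
def retagViaConstructor[T <: Base with Product](instance: T, newTag: String): T = {
  // Pair each field name with its value and swap in the new tag.
  val args = instance.productElementNames.zip(instance.productIterator)
    .map { case (name, value) => if (name == "tag") newTag else value }
    .map(_.asInstanceOf[AnyRef]) // box primitives for the Java varargs call
    .toSeq
  // Assumption: the first public constructor takes the fields in declaration order.
  instance.getClass.getConstructors.head.newInstance(args: _*).asInstanceOf[T]
}

val retagged3 = retagViaConstructor(Derived3(1, "ccc", "xxx"), "ddd")
```

Like the other reflection variants, this is checked only at runtime: a class without a matching constructor fails with an exception rather than a compile error.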