Home > OS >  Scala 3 collection partitioning with subtypes
Scala 3 collection partitioning with subtypes

Time:10-12

In Scala 3, let's say I have a List[Try[String]]. Can I split it up into success and failures, such that each list has the appropriate subtype?

If I do the following:

import scala.util.{Try, Success, Failure}
// A mixed list of successes and failures; the question treats it as List[Try[String]].
val tries = List(Success("1"), Failure(Exception("2")))
// partition splits on the predicate only, so both halves keep the element type Try[String].
val (successes, failures) = tries.partition(_.isSuccess)

then successes and failures are still of type List[Try[String]]. The same goes if I filter based on the type:

// filter only checks the runtime class; the static result type is still List[Try[String]].
val successes = tries.filter(_.isInstanceOf[Success[String]])

I could of course cast to Success and Failure respectively, but is there a type-safe way to achieve this?

CodePudding user response:

@Luis Miguel Mejía Suárez:

Use tries.partitionMap(_.toEither)

@mitchus:

@LuisMiguelMejíaSuárez ok the trick here is that Try has a toEither method which splits to the proper type. What if we have a regular sealed trait?

In Scala 2 I would do something like

import shapeless.{:+:, ::, CNil, Coproduct, Generic, HList, HNil, Inl, Inr, Poly0}
import shapeless.ops.coproduct.ToHList
import shapeless.ops.hlist.{FillWith, Mapped, Tupler}

/**
 * Type class that files one coproduct value into an HList of per-alternative lists.
 *
 * @tparam C coproduct whose alternatives are being bucketed
 * @tparam L HList of lists, one list per alternative of `C`
 */
trait Loop[C <: Coproduct, L <: HList] {
  /** Returns `l` with the value carried by `c` added to the list of its alternative. */
  def apply(c: C, l: L): L
}
object Loop {
  /**
   * Inductive step for `H :+: CT`: a value in the head alternative (`Inl`) is
   * prepended to the head list; a value in the tail (`Inr`) is delegated to the
   * instance for `CT`, leaving the head list untouched.
   */
  implicit def recur[H, CT <: Coproduct, HT <: HList](implicit
    loop: Loop[CT, HT]
  ): Loop[H :+: CT, List[H] :: HT] = {
    case (Inl(h), hs :: ht) => (h :: hs) :: ht
    case (Inr(ct), hs :: ht) => hs :: loop(ct, ht)
  }

  /** Base case: the empty coproduct has no alternatives, so the HList is returned as is. */
  implicit val base: Loop[CNil, HNil] = (_, l) => l
}

/** Nullary polymorphic value that evaluates to the empty list at every `List[A]` type. */
object nilPoly extends Poly0 {
  // One case per element type A; FillWith uses it to build an HList made of empty lists.
  implicit def cse[A]: Case0[List[A]] = at[List[A]](Nil)
}

/**
 * Splits `as` into one list per alternative of the generic (coproduct)
 * representation of `A` and returns the lists as a tuple.
 *
 * Each resulting list keeps its elements in the order in which they appear in `as`.
 *
 * @param as       values to partition
 * @param generic  conversion from `A` to its coproduct representation `C`
 * @param toHList  HList `L` of the alternatives of `C`
 * @param mapped   `L` with every element type wrapped in `List`, giving `L1`
 * @param loop     files a single coproduct value into the matching list of `L1`
 * @param fillWith builds the initial `L1` consisting of empty lists
 * @param tupler   converts the final HList into a tuple
 * @return a tuple of lists, one per alternative of `C`
 */
def partition[A, C <: Coproduct, L <: HList, L1 <: HList](as: List[A])(implicit
  generic: Generic.Aux[A, C],
  toHList: ToHList.Aux[C, L],
  mapped: Mapped.Aux[L, List, L1],
  loop: Loop[C, L1],
  fillWith: FillWith[nilPoly.type, L1],
  tupler: Tupler[L1]
): tupler.Out = {
  // Loop prepends to its bucket, so the input is folded from the right;
  // folding from the left would leave every bucket in reverse input order.
  val partitionHList: L1 = as.foldRight(fillWith())((a, l1) =>
    loop(generic.to(a), l1)
  )

  tupler(partitionHList)
}

// Example ADT used to demonstrate partition: a sealed parent with three leaf cases.
// The leaves are final so the set of alternatives cannot be extended by subclassing.
sealed trait A
final case class B() extends A
final case class C() extends A
final case class D() extends A

// Bucket a mixed list of A values by their concrete case class.
val mixed: List[A] = List(B(), B(), C(), C(), D(), D(), B(), C())
partition(mixed)
// (List(B(), B(), B()),List(C(), C(), C()),List(D(), D())): (List[B], List[C], List[D])

https://scastie.scala-lang.org/DmytroMitin/uQp603sXT7WFYmYntDXmIw


I managed to translate this code into Scala 3, although the translation turned out to be wordy (I reimplemented Generic and Coproduct).

import scala.annotation.tailrec
import scala.deriving.Mirror

/**
 * Scala 3 re-implementation of the two shapeless pieces the generic `partition`
 * relies on: a `Generic` type class and a `Coproduct` encoding built from
 * `:+:`, `Inl`, `Inr` and `CNil`.
 */
object App1 {
  // ============= Generic =====================
  /** Conversion between a type `T` and a representation type `Repr`. */
  trait Generic[T] {
    type Repr
    def to(t: T): Repr
    def from(r: Repr): T
  }
  object Generic {
    /** `Generic[T]` with the representation type exposed as a type parameter. */
    type Aux[T, Repr0] = Generic[T] { type Repr = Repr0 }

    /** Builds a `Generic` from the two conversion functions. */
    def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
      new Generic[T] {
        override type Repr = Repr0
        override def to(t: T): Repr0 = f(t)
        override def from(r: Repr0): T = g(r)
      }

    /** Extension syntax for converting to and from the representation. */
    object ops {
      extension [A](a: A) {
        def toRepr(using g: Generic[A]): g.Repr = g.to(a)
      }

      extension [Repr](a: Repr) {
        def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
      }
    }

    /**
     * Instance for product types: the representation is the tuple of the
     * mirror's element types. `to` rebuilds that tuple from `productIterator`,
     * `from` goes back through `Mirror.fromProduct`.
     */
    given [T <: Product](using
      m: Mirror.ProductOf[T]
    ): Aux[T, m.MirroredElemTypes] = instance(
      _.productIterator
       .foldRight[Tuple](EmptyTuple)(_ *: _)
       .asInstanceOf[m.MirroredElemTypes],
      m.fromProduct(_).asInstanceOf[T]
    )

    /**
     * Instance for sum types: the representation is the coproduct `C` computed
     * from the mirror's element types. `to` is generated by the `matchExpr`
     * macro, `from` unwraps the coproduct with `Coproduct.unsafeFromCoproduct`.
     */
    inline given [T, C <: Coproduct](using
      m: Mirror.SumOf[T],
      ev: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
    ): Generic.Aux[T, C] =
      instance(
        matchExpr[T, C](_).asInstanceOf[C],
        Coproduct.unsafeFromCoproduct(_).asInstanceOf[T]
      )

    import scala.quoted.*

    /** Macro entry point: injects `ident` into the coproduct `C` (see `matchExprImpl`). */
    inline def matchExpr[T, C <: Coproduct](ident: T): Coproduct =
      ${matchExprImpl[T, C]('ident)}

    /**
     * Generates a `match` on `ident` with one case per component type of `C`.
     * The case for the i-th component calls `Coproduct.unsafeToCoproduct(i, ident)`.
     */
    def matchExprImpl[T: Type, C <: Coproduct : Type](
      ident: Expr[T]
    )(using Quotes): Expr[Coproduct] = {
      import quotes.reflect.*

      // Walks nested two-argument type applications (H :+: T), collecting each
      // first argument; any other type (such as CNil) ends the list.
      def unwrapCoproduct(typeRepr: TypeRepr): List[TypeRepr] = typeRepr match {
        case AppliedType(_, List(typ1, typ2)) => typ1 :: unwrapCoproduct(typ2)
        case _  => Nil
      }

      val typeReprs = unwrapCoproduct(TypeRepr.of[C])

      val methodIdent =
        Ident(TermRef(TypeRepr.of[Coproduct.type], "unsafeToCoproduct"))

      def caseDefs(ident: Term): List[CaseDef] =
        typeReprs.zipWithIndex.map { (typeRepr, i) =>
          CaseDef(
            Typed(ident, TypeIdent(typeRepr.typeSymbol)),
            None,
            Block(
              Nil,
              Apply(
                methodIdent,
                List(Literal(IntConstant(i)), ident)
              )
            )
          )
        }

      def matchTerm(ident: Term): Term = Match(ident, caseDefs(ident))

      matchTerm(ident.asTerm).asExprOf[Coproduct]
    }
  }

  // ============= Coproduct =====================
  /** A value of exactly one of a list of alternative types. */
  sealed trait Coproduct extends Product with Serializable
  /** Coproduct whose first alternative is `H` and whose remaining alternatives are `T`. */
  sealed trait :+:[+H, +T <: Coproduct] extends Coproduct
  /** The value belongs to the first alternative. */
  final case class Inl[+H, +T <: Coproduct](head: H) extends (H :+: T)
  /** The value belongs to one of the remaining alternatives. */
  final case class Inr[+H, +T <: Coproduct](tail: T) extends (H :+: T)
  /** The empty coproduct. */
  sealed trait CNil extends Coproduct

  object Coproduct {
    /** Wraps `value` in `Inl` and then in `length` layers of `Inr`; types are not checked. */
    def unsafeToCoproduct(length: Int, value: Any): Coproduct =
      (0 until length).foldLeft[Coproduct](Inl(value))((c, _) => Inr(c))

    /** Strips the `Inr` layers and returns the value held by the innermost `Inl`. */
    @tailrec
    def unsafeFromCoproduct(c: Coproduct): Any = c match {
      case Inl(h) => h
      case Inr(c) => unsafeFromCoproduct(c)
      case _: CNil => sys.error("impossible")
    }

    /** Match type turning a tuple of types into the coproduct of the same types. */
    type ToCoproduct[T <: Tuple] <: Coproduct = T match {
      case EmptyTuple => CNil
      case h *: t => h :+: ToCoproduct[t]
    }

//    type ToTuple[C <: Coproduct] <: Tuple = C match {
//      case CNil => EmptyTuple
//      case h :+: t => h *: ToTuple[t]
//    }

    /** Type-level function from a coproduct to the tuple of its alternatives (`Out`). */
    trait ToTuple[C <: Coproduct] {
      type Out <: Tuple
    }
    object ToTuple {
      type Aux[C <: Coproduct, Out0 <: Tuple] = ToTuple[C] { type Out = Out0 }
      def instance[C <: Coproduct, Out0 <: Tuple]: Aux[C, Out0] =
        new ToTuple[C] { override type Out = Out0 }

      // Inductive case: the head alternative is put in front of the tail's tuple.
      given [H, T <: Coproduct](using
        toTuple: ToTuple[T]
      ): Aux[H :+: T, H *: toTuple.Out] = instance
      // Base case: the empty coproduct maps to the empty tuple.
      given Aux[CNil, EmptyTuple] = instance
    }
  }
}

// different file
import App1.{:+:, CNil, Coproduct, Generic, Inl, Inr}

/** Generic `partition` of a list of a sealed type into one list per subtype, using App1. */
object App2 {
  /**
   * Files one coproduct value into a tuple of per-alternative lists.
   *
   * @tparam C coproduct whose alternatives are being bucketed
   * @tparam L tuple of lists, one list per alternative of `C`
   */
  trait Loop[C <: Coproduct, L <: Tuple] {
    def apply(c: C, l: L): L
  }
  object Loop {
    // Inductive step: an Inl value is prepended to the head list; an Inr value
    // is passed on to the instance for the remaining alternatives.
    given [H, CT <: Coproduct, HT <: Tuple](using
      loop: Loop[CT, HT]
    ): Loop[H :+: CT, List[H] *: HT] = {
      case (Inl(h), hs *: ht) => (h :: hs) *: ht
      case (Inr(ct), hs *: ht) => hs *: loop(ct, ht)
    }

    // Base case: nothing to file for the empty coproduct.
    given Loop[CNil, EmptyTuple] = (_, l) => l
  }

  /** Builds a tuple of type `L` in which every component is the empty list. */
  trait FillWithNil[L <: Tuple] {
    def apply(): L
  }
  object FillWithNil {
    given [H, T <: Tuple](using
      fillWithNil: FillWithNil[T]
    ): FillWithNil[List[H] *: T] = () => Nil *: fillWithNil()
    given FillWithNil[EmptyTuple] = () => EmptyTuple
  }

  /**
   * Splits `as` into one list per alternative of the coproduct representation
   * of `A` and returns the lists as a tuple.
   *
   * Each resulting list keeps its elements in the order in which they appear in `as`.
   *
   * @param as       values to partition
   * @param generic  conversion from `A` to its coproduct representation
   * @param toTuple  tuple of the alternatives of that representation
   * @param ev       evidence that wrapping each alternative in `List` gives `L1`
   * @param loop     files a single coproduct value into the matching list of `L1`
   * @param fillWith builds the initial `L1` consisting of empty lists
   * @return a tuple of lists, one per alternative
   */
  def partition[A, /*L <: Tuple,*/ L1 <: Tuple](as: List[A])(using
    generic: Generic.Aux[A, _ <: Coproduct],
    toTuple: Coproduct.ToTuple[generic.Repr],
    //ev0: Coproduct.ToTuple[generic.Repr] =:= L, // compile-time NPE
    ev: Tuple.Map[toTuple.Out/*L*/, List] =:= L1,
    loop: Loop[generic.Repr, L1],
    fillWith: FillWithNil[L1]
  ): L1 =
    // Loop prepends to its bucket, so the input is folded from the right;
    // folding from the left would leave every bucket in reverse input order.
    as.foldRight(fillWith())((a, l1) => loop(generic.to(a), l1))

  // Example ADT: a sealed parent with three final leaf cases.
  sealed trait A
  final case class B() extends A
  final case class C() extends A
  final case class D() extends A

  def main(args: Array[String]): Unit = {
    println(partition(List[A](B(), B(), C(), C(), D(), D(), B(), C())))
  // (List(B(), B(), B()),List(C(), C(), C()),List(D(), D()))
  }
}

Scala 3.0.2

  • Related