FS2 parallel evaluation of items with different discriminator

Time:11-10

In the following example, items with different discriminators ("a", "b" and "c") are evaluated (printed) in parallel:

package org.example

import cats.effect.std.Random
import cats.effect.{ExitCode, IO, IOApp, Temporal}
import cats.syntax.all._
import cats.{Applicative, Monad}
import fs2._

import scala.concurrent.duration._

object GitterQuestion extends IOApp {

  override def run(args: List[String]): IO[ExitCode] =
    Random.scalaUtilRandom[IO].flatMap { implicit random =>
      val flat = Stream(
        ("a", 1),
        ("a", 2),
        ("a", 3),

        ("b", 1),
        ("b", 2),
        ("b", 3),

        ("c", 1),
        ("c", 2),
        ("c", 3)
      ).covary[IO]

      val a = flat.filter(_._1 === "a").through(rndDelay)
      val b = flat.filter(_._1 === "b").through(rndDelay)
      val c = flat.filter(_._1 === "c").through(rndDelay)

      val nested = Stream(a, b, c)

      nested.parJoin(100).printlns.compile.drain.as(ExitCode.Success)
    }

  def rndDelay[F[_]: Monad: Random: Temporal, A]: Pipe[F, A, A] =
    in =>
      in.evalMap { v =>
        (Random[F].nextDouble.map(_.seconds) >>= Temporal[F].sleep) >> Applicative[F].pure(v)
      }
}

The result of running this program will look similar to this:

(c,1)
(a,1)
(c,2)
(a,2)
(c,3)
(b,1)
(a,3)
(b,2)
(b,3)

Note that items with the same discriminator are never reordered: they are processed sequentially, so (a, 2) will never be printed before (a, 1).
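The ordering property can be stated as a check: grouping the output by key must give exactly the per-key sequences obtained by grouping the input. Below is a plain-Scala sketch of that check. The object and method names are hypothetical, and the check relies on List.groupBy keeping elements in their original order within each group.

```scala
object PerKeyOrder {
  // Hypothetical helper: true when every key has exactly the same values,
  // in the same relative order, in `output` as in `input`.
  def preservesPerKeyOrder[K, V](input: List[(K, V)], output: List[(K, V)]): Boolean =
    output.groupBy(_._1) == input.groupBy(_._1)
}
```

Applied to the flat input and the sample output above, it returns true. Swapping (a,1) and (a,2) in the output makes it return false.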

In my real-world scenario, the discriminator values are not known ahead of time and there can be many of them, but I would like the same behavior. How can I do this?

Answer:

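I believe that broadcastThrough does what you want (but make sure to check its Scaladoc carefully: the pipes run concurrently, but the next chunk is pulled from the source only after every pipe has finished processing the current one, so a slow pipe holds back the others).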

I am using IO directly for simplicity, but it should be easy to adapt to an abstract F[_].

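import cats.effect.IO
import fs2.{Pipe, Stream}

def discriminateProcessing[A, B](stream: Stream[IO, A])(discriminators: List[A => Boolean])(pipe: Pipe[IO, A, B]): Stream[IO, B] = {
  // One pipe per discriminator: keep only the matching elements, then run `pipe` on them sequentially
  val allPipes: List[Pipe[IO, A, B]] = discriminators.map { p =>
    s => s.filter(p).through(pipe)
  }

  // Every element is broadcast to all pipes, which run concurrently
  stream.broadcastThrough(allPipes: _*)
}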

It would be used like this:

val result = discriminateProcessing(stream = flat)(discriminators = List(
  _._1 === "a",
  _._1 === "b",
  _._1 === "c",
)) { s =>
  s.evalMap { v =>
    random.nextDouble.map(_.seconds).flatMap(IO.sleep).as(v)
  }
}
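Since adapting this to an abstract F[_] should be easy, here is a sketch of what that might look like, assuming fs2 3.x and cats-effect 3, where broadcastThrough only needs a Concurrent instance. The names DiscriminateF and discriminateProcessingF are hypothetical.

```scala
import cats.effect.Concurrent
import fs2.{Pipe, Stream}

object DiscriminateF {
  // Hypothetical F[_]-polymorphic variant of discriminateProcessing (fs2 3.x assumed).
  def discriminateProcessingF[F[_]: Concurrent, A, B](stream: Stream[F, A])(
      discriminators: List[A => Boolean]
  )(pipe: Pipe[F, A, B]): Stream[F, B] = {
    // One sub-pipe per predicate: filter, then run `pipe` sequentially on that subset
    val allPipes: List[Pipe[F, A, B]] =
      discriminators.map(p => (s: Stream[F, A]) => s.filter(p).through(pipe))
    // Every element is broadcast to every sub-pipe; the sub-pipes run concurrently
    stream.broadcastThrough(allPipes: _*)
  }
}
```

Because each sub-pipe filters independently, an element matching no predicate is dropped, and one matching several predicates is processed once per match. The predicate list must also be known up front, so on its own this does not cover discriminators that are unknown ahead of time.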


Answer:

I think you need to roll your own groupBy function for this: create a Queue for every discriminator, and for every Queue emit one inner Stream that pulls elements from that Queue.

Here's an untested and probably naive implementation of what I had in mind:

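import cats.effect.IO
import cats.effect.std.Queue
import cats.syntax.all._
import fs2.Stream

type Q = Queue[IO, Option[(String, Int)]]

val nested: Stream[IO, Stream[IO, (String, Int)]] =
  // Wrap every element in Some and append a final None that marks the end of `flat`
  (flat.map(Some(_)) ++ Stream(None))
    .evalScan(Map.empty[String, Q] -> Option.empty[Stream[IO, (String, Int)]]) {
      case ((map, _), t @ Some((key, _))) =>
        if (map.contains(key))
          // Known key: enqueue the element, emit no new inner stream
          map(key).offer(t).as(map -> Option.empty[Stream[IO, (String, Int)]])
        else {
          // New key: create its queue and emit a stream that drains it
          for {
            q <- Queue.unbounded[IO, Option[(String, Int)]]
            _ <- q.offer(t)
          } yield (map + (key -> q)) -> Option(Stream.fromQueueNoneTerminated(q))
        }
      case ((map, _), None) =>
        // None means the flat stream is finished: terminate every inner stream
        map.values.toList.traverse_(_.offer(None))
          .as(Map.empty[String, Q] -> Option.empty[Stream[IO, (String, Int)]])
    }
    .map(_._2).unNone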
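The same idea can be packaged as a reusable helper that works for any key type, so the discriminators never have to be listed in advance. This is a sketch assuming fs2 3.x and cats-effect 3; groupByKey is a hypothetical name, not an fs2 operator. It uses unbounded queues, so memory can grow without limit if processing for one key falls far behind the input.

```scala
import cats.effect.Concurrent
import cats.effect.std.Queue
import cats.syntax.all._
import fs2.Stream

object GroupByKey {
  // Hypothetical helper (not part of fs2): splits `in` into one inner stream per key.
  // Each inner stream is emitted the first time its key is seen and keeps input order.
  def groupByKey[F[_]: Concurrent, K, A](in: Stream[F, A])(key: A => K): Stream[F, (K, Stream[F, A])] = {
    type Q = Queue[F, Option[A]]
    val noNewGroup: Option[(K, Stream[F, A])] = None

    // Some(a) carries an element; the trailing None signals the end of `in`
    (in.map[Option[A]](Some(_)) ++ Stream.emit(None))
      .evalScan((Map.empty[K, Q], noNewGroup)) {
        case ((queues, _), Some(a)) =>
          val k = key(a)
          queues.get(k) match {
            // Known key: enqueue into its existing queue
            case Some(q) => q.offer(Some(a)).as((queues, noNewGroup))
            // New key: create a queue, enqueue, and emit a stream draining it
            case None =>
              Queue.unbounded[F, Option[A]].flatMap { q =>
                q.offer(Some(a)).as((queues + (k -> q), Option((k, Stream.fromQueueNoneTerminated(q)))))
              }
          }
        case ((queues, _), None) =>
          // Input finished: terminate every inner stream
          queues.values.toList.traverse_(_.offer(None)).as((Map.empty[K, Q], noNewGroup))
      }
      .map(_._2)
      .unNone
  }
}
```

With the question's definitions in scope, it could be used as GroupByKey.groupByKey(flat)(_._1).map { case (_, s) => s.through(rndDelay) }.parJoinUnbounded.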

Answer:

We could use foldMap to group the values into a Map of Streams. We could use groupAdjacentBy instead, but it doesn't work if the same key appears in non-adjacent positions. An additional flatMap is required to un-nest the Map.

//val a = flat.filter(_._1 === "a").through(rndDelay)
//val b = flat.filter(_._1 === "b").through(rndDelay)
//val c = flat.filter(_._1 === "c").through(rndDelay)
//
//val nested = Stream(a, b, c)

//edit: this only works if equal keys are adjacent
//val nested = flat
//  .groupAdjacentBy(_._1)
//  .map { case (_, c) => Stream.chunk(c).covary[IO].through(rndDelay) }

val nested = flat
  .foldMap { case t@(k, _) => Map(k -> Stream.emit(t).covary[IO]) }
  .flatMap(m => Stream.fromIterator[IO](m.values.map(_.through(rndDelay)).iterator, 4096))

nested.parJoin(100).printlns.compile.drain.as(ExitCode.Success)
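To see why groupAdjacentBy falls short and what foldMap produces, here is a small sketch on a pure stream, assuming fs2 3.x and cats 2.x. The object name GroupingDemo is illustrative, and it uses List rather than Stream as the per-key value so the results can be compared directly. Note that foldMap emits its single Map only once the input has completed, so with this approach no key starts processing until the whole input has been read and held in memory; it is not suitable for an infinite stream.

```scala
import cats.implicits._
import fs2.Stream

object GroupingDemo {
  // The same key "a" appears in two non-adjacent positions
  val input = Stream(("a", 1), ("b", 1), ("a", 2))

  // groupAdjacentBy only merges runs of equal keys, so "a" yields two groups
  val adjacentKeys: List[String] =
    input.groupAdjacentBy(_._1).map(_._1).toList

  // foldMap combines one-entry Maps with the Map monoid (List values are concatenated),
  // emitting a single Map once the input is exhausted
  val grouped: List[Map[String, List[(String, Int)]]] =
    input.foldMap { case t @ (k, _) => Map(k -> List(t)) }.toList
}
```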