Functional way to express a stateful filter in Scala

Time:04-22

Imagine the following List[Int] in Scala:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

I want to apply a kind of dynamic filter to it, such that less data is filtered out towards the head and tail than in the middle of the list:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
 _  ____  __________  ___________  ______  __

[0, 1, 3, 7, 11, 13] // result

To do this, from both ends an index is increased by the next power of 2, and from the middle onwards the powers of 2 decrease again: 0, 0 + 2**0 = 1, 1 + 2**1 = 3, 3 + 2**2 = 7, etc.

To achieve this with an imperative approach, code similar to this can be used:

var log = 0
var idx = 0
val mask: ListBuffer[Int] = mutable.ListBuffer()
while (idx < buffer.size) {
  mask += idx
  if (idx + (2 ** log) < buffer.size / 2) {
    idx += 2 ** log
    log += 1
  } else {
    idx = buffer.size - (2 ** log) + 1
    log -= 1
  }
}

This produces a mask array that can then be used to filter the original list like mask.flatMap(list.lift)

Can somebody help me do this in a more concise, functional way? What I basically need is a way to filter the list using some changing state from the outside.

Thanks in advance.

CodePudding user response:

The usual approach to iterating with state is tail recursion (in simple enough cases you can often do the same thing with reduceLeft).

You can do it recursively in one pass. This is better than the other answer because it is linear (accessing list elements by index makes the whole thing quadratic) and tail-recursive (no extra space on the stack). Also, I think the other version reverses the order of the filtered elements.

I didn't check the logic, which the other answer suggested was incorrect; I just used it as is from your snippet. Here is the idea:

import scala.annotation.tailrec
import scala.math.pow

@tailrec
def filter(
    in: List[Int],
    midpoint: Int,
    out: List[Int] = Nil,
    idx: Int = 0,
    next: Int = 0,   // index of the next element to keep
    log: Int = 0     // exponent of the current jump size
): List[Int] = in match {
    case Nil => out.reverse
    case head :: tail if (idx == next) =>
        filter(tail, midpoint, head :: out, idx + 1, idx + pow(2, log).toInt, if (idx < midpoint) log + 1 else log - 1)
    case head :: tail => filter(tail, midpoint, out, idx + 1, next, log)
}

Note that this may seem less efficient than your "mask" idea, because it looks at every element in the list rather than jumping over the indices being filtered out. In fact, as long as you are working with a List, it is more efficient. First, yours is (at least) O(N) anyway, because you have to traverse the whole list to figure out its size. Second, list.lift(idx) is O(idx), so towards the end of the list it requires multiple traversals of almost the entire list.

Now, if you had an indexed container instead of a list, the whole "masking" idea would indeed improve things:

def filter(list: IndexedSeq[Int]) = {
  val size = list.size
  Iterator.iterate((0, 0)) { case (idx, log) =>
    (idx + math.pow(2, log).toInt, if (idx < size / 2) log + 1 else log - 1)
  }.map(_._1).takeWhile(_ < size).map(list)
}
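
The fold-based alternative mentioned at the top of this answer might look roughly like the sketch below. It carries the state (elements kept so far, next index to keep, current exponent) through a single pass over the list. Note the assumptions: the name `filterFold` is made up for illustration, and the jump rule is not the one from the question's snippet but the adjusted one (in the second half the exponent is decremented before the jump is computed), since that is what reproduces the expected result for the 0..13 example.

```scala
// Hypothetical name; a sketch of a fold carrying state, not a definitive implementation.
def filterFold(in: List[Int]): List[Int] = {
  val midpoint = in.size / 2
  // State: (kept elements in reverse order, next index to keep, current exponent)
  val (kept, _, _) = in.zipWithIndex.foldLeft((List.empty[Int], 0, 0)) {
    case ((out, next, log), (elem, idx)) if idx == next =>
      if (idx < midpoint)
        (elem :: out, idx + math.pow(2, log).toInt, log + 1)     // first half: jump grows
      else
        (elem :: out, idx + math.pow(2, log - 1).toInt, log - 1) // second half: jump shrinks
    case (state, _) => state // this index is skipped, state stays unchanged
  }
  kept.reverse
}

filterFold((0 to 13).toList) // List(0, 1, 3, 7, 11, 13)
```

Like the tail-recursive version, this visits every element exactly once, so it stays linear on a List.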

CodePudding user response:

Your code snippet does not work very well. I had to tweak it a bit to make it output the result you want:

import scala.collection.mutable

var log = 0
var idx = 0
val resultList: mutable.ListBuffer[Int] = mutable.ListBuffer()
// Fill the result until the middle, increasing the jump size
while (idx < list.size / 2) {
  resultList += list(idx)
  idx += math.pow(2, log).toInt
  log += 1
}
// Fill the result from the middle until the end, decreasing the jump size again
while (idx < list.size && log >= 0) {
  resultList += list(idx)
  log -= 1
  idx += math.pow(2, log).toInt
}

With your example it works:

val list = (0 to 13).toList  ->  ListBuffer(0, 1, 3, 7, 11, 13)

However with another example I got that:

val list = (0 to 22).toList  ->  ListBuffer(0, 1, 3, 7, 15)

I don't think this is really what you want, do you?

Anyway, here is a more functional version of it:

def filter(list: List[Int]) = {
  // recursive function to traverse the list
  def recursive(l: List[Int], log: Int, idx: Int, size: Int, halfSize: Int): List[Int] = {
    if (idx >= size || log < 0)  // terminal case: end of the list
      Nil
    else if (idx < halfSize)  // first half of the list: increase the jump size
      l(idx) :: recursive(l, log + 1, idx + math.pow(2, log).toInt, size, halfSize)
    else  // second half of the list: decrease the jump size
      l(idx) :: recursive(l, log - 1, idx + math.pow(2, log - 1).toInt, size, halfSize)
  }
  // call the recursive function with initial parameters
  recursive(list, 0, 0, list.size, list.size / 2)
}
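
Since `l(idx)` on a List costs O(idx), the same jump rule could also be sketched by first generating the indices with `Iterator.iterate` and then reading them from an indexed copy of the list. This is only a sketch: the names `maskIndices` and `filterByMask` are made up here, and converting to a `Vector` is an assumption about what is acceptable for the caller.

```scala
// Hypothetical helper: the indices to keep for a sequence of the given size,
// using the same jump rule as the two loops above.
def maskIndices(size: Int): List[Int] =
  Iterator
    .iterate((0, 0)) { case (idx, log) =>
      if (idx < size / 2) (idx + math.pow(2, log).toInt, log + 1) // jump grows
      else (idx + math.pow(2, log - 1).toInt, log - 1)            // jump shrinks
    }
    .takeWhile { case (idx, log) => idx < size && log >= 0 }
    .map(_._1)
    .toList

// Hypothetical wrapper: copy into a Vector once so each lookup is effectively constant time.
def filterByMask(list: List[Int]): List[Int] = {
  val indexed = list.toVector
  maskIndices(indexed.size).map(indexed)
}

filterByMask((0 to 13).toList) // List(0, 1, 3, 7, 11, 13)
filterByMask((0 to 22).toList) // List(0, 1, 3, 7, 15)
```

It produces the same output as the loops for both examples, so it shares their behaviour on the longer list.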

However, jumping by powers of 2 is too aggressive. If you are near the middle of the list, the next jump will end at the very end, and you will not get a progressive jump decay.

Another solution could be to increase the jump size by one each time instead of working with powers of 2. You can also keep a constant jump size when you are near the middle of the list, to avoid skipping too many values in the second half before starting to reduce the jump size:

def filter2(list: List[Int]) = {
  def recursive(l: List[Int], jumpsize: Int, idx: Int, size: Int, halfSize: Int): List[Int] = {
    if (idx >= size)  // terminal case: end of the list
      Nil
    else if (idx + jumpsize < halfSize)  // first half of the list: increase the jump size
      l(idx) :: recursive(l, jumpsize + 1, idx + jumpsize, size, halfSize)
    else if (idx < halfSize)  // around the middle of the list: keep the jump size
      l(idx) :: recursive(l, jumpsize, idx + jumpsize, size, halfSize)
    else {  // second half of the list: decrease the jump size
      val nextJumpSize = jumpsize - 1
      l(idx) :: recursive(l, nextJumpSize, idx + nextJumpSize, size, halfSize)
    }
  }
  // call the recursive function with initial parameters
  recursive(list, 1, 0, list.size, list.size / 2)
}

In my opinion, the results with this version are a bit better:

val list = (0 to 22).toList  ->  List(0, 1, 3, 6, 10, 15, 19, 22)

CodePudding user response:

Your question is not entirely clear about some corner cases, but here is my solution:

scala> def filter[A](seq: Seq[A], n: Int = 1): Seq[A] = seq match {
     |   case Nil    => Nil
     |   case Seq(x) => Seq(x)
     |   case _      => seq.head +: filter(seq.drop(n).dropRight(2*n), 2*n) :+ seq.last
     | }
def filter[A](seq: Seq[A], n: Int): Seq[A]

scala> filter(0 to 13)
val res0: Seq[Int] = List(0, 1, 3, 7, 11, 13)

scala> filter(0 to 100)
val res1: Seq[Int] = List(0, 1, 3, 7, 15, 31, 38, 70, 86, 94, 98, 100) // I am not sure whether 38 should be in the result