Forgive me if I'm not calling things by their proper names; I've just started learning Scala. I've been looking around for a while but cannot find a clear answer to my question.
Suppose I have a list of objects, each object has two fields: x: Int and l: List[String], where the Strings, in my case, represent categories.
The l lists can be of arbitrary length, so an object can belong to multiple categories. Furthermore, various objects can belong to the same category. My goal is to group the objects by the individual categories, where the categories are the keys. This means that if an object is linked to, say, N categories, it will occur in N of the key-value pairs.
So far I managed to groupBy the lists of categories through:
objectList.groupBy(x => x.l)
However, this obviously groups the objects by list of categories rather than by categories.
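To make the problem concrete, here is a minimal sketch of what that groupBy call produces. The Item class and the sample data are hypothetical stand-ins for the real objects, and the snippet is meant for the REPL or a worksheet:

```scala
// Hypothetical stand-in for the objects described above.
final case class Item(x: Int, l: List[String])

val objectList = List(
  Item(1, List("a", "b")),
  Item(2, List("b")),
  Item(3, List("a", "b"))
)

// The keys are whole lists of categories, not individual categories.
val byList: Map[List[String], List[Item]] = objectList.groupBy(o => o.l)
// byList has two keys: List("a", "b") and List("b")
```

What I would like instead is a map with the keys "a" and "b", where the object with x = 1 shows up under both.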
I'm trying to do this with immutable collections, avoiding loops etc.
If anyone has some ideas, that would be much appreciated! Thanks
CodePudding user response:
Something like the following?
objectList // Seq[YourType]
  .flatMap(o => o.l.map(c => c -> o)) // Seq[(String, YourType)]
  .groupBy { case (c, _) => c } // Map[String, Seq[(String, YourType)]]
  .map { case (c, items) => c -> items.map { case (_, o) => o } } // Map[String, Seq[YourType]]
(Deliberately "heavy" to help you understand the idea behind it)
EDIT: or, as of Scala 2.13, thanks to groupMap:
objectList // Seq[YourType]
  .flatMap(o => o.l.map(c => c -> o)) // Seq[(String, YourType)]
  .groupMap { case (c, _) => c } { case (_, o) => o } // Map[String, Seq[YourType]]
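For a version you can paste into a Scala 2.13+ REPL, here is a small sketch of the same idea. The Tagged class and the sample values are made up for illustration; only flatMap and groupMap come from the standard library:

```scala
// Hypothetical class mirroring the question's fields x and l.
final case class Tagged(x: Int, l: List[String])

val items = List(
  Tagged(1, List("a", "b")),
  Tagged(2, List("b")),
  Tagged(3, Nil) // no categories, so it ends up under no key
)

val byCategory: Map[String, List[Tagged]] =
  items
    .flatMap(o => o.l.map(c => c -> o)) // one (category, object) pair per category
    .groupMap(_._1)(_._2)               // key by category, keep only the object
```

Note that an object with an empty category list produces no pairs in the flatMap step, so it does not appear in the result at all.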
CodePudding user response:
You are very close; you just need to split each list into its individual elements before grouping,
so try something like this:
// I used a Set instead of a List,
// since I don't think the order of categories matters
// and having the same category twice wouldn't make sense either.
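final case class MyObject(x: Int, categories: Set[String] = Set.empty) {
  // Returns a copy with the extra category; the original is left untouched.
  def addCategory(category: String): MyObject =
    this.copy(categories = this.categories + category)
}

// Requires Scala 2.13+ for groupMap.
def groupByCategories(data: List[MyObject]): Map[String, List[Int]] =
  data
    .flatMap(o => o.categories.map(c => c -> o.x)) // List[(String, Int)]
    .groupMap(_._1)(_._2) // key by category, keep only x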
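As a quick usage sketch (the sample data is invented, and the definitions are repeated so the snippet runs on its own in a Scala 2.13+ REPL):

```scala
final case class MyObject(x: Int, categories: Set[String] = Set.empty) {
  def addCategory(category: String): MyObject =
    this.copy(categories = this.categories + category)
}

def groupByCategories(data: List[MyObject]): Map[String, List[Int]] =
  data
    .flatMap(o => o.categories.map(c => c -> o.x))
    .groupMap(_._1)(_._2)

// Made-up sample data for illustration.
val data = List(
  MyObject(1).addCategory("a").addCategory("b"),
  MyObject(2).addCategory("b").addCategory("b"), // the Set drops the duplicate
  MyObject(3)                                    // no categories at all
)

val grouped: Map[String, List[Int]] = groupByCategories(data)
// grouped == Map("a" -> List(1), "b" -> List(1, 2))
```

Because the categories are a Set, adding "b" twice to the same object still yields a single entry under "b".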