How do I find intersection of value lists in a txt file RDD with pyspark?

I am learning Spark and want to work out the intersections between the value lists in a file.

The format of the file looks like the following:

a
1, 2, 3, 0, ...

b
0, 5, 20, 3, ...

c
0, 7, 9, 10, 2, 20, ...

d
empty

e
empty

I tried doing the following:

rdd = spark.sparkContext.textFile('data.txt')
rdd1 = rdd.map(lambda x: x.split('\t')).map(lambda x: (x[0], x[1])).map(lambda x : (x[0], list(x[1].split('\n,'))))
ab = rdd1.map(lambda x: (x[0], (x[1]))).reduceByKey(lambda x, y: (set(x[0]))).map(lambda x: (x[0], list(set(x[1]))))

And I now have the data in the following format as key-value pairs.

[('a', [1, 2, 3, 0, ...]), ('b', [0, 5, 20, 3,...]), ('c', [0, 7, 9, 10, 2, 20, ...]), ...]

I need to find the intersection of each key's values with the values of every other key, and attach the keys where the length of the intersection is >= 2.

Like:

[key, [list of keys in the entire data whose intersection with the current key's values has length >= 2]]

For example, the values of key a have intersection [0, 3] with the values of key b, and intersection [0, 2] with the values of key c. Similarly, for key b the same applies with respect to a and c: b and a share [0, 3], and b and c share [0, 20]. Finally, d will be assigned e and vice versa because both are empty.

Sample output:

[('a', ['b', 'c']), ('b', ['a', 'c']), ('c', ['a', 'b']), ('d', ['e']), ('e', ['d'])]
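
To make the rule concrete, this is roughly the logic I am after, written in plain Python over the parsed pairs (no Spark, and using only the sample elements shown above):

data = {'a': [1, 2, 3, 0], 'b': [0, 5, 20, 3],
        'c': [0, 7, 9, 10, 2, 20], 'd': [], 'e': []}

expected = {}
for k1, v1 in data.items():
    matches = []
    for k2, v2 in data.items():
        if k1 == k2:
            continue
        if not v1 and not v2:                # two empty lists match each other
            matches.append(k2)
        elif len(set(v1) & set(v2)) >= 2:    # otherwise need >= 2 common elements
            matches.append(k2)
    expected[k1] = matches

print(expected)
# {'a': ['b', 'c'], 'b': ['a', 'c'], 'c': ['a', 'b'], 'd': ['e'], 'e': ['d']}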

CodePudding user response:

Assuming your input RDD is something like:

rdd = spark.sparkContext.parallelize([
    ('a', [1, 2, 3, 0]), ('b', [0, 5, 20, 3]),
    ('c', [0, 7, 9, 10, 2, 20]), ('d', []), ('e', []),
])
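
(The snippets below assume this already-parsed RDD. If you still need to build it from data.txt, the exact parsing depends on the real file layout; a rough sketch, assuming each line holds a key, a tab, and either a comma-separated list of values or the word "empty", could look like this:)

def parse_line(line):
    # hypothetical layout: "<key>\t<v1, v2, ...>" or "<key>\tempty"
    key, _, values = line.partition('\t')
    values = values.strip()
    if not values or values == 'empty':
        return (key, [])
    return (key, [int(v) for v in values.split(',')])

rdd = spark.sparkContext.textFile('data.txt').map(parse_line)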

In order to check value intersections between all combinations of keys, you can apply a Cartesian product of the RDD with itself and then filter the pairs whose value lists intersect:

from operator import add

def lists_intersect(l1, l2):
    if len(l1) == len(l2) == 0:  # both empty
        return True
    if len(set(l1).intersection(l2)) >= 2:  # have 2 or more same elements
        return True
    return False


result = rdd.cartesian(rdd) \
    .filter(lambda x: x[0][0] != x[1][0] and lists_intersect(x[0][1], x[1][1])) \
    .map(lambda x: (x[0][0], [x[1][0]])) \
    .reduceByKey(add)

result.collect()
# [('a', ['b', 'c']), ('d', ['e']), ('c', ['a', 'b']), ('e', ['d']), ('b', ['a', 'c'])]
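
Note that cartesian generates every ordered pair of rows, so each relationship is checked in both directions, which is what produces both ('a', [..., 'b', ...]) and ('b', [..., 'a', ...]). The collected result is not ordered; add a sortByKey() before collect() if you want the keys in order.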

Another way is to use join instead of the Cartesian product:

# key the original rdd by its row index: (index, (key, values))
rdd = rdd.zipWithIndex().map(lambda x: (x[1], x[0]))

# generates another rdd that contains pairs of row indices to check
# combinations of all rows except the row itself 
indices = range(rdd.count())
indices_rdd = spark.sparkContext.parallelize([
    (i, j) for i in indices for j in indices if i != j
])

result = (
    indices_rdd.join(rdd)                                          # (i, (j, (key_i, values_i)))
    .map(lambda x: (x[1][0], (x[0], x[1][1])))                     # re-key by j: (j, (i, (key_i, values_i)))
    .join(rdd)                                                     # (j, ((i, (key_i, values_i)), (key_j, values_j)))
    .filter(lambda x: lists_intersect(x[1][0][1][1], x[1][1][1]))  # keep pairs whose value lists intersect
    .map(lambda x: (x[1][0][1][0], [x[1][1][0]]))                  # (key_i, [key_j])
    .reduceByKey(add)                                              # concatenate all matching keys per key
)
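
Collecting gives the same pairs as the Cartesian version (the order of the rows, and of the keys inside each list, may differ):

result.collect()
# e.g. [('a', ['b', 'c']), ('b', ['a', 'c']), ('c', ['a', 'b']), ('d', ['e']), ('e', ['d'])]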