I am learning Spark and want to work out the intersections between all the values in a file.
The format of the file looks like the following:
a
1, 2, 3, 0, ...
b
0, 5, 20, 3, ...
c
0, 7, 9, 10, 2, 20, ...
d
empty
e
empty
I tried doing the following:
rdd = spark.sparkContext.textFile('data.txt')
rdd1 = rdd.map(lambda x: x.split('\t')).map(lambda x: (x[0], x[1])).map(lambda x : (x[0], list(x[1].split('\n,'))))
ab = rdd1.map(lambda x: (x[0], (x[1]))).reduceByKey(lambda x, y: (set(x[0]))).map(lambda x: (x[0], list(set(x[1]))))
And I now have the data in the following format as key-value pairs.
[('a', [1, 2, 3, 0, ...]), ('b', [0, 5, 20, 3,...]), ('c', [0, 7, 9, 10, 2, 20, ...]), ...]
I need to find the intersection of each key's values with every other key's values, and attach the keys whose intersection has length >= 2, like:
[key, [list of keys in the data whose values have an intersection of length >= 2 with the current key's values]]
For example, the values of key a have intersection [0, 3] with the values of key b, and the values of key a also have intersection [0, 2] with the values of key c. Similarly for key b: the same happens with a, and b and c have intersection [0, 20]. Finally, d will be assigned e and vice versa, because both are empty.
Sample output:
[('a', ['b', 'c']), ('b', ['a', 'c']), ('c', ['a', 'b']), ('d', ['e']), ('e', ['d'])]
CodePudding user response:
Assuming your input RDD is something like:
rdd = spark.sparkContext.parallelize([
    ('a', [1, 2, 3, 0]), ('b', [0, 5, 20, 3]),
    ('c', [0, 7, 9, 10, 2, 20]), ('d', []), ('e', []),
])
In order to check the value intersections between all combinations of keys, you can apply a cartesian product on the RDD and then keep only the pairs whose values intersect:
from operator import add

def lists_intersect(l1, l2):
    if len(l1) == len(l2) == 0:  # both empty
        return True
    if len(set(l1).intersection(l2)) >= 2:  # have 2 or more same elements
        return True
    return False
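As a quick sanity check in plain Python (no Spark needed; the function is repeated here so the snippet runs standalone), the predicate behaves as described on the sample values:

```python
def lists_intersect(l1, l2):
    # same predicate as above, repeated so this snippet is self-contained
    if len(l1) == len(l2) == 0:  # both empty
        return True
    return len(set(l1).intersection(l2)) >= 2  # share 2 or more elements

assert lists_intersect([1, 2, 3, 0], [0, 5, 20, 3])         # a vs b -> {0, 3}
assert lists_intersect([1, 2, 3, 0], [0, 7, 9, 10, 2, 20])  # a vs c -> {0, 2}
assert lists_intersect([], [])                              # d vs e, both empty
assert not lists_intersect([1, 2, 3, 0], [])                # only one side empty
```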
result = rdd.cartesian(rdd) \
    .filter(lambda x: x[0][0] != x[1][0] and lists_intersect(x[0][1], x[1][1])) \
    .map(lambda x: (x[0][0], [x[1][0]])) \
    .reduceByKey(add)

result.collect()
# [('a', ['b', 'c']), ('d', ['e']), ('c', ['a', 'b']), ('e', ['d']), ('b', ['a', 'c'])]
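In plain-Python terms (a sketch only, not Spark API), the cartesian pipeline amounts to pairing every row with every other row, keeping the pairs that intersect, and collecting the matching keys per key — which is why reduceByKey(add) simply concatenates the single-element lists:

```python
from collections import defaultdict
from itertools import product

data = [('a', [1, 2, 3, 0]), ('b', [0, 5, 20, 3]),
        ('c', [0, 7, 9, 10, 2, 20]), ('d', []), ('e', [])]

def lists_intersect(l1, l2):
    if len(l1) == len(l2) == 0:
        return True
    return len(set(l1).intersection(l2)) >= 2

# cartesian + filter + map + reduceByKey(add), evaluated eagerly
grouped = defaultdict(list)
for (k1, v1), (k2, v2) in product(data, repeat=2):
    if k1 != k2 and lists_intersect(v1, v2):
        grouped[k1].append(k2)

assert dict(grouped) == {'a': ['b', 'c'], 'b': ['a', 'c'],
                         'c': ['a', 'b'], 'd': ['e'], 'e': ['d']}
```

Keep in mind that cartesian produces n² pairs, so this grows quickly with the number of rows.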
Another way is to use join instead of a cartesian product:
# add an index to the original rdd
rdd = rdd.zipWithIndex().map(lambda x: (x[1], x[0]))

# generate another rdd that contains the pairs of row indices to check:
# all combinations of rows except the row itself
indices = range(rdd.count())
indices_rdd = spark.sparkContext.parallelize([
    (i, j) for i in indices for j in indices if i != j
])

result = indices_rdd.join(rdd) \
    .map(lambda x: (x[1][0], (x[0], x[1][1]))) \
    .join(rdd) \
    .filter(lambda x: lists_intersect(x[1][0][1][1], x[1][1][1])) \
    .map(lambda x: (x[1][0][1][0], [x[1][1][0]])) \
    .reduceByKey(add)
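As a plain-Python sketch of what the two joins reassemble (Spark partitioning details aside): each index pair (i, j) picks up (key_i, values_i) from the first join and (key_j, values_j) from the second, after which the filter and reduce work exactly as in the cartesian version, so the collected result should contain the same pairs (RDD ordering may differ):

```python
from collections import defaultdict

data = [('a', [1, 2, 3, 0]), ('b', [0, 5, 20, 3]),
        ('c', [0, 7, 9, 10, 2, 20]), ('d', []), ('e', [])]

def lists_intersect(l1, l2):
    if len(l1) == len(l2) == 0:
        return True
    return len(set(l1).intersection(l2)) >= 2

indexed = dict(enumerate(data))  # stand-in for rdd.zipWithIndex()
pairs = [(i, j) for i in indexed for j in indexed if i != j]  # indices_rdd

grouped = defaultdict(list)
for i, j in pairs:
    (k1, v1), (k2, v2) = indexed[i], indexed[j]  # what the two joins attach
    if lists_intersect(v1, v2):
        grouped[k1].append(k2)

assert dict(grouped) == {'a': ['b', 'c'], 'b': ['a', 'c'],
                         'c': ['a', 'b'], 'd': ['e'], 'e': ['d']}
```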