My goal is to efficiently find, in a large list of lists (for example, 1 million entries, each a list of 3 elements), the indices of the entries that contain a given value.
For example, take the list a:
a = [[0,1,2],[0,5,6],[7,8,9]]
I want to retrieve the indices of the elements containing the value 0, so my function would return 0, 1.
My first attempt was the following:
def any_identical_value(elements, index):
    for el in elements:
        if el == index:
            return True
    return False

def get_dual_points(compliant_cells, index):
    compliant = [i for i, e in enumerate(compliant_cells) if any_identical_value(e, index)]
    return compliant

result = get_dual_points(a, 0)
The solution works correctly, but it is highly inefficient for large lists of lists. In particular, I need to run one query for every value stored in the primary list, so n_queries = len(a)*3, which is 9 in the example above.
This raises two questions:
- Is a list the right data structure for this task?
- Is there a more efficient algorithm?
CodePudding user response:
You can hash all indices in one go (a single O(N) pass, where N is the total number of values), which then lets you answer each query in O(1) time on average.
from collections import defaultdict

d = defaultdict(list)
a = [[0,1,2],[0,5,6],[7,8,9]]
queries = [0,1]

for i in range(len(a)):
    for element in a[i]:
        d[element].append(i)

for x in queries:
    print(d[x])
# prints
# [0, 1]
# [0]
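One caveat with the defaultdict version, offered as a side note: looking up an absent value with d[x] stores an empty list under that key, so a long run of misses quietly grows the dictionary. A minimal sketch of a lookup that avoids this by using .get() (the helper name rows_containing and the query value 42 are hypothetical, chosen only for illustration):

```python
from collections import defaultdict

a = [[0, 1, 2], [0, 5, 6], [7, 8, 9]]

# Build the value -> row indices mapping in a single pass.
d = defaultdict(list)
for i, row in enumerate(a):
    for element in row:
        d[element].append(i)

def rows_containing(index, value):
    # Hypothetical helper: .get() does not trigger the default factory,
    # so a missing value is not stored in the dictionary.
    return index.get(value, [])

size_before = len(d)
missing = rows_containing(d, 42)  # 42 does not occur anywhere in a
size_after = len(d)

print(rows_containing(d, 0))      # [0, 1]
print(missing)                    # []
print(size_before == size_after)  # True
```

If every query value is known to exist in the list, as in the question, plain d[x] is fine.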
CodePudding user response:
You can create a dictionary that maps from a value to a set of row indices. Then, for each query, you can simply look up the value, returning an empty set if it doesn't exist anywhere in the 2D list:
from itertools import product

a = [[0,1,2],[0,5,6],[7,8,9]]
values = {}

for row, col in product(range(len(a)), range(len(a[0]))):
    value_at_index = a[row][col]
    values.setdefault(value_at_index, set()).add(row)

print(values.get(0, set()))
This outputs:
{0, 1}
If you know in advance that the elements within each sublist are unique, then you can change the dictionary update line to:
values.setdefault(value_at_index, []).append(row)
and change the .get() call to:
values.get(0, [])
to maintain the ordering of the indices in the output.
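If the sublists might contain repeated values and you still want ordered output, one possible variation (a sketch, not part of the answer above) is to keep lists but skip a row index that was just appended. The function name index_rows and the sample data with a repeated 0 are assumptions made for illustration:

```python
def index_rows(a):
    # Hypothetical helper: maps each value to an ordered list of row indices.
    values = {}
    for row_idx, row in enumerate(a):
        for value in row:
            rows = values.setdefault(value, [])
            # Rows are visited in increasing order, so a repeat within the
            # same row can only match the last index stored for this value.
            if not rows or rows[-1] != row_idx:
                rows.append(row_idx)
    return values

a = [[0, 0, 2], [0, 5, 6], [7, 8, 9]]  # first row repeats the value 0
values = index_rows(a)
print(values.get(0, []))  # [0, 1]
print(values.get(3, []))  # []
```

Because it iterates over each row directly, this version also does not depend on every sublist having the same length.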
CodePudding user response:
Here is a proposed algorithm: iterate over the list of lists once to build a dict that maps every unique element to the indices of all the sublists it belongs to.
With this method, building the dict takes time proportional to the total number of elements in the list of lists. After that, every query is a single dict lookup, which takes constant time on average.
This requires a dict of lists:
def dict_of_indices(a):
    d = {}
    for i,l in enumerate(a):
        for e in l:
            d.setdefault(e, []).append(i)
    return d

a = [[0,1,2],[0,5,6],[7,8,9]]
d = dict_of_indices(a)
print( d[0] )
# [0, 1]
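To connect this to the workload in the question (one query per stored value, so len(a)*3 queries), here is a rough sketch that builds the dict once and then answers every query from it. The random test data, its size, and the helper naive_lookup are assumptions made for illustration only; they are not from the original post.

```python
import random

def dict_of_indices(a):
    d = {}
    for i, l in enumerate(a):
        for e in l:
            d.setdefault(e, []).append(i)
    return d

def naive_lookup(a, value):
    # Hypothetical reference implementation: scans every sublist per query.
    return [i for i, l in enumerate(a) if value in l]

random.seed(0)
# Assumed test data: 1000 rows of 3 distinct values drawn from range(500).
a = [random.sample(range(500), 3) for _ in range(1000)]

d = dict_of_indices(a)  # single pass over all values

# One query per stored value: len(a) * 3 lookups, each a dict access.
results = [d[value] for row in a for value in row]
print(len(results))  # 3000

# Spot check against the naive scan for the values of the first row.
print(all(d[v] == naive_lookup(a, v) for v in a[0]))  # True
```

The naive version rescans the whole list for each query, so its total cost grows with the product of the number of queries and the list size, while the dict version pays for one pass plus one lookup per query.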