Min and max of the values for given keys in a Python defaultdict

Time:08-07

I have a defaultdict with lists as values and tuples as keys (ddict in the code below). I want to find the min and max of the values for a given set of keys. The keys are given as a NumPy array, a 3D array of keys in which each row is a 2D block of keys. For each row, I take the keys in the corresponding 2D block, collect the values stored under those keys, and find the min and max over all of them. I need to do this for every row of the 3D array.

from operator import itemgetter
import numpy as np

ddict =  {(1.0, 1.0): [1,2,3,4], (1.0, 2.5): [2,3,4,5], (1.0, 3.75): [], (1.5, 1.0): [8,9,10], (1.5, 2.5): [2,6,8,19,1,31], (1.5,3.75): [4]}
indA = np.array([ [ [( 1.0, 1.0), ( 1.0, 3.75)], [(1.5,1.0), (1.5,3.75)] ], [ [(1.0, 2.5), (1.5,1.0)], [(1.5, 2.5), (1.5,3.75)] ] ])

mins = min(ddict, key=itemgetter(*[tuple(i) for b in indA for i in b.flatten()]))
maxs = max(ddict, key=itemgetter(*[tuple(i) for b in indA for i in b.flatten()]))

I tried the above code, expecting the following output:

min1 = min([1,2,3,4,8,9,10,4])
min2 = min([2,3,4,5,8,9,10,2,6,8,19,1,31,4])
max1 = max([1,2,3,4,8,9,10,4])
max2 = max([2,3,4,5,8,9,10,2,6,8,19,1,31,4])

I want to calculate the min and max for every 2D array in the NumPy array. Is there a workaround? Why is my code not working? It gives me the error TypeError: tuple indices must be integers or slices, not tuple

Answer:

Here is what I think you're after:

import numpy as np

# I've reformatted your example data, to make it a bit clearer
# no change in content though, just different whitespace
# whether d is a dict or defaultdict doesn't matter
d = {
    (1.0, 1.0): [1, 2, 3, 4],
    (1.0, 2.5): [2, 3, 4, 5],
    (1.0, 3.75): [],
    (1.5, 1.0): [8, 9, 10],
    (1.5, 2.5): [2, 6, 8, 19, 1, 31],
    (1.5, 3.75): [4]
}

# indA is just an array of indices, avoid capitals in variable names
indices = np.array(
    [
        [[(1.0, 1.0), (1.0, 3.75)], [(1.5, 1.0), (1.5, 3.75)]],
        [[(1.0, 2.5), (1.5, 1.0)], [(1.5, 2.5), (1.5, 3.75)]]
    ])

for group in indices:
    # you flattened each grouping of indices, but that flattens
    # the tuples you need intact as well:
    print('not: ', group.flatten())
    # Instead, you just want all the tuples:
    print('but: ', group.reshape(-1, group.shape[-1]))

# with that knowledge, this is how you can get the lists you want
# the min and max for
for group in indices:
    group = group.reshape(-1, group.shape[-1])
    values = list(x for key in group for x in d[tuple(key)])
    print(values)

# So, the solution:
result = [
    (min(vals), max(vals)) for vals in (
        list(x for key in grp.reshape(-1, grp.shape[-1]) for x in d[tuple(key)])
        for grp in indices
    )
]
print(result)

Output:

not:  [1.   1.   1.   3.75 1.5  1.   1.5  3.75]
but:  [[1.   1.  ]
 [1.   3.75]
 [1.5  1.  ]
 [1.5  3.75]]
not:  [1.   2.5  1.5  1.   1.5  2.5  1.5  3.75]
but:  [[1.   2.5 ]
 [1.5  1.  ]
 [1.5  2.5 ]
 [1.5  3.75]]
[1, 2, 3, 4, 8, 9, 10, 4]
[2, 3, 4, 5, 8, 9, 10, 2, 6, 8, 19, 1, 31, 4]
[(1, 10), (1, 31)]

That is, [(1, 10), (1, 31)] is the result you are after, 1 being the minimum of the combined values of the first group of indices, 10 the maximum of that same group of values, etc.
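As for why the original attempt raises that TypeError: the following is a minimal sketch of the likely cause, assuming the key list had been built from intact tuples (as posted, the b.flatten() call yields single floats, so tuple(i) would probably fail even earlier). The small dict and the names getter and error_message are only for illustration.

```python
from operator import itemgetter

d = {(1.0, 1.0): [1, 2, 3, 4], (1.0, 2.5): [2, 3, 4, 5]}

# itemgetter(k1, k2) builds a function f with f(obj) == (obj[k1], obj[k2])
getter = itemgetter((1.0, 1.0), (1.0, 2.5))

# Applied to the dict itself, it looks up the stored lists, which works
print(getter(d))

# min(d, key=getter) iterates over the KEYS of d and calls getter on each
# key, so it evaluates (1.0, 1.0)[(1.0, 1.0)]: a tuple indexed by a tuple
try:
    min(d, key=getter)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

In other words, the key= argument of min and max is a function applied to each item being compared, not a selection of dictionary keys, so it cannot be used to pick which entries to look at.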

Some explanation of key lines:

values = list(x for key in group for x in d[tuple(key)])

This constructs a list of combined values by looping over every pair of key values in group and using them as an index into the dictionary d. However, since key will be an ndarray after the reshaping, it is passed to the tuple() function first, so that the dict is correctly indexed. It loops over the retrieved values and adds each value x to the resulting list.
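To see why the tuple() call is needed, here is a small sketch with a one-entry dict (the data and the name lookup_failed are made up for illustration). It shows that an ndarray cannot be used as a dict key directly, while the tuple built from it matches the original key.

```python
import numpy as np

d = {(1.0, 2.5): [2, 3, 4, 5]}
key = np.array([1.0, 2.5])  # what one row of the reshaped group looks like

# An ndarray is mutable and therefore unhashable, so it cannot be a dict key
try:
    d[key]
    lookup_failed = False
except TypeError:
    lookup_failed = True

# tuple(key) yields a tuple of NumPy floats that hashes and compares
# equal to the plain (1.0, 2.5) key stored in the dict
values = d[tuple(key)]
print(lookup_failed, values)
```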

The solution is put together in a single comprehension:

[
    (min(vals), max(vals)) for vals in (
        list(x for key in grp.reshape(-1, grp.shape[-1]) for x in d[tuple(key)])
        for grp in indices
    )
]

The outer brackets indicate that a list is being constructed. (min(vals), max(vals)) is a tuple of min and max of vals, and vals loops over the inner comprehension. The inner comprehension is a generator (with parentheses instead of brackets) generating the lists for each group in indices, as explained above.
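One caveat: if every key in a group maps to an empty list (like (1.0, 3.75) in the example data), min() and max() raise a ValueError on the empty sequence. A possible variation, sketched here as a hypothetical helper named group_min_max (not part of the solution above), uses the default argument of min and max and looks keys up with .get(), so that a missing key is treated as an empty list and is not inserted into a defaultdict.

```python
from itertools import chain

import numpy as np


def group_min_max(d, indices, default=None):
    """Hypothetical helper: (min, max) of the pooled values per group."""
    result = []
    for grp in indices:
        # keep each key pair intact: one row per key
        keys = grp.reshape(-1, grp.shape[-1])
        # .get() does not trigger a defaultdict's default_factory
        vals = list(chain.from_iterable(d.get(tuple(key), []) for key in keys))
        # default is returned instead of raising when vals is empty
        result.append((min(vals, default=default), max(vals, default=default)))
    return result


d = {
    (1.0, 1.0): [1, 2, 3, 4],
    (1.0, 2.5): [2, 3, 4, 5],
    (1.0, 3.75): [],
    (1.5, 1.0): [8, 9, 10],
    (1.5, 2.5): [2, 6, 8, 19, 1, 31],
    (1.5, 3.75): [4],
}
indices = np.array(
    [
        [[(1.0, 1.0), (1.0, 3.75)], [(1.5, 1.0), (1.5, 3.75)]],
        [[(1.0, 2.5), (1.5, 1.0)], [(1.5, 2.5), (1.5, 3.75)]],
    ])

print(group_min_max(d, indices))
```

Whether None is a sensible placeholder for an empty group depends on how the result is used afterwards; any other sentinel can be passed as default.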
