Find index of element in sublist that is same across all the sublists in a nested list

Time:12-28

I have a nested list like this:

[
  ([8.0, 16.275953626987484, 5.923962654520423],
   [8.0, 3.0076746636087575, 17.05737063904884]),
  ([8.0, 3.0076746636087575, 17.05737063904884],
   [8.0, -13.268278963378728, 11.133407984528427]),
  ([8.0, -13.268278963378728, 11.133407984528427],
   [8.0, -16.275953626987487, -5.923962654520416]),
  ([8.0, -16.275953626987487, -5.923962654520416],
   [8.0, -3.0076746636087694, -17.057370639048848]),
  ([8.0, -3.0076746636087694, -17.057370639048848],
   [8.0, 13.268278963378716, -11.133407984528437]),
  ([8.0, 13.268278963378716, -11.133407984528437],
   [8.0, 16.275953626987484, 5.923962654520423])
]

In that list, I want to remove the elements that have the same value at the same index in every sublist. For instance, here I want to remove 8.0 from all the sublists, since it is the same across all of them at index 0.

How do I do it in a short and simple way without too many hard-coded conditions?

One thing to note here is that the sublists will always have 3 elements while the length of the outer nested list can be anything.

CodePudding user response:

Your problem seems a bit odd, but this is a solution:

example = [
  ([8.0, 16.275953626987484, 5.923962654520423],
   [8.0, 3.0076746636087575, 17.05737063904884]),
  ([8.0, 3.0076746636087575, 17.05737063904884],
   [8.0, -13.268278963378728, 11.133407984528427]),
  ([8.0, -13.268278963378728, 11.133407984528427],
   [8.0, -16.275953626987487, -5.923962654520416]),
  ([8.0, -16.275953626987487, -5.923962654520416],
   [8.0, -3.0076746636087694, -17.057370639048848]),
  ([8.0, -3.0076746636087694, -17.057370639048848],
   [8.0, 13.268278963378716, -11.133407984528437]),
  ([8.0, 13.268278963378716, -11.133407984528437],
   [8.0, 16.275953626987484, 5.923962654520423])
]


def unchanging_mask(data):
    def all_inner_iterables(xss):
        if xss and isinstance(xss[0], (list, tuple)):
            for xs in xss:
                yield from all_inner_iterables(xs)
        else:
            yield xss

    return list(map(lambda xs: len(set(xs)) == 1, zip(*all_inner_iterables(data))))


def select_columns(data, mask):
    if data and isinstance(data[0], (list, tuple)):
        return type(data)(select_columns(xs, mask) for xs in data)
    else:
        return type(data)(x for i, x in enumerate(data) if not mask[i])


mask = unchanging_mask(example)
result = select_columns(example, mask)
print(result)

It works by first computing a mask: one boolean per column, which is True when that column holds the same value in every list at the deepest level of the nesting.

It then applies that mask to all the values at the deepest level of the nested structure, preserving the type of each container.
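
As a rough illustration of those two steps on a much smaller, flat input (the `rows` data and the variable names here are made up for the sketch and are not part of the answer's code):

```python
# Hypothetical toy data: a list of tuples, so type preservation is visible.
rows = [
    (8.0, 1.5, 2.5),
    (8.0, 3.5, 4.5),
    (8.0, 5.5, 6.5),
]

# Step 1: zip(*rows) groups values by position; a position is "unchanging"
# when all of its values collapse into a single-element set.
mask = [len(set(column)) == 1 for column in zip(*rows)]

# Step 2: rebuild each row with its own type, skipping masked positions.
trimmed = [
    type(row)(x for x, unchanging in zip(row, mask) if not unchanging)
    for row in rows
]

print(mask)     # [True, False, False]
print(trimmed)  # [(1.5, 2.5), (3.5, 4.5), (5.5, 6.5)]
```

The rows stay tuples because each one is rebuilt through `type(row)`, which is the same trick `select_columns` uses at every level.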

It does assume the elements at the deepest level all have the same size.
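
That assumption matters because `zip` stops at the shortest input, so a ragged structure would yield a mask shorter than some rows, and `mask[i]` could then raise an `IndexError`. A minimal guard, assuming the innermost lists have already been collected into a flat list (the helper name `check_uniform_width` is hypothetical):

```python
def check_uniform_width(leaves):
    """Return the shared length of all leaves, or raise if they differ."""
    widths = {len(leaf) for leaf in leaves}
    if len(widths) != 1:
        raise ValueError(f"expected one common length, got {sorted(widths)}")
    return widths.pop()


# Invented ragged sample: the second leaf is one element short.
ragged = [[8.0, 1.0, 2.0], [8.0, 3.0]]

# zip silently truncates to the shortest leaf, so only two columns come back.
print(len(list(zip(*ragged))))  # 2

try:
    check_uniform_width(ragged)
except ValueError as exc:
    print(exc)  # expected one common length, got [2, 3]
```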

CodePudding user response:

Here is an approach that flattens the data and then uses a list comprehension.

First, we flatten the nested list with a simple recursion:

DATA = [
    [[8.0, 16.275953626987484, 5.923962654520423]],
    [[8.0, 3.0076746636087575, 17.05737063904884]],
    [[8.0, 3.0076746636087575, 17.05737063904884]],
    [[8.0, -13.268278963378728, 11.133407984528427]],
    [[8.0, -13.268278963378728, 11.133407984528427]],
    [[8.0, -16.275953626987487, -5.923962654520416]],
    [[8.0, -16.275953626987487, -5.923962654520416]],
    [[8.0, -3.0076746636087694, -17.057370639048848]],
    [[8.0, -3.0076746636087694, -17.057370639048848]],
    [[8.0, 13.268278963378716, -11.133407984528437]],
    [[8.0, 13.268278963378716, -11.133407984528437]],
    [[8.0, 16.275953626987484, 5.923962654520423]],
]

def is_last_layer(x: list) -> bool:
    return isinstance(x, list) and len(x) == 3 and all(isinstance(_, (float, int)) for _ in x)


def get_flattened_list(data: list, res: list = None) -> list:
    res = res or []
    for x in data:
        if is_last_layer(x):
            res.append(x)
        else:
            res = get_flattened_list(data=x, res=res)
    return res


FLAT_DATA = get_flattened_list(DATA)
print(FLAT_DATA)
# [
#   [8.0, 16.275953626987484, 5.923962654520423], 
#   [8.0, 3.0076746636087575, 17.05737063904884], 
#   [8.0, 3.0076746636087575, 17.05737063904884], 
#   [8.0, -13.268278963378728, 11.133407984528427], 
#   [8.0, -13.268278963378728, 11.133407984528427], 
#   [8.0, -16.275953626987487, -5.923962654520416], 
#   [8.0, -16.275953626987487, -5.923962654520416], 
#   [8.0, -3.0076746636087694, -17.057370639048848], 
#   [8.0, -3.0076746636087694, -17.057370639048848], 
#   [8.0, 13.268278963378716, -11.133407984528437], 
#   [8.0, 13.268278963378716, -11.133407984528437], 
#   [8.0, 16.275953626987484, 5.923962654520423]
# ]

Now we can use a list comprehension to find the indices to remove:

index_to_remove = [i for i in range(3) if len({_[i] for _ in FLAT_DATA}) == 1]
print(index_to_remove)
# [0]
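
If the hard-coded `range(3)` is a concern, the width can be taken from the data instead. A small sketch, assuming every flattened row has the same length (the `flat` sample below is invented for illustration):

```python
# Invented stand-in for the flattened data.
flat = [
    [8.0, 16.0, 5.0],
    [8.0, 3.0, 17.0],
    [8.0, -13.0, 11.0],
]

# zip(*flat) yields one tuple per position, so no explicit width is needed.
index_to_remove = [
    i for i, column in enumerate(zip(*flat)) if len(set(column)) == 1
]
print(index_to_remove)  # [0]
```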

To obtain a result with the same shape as the original object, we can recurse again:

def get_filtered_data(index_rm: list, data: list) -> list:
    # A leaf is returned with the unwanted indices dropped.
    if is_last_layer(data):
        return [v for i, v in enumerate(data) if i not in index_rm]
    # Otherwise rebuild this level, so siblings and nesting depth are preserved.
    return [get_filtered_data(index_rm=index_rm, data=x) for x in data]


FILTERED_DATA = get_filtered_data(index_rm=index_to_remove, data=DATA)
print(FILTERED_DATA)
# [
#   [[16.275953626987484, 5.923962654520423]], 
#   [[3.0076746636087575, 17.05737063904884]], 
#   [[3.0076746636087575, 17.05737063904884]], 
#   [[-13.268278963378728, 11.133407984528427]], 
#   [[-13.268278963378728, 11.133407984528427]], 
#   [[-16.275953626987487, -5.923962654520416]], 
#   [[-16.275953626987487, -5.923962654520416]], 
#   [[-3.0076746636087694, -17.057370639048848]], 
#   [[-3.0076746636087694, -17.057370639048848]], 
#   [[13.268278963378716, -11.133407984528437]], 
#   [[13.268278963378716, -11.133407984528437]], 
#   [[16.275953626987484, 5.923962654520423]]
# ]