Pruning a tree recursively depending on keys (python)

Time:11-28

I'm new to Python and trying to figure out how to prune a decision tree recursively by building a new tree. If a node's key is in the "keys_to_prune" list, neither it nor any of its descendants should be included in the new tree. Here is what I came up with:

def prune_tree(tree, keys_to_prune):
    new_tree = Tree()
    for child in tree.children:
        if child.key not in keys_to_prune:
            new_tree.children.append(child)
        else:
            prune_tree(child, keys_to_prune)
    return new_tree

The tree object (Tree()) has .key, .value, and .children attributes. This code does not seem to work: it appears to be checking empty nodes with no key and no value. Any assistance would be helpful!
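To see the symptom concretely, here is a minimal sketch that runs the code above (with the missing colon added) against a hypothetical stand-in for Tree; the real class is not shown, so its constructor signature here is an assumption. The returned root is a fresh Tree() with no key or value, and a kept child is appended as-is, so a pruned key deeper down survives.

```python
class Tree:
    # Hypothetical stand-in for the question's Tree: key, value, children
    def __init__(self, key=None, value=None, children=None):
        self.key = key
        self.value = value
        self.children = children if children is not None else []


def prune_tree(tree, keys_to_prune):  # the question's code, colon added
    new_tree = Tree()  # a blank node: key and value are never copied
    for child in tree.children:
        if child.key not in keys_to_prune:
            new_tree.children.append(child)  # original subtree, never pruned
        else:
            prune_tree(child, keys_to_prune)  # result is thrown away
    return new_tree


root = Tree("root", 0, [Tree("b", 2, [Tree("a", 1)])])
result = prune_tree(root, {"a"})
print(result.key, result.value)                      # None None
print([c.key for c in result.children[0].children])  # ['a']
```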

Answer:

A dict is a kind of tree. Consider this example:

{
    "a": {
        "a": 1,
        "b": 2,
        "x": {
            "c": 3,
            "d": 4,
        }
    },
    "b": {
        "a": {
            "a": 1,
            "b": 2,
        },
        "z": 99,
    },
}

If keys_to_prune contains only "a", the desired pruned tree would look like this:

{
    "b": {
        "z": 99,
    },
}

In other words, if you see a key that should be pruned, do nothing: don't add it to the new tree, and don't explore its subtrees. Otherwise, add the child to the new tree and explore the child for deeper pruning.
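That rule can be sketched directly on nested dicts. This is a minimal illustration, not the answerer's code: the function name prune is hypothetical, non-dict values are treated as leaves and copied unchanged, and the nested subtree under the first "a" uses the placeholder key "x".

```python
def prune(tree, keys_to_prune):
    """Return a new dict without any key in keys_to_prune (or its subtree)."""
    pruned = {}
    for key, child in tree.items():
        if key in keys_to_prune:
            continue  # drop this key and everything beneath it
        # Keep the key; recurse into nested dicts, copy leaf values as-is
        pruned[key] = prune(child, keys_to_prune) if isinstance(child, dict) else child
    return pruned


example = {
    "a": {"a": 1, "b": 2, "x": {"c": 3, "d": 4}},  # "x" is a placeholder key
    "b": {"a": {"a": 1, "b": 2}, "z": 99},
}
print(prune(example, {"a"}))  # {'b': {'z': 99}}
```

Because a new dict is built at every level, the original example is left untouched.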

But the more subtle problem is that your recursive implementation is mixing two different strategies. On the one hand, it is assembling and returning a new tree. On the other hand, it is ignoring the return value and expecting recursive calls to deeper levels to mutate child levels in place. You need to choose one approach or the other. For example, if you stick with the idea of returning new trees, you might take an approach like this:

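def prune_tree(tree, keys_to_prune):
    # Start a new node that keeps this node's key and value
    new_tree = Tree()
    new_tree.key = tree.key
    new_tree.value = tree.value
    for child in tree.children:
        if child.key not in keys_to_prune:
            # Keep the pruned copy of the child returned by the recursive call
            new_child = prune_tree(child, keys_to_prune)
            new_tree.children.append(new_child)
    return new_tree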
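For comparison, here is a minimal sketch of the other strategy: mutating the tree in place instead of building a new one. It assumes a hypothetical Tree whose constructor takes key, value and children and whose children attribute is a mutable list; prune_in_place is an illustrative name.

```python
class Tree:
    # Hypothetical minimal node; the question's real Tree class is not shown
    def __init__(self, key=None, value=None, children=None):
        self.key = key
        self.value = value
        self.children = children if children is not None else []


def prune_in_place(tree, keys_to_prune):
    """Remove pruned children (and their subtrees) from tree, modifying it directly."""
    # Slice assignment replaces the contents of the existing children list
    tree.children[:] = [c for c in tree.children if c.key not in keys_to_prune]
    # Recurse only into the surviving children so deeper levels get pruned too
    for child in tree.children:
        prune_in_place(child, keys_to_prune)
    # Returns None, like list.sort(), to signal in-place modification


root = Tree("root", 0, [Tree("a", 1, [Tree("c", 3)]),
                        Tree("b", 2, [Tree("a", 4), Tree("z", 99)])])
prune_in_place(root, {"a"})
print([c.key for c in root.children])              # ['b']
print([c.key for c in root.children[0].children])  # ['z']
```

Since nothing meaningful is returned, callers must not use the return value, which is the opposite of the copying version above.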

Answer:

You can easily filter stuff out using a list comprehension:

class Tree:
    def __init__(self, k, v, children=None):
        self.key = k
        self.value = v
        if children:
            self.children = children
        else:
            self.children = []
    def __repr__(self):
        return str((self.key, self.value, self.children))

def pruned(t, keys):
    return Tree(t.key, t.value, [pruned(c, keys) for c in t.children if c.key not in keys])

import random

def random_tree(max_depth=5, max_children=3):
    nb_children = random.randrange(max_children + 1) if max_depth > 0 else 0
    return Tree(random.randrange(100), random.choice('abcdefghijklmnopqrstuvwxyz'), [random_tree(max_depth-1, max_children) for _ in range(nb_children)])

t = random_tree()
s = pruned(t, set(range(0, 100, 2))) # keep only odd keys

print(t)
# (9, 'd', [(93, 'n', [(17, 'j', [(23, 'h', [(65, 'o', [(80, 'u', []), (89, 'b', [])]), (97, 'e', [(64, 'l', []), (81, 'q', []), (51, 'o', [])])]), (63, 'y', [(65, 'c', [(27, 'h', []), (25, 'z', []), (30, 'k', [])]), (78, 'r', [])]), (65, 'z', [(15, 's', [(39, 'm', []), (69, 'a', [])])])])]), (63, 'a', [(67, 'r', [(18, 'j', [(97, 'i', [(88, 'n', [])]), (6, 'a', [(72, 'l', []), (81, 'n', [])])])]), (52, 'q', [(92, 'z', [(27, 'm', [(88, 'm', []), (39, 't', [])]), (28, 'u', [(40, 'c', [])])]), (1, 'z', [])]), (99, 'r', [(49, 'x', []), (46, 'h', [(38, 'm', [(72, 'v', [])]), (98, 'h', [(75, 'z', []), (67, 'q', [])]), (78, 'w', [(67, 'v', []), (78, 'p', [])])]), (5, 'a', [(1, 'y', []), (73, 'g', [(87, 'o', []), (63, 'i', [])]), (67, 'z', [])])])])])

print(s)
# (9, 'd', [(93, 'n', [(17, 'j', [(23, 'h', [(65, 'o', [(89, 'b', [])]), (97, 'e', [(81, 'q', []), (51, 'o', [])])]), (63, 'y', [(65, 'c', [(27, 'h', []), (25, 'z', [])])]), (65, 'z', [(15, 's', [(39, 'm', []), (69, 'a', [])])])])]), (63, 'a', [(67, 'r', []), (99, 'r', [(49, 'x', []), (5, 'a', [(1, 'y', []), (73, 'g', [(87, 'o', []), (63, 'i', [])]), (67, 'z', [])])])])])
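Because random_tree produces a different tree on every run, a small fixed tree makes the behaviour easier to check. This sketch repeats the answer's Tree and pruned so it runs on its own; the tree contents are made up. It also shows that pruned only tests the keys of children, so the root is returned even when its own key is in keys.

```python
class Tree:
    def __init__(self, k, v, children=None):
        self.key = k
        self.value = v
        self.children = children or []

    def __repr__(self):
        return str((self.key, self.value, self.children))


def pruned(t, keys):
    # Rebuild this node, keeping (and recursively pruning) only children whose key is not in keys
    return Tree(t.key, t.value, [pruned(c, keys) for c in t.children if c.key not in keys])


t = Tree(2, "root", [Tree(3, "x", [Tree(4, "y")]), Tree(6, "z", [Tree(7, "w")])])
print(pruned(t, {2, 4, 6}))  # (2, 'root', [(3, 'x', [])])
```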

I must say I'm a bit puzzled by the use of key and value in your trees. Perhaps we could remove class Tree completely and use a dict directly?

def pruned(t, keys):
    return {k:pruned(v, keys) for k,v in t.items() if k not in keys}

import random

def random_tree(max_depth=5, max_children=3):
    nb_children = random.randrange(max_children + 1) if max_depth > 0 else 0
    return {random.randrange(100): random_tree(max_depth-1, max_children) for _ in range(nb_children)}

t = random_tree()
s = pruned(t, set(range(0,100,2)))

print(t)
# {63: {37: {}}, 73: {4: {57: {74: {15: {}, 58: {}}, 4: {}}, 58: {4: {33: {}, 56: {}, 70: {}}, 49: {}, 38: {85: {}}}, 66: {22: {79: {}}, 1: {}, 2: {9: {}, 19: {}}}}}}

print(s)
# {63: {37: {}}, 73: {}}
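Both recursive versions recurse once per tree level, so a very deep tree can exceed CPython's recursion limit (see sys.getrecursionlimit(), usually 1000). A minimal sketch of the same dict pruning with an explicit stack instead of recursion; pruned_iterative is a hypothetical name and, like the dict version above, it assumes every value is itself a dict.

```python
def pruned_iterative(t, keys):
    """Dict-of-dicts pruning without recursion, using an explicit stack."""
    result = {}
    # Each entry pairs a subtree still to be copied with the dict that receives it
    stack = [(t, result)]
    while stack:
        src, dst = stack.pop()
        for k, v in src.items():
            if k in keys:
                continue  # drop this key and never visit its subtree
            dst[k] = {}
            stack.append((v, dst[k]))
    return result


# A 5000-level chain {1: {2: {3: ...}}}, deeper than the default recursion limit
deep = {}
node = deep
for i in range(1, 5001):
    node[i] = {}
    node = node[i]

print(pruned_iterative(deep, {2}))  # {1: {}}
```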
