New to Python and trying to determine how to prune a decision tree recursively by creating a new tree. If a node has a key in the "keys_to_prune" list, it and its descendants are not included in the new tree. Here is what I came up with:
def prune_tree(tree, keys_to_prune):
    new_tree = Tree()
    for child in tree.children:
        if child.key not in keys_to_prune:
            new_tree.children.append(child)
        else:
            prune_tree(child, keys_to_prune)
    return new_tree
The tree object (Tree()) has .key, .value, and .children attributes. This code does not seem to work: it appears to be checking empty nodes with no key and no value. Any assistance would be helpful!
Answer:
A dict is a kind of tree. Consider this example:
{
    "a": {
        "a": 1,
        "b": 2,
        "c": {
            "c": 3,
            "d": 4,
        },
    },
    "b": {
        "a": {
            "a": 1,
            "b": 2,
        },
        "z": 99,
    },
}
If keys_to_prune contains only "a", the desired pruned tree would look like this:
{
"b": {
"z": 99,
},
}
In other words, if you see a key that should be pruned, do nothing: don't add it to the new tree, and don't explore its subtrees. Otherwise, add the child to the new tree and explore the child for deeper pruning.
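A minimal sketch of that rule applied to the nested-dict example above, assuming each key maps to its subtree and any non-dict value (such as 1 or 99) is a leaf that is copied as-is. The name prune_dict is only for illustration.

```python
def prune_dict(tree, keys_to_prune):
    """Return a new nested dict without any key in keys_to_prune (or its subtree)."""
    new_tree = {}
    for key, child in tree.items():
        if key in keys_to_prune:
            continue  # skip this key and never look at anything below it
        # Recurse into dict children; leaves (e.g. ints) are copied as-is.
        new_tree[key] = prune_dict(child, keys_to_prune) if isinstance(child, dict) else child
    return new_tree


example = {
    "a": {"a": 1, "b": 2, "c": {"c": 3, "d": 4}},
    "b": {"a": {"a": 1, "b": 2}, "z": 99},
}

print(prune_dict(example, {"a"}))  # {'b': {'z': 99}}
```

Because a fresh dict is built at every level, the original example is left untouched.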
But the more subtle problem is that your recursive implementation is mixing two different strategies. On the one hand, it is assembling and returning a new tree. On the other hand, it is ignoring the return value and expecting recursive calls to deeper levels to mutate child levels in place. You need to choose one approach or the other. For example, if you stick with the idea of returning new trees, you might take an approach like this:
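def prune_tree(tree, keys_to_prune):
    # Copy the node's own data; a bare Tree() would have no key and no value.
    new_tree = Tree()
    new_tree.key = tree.key
    new_tree.value = tree.value
    for child in tree.children:
        if child.key not in keys_to_prune:
            # Use the pruned copy returned by the recursive call.
            new_child = prune_tree(child, keys_to_prune)
            new_tree.children.append(new_child)
    return new_tree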
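For comparison, the other strategy (mutating the existing tree in place instead of returning a new one) could look like the sketch below. The minimal Tree class is an assumption modeled on the .key, .value and .children attributes described in the question, and prune_in_place is a hypothetical name. The original tree is modified, so this only fits if you don't need the unpruned version afterwards.

```python
class Tree:
    """Minimal stand-in for the question's Tree (assumed: .key, .value, .children)."""

    def __init__(self, key=None, value=None, children=None):
        self.key = key
        self.value = value
        self.children = children if children is not None else []


def prune_in_place(tree, keys_to_prune):
    """Remove every child (and its subtree) whose key is in keys_to_prune.

    Mutates `tree` and returns None; the root's own key is not checked.
    """
    # Replace the contents of the children list, keeping only the survivors.
    tree.children[:] = [c for c in tree.children if c.key not in keys_to_prune]
    # Recurse only into the survivors; pruned subtrees are never visited.
    for child in tree.children:
        prune_in_place(child, keys_to_prune)


root = Tree("root", 0, [
    Tree("a", 1, [Tree("x", 2)]),
    Tree("b", 3, [Tree("a", 4), Tree("z", 99)]),
])
prune_in_place(root, {"a"})
print([c.key for c in root.children])              # ['b']
print([c.key for c in root.children[0].children])  # ['z']
```

Here nothing is returned, so no caller can accidentally ignore a result: the recursion's only job is to edit the children lists it is handed.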
Answer:
You can filter out the pruned children with a list comprehension:
class Tree:
    def __init__(self, k, v, children=None):
        self.key = k
        self.value = v
        if children:
            self.children = children
        else:
            self.children = []

    def __repr__(self):
        return str((self.key, self.value, self.children))

def pruned(t, keys):
    return Tree(t.key, t.value, [pruned(c, keys) for c in t.children if c.key not in keys])
import random

def random_tree(max_depth=5, max_children=3):
    # Random number of children (0..max_children); leaves once max_depth reaches 0.
    nb_children = random.randrange(max_children + 1) if max_depth > 0 else 0
    return Tree(random.randrange(100), random.choice('abcdefghijklmnopqrstuvwxyz'),
                [random_tree(max_depth-1, max_children) for _ in range(nb_children)])
t = random_tree()
s = pruned(t, set(range(0, 100, 2)))  # prune every even key (and its whole subtree)
print(t)
# (9, 'd', [(93, 'n', [(17, 'j', [(23, 'h', [(65, 'o', [(80, 'u', []), (89, 'b', [])]), (97, 'e', [(64, 'l', []), (81, 'q', []), (51, 'o', [])])]), (63, 'y', [(65, 'c', [(27, 'h', []), (25, 'z', []), (30, 'k', [])]), (78, 'r', [])]), (65, 'z', [(15, 's', [(39, 'm', []), (69, 'a', [])])])])]), (63, 'a', [(67, 'r', [(18, 'j', [(97, 'i', [(88, 'n', [])]), (6, 'a', [(72, 'l', []), (81, 'n', [])])])]), (52, 'q', [(92, 'z', [(27, 'm', [(88, 'm', []), (39, 't', [])]), (28, 'u', [(40, 'c', [])])]), (1, 'z', [])]), (99, 'r', [(49, 'x', []), (46, 'h', [(38, 'm', [(72, 'v', [])]), (98, 'h', [(75, 'z', []), (67, 'q', [])]), (78, 'w', [(67, 'v', []), (78, 'p', [])])]), (5, 'a', [(1, 'y', []), (73, 'g', [(87, 'o', []), (63, 'i', [])]), (67, 'z', [])])])])])
print(s)
# (9, 'd', [(93, 'n', [(17, 'j', [(23, 'h', [(65, 'o', [(89, 'b', [])]), (97, 'e', [(81, 'q', []), (51, 'o', [])])]), (63, 'y', [(65, 'c', [(27, 'h', []), (25, 'z', [])])]), (65, 'z', [(15, 's', [(39, 'm', []), (69, 'a', [])])])])]), (63, 'a', [(67, 'r', []), (99, 'r', [(49, 'x', []), (5, 'a', [(1, 'y', []), (73, 'g', [(87, 'o', []), (63, 'i', [])]), (67, 'z', [])])])])])
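One detail worth knowing: pruned only filters children, so the root node is always kept, even if its own key is in keys. If the root should be checked too, a small wrapper can handle it. The sketch below repeats the Tree class and pruned from this answer so it runs on its own; pruned_or_none is a hypothetical name, and returning None for a pruned root is just one possible convention.

```python
class Tree:
    # Same shape as the Tree class above: key, value, list of children.
    def __init__(self, k, v, children=None):
        self.key = k
        self.value = v
        self.children = children if children is not None else []

    def __repr__(self):
        return str((self.key, self.value, self.children))


def pruned(t, keys):
    # Same list-comprehension version as above: filters children only.
    return Tree(t.key, t.value, [pruned(c, keys) for c in t.children if c.key not in keys])


def pruned_or_none(t, keys):
    """Hypothetical wrapper: also prune the root, returning None if its key matches."""
    return None if t.key in keys else pruned(t, keys)


t = Tree(1, 'a', [Tree(2, 'b', [Tree(3, 'c')]), Tree(5, 'e')])
print(pruned(t, {1, 2}))          # (1, 'a', [(5, 'e', [])])
print(pruned_or_none(t, {1, 2}))  # None
print(t)                          # (1, 'a', [(2, 'b', [(3, 'c', [])]), (5, 'e', [])])
```

The last line shows that pruned builds new nodes and leaves the input tree unchanged.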
I must say I'm a bit puzzled by the use of key and value in your trees. Perhaps we could remove class Tree completely and use a dict directly?
def pruned(t, keys):
    return {k: pruned(v, keys) for k, v in t.items() if k not in keys}
import random

def random_tree(max_depth=5, max_children=3):
    nb_children = random.randrange(max_children + 1) if max_depth > 0 else 0
    return {random.randrange(100): random_tree(max_depth-1, max_children) for _ in range(nb_children)}
t = random_tree()
s = pruned(t, set(range(0,100,2)))
print(t)
# {63: {37: {}}, 73: {4: {57: {74: {15: {}, 58: {}}, 4: {}}, 58: {4: {33: {}, 56: {}, 70: {}}, 49: {}, 38: {85: {}}}, 66: {22: {79: {}}, 1: {}, 2: {9: {}, 19: {}}}}}}
print(s)
# {63: {37: {}}, 73: {}}