The problem is Flatten Binary Tree to Linked List. Its starter code is:
```python
# class TreeNode(object):
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution(object):
    def flatten(self, root):
        """
        :type root: TreeNode
        :rtype: None Do not return anything, modify root in-place instead.
        """
```
The solution is:
```python
class Solution:
    def flattenTree(self, node):
        # Handle the null scenario
        if not node:
            return None

        # For a leaf node, we simply return the
        # node as is.
        if not node.left and not node.right:
            return node

        # Recursively flatten the left subtree
        leftTail = self.flattenTree(node.left)

        # Recursively flatten the right subtree
        rightTail = self.flattenTree(node.right)

        # If there was a left subtree, we shuffle the connections
        # around so that there is nothing on the left side
        # anymore.
        if leftTail:
            leftTail.right = node.right
            node.right = node.left
            node.left = None

        # We need to return the "rightmost" node after we are
        # done wiring the new connections.
        return rightTail if rightTail else leftTail

    def flatten(self, root: TreeNode) -> None:
        """
        Do not return anything, modify root in-place instead.
        """
        self.flattenTree(root)
```
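To see what `flattenTree` returns, here is a standalone sketch of the same recursion written as a module-level function. `flatten_tree` and `right_chain` are hypothetical names, and `TreeNode` is a minimal stand-in for LeetCode's class. The return value is the tail of the flattened list, which is why the caller can safely hang `node.right` off `leftTail.right`.

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for the class LeetCode provides (assumed shape).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_tree(node: Optional[TreeNode]) -> Optional[TreeNode]:
    """Same logic as Solution.flattenTree: flattens in place, returns the tail."""
    if not node:
        return None
    if not node.left and not node.right:
        return node  # a leaf is its own tail
    left_tail = flatten_tree(node.left)
    right_tail = flatten_tree(node.right)
    if left_tail:
        left_tail.right = node.right  # step 1: attach right list after left list
        node.right = node.left        # step 2: move left list to the right
        node.left = None              # step 3: clear the left pointer
    return right_tail if right_tail else left_tail


def right_chain(node: Optional[TreeNode]) -> List[int]:
    """Collect values by following right pointers."""
    values = []
    while node:
        values.append(node.val)
        node = node.right
    return values


# 1 has children 2 and 5; 2 has children 3 and 4; 5 has only a right child 6.
root = TreeNode(1,
                TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
tail = flatten_tree(root)
print(right_chain(root))  # [1, 2, 3, 4, 5, 6]
print(tail.val)           # 6
```

The leaf check is what keeps the tail well defined: a leaf is its own tail, and any other node's tail is the right subtree's tail if there is one, otherwise the left subtree's.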
I don't understand this block of code:
```python
if leftTail:
    leftTail.right = node.right  # step 1
    node.right = node.left       # step 2
    node.left = None
```
For example, if the input tree is [1, 2, 3], then after step 1 the subtree starting at leftTail is [2, null, 3]. My naive thought was that after step 2 the tree would become [1, null, 3], but to my surprise, once the block finishes it becomes [1, null, 2, null, 3].
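The lists above use LeetCode-style level-order notation: nodes are listed level by level, null marks a missing child, and (assuming LeetCode's usual convention) trailing nulls are dropped. The sketch below uses a hypothetical `serialize` helper, with Python's `None` standing in for null, to show why a chain of right children prints as [1, null, 2, null, 3]: every node in the chain has an empty left slot.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for the class LeetCode provides (assumed shape).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Level-order list with None for missing children, trailing Nones trimmed."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        # Enqueue both children, even missing ones, so gaps show up as None.
        queue.append(node.left)
        queue.append(node.right)
    # Drop the trailing Nones produced by the last level's empty children.
    while out and out[-1] is None:
        out.pop()
    return out


original = TreeNode(1, TreeNode(2), TreeNode(3))
chain = TreeNode(1, None, TreeNode(2, None, TreeNode(3)))
print(serialize(original))  # [1, 2, 3]
print(serialize(chain))     # [1, None, 2, None, 3]
```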
Answer:
Take your example, the tree [1, 2, 3]:
```text
    1 (node)
   / \
  2   3
```
Now let's check what each step does:
```python
if leftTail:
    leftTail.right = node.right  # step 1
    node.right = node.left       # step 2
    node.left = None             # step 3
```
Step 1:

```text
    1 (node)
   / \
  2   3
   \
    3 (the same node as 1's right child)
```
Step 2:

```text
    1 (node)
   / \
  2   2 (the same node as 1's left child)
   \   \
    3   3
```
Step 3:

```text
  1 (node)
   \
    2
     \
      3
```
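So the result is [1, null, 2, null, 3]. The key point is that step 2 copies nothing: `node.right = node.left` makes node 1's right pointer refer to node 2 itself, and node 2 already points to node 3 because of step 1. That is why node 3 stays in the chain instead of disappearing, as it would in the [1, null, 3] you expected.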
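To check these pointer-sharing claims directly, here is a small trace that replays the three steps on [1, 2, 3] and uses Python's `is` operator to compare node identity. `TreeNode` is a minimal stand-in for LeetCode's class and the tree is built by hand. Because node 2 is a leaf, `flattenTree(node.left)` returns node 2 itself, so the sketch sets `leftTail` to it directly.

```python
class TreeNode:
    # Minimal stand-in for the class LeetCode provides (assumed shape).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Build the example tree [1, 2, 3].
two, three = TreeNode(2), TreeNode(3)
node = TreeNode(1, two, three)

# Node 2 is a leaf, so flattenTree(node.left) would return node 2 itself.
leftTail = node.left

# Step 1: node 2's right pointer now refers to the same object as node 1's right.
leftTail.right = node.right
print(leftTail.right is three)    # True

# Step 2: node 1's right now refers to node 2 itself (not a copy),
# and node 2 still carries its pointer to node 3.
node.right = node.left
print(node.right is node.left)    # True
print(node.right.right is three)  # True

# Step 3: drop the left pointer; only the right chain remains.
node.left = None

values = []
cur = node
while cur:
    values.append(cur.val)
    cur = cur.right
print(values)  # [1, 2, 3]
```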