Simple binary tree DFS recursion stops early instead of walking the full tree

I’m trying to walk down a binary tree, look for a target TreeNode, and return that node when I find it. My code is:

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution:
    def findTarget(self, root: TreeNode, target: TreeNode) -> TreeNode:
        # base case 
        if root.val == target.val:
            return root
        if root.left:
            return self.findTarget(root.left, target)        
        if root.right:
            return self.findTarget(root.right, target)
        return None

But for the following tree, where I am looking for node 5:

[Image: example binary tree (not reproduced)]

My code only visits nodes 1, 2, 4 and 8. What am I missing? I thought that once it reached node 8, it would see that 8 has neither a left nor a right child and return None to node 4’s left-child call. After that, node 4 would look at its right child, which is 9. Why does it just stop at 8?

I’ve seen other solutions online, I am just curious what I am getting wrong in this specific case here.

Answer

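Your first branch is "if root.left: return ...". Because that return runs whenever a left child exists, the function hands back whatever the left recursive call produces, including None. So when a node has both a left and a right subtree, you never explore the right subtree, even if the left subtree contains no match.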
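To see this concretely, here is a minimal reproduction of the question's control flow with a visit log added. The tree layout is an assumption (the original image is not available), chosen so that 5 is the right child of 2, and the function takes the target value as an int rather than a TreeNode to keep it short; find_target_buggy and build_example_tree are hypothetical names for illustration.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, x: int):
        self.val = x
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_example_tree() -> TreeNode:
    # Hypothetical layout (the question's image is unavailable):
    #         1
    #       /   \
    #      2     3
    #     / \
    #    4   5
    #   / \
    #  8   9
    nodes = {v: TreeNode(v) for v in (1, 2, 3, 4, 5, 8, 9)}
    nodes[1].left, nodes[1].right = nodes[2], nodes[3]
    nodes[2].left, nodes[2].right = nodes[4], nodes[5]
    nodes[4].left, nodes[4].right = nodes[8], nodes[9]
    return nodes[1]


def find_target_buggy(root: TreeNode, target_val: int, visited: List[int]) -> Optional[TreeNode]:
    # Same branching as the question's findTarget, plus a log of visited values.
    visited.append(root.val)
    if root.val == target_val:
        return root
    if root.left:
        # Returns unconditionally, even when the left subtree yields None.
        return find_target_buggy(root.left, target_val, visited)
    if root.right:
        return find_target_buggy(root.right, target_val, visited)
    return None


visited: List[int] = []
result = find_target_buggy(build_example_tree(), 5, visited)
print(visited)  # [1, 2, 4, 8]
print(result)   # None
```

The None from node 8 is returned straight through the "return" at nodes 4, 2 and 1, so no right-child branch is ever reached.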

Make this into “or” logic: if the left child recursive call found a matching node, you can return it immediately. If no result was found in the left subtree, don’t return just yet. Explore the right subtree and see if you find anything.
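Applied to the question's own structure (still checking children before recursing), the "or" logic might look like the sketch below. The function name find_target_fixed and the tree layout are hypothetical, since the original image is not available.

```python
from typing import Optional


class TreeNode:
    def __init__(self, x: int):
        self.val = x
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def find_target_fixed(root: TreeNode, target: TreeNode) -> Optional[TreeNode]:
    # Same shape as the question's code: children are checked before recursing.
    if root.val == target.val:
        return root
    if root.left:
        found = find_target_fixed(root.left, target)
        if found is not None:
            # Match in the left subtree: return it right away.
            return found
    if root.right:
        # No left child, or no match on the left: search the right subtree.
        return find_target_fixed(root.right, target)
    return None


# Hypothetical tree: 1 -> (2, 3), 2 -> (4, 5), 4 -> (8, 9)
nodes = {v: TreeNode(v) for v in (1, 2, 3, 4, 5, 8, 9)}
nodes[1].left, nodes[1].right = nodes[2], nodes[3]
nodes[2].left, nodes[2].right = nodes[4], nodes[5]
nodes[4].left, nodes[4].right = nodes[8], nodes[9]

match = find_target_fixed(nodes[1], TreeNode(5))
print(match.val if match else None)  # 5
```

The only change from the question's version is that the left call's result is inspected before deciding whether to return.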

As an aside, I think checking properties of child nodes is a recursion antipattern: it is inelegant and tends to create confusion and subtle bugs. Check whether the current node is None before operating on it, rather than checking its children. Let the recursion handle those checks.

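from typing import Optional

class Solution:
    def findTarget(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
        # Check the current node, not its children; an empty subtree has no match.
        if root is None:
            return None
        if root.val == val:
            return root
        # Search left first; "or" only evaluates the right call if the left one returns None.
        return (self.findTarget(root.left, val) or
                self.findTarget(root.right, val))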
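As a quick check, here is the same approach as a standalone function on the hypothetical tree used earlier (the original image is unavailable, so the layout is assumed), with a visit log added purely for illustration. It shows the search now backtracking from 8 to 9 and then to 5.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, x: int):
        self.val = x
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def find_target(root: Optional[TreeNode], val: int, visited: List[int]) -> Optional[TreeNode]:
    # Guard on the current node; None children are handled by the next call down.
    if root is None:
        return None
    visited.append(root.val)
    if root.val == val:
        return root
    # Short-circuit: the right subtree is searched only if the left one found nothing.
    return find_target(root.left, val, visited) or find_target(root.right, val, visited)


# Hypothetical tree: 1 -> (2, 3), 2 -> (4, 5), 4 -> (8, 9)
nodes = {v: TreeNode(v) for v in (1, 2, 3, 4, 5, 8, 9)}
nodes[1].left, nodes[1].right = nodes[2], nodes[3]
nodes[2].left, nodes[2].right = nodes[4], nodes[5]
nodes[4].left, nodes[4].right = nodes[8], nodes[9]

visited: List[int] = []
result = find_target(nodes[1], 5, visited)
print(visited)     # [1, 2, 4, 8, 9, 5]
print(result.val)  # 5
```

One caveat with "or": it treats any truthy result as a match. Instances of a plain class are always truthy unless the class defines __bool__ or __len__, so this works for the TreeNode above; if your node class defines either method, use an explicit "is not None" check instead.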