Finding the minimum number of perfect squares to sum up to n

Time:03-09

I'm trying to solve the problem of finding the minimum number of perfect squares (i.e. 1, 4, 9, 16, ...) that sum up to n. For example, for n = 12 the answer is 3 (4 + 4 + 4), and for n = 13 it is 2 (4 + 9).

Here's my recursive top-down approach:

import math

class Solution:
    def numSquares(self, n: int) -> int:
        
        dp = [math.inf] * (n + 1)
        dp[0] = 0
        dp[1] = 1 

        def solve(n: int):
            if dp[n] != math.inf:
                return dp[n]

            for i in range(n, 0, -1):
                if n - i * i >= 0:
                    sol = solve(n - i*i)
                    dp[i] = min(1 + sol, dp[i])

            return dp[n]

        solve(n)
        print(dp)

Solution().numSquares(12)

I can't put my finger on why this code doesn't yield the correct result. Can you please help me find the bug?

Thanks!

Answer:

import math

class Solution:
    def numSquares(self, n: int) -> int:
        dp = [math.inf] * (n + 1)  # dp[k] = fewest squares summing to k
        dp[0] = 0
        dp[1] = 1

        def solve(n: int):
            if dp[n] != math.inf:  # already computed
                return dp[n]
            sol = math.inf
            for i in range(n, 0, -1):
                if n - i * i >= 0:
                    sol = min(sol, solve(n - i * i))
            dp[n] = sol + 1  # store the answer for n itself, not for i

            return dp[n]

        return solve(n)

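The above is the corrected version of your solution. In your version, the candidate 1 + sol is stored in dp[i] (indexed by the number being squared) instead of dp[n], so solve(n) never records its own answer and other table entries end up holding values that belong to different subproblems; numSquares also prints dp instead of returning a result. The fix keeps a running minimum sol over all i and stores sol + 1 in dp[n]. It's still too slow, though, because the loop runs over every i from n down to 1 even though every i with i * i > n is skipped anyway. Below is an optimized version that starts at i = 1 and stops as soon as i * i > n, which passes LeetCode's time limits: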

import math

class Solution:
    def numSquares(self, n: int) -> int:
        dp = [math.inf] * (n + 1)
        dp[0] = 0
        dp[1] = 1

        def solve(n: int):
            if dp[n] != math.inf:
                return dp[n]
            sol = math.inf
            for i in range(1, n):
                if n - i * i >= 0:
                    sol = min(sol, solve(n - i * i))
                else:
                    break  # i * i > n, and larger i only make it worse
            dp[n] = sol + 1

            return dp[n]

        return solve(n)

There is still a little room for optimization: bounding the loop by floor(sqrt(n)) removes the comparison inside it:

import math

class Solution:
    def numSquares(self, n: int) -> int:
        dp = [math.inf] * (n + 1)
        dp[0] = 0
        dp[1] = 1

        def solve(n: int):
            if dp[n] != math.inf:
                return dp[n]
            sol = math.inf
            # only i with i * i <= n can be used; math.isqrt(n) is floor(sqrt(n)) without float rounding
            for i in range(1, math.isqrt(n) + 1):
                sol = min(sol, solve(n - i * i))
            dp[n] = sol + 1

            return dp[n]

        return solve(n)
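All the memoized versions above end up calling solve(n - 1), solve(n - 2), ... while the outer calls are still on the stack, so the recursion depth grows roughly linearly with n and can raise RecursionError under CPython's default recursion limit (1000) for larger inputs. Here is a minimal sketch of the same recurrence, dp[k] = 1 + min(dp[k - i*i]) over all i with i*i <= k, filled bottom-up with plain loops instead; the function name num_squares_dp is just for illustration:

```python
import math


def num_squares_dp(n: int) -> int:
    """Fewest perfect squares summing to n (n >= 1), computed bottom-up without recursion."""
    dp = [0] + [math.inf] * n  # dp[k] = fewest squares summing to k; dp[0] = 0
    for k in range(1, n + 1):
        # try every square i * i <= k as the last term of the sum
        for i in range(1, math.isqrt(k) + 1):
            dp[k] = min(dp[k], dp[k - i * i] + 1)
    return dp[n]


print(num_squares_dp(12))  # 3  (4 + 4 + 4)
print(num_squares_dp(13))  # 2  (4 + 9)
```

The work is the same O(n * sqrt(n)) as the memoized version; the only difference is that the table is filled in increasing order of k, so no call stack is needed.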

Another way to solve this is with BFS. Think of the search as a tree rooted at n, where each edge subtracts a perfect square and costs one step. Because every edge has the same cost, BFS explores the tree level by level, so the first node found that is itself a perfect square gives the fewest squares:

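import collections
import math

class Solution:
    def numSquares(self, n: int) -> int:
        deq = collections.deque([n])
        seen = {n}  # remainders already queued; reaching them again cannot be shorter
        steps = 1

        while deq:
            level_size = len(deq)

            for _ in range(level_size):
                node = deq.popleft()
                for i in range(1, math.isqrt(node) + 1):
                    num = i ** 2
                    if num == node:  # node itself is a square: done
                        return steps
                    rest = node - num
                    if rest not in seen:
                        seen.add(rest)
                        deq.append(rest)
            steps += 1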

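As pointed out by Dmitry Bychenko, you can also solve this using Lagrange's four-square theorem, which states that every natural number is a sum of at most four perfect squares, so the answer is always 1, 2, 3 or 4. Here's a nice write-up and Python solution about it.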
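A minimal sketch of that math-based approach, assuming the usual combination with Legendre's three-square theorem (n needs exactly four squares when n = 4^a * (8b + 7)); the function names is_square and num_squares_math are illustrative, not taken from the linked write-up:

```python
import math


def is_square(k: int) -> bool:
    """Return True if k (k >= 0) is a perfect square."""
    r = math.isqrt(k)
    return r * r == k


def num_squares_math(n: int) -> int:
    """Fewest perfect squares summing to n (n >= 1), in O(sqrt(n)) time."""
    if is_square(n):
        return 1
    # Dividing out factors of 4 does not change the answer for 2 or 3 squares:
    # if 4m = x^2 + y^2 (+ z^2), then x, y (and z) must all be even.
    while n % 4 == 0:
        n //= 4
    # Legendre: numbers of the form 4^a * (8b + 7) are not a sum of three squares
    if n % 8 == 7:
        return 4
    # Two squares: try each first square i^2 and check the remainder
    for i in range(1, math.isqrt(n) + 1):
        if is_square(n - i * i):
            return 2
    # Lagrange's theorem caps the answer at 4, and 1, 2 and 4 are ruled out
    return 3


print([num_squares_math(k) for k in (12, 13, 7, 28)])  # [3, 2, 4, 4]
```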
