I'm trying to solve the problem of finding the minimum number of perfect squares (e.g. 1, 4, 9, 16, ...) that sum up to n.
Here's my recursive top-down approach:
import math
class Solution:
    def numSquares(self, n: int) -> int:
        """Asker's original top-down attempt, kept as posted apart from
        restoring the ``+`` operators lost in extraction.

        Builds a memo table ``dp`` where ``dp[k]`` is meant to hold the
        minimum number of perfect squares summing to ``k``, runs the
        recursive search for ``n`` and prints the table. It prints instead
        of returning, so the call evaluates to ``None`` despite the ``-> int``
        annotation.
        """
        dp = [math.inf] * (n + 1)
        dp[0] = 0
        dp[1] = 1

        def solve(n: int):
            # Memo hit: this value was already filled in.
            if dp[n] != math.inf:
                return dp[n]
            for i in range(n, 0, -1):
                if n - i * i >= 0:
                    sol = solve(n - i*i)
                    # BUG (the one this question asks about): the result is
                    # written to dp[i] instead of dp[n]. The value for n is
                    # never stored, and entries indexed by square roots get
                    # overwritten with wrong values (e.g. dp[2] ends up as 1).
                    dp[i] = min(1 + sol, dp[i])
            return dp[n]

        solve(n)
        print(dp)


Solution().numSquares(12)
I can't put my finger on why this code doesn't yield the correct result. Can you please help me find the bug?
Thanks!
Answer:
class Solution:
    def numSquares(self, n: int) -> int:
        """Return the minimum number of perfect squares that sum to ``n``.

        Corrected top-down (memoized) version of the question's code.
        ``dp[k]`` caches the answer for ``k``, and each call stores its own
        result in ``dp[k]``.

        For every target ``k`` it still tries each ``i`` from ``k`` down to 1
        and skips those whose square exceeds ``k``, which is why it is slow.
        Recursion depth grows roughly linearly with ``n``, because the
        ``i = 1`` branch recurses on ``k - 1``. Large ``n`` can therefore hit
        Python's default recursion limit.

        Args:
            n: Non-negative target value.

        Returns:
            The minimum count of perfect squares; 0 for ``n == 0``.
        """
        dp = [math.inf] * (n + 1)
        dp[0] = 0  # zero squares sum to 0
        if n >= 1:  # dp has a single slot when n == 0
            dp[1] = 1

        def solve(k: int):
            if dp[k] != math.inf:  # already computed
                return dp[k]
            sol = math.inf
            for i in range(k, 0, -1):
                if k - i * i >= 0:
                    sol = min(sol, solve(k - i * i))
            # One square (i*i) at this level plus the best remainder.
            dp[k] = sol + 1
            return dp[k]

        return solve(n)
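The bug in your version is the line dp[i] = min(1 + sol, dp[i]). It writes the result into dp[i], which is indexed by the square root being tried, instead of into the entry for the current target. That entry is never filled in, and other entries get overwritten with wrong values. The corrected version above keeps the best sub-result in sol and stores sol + 1 in the current target's slot. However, it's still too slow, because it checks every number down from n, even those whose squares are obviously greater than n. Below is an optimized version that passes LeetCode's time limits: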
class Solution:
    def numSquares(self, n: int) -> int:
        """Return the minimum number of perfect squares that sum to ``n``.

        Top-down memoized search that tries ``i = 1, 2, ...`` in increasing
        order. It stops as soon as ``i*i`` exceeds the current target,
        because every larger ``i`` overshoots too.

        Recursion depth grows roughly linearly with ``n``, because the first
        branch (``i = 1``) recurses on ``k - 1``. Very large ``n`` may
        therefore exceed Python's default recursion limit.

        Args:
            n: Non-negative target value.

        Returns:
            The minimum count of perfect squares; 0 for ``n == 0``.
        """
        dp = [math.inf] * (n + 1)
        dp[0] = 0  # zero squares sum to 0
        if n >= 1:
            # Required seed: for k == 1 the loop below is range(1, 1), which
            # is empty. Guarded because dp has a single slot when n == 0.
            dp[1] = 1

        def solve(k: int):
            if dp[k] != math.inf:  # already computed
                return dp[k]
            sol = math.inf
            for i in range(1, k):
                if k - i * i < 0:
                    break  # squares only grow, so no larger i can fit
                sol = min(sol, solve(k - i * i))
            # One square (i*i) at this level plus the best remainder.
            dp[k] = sol + 1
            return dp[k]

        return solve(n)
There is still a little room for optimization: compute the loop bound ⌊√n⌋ up front, so no per-iteration check or break is needed:
class Solution:
    def numSquares(self, n: int) -> int:
        """Return the minimum number of perfect squares that sum to ``n``.

        Top-down memoized search that only tries squares that can fit. For a
        target ``k`` the loop runs over ``i = 1 .. floor(sqrt(k))``, so no
        ``k - i*i >= 0`` check is needed inside it.

        Recursion depth grows roughly linearly with ``n``, because the first
        branch (``i = 1``) recurses on ``k - 1``. Very large ``n`` may
        therefore exceed Python's default recursion limit.

        Args:
            n: Non-negative target value.

        Returns:
            The minimum count of perfect squares; 0 for ``n == 0``.
        """
        dp = [math.inf] * (n + 1)
        dp[0] = 0  # zero squares sum to 0
        if n >= 1:
            dp[1] = 1  # guarded: dp has a single slot when n == 0

        def solve(k: int):
            if dp[k] != math.inf:  # memo hit
                return dp[k]
            sol = math.inf
            largest_root = math.floor(k ** 0.5)
            for i in range(1, largest_root + 1):
                sol = min(sol, solve(k - i * i))
            # One square (i*i) at this level plus the best remainder.
            dp[k] = sol + 1
            return dp[k]

        return solve(n)
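Another way to solve this is with a breadth-first search (BFS). Imagine all the ways to reach n as a tree:
- the root is n;
- each edge subtracts one perfect square;
- every path from the root down to 0 is one way of writing n as a sum of squares.

Every edge costs one, so in a BFS the first hit is the best hit. The first level at which a remaining value is itself a perfect square gives the minimum count: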
class Solution:
    def numSquares(self, n: int) -> int:
        """Return the minimum number of perfect squares that sum to ``n`` (BFS).

        Level ``steps`` of the search holds the remainders left after
        subtracting ``steps - 1`` squares from ``n``. The first level
        containing a remainder that is itself a perfect square yields the
        answer. All edges cost one, so that first hit is optimal.

        Args:
            n: Non-negative target value.

        Returns:
            The minimum count of perfect squares; 0 for ``n == 0``.
        """
        import collections
        import math

        if n == 0:
            return 0  # empty sum; the search below would never find a square

        deq = collections.deque([n])
        # Remainders already queued. Reaching one again at a later level
        # cannot give a shorter path, so skipping them keeps the queue small.
        seen = {n}
        steps = 1
        while deq:
            # Process exactly one BFS level.
            for _ in range(len(deq)):
                node = deq.popleft()
                for i in range(1, math.floor(node ** 0.5) + 1):
                    num = i * i
                    if num == node:
                        return steps
                    remainder = node - num
                    if remainder not in seen:
                        seen.add(remainder)
                        deq.append(remainder)
            steps += 1
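As pointed out by Dmitry Bychenko, you can also solve this with Lagrange's four-square theorem. Every natural number is the sum of at most four squares, so the answer is always between 1 and 4. Legendre's three-square theorem adds that n needs exactly four squares if and only if n = 4^a(8b + 7). Together they give an O(√n) solution. Here's a nice write-up and Python solution about it.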