Python curve fitting raises TypeError: only size-1 arrays can be converted to Python scalars

Time:10-30

I am trying to fit a curve. This is my code:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.optimize import curve_fit
import math
vector = np.vectorize(np.int_)

x_data = np.array([-5.0, -4, -3, -2, -1, 0, 1, 2, 3, 4])
x1 = vector(x_data)


y_data = np.array([77, 81, 171, 303, 409, 302, 139, 115, 88, 89])
y1 = vector(y_data)
def model_f(x, a, b, c, d):
    return a/(math.sqrt(2*math.pi*d**2)) * math.exp( -(x-c)**2/(2*d**2) ) + b

popt, pcov = curve_fit(model_f, x1, y1, p0=[3,2,-16, 2])

This is the error I get:

TypeError: only size-1 arrays can be converted to Python scalars

From what I understand, math.sqrt() and math.exp() are causing the problem. I thought that vectorizing the arrays would fix it. Am I missing something?
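To see why vectorizing the data does not help, here is a small check (a sketch, assuming NumPy is installed; the variable names are illustrative). The vectorized conversion still yields an array, math.exp() rejects that array with a TypeError, and np.exp() processes it element by element.

```python
import math

import numpy as np

x = np.array([-5.0, -4, -3, -2, -1, 0, 1, 2, 3, 4])

# np.vectorize(np.int_) converts each element to an integer,
# but the result is still a 10-element ndarray, not a scalar.
as_ints = np.vectorize(np.int_)(x)
print(type(as_ints).__name__, as_ints.shape)  # ndarray (10,)

# math.exp() needs a single number, so passing the whole array fails.
try:
    math.exp(-as_ints**2 / 2)
    math_failed = False
except TypeError as err:
    math_failed = True
    print('math.exp failed:', err)

# np.exp() is a ufunc: it works element by element and returns an array.
y = np.exp(-x**2 / 2)
print(y.shape)  # (10,)
```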

Answer:

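Don't call vectorize, and don't use the math module; use the NumPy equivalents (np.sqrt, np.exp, np.pi) instead. curve_fit calls your model with the entire x array at once, so -(x-c)**2/(2*d**2) is a 10-element array, and math.exp() can only convert a single number. np.vectorize(np.int_) does not change that: it converts each element to an integer and still returns an array. Also, your initial values were way off and produced a degenerate solution. Either don't provide initial values at all, or provide ones in the ballpark of what you know to be needed: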
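A quick way to confirm that curve_fit hands the whole x array to the model is to record the shape it receives. This sketch uses a hypothetical straight-line model and a hypothetical seen_shapes list purely for illustration, and assumes NumPy and SciPy are installed.

```python
import numpy as np
from scipy.optimize import curve_fit

seen_shapes = []  # hypothetical helper, only for this demonstration


def recording_model(x, a, b):
    # Record what curve_fit actually passes in as x on every call.
    seen_shapes.append(np.shape(x))
    return a * x + b


x = np.arange(-5, 5)
y = 2.0 * x + 1.0
popt, _ = curve_fit(recording_model, x, y)
print(set(seen_shapes))  # {(10,)}
print('Parameters:', popt)
```

Every call receives the full 10-element array, which is why the model body has to use array-aware functions.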

import numpy as np
from scipy.optimize import curve_fit


def model_f(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    return a/d/np.sqrt(2*np.pi) * np.exp(-((x-c)/d)**2 / 2) + b


x1 = np.arange(-5, 5)
y1 = np.array((77, 81, 171, 303, 409, 302, 139, 115, 88, 89))
popt, _ = curve_fit(model_f, x1, y1, p0=(1000, 100, -1, 1))
print('Parameters:', popt)
print('Ideal vs. fit y:')
print(np.stack((y1, model_f(x1, *popt))))

Output:

Parameters: [916.86287196  85.71611182  -1.03419295   1.13753421]
Ideal vs. fit y:
[[ 77.          81.         171.         303.         409.
  302.         139.         115.          88.          89.        ]
 [ 86.45393326  96.46010314 157.95219577 309.95808531 407.12196914
  298.41481145 150.70663751  94.88484707  86.3133437   85.73407366]]
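When ballpark values are not obvious, they can often be estimated from the data. The sketch below shows one common heuristic, not something curve_fit does for you: the helper guess_p0 is hypothetical, and it assumes evenly spaced, sorted x values and a single peak above a constant baseline.

```python
import numpy as np
from scipy.optimize import curve_fit


def model_f(x, a, b, c, d):
    return a / d / np.sqrt(2 * np.pi) * np.exp(-((x - c) / d) ** 2 / 2) + b


def guess_p0(x, y):
    """Hypothetical helper: rough starting values for a Gaussian peak.

    Assumed heuristic: baseline b from the smallest y, centre c at the x of
    the largest y, width d of one x step, and area a chosen so that the
    initial peak height equals max(y).
    """
    b = float(y.min())
    c = float(x[np.argmax(y)])
    d = float(x[1] - x[0])  # assumes evenly spaced, sorted x
    a = (float(y.max()) - b) * d * np.sqrt(2 * np.pi)
    return a, b, c, d


x1 = np.arange(-5, 5)
y1 = np.array((77, 81, 171, 303, 409, 302, 139, 115, 88, 89))
p0 = guess_p0(x1, y1)
popt, _ = curve_fit(model_f, x1, y1, p0=p0)
print('Initial guess:', p0)
print('Parameters:', popt)
```

Because the model's peak height is a/(d*sqrt(2*pi)) + b, choosing a this way makes the starting curve pass through the highest data point, which keeps the optimizer near a sensible solution.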