Python curve_fit with multiple independent variables (in order to get the value for some unknown parameters)

Time:06-25

Is there a way to use curve_fit to fit a function with multiple independent variables, like the ones below?

I am trying to find the values of a1, b1, c1, a2, b2, c2, a3, b3, c3 and d, given that x1, x2, x3 and y1 (the dependent variable) are all known. I want to optimize these values with scipy.optimize so that the error is minimized. Note that in my real data I have more than a hundred data points for each of x1, x2, x3 and y1.

Alternatively, is there a better or more appropriate way to obtain the values of a1, b1, c1, a2, b2, c2, a3, b3, c3 and d?

import numpy as np
from scipy.optimize import curve_fit

x1 = [3,2,1]
x2 = [3,4,2]
x3 = [1,2,4]
y1 = [5,7,9]

def func(x1, x2, a1, b1, c1, a2, b2, c2, d):
    return (a1*x1**3 + b1*x1**2 + c1*x1) + (a2*x2**3 + b2*x2**2 + c2*x2) + d

def func2(x1, x2, x3, a1, b1, c1, a2, b2, c2, a3, b3, c3, d):
    return (a1*x1**3 + b1*x1**2 + c1*x1) + (a2*x2**3 + b2*x2**2 + c2*x2) + (a3*x3**3 + b3*x3**2 + c3*x3) + d

Answer:

You need to pass x1 and x2 together as a single object. See the description of xdata in the curve_fit documentation:

The independent variable where the data is measured. Should usually be an M-length sequence or an (k,M)-shaped array for functions with k predictors, but can actually be any object.
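To make the (k, M) convention concrete, here is a minimal sketch with a handful of made-up values (five observations, two predictors; the numbers are purely illustrative). np.stack places each predictor on its own row, so the model function receives one array and reads predictor i as row x[i].

```python
import numpy as np

# Two predictors with M = 5 observations each (illustrative values only)
x1 = np.array([3.0, 2.0, 1.0, 4.0, 5.0])
x2 = np.array([3.0, 4.0, 2.0, 1.0, 0.5])

# np.stack puts each predictor on its own row: shape (k, M) = (2, 5)
xdata = np.stack([x1, x2])
print(xdata.shape)  # (2, 5)

# Inside the model function, row i of the stacked array is predictor i
def func(x, a1, b1, c1, a2, b2, c2, d):
    return (a1 * x[0]**3 + b1 * x[0]**2 + c1 * x[0]) + (a2 * x[1]**3 + b2 * x[1]**2 + c2 * x[1]) + d

# The model returns one prediction per observation, i.e. a length-M array
y_pred = func(xdata, 1, 2, 3, 4, 5, 6, 7)
print(y_pred.shape)  # (5,)
```

This length-M output is what curve_fit compares against ydata, so the two must have the same length.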

Example:

import numpy as np
from scipy.optimize import curve_fit

# generate sample data for a1, b1, c1, a2, b2, c2, d = 1, 2, 3, 4, 5, 6, 7
np.random.seed(0)
x1 = np.random.randint(0, 100, 100)
x2 = np.random.randint(0, 100, 100)
y1 = (1 * x1**3 + 2 * x1**2 + 3 * x1) + (4 * x2**3 + 5 * x2**2 + 6 * (x2 + np.random.randint(-1, 1, 100))) + 7

def func(x, a1, b1, c1, a2, b2, c2, d):
    return (a1 * x[0]**3 + b1 * x[0]**2 + c1 * x[0]) + (a2 * x[1]**3 + b2 * x[1]**2 + c2 * x[1]) + d

popt, _ = curve_fit(func, np.stack([x1, x2]), y1)

Result:

array([1.00000978, 1.99945039, 2.97065876, 4.00001038, 4.99920966,
       5.97424668, 6.71464229])
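The same stacking approach should extend to the three-variable func2 from the question by stacking x1, x2 and x3 into a (3, M) array. The sketch below assumes hypothetical "true" coefficients and noise-free synthetic data purely to check the fit. Because every term is linear in a1 ... d, it also cross-checks the result with ordinary linear least squares (np.linalg.lstsq). That is a different technique from curve_fit, offered here as one possible answer to the "better way" question rather than as part of the original answer.

```python
import numpy as np
from scipy.optimize import curve_fit

# Hypothetical coefficients (a1, b1, c1, a2, b2, c2, a3, b3, c3, d),
# used only to generate synthetic test data
true_params = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, -1.0, 2.0, 7.0])

rng = np.random.default_rng(0)
M = 200
x1 = rng.uniform(0, 5, M)
x2 = rng.uniform(0, 5, M)
x3 = rng.uniform(0, 5, M)

def func2(x, a1, b1, c1, a2, b2, c2, a3, b3, c3, d):
    # x has shape (3, M): rows 0, 1, 2 are x1, x2, x3
    return ((a1 * x[0]**3 + b1 * x[0]**2 + c1 * x[0])
            + (a2 * x[1]**3 + b2 * x[1]**2 + c2 * x[1])
            + (a3 * x[2]**3 + b3 * x[2]**2 + c3 * x[2])
            + d)

xdata = np.stack([x1, x2, x3])   # shape (3, M)
y1 = func2(xdata, *true_params)  # noise-free synthetic target

# Nonlinear least-squares fit; without p0, curve_fit starts from all ones
popt, pcov = curve_fit(func2, xdata, y1)

# Cross-check with linear least squares: one design-matrix column per
# coefficient, in the same order as func2's parameters
A = np.column_stack([
    x1**3, x1**2, x1,
    x2**3, x2**2, x2,
    x3**3, x3**2, x3,
    np.ones(M),
])
coef, *_ = np.linalg.lstsq(A, y1, rcond=None)

print(np.round(popt, 4))
print(np.round(coef, 4))
```

A design note: lstsq needs no initial guess and returns the least-squares solution in a single step, which suits a model that is linear in its coefficients like this one. curve_fit remains the more general tool if the model later becomes nonlinear in its parameters.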