Differences in the standard error of parameters computed via the Hessian inverse and via QR decomposition


I have solved a nonlinear optimization problem and am trying to compute the standard errors of the fitted parameters. I found two options: one uses the fractional covariance matrix formed from the inverse of the Hessian, while the other uses a QR decomposition of the weighted Jacobian. However, the two do not agree: the standard errors obtained via QR are smaller than those obtained from the Hessian inverse. I am at a loss as to how and why the two approaches differ, and I would like to understand which is the more correct way. Below is a working example.
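
For reference, the textbook relations I am working from are, with J the Jacobian of the sigma-weighted model and s^2 the reduced chi-square:

    Cov(p) ≈ (J^T J)^{-1} s^2,    and, with J = QR,    (J^T J)^{-1} = R^{-1} R^{-T}

so I would expect both routes to estimate the same covariance matrix.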

# import libraries
import jax
import jax.numpy as jnp  
import jaxopt
jax.config.update("jax_enable_x64", True)


# Create data
F = jnp.asarray([1.00e-01, 1.30e-01, 1.69e-01, 2.20e-01, 2.86e-01, 3.71e-01,
             4.83e-01, 6.27e-01, 8.16e-01, 1.06e+00, 1.38e+00, 1.79e+00,
             2.33e+00, 3.03e+00, 3.94e+00, 5.12e+00, 6.65e+00, 8.65e+00,
             1.12e+01, 1.46e+01, 1.90e+01, 2.47e+01, 3.21e+01, 4.18e+01,
             5.43e+01, 7.06e+01, 9.17e+01, 1.19e+02, 1.55e+02, 2.02e+02,
             2.62e+02, 3.41e+02, 4.43e+02, 5.76e+02, 7.48e+02, 9.73e+02,
             1.26e+03, 1.64e+03, 2.14e+03, 2.78e+03, 3.61e+03, 4.70e+03,
             6.10e+03, 7.94e+03, 1.03e+04, 1.34e+04, 1.74e+04, 2.27e+04,
             2.95e+04, 3.83e+04, 4.98e+04, 6.47e+04, 8.42e+04, 1.09e+05], dtype=jnp.float64)

ydata = jnp.asarray([45.1  -1.09j, 47.5  -1.43j, 46.8  -1.77j, 46.2  -2.29j,
             46.2  -2.97j, 47.2  -3.8j , 47.   -4.85j, 45.1  -5.99j,
             45.8  -7.33j, 42.3  -9.05j, 42.6 -10.2j , 36.5 -10.8j ,
             34.5 -11.2j , 32.1 -10.2j , 30.   -9.18j, 29.4  -8.j  ,
             27.3  -6.64j, 26.7  -5.18j, 25.3  -4.12j, 25.4  -3.26j,
             25.2  -2.51j, 24.9  -1.94j, 24.9  -1.64j, 25.4  -1.35j,
             25.5  -1.24j, 24.8  -1.1j , 24.7  -1.03j, 23.9  -1.04j,
             25.2  -1.1j , 24.9  -1.27j, 25.   -1.46j, 25.4  -1.65j,
             24.4  -1.98j, 24.5  -2.34j, 24.5  -2.91j, 23.8  -3.47j,
             22.9  -4.13j, 22.3  -4.91j, 20.9  -5.66j, 20.3  -6.03j,
             18.4  -6.96j, 17.6  -7.24j, 16.5  -7.74j, 14.3  -7.42j,
             12.7  -7.17j, 11.2  -6.76j,  9.85 -5.89j,  8.68 -5.38j,
              7.92 -4.53j,  7.2  -3.83j,  6.81 -3.2j ,  6.65 -2.67j,
              6.11 -2.16j,  5.86 -1.77j], dtype=jnp.complex128)

sigma = jnp.asarray([45.11316992, 47.52152039, 46.83345919, 46.25671951,
             46.29536586, 47.35271903, 47.24957672, 45.49604488,
             46.38285136, 43.25728262, 43.8041094 , 38.06428772,
             36.27244133, 33.68159735, 31.37311588, 30.46900064,
             28.09590006, 27.19783815, 25.63326745, 25.6083502 ,
             25.32469348, 24.97545996, 24.95394959, 25.43585068,
             25.53013122, 24.82438317, 24.72146638, 23.92261691,
             25.22399651, 24.93236651, 25.04259571, 25.4535361 ,
             24.48020425, 24.61149325, 24.67221312, 24.05162988,
             23.26944133, 22.83414329, 21.65284277, 21.17665932,
             19.67235624, 19.03096424, 18.22519136, 16.11044382,
             14.58420036, 13.08195704, 11.47669813, 10.21209087,
              9.12399584,  8.15529889,  7.52436708,  7.16598912,
              6.48056325,  6.12147858], dtype=jnp.float64)
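
# Note: sigma is simply the modulus of each complex data point
# (modulus weighting); the array above matches jnp.abs(ydata).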

# Define model: complex impedance of an equivalent circuit
# (Rs in series with [CPE parallel to (Rct + parallel R1-C1 element)])
def rrpwrcwo(p, x):
    w = 2*jnp.pi*x        # angular frequency
    s = 1j*w
    Rs = p[0]             # series resistance
    Qh = p[1]             # CPE coefficient
    nh = p[2]             # CPE exponent
    Rct = p[3]
    C1 = p[4]
    R1 = p[5]
    Y1 = s*C1 + 1/R1      # admittance of the parallel R1-C1 element
    Z1 = 1/Y1
    Zct = Rct + Z1
    Ydl = (s**nh)*Qh      # CPE admittance
    Yin = Ydl + 1/Zct
    Zin = 1/Yin
    Z = Rs + Zin
    # real parts first, then imaginary parts: shape (2*len(x),)
    return jnp.concatenate((Z.real, Z.imag), axis=0)


# Define cost function: the reduced chi-square (chi-square per degree of freedom)
def obj_fun(p, x, y, yerr, lb, ub):
    ndata = len(x)
    dof = 2*ndata - len(p)  # real + imaginary data points minus fitted parameters
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    sigma = jnp.concatenate([yerr, yerr], axis=0)
    y_model = rrpwrcwo(p, x)
    # everything here is real, so jnp.abs is effectively a no-op
    chi_sqr = (1/dof)*jnp.sum(jnp.abs((1/sigma**2) * (y_concat - y_model)**2))
    return chi_sqr

# Define minimization function
def cnls(p, x, y, yerr, lb, ub):
    """Fit the model with BFGS and return the jaxopt solution object.

    Note: BFGS is unconstrained, so lb and ub are passed through to
    obj_fun but never actually enforced.
    """
    solver = jaxopt.ScipyMinimize(method='BFGS', fun=obj_fun)
    sol = solver.run(p, x, y, yerr, lb, ub)
    return sol
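
# If the bounds actually need to hold, one option is a minimal sketch using
# jaxopt's ScipyBoundedMinimize wrapper (an alternative I am assuming here,
# not part of the original question):
def cnls_bounded(p, x, y, yerr, lb, ub):
    solver = jaxopt.ScipyBoundedMinimize(method='L-BFGS-B', fun=obj_fun)
    # bounds go to L-BFGS-B; the remaining arguments are forwarded to obj_fun
    return solver.run(p, (lb, ub), x, y, yerr, lb, ub)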

# Define initial values and bounds
p0 = jnp.asarray([5, 0.000103, 1, 20, 0.001, 20])

lb = jnp.zeros(len(p0))
lb = lb.at[2].set(0.1)
ub = jnp.full((len(p0),), jnp.inf)
# .at[...].set() returns a new array, so the result must be reassigned
ub = ub.at[2].set(1.01)

# Run optimization
res = cnls(p0, F, ydata, sigma, lb, ub)
popt = res.params
# DeviceArray([5.26589219e+00, 7.46288724e-06, 8.27089860e-01,
#              1.99066599e+01, 3.40764484e-03, 2.19277541e+01], dtype=float64)

# Get the weighted residual mean square (the reduced chi-square,
# since obj_fun already divides by dof)
chisqr = res.state.fun_val
# 0.00020399

# Method 1: Error computation using the fractional covariance matrix

# Get the Hessian of the objective, evaluated at the minimum
hess = jax.jacfwd(jax.jacrev(obj_fun))(popt, F, ydata, sigma, lb, ub)

# Take the hessian inv
hess_inv = jnp.linalg.inv(hess)

# Form the fractional covariance matrix
cov_mat = hess_inv * chisqr

# Compute standard error of the parameters
perr = jnp.sqrt(jnp.diag(cov_mat))
perr
# DeviceArray([4.60842608e-01, 3.64957208e-06, 4.59190021e-02,
#              8.29162454e-01, 4.47488639e-04, 1.49346052e+00], dtype=float64)
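
# Aside: jax.hessian is shorthand for the same forward-over-reverse
# composition used above, so this line computes the same matrix:
# hess = jax.hessian(obj_fun)(popt, F, ydata, sigma, lb, ub)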


# Method 2: Error Computation using QR Decomposition

# Compute the Jacobian of the model with respect to the parameters
grads = jax.jacfwd(rrpwrcwo)(popt, F)
gradsre = grads[:len(F)]   # rows for the real parts
gradsim = grads[len(F):]   # rows for the imaginary parts

# Form diagonal weight matrices
rtwre = jnp.diag((1/sigma))
rtwim = jnp.diag((1/sigma))

vre = rtwre@gradsre
vim = rtwim@gradsim

# Compute QR decomposition
Q1, R1 = jnp.linalg.qr(jnp.concatenate([vre,vim], axis = 0))

# Compute inverse of R1
invR1 = jnp.linalg.inv(R1)

# Compute standard error of the parameters
perr = jnp.linalg.norm(invR1, axis=1)*jnp.sqrt(chisqr)
perr

# DeviceArray([6.48631283e-02, 5.14577571e-07, 6.48070403e-03,
#              1.16523404e-01, 6.28434098e-05, 2.09238133e-01],dtype=float64)
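
# Why the QR route works: with J = QR, (J^T J)^{-1} = R^{-1} R^{-T}, whose
# diagonal is the squared row norms of R^{-1}; QR just avoids explicitly
# forming the worse-conditioned J^T J. A quick consistency check:
J = jnp.concatenate([vre, vim], axis=0)
diag_qr = jnp.sum(invR1**2, axis=1)              # squared row norms of R^{-1}
diag_normal = jnp.diag(jnp.linalg.inv(J.T @ J))  # normal-equations form
# diag_qr and diag_normal should agree closely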


CodePudding user response:

I believe the issue is that you are computing the Hessian of the chi-square per degree of freedom, when you should be computing the Hessian of (one half of) the chi-square itself. If you change this line:

chi_sqr = (1/dof)*(jnp.sum(jnp.abs((1/sigma**2) * (y_concat - y_model)**2)))

to this:

chi_sqr = 0.5 * (jnp.sum(jnp.abs((1/sigma**2) * (y_concat - y_model)**2)))

then the two approaches return approximately the same results.
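
As a sketch of the check (with one caveat: once the objective is 0.5*chi-square, its value at the minimum is no longer the reduced chi-square, so the scale factor applied to the inverse Hessian should still be the chi-square per degree of freedom, i.e. the original chisqr ≈ 0.00020399):

# Minimal sketch: Method 1 redone with a half-chi-square cost. Its inverse
# Hessian approximates (J^T J)^{-1} directly, so scaling by the reduced
# chi-square reproduces the QR-based errors.
def half_chi_sqr(p, x, y, yerr, lb, ub):
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    s = jnp.concatenate([yerr, yerr], axis=0)
    r = (y_concat - rrpwrcwo(p, x)) / s
    return 0.5 * jnp.sum(r**2)

hess = jax.jacfwd(jax.jacrev(half_chi_sqr))(popt, F, ydata, sigma, lb, ub)
perr_hess = jnp.sqrt(jnp.diag(jnp.linalg.inv(hess)) * chisqr)
# perr_hess should now be close to the QR-based perr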
