Fail to overwrite a 2D numpy.ndarray in a loop

Time:12-07

My program fails to overwrite an np.ndarray (the variable X) inside a for loop. An assignment like "X[i] = <another np.ndarray with a matching shape>" leaves X[i] unchanged, and I have no idea how this could happen.

Code:

import numpy as np
def qr_tridiagonal(T: np.ndarray):
    m, n = T.shape
    X = T.copy()
    Qt = np.identity(m)
    for i in range(n-1):
        ai = X[i, i]
        ak = X[i+1, i]
        c = ai/(ai**2 + ak**2)**.5
        s = ak/(ai**2 + ak**2)**.5
        # Givens rotation
        tmp1 = c*X[i] + s*X[i+1]
        tmp2 = c*X[i+1] - s*X[i]
        print("tmp1 before:", tmp1)
        print("X[i] before:", X[i])
        X[i] = tmp1
        X[i+1] = tmp2
        print("tmp1 after:", tmp1)
        print("X[i] after:", X[i])
        print()

        print(X)

    return Qt.T, X


A = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
Q, R = qr_tridiagonal(A)

Output (the first 4 lines):

tmp1 before: [1.41421356 1.41421356 0.70710678 0.        ]
X[i] before: [1 1 0 0]
tmp1 after: [1.41421356 1.41421356 0.70710678 0.        ]
X[i] after: [1 1 0 0]

Although tmp1 is assigned to X[i], the values in X[i] (equivalently X[i, :]) remain unchanged. I hope somebody can help me out.

Other info: the function above computes the QR factorization of a tridiagonal matrix using Givens rotations.

I did check that assigning a constant to X[i] works: after X[i] = 10, the printed results match. But X[i] = someArray fails in my code. I am not sure whether this is triggered by the particular algorithm I am implementing, because I have never run into anything like it before.
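The behaviour described above can be reproduced without the QR code at all. This is a minimal sketch, assuming only that np.array built from Python ints gets a signed integer dtype (int64 on most platforms, possibly int32 on older Windows builds). The integer scalar 10 survives the assignment, while a floating-point row is cast to the integer dtype and truncated:

```python
import numpy as np

# np.array built from Python ints gets a signed integer dtype
X = np.array([[1, 1, 0, 0], [1, 1, 1, 0]])
print(X.dtype.kind)  # 'i'

# Assigning an integer scalar behaves as expected
X[0] = 10
print(X[0])  # [10 10 10 10]

# Assigning a float row silently casts it to X's integer dtype,
# truncating every element toward zero
row = np.array([1.41421356, 1.41421356, 0.70710678, 0.0])
X[0] = row
print(X[0])  # [1 1 0 0]

# A float copy of the array keeps the assigned values intact
Xf = X.astype(float)
Xf[0] = row
print(Xf[0])
```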

I also tried creating new environments with conda to make sure my Python installation is not the problem. The strange output above should be reproducible on other machines.

Answer:

Many thanks to @hpaulj

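It turns out to be a dtype problem. The algorithm is fine, but the input A is an integer array, and X = T.copy() keeps that integer dtype. Assigning the floating-point row tmp1 to X[i] therefore truncates every element toward zero (1.41421356 becomes 1, 0.70710678 becomes 0), which is exactly the [1 1 0 0] printed above. This is also why X[i] = 10 seemed to work: 10 is already an integer. Converting to floating point first (for example X = T.astype(float), or building A with dtype=float) fixes it.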
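Here is a sketch of the function with the dtype fix applied. Two further changes are assumptions of this sketch, not part of the answer: it uses np.hypot instead of (ai**2 + ak**2)**.5 (same value, less prone to overflow), and it applies each rotation to Qt as well, because the loop in the question never updates Qt and would always return the identity as Q. It also assumes T is square.

```python
import numpy as np

def qr_tridiagonal(T: np.ndarray):
    """QR factorization of a square tridiagonal matrix via Givens rotations.

    Returns (Q, R) with T ~= Q @ R. Sketch only: assumes T is square and
    that ai and ak are never both zero at the same step.
    """
    m, n = T.shape
    # astype returns a float copy, so assignments below are not truncated
    X = T.astype(float)
    Qt = np.identity(m)
    for i in range(n - 1):
        ai = X[i, i]
        ak = X[i + 1, i]
        r = np.hypot(ai, ak)
        c, s = ai / r, ak / r
        # Rotate rows i and i+1 of X; this zeroes X[i+1, i]
        tmp1 = c * X[i] + s * X[i + 1]
        tmp2 = c * X[i + 1] - s * X[i]
        X[i], X[i + 1] = tmp1, tmp2
        # Apply the same rotation to Qt, so that Qt @ T == X throughout
        q1 = c * Qt[i] + s * Qt[i + 1]
        q2 = c * Qt[i + 1] - s * Qt[i]
        Qt[i], Qt[i + 1] = q1, q2
    return Qt.T, X


A = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]])
Q, R = qr_tridiagonal(A)
print(np.round(R, 4))
```

With the integer A from the question, Q @ R now reconstructs A up to rounding, R is upper triangular, and Q is orthogonal.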

A lesson learned: be aware of the dtype of np.ndarray!
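One way to act on that lesson is to make lossy assignments fail loudly instead of truncating. The sketch below uses np.copyto with casting="same_kind", which raises TypeError when asked to copy float data into an integer array; assign_row is a hypothetical helper name, not a NumPy function.

```python
import numpy as np

def assign_row(X: np.ndarray, i: int, row: np.ndarray) -> None:
    """Copy `row` into X[i], refusing lossy casts such as float -> int.

    Hypothetical helper: plain `X[i] = row` truncates floats into an
    integer array without complaint, whereas np.copyto with
    casting="same_kind" raises TypeError for that cast.
    """
    np.copyto(X[i], row, casting="same_kind")


X_int = np.zeros((2, 3), dtype=np.int64)
X_float = np.zeros((2, 3))
row = np.array([1.5, 2.5, 3.5])

assign_row(X_float, 0, row)  # float64 -> float64 is allowed
try:
    assign_row(X_int, 0, row)  # float64 -> int64 is not "same_kind"
except TypeError as err:
    print("refused:", err)
```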
