I have a complex function that returns a 2D array. To keep it simple, consider the following:
import numpy as np

def get_array():
    a = np.array([[0, 9, 5]])
    return a
Is there a numpy command that allows me to retrieve the single row automatically instead of doing the following?
def get_array():
    a = np.array([[0, 9, 5]])
    if a.shape[0] == 1:
        return a[0]
    else:
        return a
Thanks a lot!
CodePudding user response:
np.squeeze(a) will remove any unit-sized axes from a.
>>> np.squeeze([1, 2, 3])
array([1, 2, 3])
>>> np.squeeze([[1, 2, 3]])
array([1, 2, 3])
>>> np.squeeze([[1, 2, 3], [4, 5, 6]])
array([[1, 2, 3],
       [4, 5, 6]])
It might not be exactly what you want, though, since it also removes a trailing unit axis:
>>> np.squeeze([[1], [2]])
array([1, 2])
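If you only want to drop the leading axis, np.squeeze accepts an axis argument, but it raises a ValueError when that axis is not of length 1, so a guard is still needed. A minimal sketch, assuming the input is at least 2D as in the question (drop_leading_axis is just a name I made up for illustration):

```python
import numpy as np

def drop_leading_axis(a):
    """Return `a` without its first axis when that axis has length 1."""
    a = np.asarray(a)
    if a.ndim > 1 and a.shape[0] == 1:
        # axis=0 restricts the squeeze to the leading axis only
        return np.squeeze(a, axis=0)
    # Anything else (several rows, or already 1-D) is returned unchanged
    return a

row = drop_leading_axis([[0, 9, 5]])
col = drop_leading_axis([[1], [2]])
print(row.shape)  # (3,)
print(col.shape)  # (2, 1)
```

This is essentially the same check as in the question, so it mostly helps if you want the column-vector case left alone.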
CodePudding user response:
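On top of @orlp's answer: if you are concerned about performance, ravel seems to be faster than squeeze (note that ravel always flattens to 1-D, so the two only agree when the array has a single row):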