I'm playing with ND4J basics to come up to speed with its linear algebra capabilities.
I'm running on a MacBook Pro using the nd4j-api and nd4j-native dependencies, version 1.0.0-M2.1, with OpenJDK 17, Kotlin 1.7.20, and IntelliJ IDEA 2022.2.2 Ultimate Edition.
I'm writing JUnit 5 tests for simple operations between a 2x2 matrix and a scalar: add, subtract, multiply, and divide. All of them pass.
I also succeeded in adding a row vector of length 2 (shape [2]) to the first and second rows of a 2x2 matrix:
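```kotlin
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import org.nd4j.linalg.factory.Nd4j

@ParameterizedTest
@ValueSource(longs = [0L, 1L])
fun `add a row vector to each row in a matrix`(rowIndex: Long) {
    // setup
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val row = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2))
    // Expected result after adding the row vector to row 0 and to row 1, respectively
    val expected = arrayOf(
        Nd4j.create(doubleArrayOf(12.0, 15.0, 3.0, 4.0), intArrayOf(2, 2)),
        Nd4j.create(doubleArrayOf(1.0, 2.0, 14.0, 17.0), intArrayOf(2, 2)))
    // exercise: getRow returns a view, so addi updates a in place
    a.getRow(rowIndex).addi(row)
    // assert
    Assertions.assertEquals(expected[rowIndex.toInt()], a)
}
```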
I tried the same trick with a 2x1 column vector, adding it to the second column of the 2x2 matrix:
```kotlin
@Test
fun `add a column vector to the second column of a matrix`() {
    // setup
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val col = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2, 1))
    // Expected result after adding the column vector to the 2nd column
    val expected = Nd4j.create(doubleArrayOf(1.0, 13.0, 3.0, 17.0), intArrayOf(2, 2))
    // exercise
    a.getColumn(1).addi(col)
    // assert
    Assertions.assertEquals(expected, a)
}
```
I get an error saying that the array shapes don't match:
```
java.lang.IllegalStateException: Cannot perform in-place operation "addi": result array shape does not match the broadcast operation output shape: [2].addi([2, 1]) != [2].
In-place operations like x.addi(y) can only be performed when x and y have the same shape, or x and y are broadcastable with x.shape() == broadcastShape(x,y)
```
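The rule quoted in the message can be sketched in plain Kotlin. This is a hypothetical illustration, not ND4J code: it assumes NumPy-style broadcasting (shapes aligned from the right, size-1 dimensions stretched), and the helper names `broadcastShape` and `canAddInPlace` are made up for this sketch. It uses the two shapes from the error message: `[2]` for `a.getColumn(1)` and `[2, 1]` for `col`.

```kotlin
// Hypothetical illustration of the in-place broadcasting rule from the error message.
// Plain Kotlin, NumPy-style rules assumed; not ND4J's implementation.

/** Broadcasts two shapes aligned from the right; returns null if they are incompatible. */
fun broadcastShape(x: LongArray, y: LongArray): LongArray? {
    val rank = maxOf(x.size, y.size)
    val out = LongArray(rank)
    for (i in 0 until rank) {
        // Missing leading dimensions are treated as size 1.
        val dx = x.getOrElse(i - (rank - x.size)) { 1L }
        val dy = y.getOrElse(i - (rank - y.size)) { 1L }
        out[i] = when {
            dx == dy -> dx
            dx == 1L -> dy
            dy == 1L -> dx
            else -> return null
        }
    }
    return out
}

/** x.addi(y) is allowed only if broadcasting leaves x's shape unchanged. */
fun canAddInPlace(x: LongArray, y: LongArray): Boolean =
    broadcastShape(x, y)?.contentEquals(x) == true

fun main() {
    val columnView = longArrayOf(2)       // shape of a.getColumn(1), per the error message
    val columnVector = longArrayOf(2, 1)  // shape of col
    println(broadcastShape(columnView, columnVector)?.contentToString()) // [2, 2]
    println(canAddInPlace(columnView, columnVector))                     // false
    println(canAddInPlace(longArrayOf(2, 1), columnVector))              // true
}
```

Under these rules the row test passes because both operands have shape `[2]`, while `[2]` plus `[2, 1]` broadcasts to `[2, 2]`, which no longer fits into the `[2]` view being updated in place.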
I haven't been able to figure out why. Can anyone see where I've gone wrong and suggest a fix?
Answer:
ND4J already has a method for this: to add a column vector to every column of a matrix, use addiColumnVector.
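A minimal Kotlin sketch of that method on the question's matrices, assuming the same nd4j-api/nd4j-native 1.0.0-M2.1 setup. Note that it updates both columns, not only the second one; `addToEveryColumn` is a hypothetical helper name.

```kotlin
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical helper: adds col to every column of a, in place, and returns a.
fun addToEveryColumn(a: INDArray, col: INDArray): INDArray {
    a.addiColumnVector(col) // modifies a itself
    return a
}

fun main() {
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val col = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2, 1))
    // Expected values: row 0 = 12, 13 (each +11); row 1 = 16, 17 (each +13)
    println(addToEveryColumn(a, col))
}
```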
For views, make sure both operands have exactly the same shape by using reshape. For example, with a vector:
```java
INDArray vec = Nd4j.zeros(5);
vec.getColumn(0).addi(vec.reshape(5,1));
```
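Applied to the question's case, the same idea can be sketched with the `getColumn` overload that keeps the column dimension. This assumes your ND4J version provides `getColumn(long, boolean keepDim)` and that with `keepDim = true` it returns a view of shape `[rows, 1]`; `addToColumnKeepDim` is a hypothetical helper name.

```kotlin
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical helper: adds a [rows, 1] column vector to column colIndex of a, in place.
// Assumes getColumn(i, true) (keepDim = true) returns a view of shape [rows, 1].
fun addToColumnKeepDim(a: INDArray, colIndex: Long, col: INDArray): INDArray {
    a.getColumn(colIndex, true).addi(col) // shapes [2, 1] and [2, 1] now match
    return a
}

fun main() {
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val col = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2, 1))
    // Expected values: row 0 = 1, 13; row 1 = 3, 17
    println(addToColumnKeepDim(a, 1L, col))
}
```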
Answer (from the question author):
This solution did the trick. Thanks to Adam Gibson for pointing out the need for reshape:
```kotlin
@Test
fun `add a column vector to the second column of a matrix`() {
    // setup
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val col = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2, 1))
    // Adds the column vector to the 2nd column of the 2x2 matrix
    val expected = Nd4j.create(doubleArrayOf(1.0, 13.0, 3.0, 17.0), intArrayOf(2, 2))
    // exercise: getColumn(1) has shape [2]; reshape it to [2, 1] to match col
    a.getColumn(1).reshape(intArrayOf(2, 1)).addi(col)
    // assert
    Assertions.assertEquals(expected, a)
}
```
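An alternative sketch reshapes the other operand instead. Since `getColumn(1)` returns a rank-1 view of shape `[2]` (as the error message shows), flattening `col` to shape `[2]` makes the two shapes identical, just like the working row-vector test. One possible advantage: `reshape` may return a copy when a view is not possible, and an in-place add on a copy would not reach `a`; reshaping `col`, which is not the target of `addi`, sidesteps that question. `addToColumnFlat` is a hypothetical helper name.

```kotlin
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

// Hypothetical helper: adds col to column colIndex of a, in place.
// getColumn returns a rank-1 view of shape [rows]; col is flattened to the same shape.
fun addToColumnFlat(a: INDArray, colIndex: Long, col: INDArray): INDArray {
    a.getColumn(colIndex).addi(col.reshape(col.length()))
    return a
}

fun main() {
    val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2))
    val col = Nd4j.create(doubleArrayOf(11.0, 13.0), intArrayOf(2, 1))
    // Expected values: row 0 = 1, 13; row 1 = 3, 17
    println(addToColumnFlat(a, 1L, col))
}
```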