package org.nd4j.linalg;
/**
* Created by susaneraly on 8/26/16.
*/
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
/**
 * Regression test for matrix multiplication when the right-hand operand has had its
 * ordering flag switched with {@code setOrder('f')}.
 *
 * <p>The result of out-of-place {@code mmul} is taken as the reference. The test then
 * checks two other paths against it:
 * <ul>
 *   <li>{@code mmul} into a preallocated, c-ordered target;
 *   <li>{@code mmuli} whose result array is its own left operand.
 * </ul>
 */
public class MmulBug {

    @Test
    public void simpleTest() {
        // Left operand: a 4x1 column holding 1..4, reshaped to 2x2.
        INDArray m1 = Nd4j.create(new double[][] {{1.0}, {2.0}, {3.0}, {4.0}});
        m1 = m1.reshape(2, 2);

        // Right operand: a 1x4 row holding 1..4, reshaped to 2x2, then its order flag set to 'f'.
        // NOTE(review): presumably setOrder only changes the ordering flag and leaves the
        // underlying data untouched; confirm against the INDArray implementation.
        INDArray m2 = Nd4j.create(new double[][] {{1.0, 2.0, 3.0, 4.0}});
        m2 = m2.reshape(2, 2);
        m2.setOrder('f');

        // Reference result from out-of-place mmul.
        INDArray correctResult = m1.mmul(m2);
        System.out.println("================");
        System.out.println(m1);
        System.out.println(m2);
        System.out.println(correctResult);
        System.out.println("================");

        // mmul into a preallocated, c-ordered target must match the reference.
        INDArray newResult = Nd4j.zeros(correctResult.shape(), 'c');
        m1.mmul(m2, newResult);
        assertEquals(correctResult, newResult);

        // In-place mmuli whose result array is the left operand itself (m1), i.e. the output
        // aliases an input. This is the path reported as misbehaving. Both the returned
        // array and m1 (overwritten in place) must equal the reference.
        INDArray target = m1.mmuli(m2, m1);
        assertEquals(correctResult, target);
        assertEquals(correctResult, m1);
    }
}