package org.nd4j.linalg.inverse;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.Pair;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.CheckUtil;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.List;
import static org.junit.Assert.*;
/**
 * Tests for {@link InvertMatrix#invert(INDArray, boolean)}, comparing its results against the
 * Apache Commons Math reference implementation and checking rejection of non-invertible shapes.
 *
 * Created by agibsoncccc on 12/7/15.
 */
@RunWith(Parameterized.class)
public class TestInvertMatrices extends BaseNd4jTest {

    /** Seed shared by the matrix generator and the global RNG so failures are reproducible. */
    private static final int SEED = 12345;

    /**
     * Absolute tolerance for the small, well-conditioned 2x2 case. Loose enough for single
     * precision, tight enough to catch genuinely wrong entries.
     */
    private static final double SMALL_MATRIX_TOLERANCE = 1e-4;

    public TestInvertMatrices(Nd4jBackend backend) {
        super(backend);
    }

    /** Inverts [[1, 2], [3, 4]] and compares every entry with the Apache Commons Math inverse. */
    @Test
    public void testInverse() {
        RealMatrix matrix = new Array2DRowRealMatrix(new double[][] {{1, 2}, {3, 4}});
        RealMatrix expected = MatrixUtils.inverse(matrix);
        // linspace(1, 4, 4) reshaped to 2x2 in 'c' order is [[1, 2], [3, 4]], matching `matrix`.
        INDArray actual = InvertMatrix.invert(Nd4j.linspace(1, 4, 4).reshape(2, 2), false);

        assertEquals(expected.getRowDimension(), actual.rows());
        assertEquals(expected.getColumnDimension(), actual.columns());
        for (int row = 0; row < expected.getRowDimension(); row++) {
            for (int col = 0; col < expected.getColumnDimension(); col++) {
                // JUnit convention: expected value first, actual value second.
                assertEquals(expected.getEntry(row, col), actual.getDouble(row, col),
                                SMALL_MATRIX_TOLERANCE);
            }
        }
    }

    /**
     * Fills each 10x10 test matrix variant (views, orderings, etc.) with random values, inverts
     * it, and compares the result against an LU-based inverse from Apache Commons Math.
     */
    @Test
    public void testInverseComparison() {
        // Seed the global RNG so the random fill below is reproducible across runs.
        Nd4j.getRandom().setSeed(SEED);
        List<Pair<INDArray, String>> list = NDArrayCreationUtil.getAllTestMatricesWithShape(10, 10, SEED);
        for (Pair<INDArray, String> p : list) {
            INDArray orig = p.getFirst();
            orig.assign(Nd4j.rand(orig.shape()));
            INDArray inverse = InvertMatrix.invert(orig, false);
            RealMatrix rm = CheckUtil.convertToApacheMatrix(orig);
            RealMatrix rmInverse = new LUDecomposition(rm).getSolver().getInverse();
            INDArray expected = CheckUtil.convertFromApacheMatrix(rmInverse);
            assertTrue(p.getSecond(), CheckUtil.checkEntries(expected, inverse, 1e-3, 1e-4));
        }
    }

    /** Non-square and non-2D inputs must be rejected with an exception. */
    @Test
    public void testInvalidMatrixInversion() {
        assertInversionFails(Nd4j.create(5, 4), "non-square 5x4 matrix");
        assertInversionFails(Nd4j.create(5, 5, 5), "rank-3 5x5x5 array");
        assertInversionFails(Nd4j.create(1, 5), "non-square 1x5 row vector");
    }

    /**
     * Asserts that inverting {@code input} throws an exception.
     *
     * @param input       an array that {@link InvertMatrix#invert} should reject
     * @param description human-readable label for the case, used in the failure message
     */
    private static void assertInversionFails(INDArray input, String description) {
        try {
            InvertMatrix.invert(input, false);
        } catch (Exception expected) {
            // Expected: the input cannot be inverted. The concrete exception type is not pinned
            // because InvertMatrix's contract for invalid shapes is not specified here.
            return;
        }
        // fail() throws AssertionError (an Error), so it is never swallowed by the catch above.
        fail("No exception thrown for invalid input: " + description);
    }

    @Override
    public char ordering() {
        return 'c';
    }
}