package org.nd4j.linalg.factory; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import static org.junit.Assert.assertEquals; /** */ @RunWith(Parameterized.class) public class Nd4jTest extends BaseNd4jTest { public Nd4jTest(Nd4jBackend backend) { super(backend); } @Test public void testRandShapeAndRNG() { INDArray ret = Nd4j.rand(new int[] {4, 2}, Nd4j.getRandomFactory().getNewRandomInstance(123)); INDArray ret2 = Nd4j.rand(new int[] {4, 2}, Nd4j.getRandomFactory().getNewRandomInstance(123)); assertEquals(ret, ret2); } @Test public void testRandShapeAndMinMax() { INDArray ret = Nd4j.rand(new int[] {4, 2}, -0.125f, 0.125f, Nd4j.getRandomFactory().getNewRandomInstance(123)); INDArray ret2 = Nd4j.rand(new int[] {4, 2}, -0.125f, 0.125f, Nd4j.getRandomFactory().getNewRandomInstance(123)); assertEquals(ret, ret2); } @Test public void testCreateShape() { INDArray ret = Nd4j.create(new int[] {4, 2}); assertEquals(ret.length(), 8); } @Test public void testGetRandom() { Random r = Nd4j.getRandom(); Random t = Nd4j.getRandom(); assertEquals(r, t); } @Test public void testGetRandomSetSeed() { Random r = Nd4j.getRandom(); Random t = Nd4j.getRandom(); assertEquals(r, t); r.setSeed(123); assertEquals(r, t); } @Test public void testOrdering() { INDArray fNDArray = Nd4j.create(new float[] {1f}, NDArrayFactory.FORTRAN); assertEquals(NDArrayFactory.FORTRAN, fNDArray.ordering()); INDArray cNDArray = Nd4j.create(new float[] {1f}, NDArrayFactory.C); assertEquals(NDArrayFactory.C, cNDArray.ordering()); } @Override public char ordering() { return 'c'; } @Test public void testMean() { INDArray data = Nd4j.create(new double[] {4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8, 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4.}, new int[] {2, 2, 4, 4}); INDArray actualResult = data.mean(0); INDArray expectedResult = Nd4j.create(new double[] {3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6., 3., 3., 3., 3., 6., 6., 6., 6.}, new int[] {2, 4, 4}); assertEquals(getFailureMessage(), expectedResult, actualResult); } @Test public void testVar() { INDArray data = Nd4j.create(new double[] {4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8., 4., 4., 4., 4., 8., 8., 8., 8, 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4., 2., 2., 2., 2., 4., 4., 4., 4.}, new int[] {2, 2, 4, 4}); INDArray actualResult = data.var(false, 0); INDArray expectedResult = Nd4j.create(new double[] {1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4., 1., 1., 1., 1., 4., 4., 4., 4.}, new int[] {2, 4, 4}); assertEquals(getFailureMessage(), expectedResult, actualResult); } @Test public void testVar2() { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray var = arr.var(false, 0); assertEquals(Nd4j.create(new double[] {2.25, 2.25, 2.25}), var); } }