package org.nd4j.linalg.shape; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.*; /** * @author Adam Gibson */ @RunWith(Parameterized.class) public class ShapeTestsC extends BaseNd4jTest { public ShapeTestsC(Nd4jBackend backend) { super(backend); } @Test public void testSixteenZeroOne() { INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); assertEquals(4, baseArr.tensorssAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 5}, {9, 13}}); INDArray columnVectorSecond = Nd4j.create(new double[][] {{2, 6}, {10, 14}}); INDArray columnVectorThird = Nd4j.create(new double[][] {{3, 7}, {11, 15}}); INDArray columnVectorFourth = Nd4j.create(new double[][] {{4, 8}, {12, 16}}); INDArray[] assertions = new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth}; for (int i = 0; i < baseArr.tensorssAlongDimension(0, 1); i++) { INDArray test = baseArr.tensorAlongDimension(i, 0, 1); assertEquals("Wrong at index " + i, assertions[i], test); } } @Test public void testSixteenSecondDim() { INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {2, 4}), Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), Nd4j.create(new double[] {13, 15}), Nd4j.create(new double[] {14, 16}), }; for (int i = 0; i < baseArr.tensorssAlongDimension(2); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 2); assertEquals("Failed at index " + i, assertions[i], arr); } } @Test public void testThreeTwoTwo() { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {2, 4}), Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {6, 8}), Nd4j.create(new double[] {9, 11}), Nd4j.create(new double[] {10, 12}), }; assertEquals(assertions.length, threeTwoTwo.tensorssAlongDimension(1)); for (int i = 0; i < assertions.length; i++) { INDArray arr = threeTwoTwo.tensorAlongDimension(i, 1); assertEquals(assertions[i], arr); } } @Test public void testThreeTwoTwoTwo() { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 2}), Nd4j.create(new double[] {3, 4}), Nd4j.create(new double[] {5, 6}), Nd4j.create(new double[] {7, 8}), Nd4j.create(new double[] {9, 10}), Nd4j.create(new double[] {11, 12}), }; assertEquals(assertions.length, threeTwoTwo.tensorssAlongDimension(2)); for (int i = 0; i < assertions.length; i++) { assertEquals(assertions[i], threeTwoTwo.tensorAlongDimension(i, 2)); } } @Test public void testPutRow() { INDArray matrix = Nd4j.create(new double[][] {{1, 2}, {3, 4}}); for (int i = 0; i < matrix.rows(); i++) { INDArray row = matrix.getRow(i); System.out.println(matrix.getRow(i)); } matrix.putRow(1, Nd4j.create(new double[] {1, 2})); assertEquals(matrix.getRow(0), matrix.getRow(1)); } @Test public void testSixteenFirstDim() { INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); INDArray[] assertions = new INDArray[] {Nd4j.create(new double[] {1, 5}), Nd4j.create(new double[] {2, 6}), Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {4, 8}), Nd4j.create(new double[] {9, 13}), Nd4j.create(new double[] {10, 14}), Nd4j.create(new double[] {11, 15}), Nd4j.create(new double[] {12, 16}), }; for (int i = 0; i < baseArr.tensorssAlongDimension(1); i++) { INDArray arr = baseArr.tensorAlongDimension(i, 1); assertEquals("Failed at index " + i, assertions[i], arr); } } @Test public void testReshapePermute() { INDArray arrNoPermute = Nd4j.ones(5, 3, 4); INDArray reshaped2dNoPermute = arrNoPermute.reshape(5 * 3, 4); //OK assertArrayEquals(reshaped2dNoPermute.shape(), new int[] {5 * 3, 4}); INDArray arr = Nd4j.ones(5, 4, 3); INDArray permuted = arr.permute(0, 2, 1); assertArrayEquals(arrNoPermute.shape(), permuted.shape()); INDArray reshaped2D = permuted.reshape(5 * 3, 4); //NullPointerException assertArrayEquals(reshaped2D.shape(), new int[] {5 * 3, 4}); } @Test public void testEight() { INDArray baseArr = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2); assertEquals(2, baseArr.tensorssAlongDimension(0, 1)); INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {5, 7}}); INDArray columnVectorSecond = Nd4j.create(new double[][] {{2, 4}, {6, 8}}); INDArray test1 = baseArr.tensorAlongDimension(0, 0, 1); assertEquals(columnVectorFirst, test1); INDArray test2 = baseArr.tensorAlongDimension(1, 0, 1); assertEquals(columnVectorSecond, test2); } @Test public void testOtherReshape() { INDArray nd = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {2, 3}); INDArray slice = nd.slice(1, 0); INDArray vector = slice.reshape(1, 3); for (int i = 0; i < vector.length(); i++) { System.out.println(vector.getDouble(i)); } assertEquals(Nd4j.create(new double[] {4, 5, 6}), vector); } @Test public void testVectorAlongDimension() { INDArray arr = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2); INDArray assertion = Nd4j.create(new float[] {3, 4}, new int[] {1, 2}); INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2); assertEquals(assertion, vectorDimensionTest); int vectorsAlongDimension1 = arr.vectorsAlongDimension(1); assertEquals(8, vectorsAlongDimension1); INDArray zeroOne = arr.vectorAlongDimension(0, 1); assertEquals(zeroOne, Nd4j.create(new float[] {1, 3, 5})); INDArray testColumn2Assertion = Nd4j.create(new float[] {2, 4, 6}); INDArray testColumn2 = arr.vectorAlongDimension(1, 1); assertEquals(testColumn2Assertion, testColumn2); INDArray testColumn3Assertion = Nd4j.create(new float[] {7, 9, 11}); INDArray testColumn3 = arr.vectorAlongDimension(2, 1); assertEquals(testColumn3Assertion, testColumn3); INDArray v1 = Nd4j.linspace(1, 4, 4).reshape(new int[] {2, 2}); INDArray testColumnV1 = v1.vectorAlongDimension(0, 0); INDArray testColumnV1Assertion = Nd4j.create(new float[] {1, 3}); assertEquals(testColumnV1Assertion, testColumnV1); INDArray testRowV1 = v1.vectorAlongDimension(1, 0); INDArray testRowV1Assertion = Nd4j.create(new float[] {2, 4}); assertEquals(testRowV1Assertion, testRowV1); INDArray n = Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 2, 2}); INDArray vectorOne = n.vectorAlongDimension(1, 2); INDArray assertionVectorOne = Nd4j.create(new double[] {3, 4}); assertEquals(assertionVectorOne, vectorOne); INDArray oneThroughSixteen = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); assertEquals(8, oneThroughSixteen.vectorsAlongDimension(1)); assertEquals(Nd4j.create(new double[] {1, 5}), oneThroughSixteen.vectorAlongDimension(0, 1)); assertEquals(Nd4j.create(new double[] {2, 6}), oneThroughSixteen.vectorAlongDimension(1, 1)); assertEquals(Nd4j.create(new double[] {3, 7}), oneThroughSixteen.vectorAlongDimension(2, 1)); assertEquals(Nd4j.create(new double[] {4, 8}), oneThroughSixteen.vectorAlongDimension(3, 1)); assertEquals(Nd4j.create(new double[] {9, 13}), oneThroughSixteen.vectorAlongDimension(4, 1)); assertEquals(Nd4j.create(new double[] {10, 14}), oneThroughSixteen.vectorAlongDimension(5, 1)); assertEquals(Nd4j.create(new double[] {11, 15}), oneThroughSixteen.vectorAlongDimension(6, 1)); assertEquals(Nd4j.create(new double[] {12, 16}), oneThroughSixteen.vectorAlongDimension(7, 1)); INDArray fourdTest = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); double[][] assertionsArr = new double[][] {{1, 3}, {2, 4}, {5, 7}, {6, 8}, {9, 11}, {10, 12}, {13, 15}, {14, 16}, }; assertEquals(assertionsArr.length, fourdTest.vectorsAlongDimension(2)); for (int i = 0; i < assertionsArr.length; i++) { INDArray test = fourdTest.vectorAlongDimension(i, 2); INDArray assertionEntry = Nd4j.create(assertionsArr[i]); assertEquals(assertionEntry, test); } } @Test public void testColumnSum() { INDArray twoByThree = Nd4j.linspace(1, 600, 600).reshape(150, 4); INDArray columnVar = twoByThree.sum(0); INDArray assertion = Nd4j.create(new float[] {44850.0f, 45000.0f, 45150.0f, 45300.0f}); assertEquals(getFailureMessage(), assertion, columnVar); } @Test public void testRowMean() { INDArray twoByThree = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray rowMean = twoByThree.mean(1); INDArray assertion = Nd4j.create(new double[] {1.5, 3.5}); assertEquals(getFailureMessage(), assertion, rowMean); } @Test public void testRowStd() { INDArray twoByThree = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray rowStd = twoByThree.std(1); INDArray assertion = Nd4j.create(new float[] {0.7071067811865476f, 0.7071067811865476f}); assertEquals(getFailureMessage(), assertion, rowStd); } @Test public void testColumnSumDouble() { DataBuffer.Type initialType = Nd4j.dataType(); DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); INDArray twoByThree = Nd4j.linspace(1, 600, 600).reshape(150, 4); INDArray columnVar = twoByThree.sum(0); INDArray assertion = Nd4j.create(new float[] {44850.0f, 45000.0f, 45150.0f, 45300.0f}); assertEquals(getFailureMessage(), assertion, columnVar); DataTypeUtil.setDTypeForContext(initialType); } @Test public void testColumnVariance() { INDArray twoByThree = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray columnVar = twoByThree.var(true, 0); INDArray assertion = Nd4j.create(new double[] {2, 2}); assertEquals(assertion, columnVar); } @Test public void testCumSum() { INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {1, 4}); INDArray cumSumAnswer = Nd4j.create(new float[] {1, 3, 6, 10}, new int[] {1, 4}); INDArray cumSumTest = n.cumsum(0); assertEquals(getFailureMessage(), cumSumAnswer, cumSumTest); INDArray n2 = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2); INDArray axis0assertion = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0}, n2.shape()); INDArray axis0Test = n2.cumsum(0); assertEquals(getFailureMessage(), axis0assertion, axis0Test); } @Test public void testSumRow() { INDArray rowVector10 = Nd4j.ones(10); INDArray sum1 = rowVector10.sum(1); assertArrayEquals(sum1.shape(), new int[] {1, 1}); assertTrue(sum1.getDouble(0) == 10); } @Test public void testSumColumn() { INDArray colVector10 = Nd4j.ones(10, 1); INDArray sum0 = colVector10.sum(0); assertArrayEquals(sum0.shape(), new int[] {1, 1}); assertTrue(sum0.getDouble(0) == 10); } @Test public void testSum2d() { INDArray arr = Nd4j.ones(10, 10); INDArray sum0 = arr.sum(0); assertArrayEquals(sum0.shape(), new int[] {1, 10}); INDArray sum1 = arr.sum(1); assertArrayEquals(sum1.shape(), new int[] {10, 1}); } @Test public void testSum2dv2() { INDArray arr = Nd4j.ones(10, 10); INDArray sumBoth = arr.sum(0, 1); assertArrayEquals(sumBoth.shape(), new int[] {1, 1}); assertTrue(sumBoth.getDouble(0) == 100); } @Test public void testPermuteReshape() { INDArray arrTest = Nd4j.arange(60).reshape('c', 3, 4, 5); INDArray permute = arrTest.permute(2, 1, 0); assertArrayEquals(new int[] {5, 4, 3}, permute.shape()); assertArrayEquals(new int[] {1, 5, 20}, permute.stride()); INDArray reshapedPermute = permute.reshape(-1, 12); assertArrayEquals(new int[] {5, 12}, reshapedPermute.shape()); assertArrayEquals(new int[] {12, 1}, reshapedPermute.stride()); } @Test public void testRavel() { INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray asseriton = Nd4j.linspace(1, 4, 4); INDArray raveled = linspace.ravel(); assertEquals(asseriton, raveled); INDArray tensorLinSpace = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2); INDArray linspaced = Nd4j.linspace(1, 16, 16); INDArray tensorLinspaceRaveled = tensorLinSpace.ravel(); assertEquals(linspaced, tensorLinspaceRaveled); } @Test public void testPutScalar() { //Check that the various putScalar methods have the same result... int[][] shapes = new int[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {1, 4, 5}, {3, 1, 5}, {3, 4, 1}, {1, 1, 5}, {3, 4, 5, 6}, {1, 4, 5, 6}, {3, 1, 5, 6}, {3, 4, 1, 6}, {3, 4, 5, 1}, {1, 1, 5, 6}, {3, 1, 1, 6}, {3, 1, 1, 1}}; for (int[] shape : shapes) { int rank = shape.length; NdIndexIterator iter = new NdIndexIterator(shape); INDArray firstC = Nd4j.create(shape, 'c'); INDArray firstF = Nd4j.create(shape, 'f'); INDArray secondC = Nd4j.create(shape, 'c'); INDArray secondF = Nd4j.create(shape, 'f'); int i = 0; while (iter.hasNext()) { int[] currIdx = iter.next(); firstC.putScalar(currIdx, i); firstF.putScalar(currIdx, i); switch (rank) { case 2: secondC.putScalar(currIdx[0], currIdx[1], i); secondF.putScalar(currIdx[0], currIdx[1], i); break; case 3: secondC.putScalar(currIdx[0], currIdx[1], currIdx[2], i); secondF.putScalar(currIdx[0], currIdx[1], currIdx[2], i); break; case 4: secondC.putScalar(currIdx[0], currIdx[1], currIdx[2], currIdx[3], i); secondF.putScalar(currIdx[0], currIdx[1], currIdx[2], currIdx[3], i); break; default: throw new RuntimeException(); } i++; } assertEquals(firstC, firstF); assertEquals(firstC, secondC); assertEquals(firstC, secondF); } } @Override public char ordering() { return 'c'; } }