package org.nd4j.linalg.shape;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import static org.junit.Assert.*;
/**
 * Tests for dimension-wise views of {@link INDArray}s: tensors and vectors along one or more
 * dimensions, no-copy reshaping, new-axis indexing and {@code dimShuffle}.
 *
 * <p>{@link #ordering()} returns {@code 'f'}; the hard-coded expected values below appear to
 * assume that column-major layout for arrays built via {@code linspace(...).reshape(...)}
 * (NOTE(review): confirm how {@link BaseNd4jTest} applies the ordering).
 *
 * <p>All assertions follow the JUnit convention {@code (expected, actual)} so that failure
 * messages report the two values the right way round.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class ShapeTests extends BaseNd4jTest {

    /**
     * Creates the test instance for one backend supplied by the parameterized runner.
     *
     * @param backend backend forwarded to the {@link BaseNd4jTest} constructor
     */
    public ShapeTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * A 2x2x2x2 array holding 1..16 must expose four 2x2 tensors along dimensions (0, 1),
     * each matching the hard-coded expectation for its index.
     */
    @Test
    public void testSixteenZeroOne() {
        INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
        assertEquals(4, baseArr.tensorssAlongDimension(0, 1));

        INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {2, 4}});
        INDArray columnVectorSecond = Nd4j.create(new double[][] {{9, 11}, {10, 12}});
        INDArray columnVectorThird = Nd4j.create(new double[][] {{5, 7}, {6, 8}});
        INDArray columnVectorFourth = Nd4j.create(new double[][] {{13, 15}, {14, 16}});
        INDArray[] assertions =
                new INDArray[] {columnVectorFirst, columnVectorSecond, columnVectorThird, columnVectorFourth};

        for (int i = 0; i < baseArr.tensorssAlongDimension(0, 1); i++) {
            INDArray test = baseArr.tensorAlongDimension(i, 0, 1);
            assertEquals("Wrong at index " + i, assertions[i], test);
        }
    }

    /**
     * Checks the number of vectors along dimensions 0 and 1 of a 1x5x5 array and the length
     * of the interior vectors along dimension 0.
     */
    @Test
    public void testVectorAlongDimension1() {
        INDArray arr = Nd4j.create(1, 5, 5);
        assertEquals(5, arr.vectorsAlongDimension(0));
        assertEquals(5, arr.vectorsAlongDimension(1));

        // NOTE(review): only interior indices (neither first nor last) are checked, and the
        // expected length of 25 is kept exactly as originally written - confirm both are intended.
        for (int i = 0; i < arr.vectorsAlongDimension(0); i++) {
            if (i < arr.vectorsAlongDimension(0) - 1 && i > 0) {
                assertEquals(25, arr.vectorAlongDimension(i, 0).length());
            }
        }
    }

    /**
     * Every tensor along dimension 2 of a 2x2x2x2 array holding 1..16 must equal the
     * hard-coded two-element expectation for its index.
     */
    @Test
    public void testSixteenSecondDim() {
        INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
        INDArray[] assertions = new INDArray[] {
                Nd4j.create(new double[] {1, 5}), Nd4j.create(new double[] {9, 13}),
                Nd4j.create(new double[] {3, 7}), Nd4j.create(new double[] {11, 15}),
                Nd4j.create(new double[] {2, 6}), Nd4j.create(new double[] {10, 14}),
                Nd4j.create(new double[] {4, 8}), Nd4j.create(new double[] {12, 16}),
        };

        for (int i = 0; i < baseArr.tensorssAlongDimension(2); i++) {
            INDArray arr = baseArr.tensorAlongDimension(i, 2);
            assertEquals("Failed at index " + i, assertions[i], arr);
        }
    }

    /**
     * Spot-checks {@code vectorAlongDimension} on a 4x3x2 array holding 1..24 and on a
     * 2x2 array holding 1..4.
     */
    @Test
    public void testVectorAlongDimension() {
        INDArray arr = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);

        // Vector 1 along the last dimension, compared against an explicitly 1x2 array.
        INDArray assertion = Nd4j.create(new float[] {5, 17}, new int[] {1, 2});
        INDArray vectorDimensionTest = arr.vectorAlongDimension(1, 2);
        assertEquals(assertion, vectorDimensionTest);

        // Vectors 0, 1 and 2 along dimension 1.
        INDArray zeroOneAssertion = Nd4j.create(new float[] {1, 5, 9});
        INDArray zeroOne = arr.vectorAlongDimension(0, 1);
        assertEquals(zeroOneAssertion, zeroOne);

        INDArray testColumn2Assertion = Nd4j.create(new float[] {13, 17, 21});
        INDArray testColumn2 = arr.vectorAlongDimension(1, 1);
        assertEquals(testColumn2Assertion, testColumn2);

        INDArray testColumn3Assertion = Nd4j.create(new float[] {2, 6, 10});
        INDArray testColumn3 = arr.vectorAlongDimension(2, 1);
        assertEquals(testColumn3Assertion, testColumn3);

        // Both vectors along dimension 0 of a 2x2 array.
        INDArray v1 = Nd4j.linspace(1, 4, 4).reshape(new int[] {2, 2});
        INDArray testColumnV1 = v1.vectorAlongDimension(0, 0);
        INDArray testColumnV1Assertion = Nd4j.create(new float[] {1, 2});
        assertEquals(testColumnV1Assertion, testColumnV1);

        INDArray testRowV1 = v1.vectorAlongDimension(1, 0);
        INDArray testRowV1Assertion = Nd4j.create(new float[] {3, 4});
        assertEquals(testRowV1Assertion, testRowV1);
    }

    /**
     * A 3x2x2 array holding 1..12 must expose six tensors along dimension 1, each matching
     * the hard-coded expectation for its index.
     */
    @Test
    public void testThreeTwoTwo() {
        INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
        INDArray[] assertions = new INDArray[] {
                Nd4j.create(new double[] {1, 4}), Nd4j.create(new double[] {7, 10}),
                Nd4j.create(new double[] {2, 5}), Nd4j.create(new double[] {8, 11}),
                Nd4j.create(new double[] {3, 6}), Nd4j.create(new double[] {9, 12}),
        };

        assertEquals(assertions.length, threeTwoTwo.tensorssAlongDimension(1));
        for (int i = 0; i < assertions.length; i++) {
            INDArray test = threeTwoTwo.tensorAlongDimension(i, 1);
            assertEquals("Failed at index " + i, assertions[i], test);
        }
    }

    /**
     * {@code Shape.newShapeNoCopy} applied to a 12-element array must yield an array whose
     * shape is 3x2x2.
     */
    @Test
    public void testNoCopy() {
        INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12);
        INDArray arr = Shape.newShapeNoCopy(threeTwoTwo, new int[] {3, 2, 2}, true);
        assertArrayEquals(new int[] {3, 2, 2}, arr.shape());
    }

    /**
     * A 3x2x2 array holding 1..12 must expose six tensors along dimension 2, each matching
     * the hard-coded expectation for its index.
     */
    @Test
    public void testThreeTwoTwoTwo() {
        INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
        INDArray[] assertions = new INDArray[] {
                Nd4j.create(new double[] {1, 7}), Nd4j.create(new double[] {4, 10}),
                Nd4j.create(new double[] {2, 8}), Nd4j.create(new double[] {5, 11}),
                Nd4j.create(new double[] {3, 9}), Nd4j.create(new double[] {6, 12}),
        };

        assertEquals(assertions.length, threeTwoTwo.tensorssAlongDimension(2));
        for (int i = 0; i < assertions.length; i++) {
            INDArray test = threeTwoTwo.tensorAlongDimension(i, 2);
            assertEquals("Failed at index " + i, assertions[i], test);
        }
    }

    /**
     * Indexing with {@code point(0)} followed by {@code newAxis()} must match the expected
     * reshaped arrays, for both a 2x2 matrix and a 3x2x2 tensor.
     */
    @Test
    public void testNewAxis() {
        INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        INDArray newAxisAssertion = Nd4j.create(new double[] {1, 3}).reshape(1, 2, 1);
        INDArray newAxisGet = arr.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
        assertEquals(newAxisAssertion, newAxisGet);

        INDArray tensor = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
        INDArray assertion = Nd4j.create(new double[][] {{1, 7}, {4, 10}}).reshape(1, 2, 2);
        INDArray tensorGet = tensor.get(NDArrayIndex.point(0), NDArrayIndex.newAxis());
        assertEquals(assertion, tensorGet);
    }

    /**
     * Every tensor along dimension 1 of a 2x2x2x2 array holding 1..16 must equal the
     * hard-coded two-element expectation for its index.
     */
    @Test
    public void testSixteenFirstDim() {
        INDArray baseArr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
        INDArray[] assertions = new INDArray[] {
                Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {9, 11}),
                Nd4j.create(new double[] {5, 7}), Nd4j.create(new double[] {13, 15}),
                Nd4j.create(new double[] {2, 4}), Nd4j.create(new double[] {10, 12}),
                Nd4j.create(new double[] {6, 8}), Nd4j.create(new double[] {14, 16}),
        };

        for (int i = 0; i < baseArr.tensorssAlongDimension(1); i++) {
            INDArray arr = baseArr.tensorAlongDimension(i, 1);
            assertEquals("Failed at index " + i, assertions[i], arr);
        }
    }

    /**
     * Exercises {@code dimShuffle} on a scalar and on a four-element vector, checking the
     * resulting rank or shape for several rearrangement patterns containing {@code 'x'}.
     */
    @Test
    public void testDimShuffle() {
        INDArray scalarTest = Nd4j.scalar(0.0);
        INDArray broadcast = scalarTest.dimShuffle(new Object[] {'x'}, new int[] {0, 1}, new boolean[] {true, true});
        // assertEquals (rather than assertTrue on ==) so a failure reports the actual rank.
        assertEquals(3, broadcast.rank());

        // An identity rearrangement must leave the vector equal to itself.
        INDArray rowVector = Nd4j.linspace(1, 4, 4);
        assertEquals(rowVector,
                rowVector.dimShuffle(new Object[] {0, 1}, new int[] {0, 1}, new boolean[] {false, false}));

        // Add an extra dimension to the row vector in the middle.
        INDArray rearrangedRowVector =
                rowVector.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {true, true});
        assertArrayEquals(new int[] {1, 1, 4}, rearrangedRowVector.shape());

        // Three 'x' entries: the expected result has three more dimensions of size 1.
        INDArray dimshuffed = rowVector.dimShuffle(new Object[] {'x', 0, 'x', 'x'}, new int[] {0, 1},
                new boolean[] {true, true});
        assertArrayEquals(new int[] {1, 1, 1, 1, 4}, dimshuffed.shape());
    }

    /**
     * A 2x2x2 array holding 1..8 must expose two 2x2 tensors along dimensions (0, 1) with
     * the expected contents.
     */
    @Test
    public void testEight() {
        INDArray baseArr = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2);
        assertEquals(2, baseArr.tensorssAlongDimension(0, 1));

        INDArray columnVectorFirst = Nd4j.create(new double[][] {{1, 3}, {2, 4}});
        INDArray columnVectorSecond = Nd4j.create(new double[][] {{5, 7}, {6, 8}});
        assertEquals(columnVectorFirst, baseArr.tensorAlongDimension(0, 0, 1));
        assertEquals(columnVectorSecond, baseArr.tensorAlongDimension(1, 0, 1));
    }

    /**
     * Overrides the base-class ordering hook.
     *
     * @return {@code 'f'}
     */
    @Override
    public char ordering() {
        return 'f';
    }
}