package org.nd4j.linalg.shape.reshape; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assume.assumeNotNull; /** * @author Adam Gibson */ @RunWith(Parameterized.class) public class ReshapeTests extends BaseNd4jTest { public ReshapeTests(Nd4jBackend backend) { super(backend); } @Test public void testThreeTwoTwoTwo() { INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2); INDArray sliceZero = Nd4j.create(new double[][] {{1, 7}, {4, 10}}); INDArray sliceOne = Nd4j.create(new double[][] {{2, 8}, {5, 11}}); INDArray sliceTwo = Nd4j.create(new double[][] {{3, 9}, {6, 12}}); INDArray[] assertions = new INDArray[] {sliceZero, sliceOne, sliceTwo}; for (int i = 0; i < threeTwoTwo.slices(); i++) { INDArray sliceI = threeTwoTwo.slice(i); assertEquals(assertions[i], sliceI); } INDArray linspaced = Nd4j.linspace(1, 4, 4).reshape(2, 2); INDArray[] assertionsTwo = new INDArray[] {Nd4j.create(new double[] {1, 3}), Nd4j.create(new double[] {2, 4})}; for (int i = 0; i < assertionsTwo.length; i++) assertEquals(linspaced.slice(i), assertionsTwo[i]); } @Test public void testColumnVectorReshape() { double delta = 1e-1; INDArray arr = Nd4j.create(1, 3); INDArray reshaped = arr.reshape('f', 3, 1); assertArrayEquals(new int[] {3, 1}, reshaped.shape()); assertEquals(0.0, reshaped.getDouble(1), delta); assertEquals(0.0, reshaped.getDouble(2), delta); assumeNotNull(reshaped.toString()); } @Override public char ordering() { return 'f'; } }