package org.nd4j.linalg.slicing;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class SlicingTestsC extends BaseNd4jTest {
public SlicingTestsC(Nd4jBackend backend) {
super(backend);
}
@Test
public void testSliceRowVector() {
INDArray arr = Nd4j.zeros(5);
System.out.println(arr.slice(1));
}
@Test
public void testSliceAssertion() {
INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2);
INDArray firstRow = arr.slice(0).slice(0);
for (int i = 0; i < firstRow.length(); i++) {
System.out.println(firstRow.getDouble(i));
}
System.out.println(firstRow);
}
@Test
public void testSliceShape() {
INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2);
INDArray sliceZero = arr.slice(0);
for (int i = 0; i < sliceZero.rows(); i++) {
INDArray row = sliceZero.slice(i);
for (int j = 0; j < row.length(); j++) {
System.out.println(row.getDouble(j));
}
System.out.println(row);
}
INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, new int[] {5, 2});
for (int i = 0; i < assertion.rows(); i++) {
INDArray row = assertion.slice(i);
for (int j = 0; j < row.length(); j++) {
System.out.println(row.getDouble(j));
}
System.out.println(row);
}
assertArrayEquals(new int[] {5, 2}, sliceZero.shape());
assertEquals(assertion, sliceZero);
INDArray assertionTwo = Nd4j.create(new double[] {11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, new int[] {5, 2});
INDArray sliceTest = arr.slice(1);
assertEquals(assertionTwo, sliceTest);
}
@Test
public void testSwapReshape() {
INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[] {3, 5, 2});
INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1);
INDArray firstSlice2 = swapped.slice(0).slice(0);
INDArray oneThreeFiveSevenNine = Nd4j.create(new float[] {1, 3, 5, 7, 9});
assertEquals(firstSlice2, oneThreeFiveSevenNine);
INDArray raveled = oneThreeFiveSevenNine.reshape(5, 1);
INDArray raveledOneThreeFiveSevenNine = oneThreeFiveSevenNine.reshape(5, 1);
assertEquals(raveled, raveledOneThreeFiveSevenNine);
INDArray firstSlice3 = swapped.slice(0).slice(1);
INDArray twoFourSixEightTen = Nd4j.create(new float[] {2, 4, 6, 8, 10});
assertEquals(firstSlice2, oneThreeFiveSevenNine);
INDArray raveled2 = twoFourSixEightTen.reshape(5, 1);
INDArray raveled3 = firstSlice3.reshape(5, 1);
assertEquals(raveled2, raveled3);
}
@Test
public void testGetRow() {
INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
INDArray get = arr.getRow(1);
INDArray get2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
INDArray assertion = Nd4j.create(new double[] {4, 5, 6});
assertEquals(assertion, get);
assertEquals(get, get2);
get2.assign(Nd4j.linspace(1, 3, 3));
assertEquals(Nd4j.linspace(1, 3, 3), get2);
INDArray threeByThree = Nd4j.linspace(1, 9, 9).reshape(3, 3);
INDArray offsetTest = threeByThree.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
INDArray threeByThreeAssertion = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});
assertEquals(threeByThreeAssertion, offsetTest);
}
@Test
public void testVectorIndexing() {
INDArray zeros = Nd4j.create(1, 400000);
INDArray get = zeros.get(NDArrayIndex.interval(0, 300000));
assertArrayEquals(new int[] {1, 300000}, get.shape());
}
@Override
public char ordering() {
return 'c';
}
}