package org.nd4j.linalg.shape.indexing;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import static org.junit.Assert.*;
/**
 * Indexing tests ({@code get} with {@link NDArrayIndex} / {@link SpecifiedIndex})
 * run with the ordering reported by {@link #ordering()}, i.e. {@code 'c'}.
 *
 * <p>The class is parameterized over {@link Nd4jBackend}; each backend instance is
 * handed to {@link BaseNd4jTest} through the constructor.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class IndexingTestsC extends BaseNd4jTest {

    /**
     * Creates the test for one backend.
     *
     * @param backend backend forwarded to the {@link BaseNd4jTest} constructor
     */
    public IndexingTestsC(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Executes a {@link ScalarAdd} of 2 on the first two columns of a 2x3 array
     * and checks the resulting sub-array values.
     */
    @Test
    public void testExecSubArray() {
        INDArray nd = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {2, 3});
        // Columns [0, 2) of every row: {{1, 2}, {4, 5}} before the add.
        INDArray sub = nd.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2));
        Nd4j.getExecutioner().exec(new ScalarAdd(sub, 2));
        assertEquals(getFailureMessage(), Nd4j.create(new double[][] {{3, 4}, {6, 7}}), sub);
    }

    /**
     * Adds a duplicate of a 2x2 array to the array itself in place and checks
     * that every element was doubled.
     */
    @Test
    public void testLinearViewElementWiseMatching() {
        INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        INDArray dup = linspace.dup();
        linspace.addi(dup);
        // {{1, 2}, {3, 4}} + {{1, 2}, {3, 4}}; previously this test asserted nothing.
        assertEquals(Nd4j.create(new double[][] {{2, 4}, {6, 8}}), linspace);
    }

    /**
     * Selects rows 1 and 2 and columns 0 and 1 of a 3x3 array via
     * {@link SpecifiedIndex} and checks the selected values.
     */
    @Test
    public void testGetRows() {
        INDArray arr = Nd4j.linspace(1, 9, 9).reshape(3, 3);
        INDArray testAssertion = Nd4j.create(new double[][] {{4, 5}, {7, 8}});
        INDArray test = arr.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(0, 1));
        assertEquals(testAssertion, test);
    }

    /**
     * Extracts column 0 of a 2x2 array using an {@code all} / {@code point}
     * index pair and checks its values.
     */
    @Test
    public void testFirstColumn() {
        INDArray arr = Nd4j.create(new double[][] {{5, 7}, {6, 8}});
        INDArray assertion = Nd4j.create(new double[] {5, 6});
        INDArray test = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        assertEquals(assertion, test);
    }

    /**
     * Combines a {@link SpecifiedIndex} on rows 1 and 2 with a column interval
     * {@code [0, 1)} on a 3x3 array and checks the selected values.
     */
    @Test
    public void testMultiRow() {
        INDArray matrix = Nd4j.linspace(1, 9, 9).reshape(3, 3);
        INDArray assertion = Nd4j.create(new double[][] {{4, 7}});
        INDArray test = matrix.get(new SpecifiedIndex(1, 2), NDArrayIndex.interval(0, 1));
        assertEquals(assertion, test);
    }

    /**
     * Applies a {@code point} index on the middle dimension of a 4x3x2 array and
     * checks the resulting shape, stride, per-slice values and overall values.
     */
    @Test
    public void testPointIndexes() {
        INDArray arr = Nd4j.create(4, 3, 2);
        INDArray get = arr.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());
        assertArrayEquals(new int[] {4, 2}, get.shape());

        INDArray linspaced = Nd4j.linspace(1, 24, 24).reshape(4, 3, 2);
        INDArray assertion = Nd4j.create(new double[][] {{3, 4}, {9, 10}, {15, 16}, {21, 22}});
        INDArray linspacedGet =
                        linspaced.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all());

        // Compare slice by slice first so a failure points at the offending slice.
        for (int i = 0; i < linspacedGet.slices(); i++) {
            INDArray sliceI = linspacedGet.slice(i);
            assertEquals(assertion.slice(i), sliceI);
        }

        assertArrayEquals(new int[] {6, 1}, linspacedGet.stride());
        assertEquals(assertion, linspacedGet);
    }

    /**
     * Pads a 1x1x8x8 image by one trailing row and column and reads it back with
     * three-argument intervals on the last two dimensions, checking stride and
     * values for three different start offsets.
     */
    @Test
    public void testGetWithVariedStride() {
        final int ph = 0;
        final int pw = 0;
        final int sy = 2;
        final int sx = 2;

        // Eight rows of eight identical values; the row values cycle 1, 2, 3, 4.
        INDArray img = Nd4j.create(new double[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
                        4, 4, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
                        4, 4, 4, 4, 4, 4, 4, 4}, new int[] {1, 1, 8, 8});

        INDArray padded = Nd4j.pad(img, new int[][] {{0, 0}, {0, 0}, {ph, ph + sy - 1}, {pw, pw + sx - 1}},
                        Nd4j.PadMode.CONSTANT);

        int[] expectedStride = {81, 81, 18, 2};
        INDArray oddValuedRows = Nd4j.create(new double[] {1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3},
                        new int[] {1, 1, 4, 4});
        INDArray evenValuedRows = Nd4j.create(new double[] {2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 2, 2, 4, 4, 4, 4},
                        new int[] {1, 1, 4, 4});

        // Case 1: rows interval(0, sy, 8), columns interval(0, sx, 8).
        INDArray get = stridedGet(padded, 0, sy, 8, 0, sx, 8);
        assertArrayEquals(expectedStride, get.stride());
        assertEquals(oddValuedRows, get);

        // Case 2: rows shifted by one -> interval(1, sy, 9); columns unchanged.
        INDArray get3 = stridedGet(padded, 1, sy, 9, 0, sx, 8);
        assertArrayEquals(expectedStride, get3.stride());
        assertEquals(evenValuedRows, get3);

        // Case 3: columns shifted by one -> interval(1, sx, 9); rows as in case 1.
        INDArray get2 = stridedGet(padded, 0, sy, 8, 1, sx, 9);
        assertArrayEquals(expectedStride, get2.stride());
        assertEquals(oddValuedRows, get2);
    }

    /**
     * Reads sub-ranges of a 30-element row {@code 0..29} using interval indexes,
     * both with and without a leading {@code point(0)}, and checks shape and
     * values of the first and last ten elements.
     */
    @Test
    public void testRowVectorInterval() {
        int len = 30;
        INDArray row = Nd4j.zeros(len);
        for (int i = 0; i < len; i++) {
            row.putScalar(i, i);
        }

        INDArray first10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 10));
        assertTenConsecutiveValues(first10a, 0);

        INDArray first10b = row.get(NDArrayIndex.interval(0, 10));
        assertTenConsecutiveValues(first10b, 0);

        INDArray last10a = row.get(NDArrayIndex.point(0), NDArrayIndex.interval(20, 30));
        assertTenConsecutiveValues(last10a, 20);

        INDArray last10b = row.get(NDArrayIndex.interval(20, 30));
        assertTenConsecutiveValues(last10b, 20);
    }

    /**
     * @return {@code 'c'}
     */
    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Calls {@code get} on a 4-d array with {@code all} for the first two
     * dimensions and three-argument intervals for the last two.
     *
     * <p>The interval arguments are passed to
     * {@code NDArrayIndex.interval(int, int, int)} in the order given here.
     */
    private static INDArray stridedGet(INDArray source, int rowFrom, int rowStep, int rowTo, int colFrom,
                    int colStep, int colTo) {
        return source.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(rowFrom, rowStep, rowTo),
                        NDArrayIndex.interval(colFrom, colStep, colTo));
    }

    /**
     * Asserts that {@code actual} has shape {@code {1, 10}} and that its elements
     * at linear indexes 0..9 are exactly {@code first, first + 1, ..., first + 9}.
     *
     * @param actual array under test
     * @param first  expected value of element 0
     */
    private static void assertTenConsecutiveValues(INDArray actual, int first) {
        assertArrayEquals(new int[] {1, 10}, actual.shape());
        for (int i = 0; i < 10; i++) {
            // Delta 0.0 keeps the exact comparison of the former `==` check.
            assertEquals(first + i, actual.getDouble(i), 0.0);
        }
    }
}