package org.nd4j.linalg.shape.indexing;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Tests for sub-array extraction and insertion through {@link INDArrayIndex} implementations
 * ({@link NDArrayIndex#point}, {@link NDArrayIndex#interval}, {@link NDArrayIndex#all} and
 * {@link SpecifiedIndex}). Arrays in these tests are created with the ordering returned by
 * {@link #ordering()} ({@code 'f'}).
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class IndexingTests extends BaseNd4jTest {

    public IndexingTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Extracts a 2D window from one slice of a 5x5x5 array in two ways and checks they agree:
     * (A) element by element with {@code getDouble}/{@code putScalar}, and
     * (B) with {@code get} using point/interval indices followed by {@code put} with all() indices.
     */
    @Test
    public void testGet() {
        System.out.println( "Testing sub-array put and get with a 3D array ..." );
        INDArray arr = Nd4j.linspace( 0, 124, 125 ).reshape( 5, 5, 5 );
        /*
         * Extract elements with the following indices:
         *
         * (2,1,1) (2,1,2) (2,1,3)
         * (2,2,1) (2,2,2) (2,2,3)
         * (2,3,1) (2,3,2) (2,3,3)
         */
        int slice = 2;
        int iStart = 1;
        int jStart = 1;
        int iEnd = 4;
        int jEnd = 4;

        // Method A: Element-wise.
        INDArray subArr_A = Nd4j.create( new int[] {iEnd - iStart, jEnd - jStart} );
        for (int i = iStart; i < iEnd; i++) {
            for (int j = jStart; j < jEnd; j++) {
                double val = arr.getDouble( slice, i, j );
                int[] sub = new int[] {i - iStart, j - jStart};
                subArr_A.putScalar( sub, val );
            }
        }

        // Method B: Using NDArray get and put with index classes.
        INDArray subArr_B = Nd4j.create( new int[] {iEnd - iStart, jEnd - jStart} );
        INDArrayIndex ndi_Slice = NDArrayIndex.point( slice );
        // Dimension 1 is indexed by i and dimension 2 by j, mirroring getDouble(slice, i, j) above.
        INDArrayIndex ndi_I = NDArrayIndex.interval( iStart, iEnd );
        INDArrayIndex ndi_J = NDArrayIndex.interval( jStart, jEnd );
        INDArrayIndex[] whereToGet = new INDArrayIndex[] {ndi_Slice, ndi_I, ndi_J};
        INDArray whatToPut = arr.get( whereToGet );
        System.out.println(whatToPut);
        INDArrayIndex[] whereToPut = new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all()};
        subArr_B.put( whereToPut, whatToPut );

        assertEquals( subArr_A, subArr_B );
        System.out.println( "... done" );
    }

    /**
     * Checks that (point, point, all) indexing of a 3x2x2 tensor yields the expected 1D vectors.
     */
    @Test
    // Ignored: per the original author, the .equals comparison here fails on a comparison of a
    // row vector against a column vector.
    // TODO: possibly figure out what's going on here at some point? - Adam
    @Ignore
    public void testTensorGet() {
        INDArray threeTwoTwo = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
        /*
         * [[[ 1., 7.],
         *   [ 4., 10.]],
         *  [[ 2., 8.],
         *   [ 5., 11.]],
         *  [[ 3., 9.],
         *   [ 6., 12.]]]
         */
        INDArray firstAssertion = Nd4j.create(new double[] {1, 7});
        INDArray firstTest = threeTwoTwo.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all());
        assertEquals(firstAssertion, firstTest);

        INDArray secondAssertion = Nd4j.create(new double[] {3, 9});
        INDArray secondTest = threeTwoTwo.get(NDArrayIndex.point(2), NDArrayIndex.point(0), NDArrayIndex.all());
        assertEquals(secondAssertion, secondTest);
    }

    /**
     * Regression test: after concatenating two 4D arrays along dimension 0, sub-arrays taken back
     * out with {@code get} must equal the original inputs. This applies both to the view returned
     * by {@code get} and to its {@code dup()}.
     */
    @Test
    public void concatGetBug() {
        int width = 5;
        int height = 4;
        int depth = 3;
        int nExamples1 = 2;
        int nExamples2 = 1;
        int length1 = width * height * depth * nExamples1;
        int length2 = width * height * depth * nExamples2;
        INDArray first = Nd4j.linspace(1, length1, length1).reshape('c', nExamples1, depth, width, height);
        // Offset by 0.1 so the second input's values differ from the first input's values.
        INDArray second = Nd4j.linspace(1, length2, length2).reshape('c', nExamples2, depth, width, height).addi(0.1);
        INDArray fMerged = Nd4j.concat(0, first, second);

        assertEquals(first, fMerged.get(NDArrayIndex.interval(0, nExamples1), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));

        // NOTE(review): with inclusive=true the upper bound nExamples1 + nExamples2 equals
        // fMerged.size(0), so this relies on the interval being clamped to the dimension size.
        // Confirm that this form is intentional for reproducing the original bug.
        INDArray get = fMerged.get(NDArrayIndex.interval(nExamples1, nExamples1 + nExamples2, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
        // Comparing the dup() (a fresh copy) was originally reported to pass...
        assertEquals(second, get.dup());
        // ...while comparing the view directly was originally reported to fail; both must hold.
        assertEquals(second, get);
    }

    /**
     * Checks that taking a single row of a 2x2 matrix yields a row vector of shape [1, 2].
     */
    @Test
    public void testShape() {
        INDArray ndarray = Nd4j.create(new float[][] {{1f, 2f}, {3f, 4f}});
        INDArray subarray = ndarray.get(NDArrayIndex.point(0), NDArrayIndex.all());
        assertTrue(subarray.isRowVector());
        int[] shape = subarray.shape();
        // JUnit's assertEquals takes (expected, actual).
        assertEquals(1, shape[0]);
        assertEquals(2, shape[1]);
    }

    /**
     * Checks that {@link SpecifiedIndex} selects rows {1, 2} and columns {1, 2} of a 3x3 matrix.
     */
    @Test
    public void testGetRows() {
        INDArray arr = Nd4j.linspace(1, 9, 9).reshape(3, 3);
        INDArray testAssertion = Nd4j.create(new double[][] {{5, 8}, {6, 9}});
        INDArray test = arr.get(new SpecifiedIndex(1, 2), new SpecifiedIndex(1, 2));
        assertEquals(testAssertion, test);
    }

    /**
     * Checks that (all, point(0)) indexing extracts the first column of a 2x2 matrix.
     */
    @Test
    public void testFirstColumn() {
        INDArray arr = Nd4j.create(new double[][] {{5, 6}, {7, 8}});
        INDArray assertion = Nd4j.create(new double[] {5, 7});
        INDArray test = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        assertEquals(assertion, test);
    }

    /**
     * Checks that linear indexing via {@code getDouble(i)} walks a reshaped linspace in order
     * 1, 2, 3, 4.
     */
    @Test
    public void testLinearIndex() {
        INDArray linspace = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        for (int i = 0; i < linspace.length(); i++) {
            assertEquals(i + 1, linspace.getDouble(i), 1e-1);
        }
    }

    /**
     * @return the array ordering used by this test class ({@code 'f'}).
     */
    @Override
    public char ordering() {
        return 'f';
    }
}