package org.nd4j.linalg;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Collection of standalone regression tests for assorted Nd4j operations
 * (softmax, flattening of views, interval indexing, concat, tensorMmul,
 * DataSet merging, tensors along dimension and argMax).
 *
 * <p>The class is run with the JUnit {@link Parameterized} runner and receives
 * an {@link Nd4jBackend} through its constructor, which is handed to
 * {@link BaseNd4jTest}.
 *
 * Created by agibsonccc on 4/1/16.
 */
@RunWith(Parameterized.class)
public class LoneTest extends BaseNd4jTest {

    /**
     * @param backend the backend passed on to {@link BaseNd4jTest}
     */
    public LoneTest(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Executes a {@link SoftMax} op on a transposed 10-element input, writing into a
     * separately allocated 10x1 output. The test has no assertions; it only checks that
     * execution completes without throwing.
     */
    @Test
    public void testSoftmaxStability() {
        INDArray input = Nd4j.create(new double[] {-0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04})
                        .transpose();
        System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
        INDArray output = Nd4j.create(10, 1);
        System.out.println("Element wise stride of output " + output.elementWiseStride());
        Nd4j.getExecutioner().exec(new SoftMax(input, output));
    }

    /**
     * @return {@code 'c'}; presumably the array ordering the base test should use
     *         (NOTE(review): confirm against BaseNd4jTest)
     */
    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Flattens a strided view of an 8x8 matrix in both 'c' and 'f' order and compares the
     * results against hard-coded expectations.
     */
    @Test
    public void testFlattenedView() {
        int rows = 8;
        int cols = 8;
        int dim2 = 4;
        int length = rows * cols;
        int length3d = rows * cols * dim2;

        INDArray first = Nd4j.linspace(1, length, length).reshape('c', rows, cols);
        INDArray second = Nd4j.create(new int[] {rows, cols}, 'f').assign(first);
        INDArray third = Nd4j.linspace(1, length3d, length3d).reshape('c', rows, cols, dim2);
        first.addi(0.1);
        second.addi(0.2);
        third.addi(0.3);

        // View over rows 4..7 and every second column (see the expected values below).
        first = first.get(NDArrayIndex.interval(4, 8), NDArrayIndex.interval(0, 2, 8));
        for (int i = 0; i < first.tensorssAlongDimension(0); i++) {
            System.out.println(first.tensorAlongDimension(i, 0));
        }
        for (int i = 0; i < first.tensorssAlongDimension(1); i++) {
            System.out.println(first.tensorAlongDimension(i, 1));
        }

        // These views are never asserted on; the calls are kept so that the slicing and
        // permute operations themselves are still exercised.
        second = second.get(NDArrayIndex.interval(3, 7), NDArrayIndex.all());
        third = third.permute(0, 2, 1);

        INDArray cAssertion = Nd4j.create(new double[] {33.10, 35.10, 37.10, 39.10, 41.10, 43.10, 45.10, 47.10, 49.10,
                        51.10, 53.10, 55.10, 57.10, 59.10, 61.10, 63.10});
        INDArray fAssertion = Nd4j.create(new double[] {33.10, 41.10, 49.10, 57.10, 35.10, 43.10, 51.10, 59.10, 37.10,
                        45.10, 53.10, 61.10, 39.10, 47.10, 55.10, 63.10});
        assertEquals(cAssertion, Nd4j.toFlattened('c', first));
        assertEquals(fAssertion, Nd4j.toFlattened('f', first));
    }

    /**
     * Checks element access on a 1x5 row vector and its transpose through
     * {@code getRow}, {@code getColumn} and single-element intervals, then requests an
     * interval whose end lies beyond the vector to make sure that call does not throw.
     */
    @Test
    public void testIndexingColVec() {
        int elements = 5;
        INDArray rowVector = Nd4j.linspace(1, elements, elements).reshape(1, elements);
        INDArray colVector = rowVector.transpose();

        for (int i = 0; i < elements; i++) {
            int expected = i + 1;
            // JUnit's contract is assertEquals(expected, actual); keeping that order makes
            // failure messages report the right values.
            assertEquals(expected, colVector.getRow(i).getInt(0));
            assertEquals(expected, rowVector.getColumn(i).getInt(0));
            assertEquals(expected, rowVector.get(NDArrayIndex.interval(i, i + 1)).getInt(0));
            assertEquals(expected, colVector.get(NDArrayIndex.interval(i, i + 1)).getInt(0));

            System.out.println("Making sure index interval will not crash with begin/end vals...");
            // Result intentionally discarded: only the absence of an exception matters.
            colVector.get(NDArrayIndex.interval(i, i + 10));
        }
    }

    /**
     * Regression test: a bug was found when the first array that concat sees is a scalar
     * and the rest are vectors and scalars. Concatenates a 1x1, a 1x8 and a 1x1 array
     * along dimension 1 and asserts that the sum of the result does not exceed
     * {@code Nd4j.EPS_THRESHOLD}.
     */
    @Test
    public void concatScalarVectorIssue() {
        INDArray arr1 = Nd4j.create(1, 1);
        INDArray arr2 = Nd4j.create(1, 8);
        INDArray arr3 = Nd4j.create(1, 1);
        INDArray arr4 = Nd4j.concat(1, arr1, arr2, arr3);
        assertTrue(arr4.sumNumber().floatValue() <= Nd4j.EPS_THRESHOLD);
    }

    /**
     * Regression test: this {@code tensorMmul} call used to throw an exception. There are
     * no assertions; the test passes if the call completes.
     */
    @Test
    public void reshapeTensorMmul() {
        INDArray a = Nd4j.linspace(1, 2, 12).reshape(2, 3, 2);
        INDArray b = Nd4j.linspace(3, 4, 4).reshape(2, 2);
        int[][] axes = new int[2][];
        axes[0] = new int[] {0, 1};
        axes[1] = new int[] {0, 2};
        // Result intentionally discarded: only the absence of an exception matters.
        Nd4j.tensorMmul(b, a, axes);
    }

    /**
     * Merges two data sets whose features differ in their last dimension (5 vs 3) and
     * asserts that the merged data set, and a copy of it, have a features mask array.
     */
    @Test
    public void maskWhenMerge() {
        DataSet dsA = new DataSet(Nd4j.linspace(1, 15, 15).reshape(1, 3, 5), Nd4j.zeros(1, 3, 5));
        DataSet dsB = new DataSet(Nd4j.linspace(1, 9, 9).reshape(1, 3, 3), Nd4j.zeros(1, 3, 3));
        List<DataSet> dataSetList = new ArrayList<DataSet>();
        dataSetList.add(dsA);
        dataSetList.add(dsB);

        DataSet fullDataSet = DataSet.merge(dataSetList);
        assertTrue("merged data set has no features mask array", fullDataSet.getFeaturesMaskArray() != null);

        DataSet fullDataSetCopy = fullDataSet.copy();
        assertTrue("copy of merged data set has no features mask array",
                        fullDataSetCopy.getFeaturesMaskArray() != null);
    }

    /**
     * Runs a {@link Tanh} op on a 2x4 array and prints the arrays involved.
     *
     * <p>NOTE(review): despite its name this test executes Tanh, not a ReLU, and it has
     * no assertions. A TanhDerivative run on {@code aD} was disabled at some point, so
     * {@code aD} is only printed.
     */
    @Test
    public void testRelu() {
        INDArray aA = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
        INDArray aD = Nd4j.linspace(-3, 4, 8).reshape(2, 4);
        INDArray b = Nd4j.getExecutioner().execAndReturn(new Tanh(aA));
        System.out.println(aA);
        System.out.println(aD);
        System.out.println(b);
    }

    /**
     * Asserts that {@code javaTensorAlongDimension} and {@code tensorAlongDimension}
     * return equal arrays for tensor 0 along dimensions (2, 3) of a {2, 1, 3, 3} array.
     */
    @Test
    public void testTad() {
        int[] someShape = {2, 1, 3, 3};
        INDArray a = Nd4j.linspace(1, 18, 18).reshape(someShape);
        INDArray javaTad = a.javaTensorAlongDimension(0, 2, 3);
        INDArray tad = a.tensorAlongDimension(0, 2, 3);
        // A rank == 2 check was disabled here: the TAD is rank 3 with an extra 1.
        assertEquals(javaTad, tad);
    }

    /**
     * Expects an {@link IllegalStateException} when adding arrays of shape {2, 3, 1} and
     * {2, 3}.
     *
     * <p>NOTE(review): the test passes as soon as either statement throws, so
     * {@code B.add(A)} is only reached if {@code A.add(B)} does not throw.
     */
    @Test(expected = IllegalStateException.class)
    public void opsNotAllowed() {
        INDArray A = Nd4j.ones(2, 3, 1);
        INDArray B = Nd4j.ones(2, 3);
        System.out.println(A.add(B));
        System.out.println(B.add(A));
    }

    /**
     * Asserts that argMax of a 1xN row vector holding 1..N is N - 1, for N = 63 and
     * N = 64. Per the original author's note, argMax was broken at a threshold.
     */
    @Test
    public void testArgMax() {
        int[] lengths = {63, 64};
        for (int length : lengths) {
            int currentArgMax = argMaxOfLinspaceRow(length);
            System.out.println("Returned argMax is " + currentArgMax);
            assertEquals("argMax of a row vector of length " + length, length - 1, currentArgMax);
        }
    }

    /** Builds a 1 x length row vector holding 1..length and returns its argMax as an int. */
    private static int argMaxOfLinspaceRow(int length) {
        INDArray row = Nd4j.linspace(1, length, length).reshape(1, length);
        return Nd4j.argMax(row).getInt(0, 0);
    }
}