package jcuda.jcublas.ops;
import org.apache.commons.math3.util.Pair;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
 * Created by raver on 08.05.2016.
 *
 * <p>Checks around element-wise stride (ews) reporting and around {@code Nd4j.vstack} /
 * {@code Nd4j.hstack} concatenation. The whole class is currently disabled via {@code @Ignore}.
 */
@Ignore
public class ElementWiseStrideTests {

    /**
     * Applies the CUDA environment settings shared by all tests in this class and prints a
     * marker line once that is done.
     */
    @Before
    public void setUp() {
        CudaEnvironment environment = CudaEnvironment.getInstance();
        environment.getConfiguration()
                .setFirstMemory(AllocationStatus.DEVICE)
                .setExecutionModel(Configuration.ExecutionModel.SEQUENTIAL)
                .setAllocationModel(Configuration.AllocationModel.CACHE_ALL)
                .setMaximumBlockSize(128)
                .enableDebug(true)
                .setVerbose(true);

        System.out.println("Init called");
    }

    /**
     * Walks over 2d..6d test arrays produced by {@code NDArrayCreationUtil} and prints a
     * diagnostic report for every array whose shape info yields an ews of -1 although a
     * no-copy reshape to {@code [1, length]} succeeds and reports an ews other than -1.
     *
     * <p>This test only reports; it contains no active assertion.
     */
    @Test
    public void testEWS1() throws Exception {
        List<Pair<INDArray, String>> cases = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 12345);
        cases.addAll(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 4, 5, 6));
        cases.addAll(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 4, 5, 6, 7));
        cases.addAll(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 4, 5, 6, 7, 8));
        cases.addAll(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 4, 5, 6, 7, 8, 9));

        for (Pair<INDArray, String> candidate : cases) {
            INDArray source = candidate.getFirst();
            int strideBefore = Shape.elementWiseStride(source.shapeInfo());
            INDArray flattened = Shape.newShapeNoCopy(
                    source, new int[] {1, source.length()}, Nd4j.order() == 'f');

            // Only arrays that report ews == -1 but still reshape (without copy) into
            // something with a valid ews are of interest; everything else is skipped silently.
            if (flattened == null || strideBefore != -1 || flattened.elementWiseStride() == -1) {
                // System.out.println("FAILED: " + source.shapeInfoToString());
                continue;
            }

            System.out.println("NDArrayCreationUtil." + candidate.getSecond());
            System.out.println("ews before: " + strideBefore);
            System.out.println(source.shapeInfoToString());
            System.out.println("ews returned by elementWiseStride(): " + source.elementWiseStride());
            System.out.println("ews returned by reshape(): " + flattened.elementWiseStride());
            System.out.println();
            // The hard failure is left disabled, so a mismatch is reported but does not fail:
            // assertTrue(false);
        }
    }

    /**
     * Allocates two batches of 50 arrays in lockstep, vstacks the first batch and prints the
     * result. The second batch is allocated but its vstack call is left commented out.
     */
    @Test
    public void testDualVStack() throws Exception {
        final int batchSize = 50;
        INDArray[] stacked = new INDArray[batchSize];
        INDArray[] spare = new INDArray[batchSize];

        for (int idx = 0; idx < batchSize; idx++) {
            stacked[idx] = Nd4j.create(new float[] {1f, 2f}).dup('c');
            spare[idx] = Nd4j.create(new int[] {1, 2}, 'c');
        }

        System.out.println("Result: " + Nd4j.vstack(stacked));
        // Nd4j.vstack(spare);
    }

    /**
     * Vstacks five {@code [1, 5749]} arrays in 'c' order and then creates one more array via
     * {@code Nd4j.create(1)}. Passes as long as nothing throws.
     */
    @Test
    public void testBVStack() throws Exception {
        final int count = 5;
        INDArray[] rows = new INDArray[count];
        for (int idx = 0; idx < count; idx++) {
            rows[idx] = Nd4j.create(new int[] {1, 5749}, 'c');
        }

        Nd4j.vstack(rows);
        Nd4j.create(1);
    }

    /**
     * Allocates ten {@code [1, 784]} and ten {@code [1, 5749]} arrays in lockstep, vstacks each
     * group in turn and then creates one more array via {@code Nd4j.create(1)}. Passes as long
     * as nothing throws.
     */
    @Test
    public void test2() {
        final int count = 10;
        INDArray[] narrow = new INDArray[count];
        INDArray[] wide = new INDArray[count];

        for (int idx = 0; idx < count; idx++) {
            narrow[idx] = Nd4j.create(new int[] {1, 784}, 'c');
            wide[idx] = Nd4j.create(new int[] {1, 5749}, 'c');
        }

        Nd4j.vstack(narrow);
        Nd4j.vstack(wide);
        Nd4j.create(1);
    }

    /**
     * Vstacks a 5x10 matrix, a 10-element linspace array and a 4x10 matrix, and expects the
     * result to equal {@code linspace(0, 99, 100)} reshaped to 10x10 in 'c' order.
     */
    @Test
    public void testVstackWithMatrices() {
        INDArray[] pieces = {
            Nd4j.linspace(0, 49, 50).reshape('c', 5, 10),
            Nd4j.linspace(50, 59, 10),
            Nd4j.linspace(60, 99, 40).reshape('c', 4, 10)
        };

        INDArray expected = Nd4j.linspace(0, 99, 100).reshape('c', 10, 10);
        INDArray actual = Nd4j.vstack(pieces);

        printExpectedAndActual(expected, actual);
        assertEquals(expected, actual);
    }

    /**
     * Hstacks five transposed 10-element linspace chunks and expects the result to equal
     * {@code linspace(0, 49, 50)} reshaped to 10x5 in 'f' order.
     */
    @Test
    public void testHstackConcatCols() {
        final int rows = 10;
        final int count = 5;
        INDArray[] columns = linspaceChunks(count, rows, true);

        int total = count * rows;
        INDArray expected = Nd4j.linspace(0, total - 1, total).reshape('f', rows, count);
        INDArray actual = Nd4j.hstack(columns);

        printExpectedAndActual(expected, actual);
        assertEquals(expected, actual);
    }

    /**
     * Hstacks five 10-element linspace chunks and expects the result to equal
     * {@code linspace(0, 49, 50)}.
     */
    @Test
    public void testHstackConcatSimple() {
        final int rows = 10;
        final int count = 5;
        INDArray[] chunks = linspaceChunks(count, rows, false);

        int total = count * rows;
        INDArray expected = Nd4j.linspace(0, total - 1, total);
        INDArray actual = Nd4j.hstack(chunks);

        printExpectedAndActual(expected, actual);
        assertEquals(expected, actual);
    }

    /**
     * Vstacks five 10-element linspace chunks and expects the result to equal
     * {@code linspace(0, 49, 50)} reshaped to 5x10 in 'c' order.
     */
    @Test
    public void testVstackConcatRows() {
        final int cols = 10;
        final int count = 5;
        INDArray[] chunks = linspaceChunks(count, cols, false);

        int total = count * cols;
        INDArray expected = Nd4j.linspace(0, total - 1, total).reshape('c', count, cols);
        INDArray actual = Nd4j.vstack(chunks);

        printExpectedAndActual(expected, actual);
        assertEquals(expected, actual);
    }

    /**
     * Builds {@code count} consecutive linspace chunks; chunk {@code i} is
     * {@code Nd4j.linspace(i * size, (i + 1) * size - 1, size)}, optionally transposed.
     *
     * @param count     number of chunks to create
     * @param size      number of elements in each chunk
     * @param transpose whether {@code transpose()} is applied to each chunk right after creation
     * @return the chunks, in creation order
     */
    private static INDArray[] linspaceChunks(int count, int size, boolean transpose) {
        INDArray[] chunks = new INDArray[count];
        for (int idx = 0; idx < count; idx++) {
            INDArray chunk = Nd4j.linspace(idx * size, (idx + 1) * size - 1, size);
            chunks[idx] = transpose ? chunk.transpose() : chunk;
        }
        return chunks;
    }

    /**
     * Prints the expected array, an empty line and the actual array to standard output.
     *
     * @param expected the reference array
     * @param actual   the array produced by the operation under test
     */
    private static void printExpectedAndActual(INDArray expected, INDArray actual) {
        System.out.println(expected);
        System.out.println();
        System.out.println(actual);
    }
}