package org.nd4j.linalg.shape;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.util.Pair;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
 * Tests for tensor-along-dimension (TAD) computation, including checks that the Java-side TAD
 * ({@code javaTensorAlongDimension}) matches the one produced by {@code tensorAlongDimension}.
 *
 * @author raver119@gmail.com
 */
@Slf4j
@RunWith(Parameterized.class)
public class TADTests extends BaseNd4jTest {

    public TADTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Computes one TAD over dimensions {1, 2, 3} of a {3, 3, 4, 5} array. There is no explicit
     * assertion: the test passes as long as the call returns without throwing (the name suggests
     * it guards against a former stall/hang).
     */
    @Test
    public void testStall() {
        //[4, 3, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99], dimensions: [1, 2, 3]
        INDArray arr = Nd4j.create(3, 3, 4, 5);
        arr.tensorAlongDimension(0, 1, 2, 3);
    }

    /**
     * This test checks for TADs equality between Java & native.
     *
     * <p>For both 'c' and 'f' orderings, the first TAD over every tested 2D and 3D combination of
     * dimensions must be identical (values and shape-info buffer) whether computed by
     * {@code javaTensorAlongDimension} or by {@code tensorAlongDimension}.
     *
     * @throws Exception kept for signature compatibility; the body throws no checked exceptions
     */
    @Test
    public void testEquality1() throws Exception {
        char[] orders = new char[] {'c', 'f'};
        int[] dim_e = new int[] {0, 2};
        int[] dim_x = new int[] {1, 3};
        List<int[]> dim_3 = Arrays.asList(new int[] {0, 2, 3}, new int[] {0, 1, 2}, new int[] {1, 2, 3},
                        new int[] {0, 1, 3});

        // 2D TADs: every pairing of one dimension from dim_e with one from dim_x.
        for (char o : orders) {
            INDArray array = Nd4j.create(new int[] {3, 5, 7, 9}, o);
            for (int e : dim_e) {
                for (int x : dim_x) {
                    int[] dimensions = new int[] {e, x};
                    Arrays.sort(dimensions);
                    assertTadsMatch(array, o, dimensions);
                }
            }
        }

        log.info("3D TADs:");
        for (char o : orders) {
            INDArray array = Nd4j.create(new int[] {9, 7, 5, 3}, o);
            for (int[] dims : dim_3) {
                // Sort a copy so the shared dim_3 entries are never mutated.
                int[] dimensions = dims.clone();
                Arrays.sort(dimensions);
                log.info("About to do shape: {} for array of shape {}", Arrays.toString(dimensions),
                                array.shapeInfoToString());
                assertTadsMatch(array, o, dimensions);
            }
        }
    }

    /**
     * Asserts that the first TAD of {@code array} over {@code dimensions} is the same whether it is
     * computed via {@code javaTensorAlongDimension} or {@code tensorAlongDimension}, comparing both
     * the array contents and the shape-info buffers.
     *
     * @param array      source array
     * @param arrayOrder ordering the array was created with (used only in failure messages)
     * @param dimensions TAD dimensions, expected to be sorted ascending by the caller
     */
    private static void assertTadsMatch(INDArray array, char arrayOrder, int[] dimensions) {
        String context = "order=" + arrayOrder + ", dimensions=" + Arrays.toString(dimensions);
        INDArray assertion = array.javaTensorAlongDimension(0, dimensions);
        INDArray test = array.tensorAlongDimension(0, dimensions);
        assertEquals(context, assertion, test);
        assertEquals(context, assertion.shapeInfoDataBuffer(), test.shapeInfoDataBuffer());
    }

    /**
     * Computes Java-side TADs and TAD-only shape info from the executioner's TAD manager over
     * dimensions {2, 3} for {1, 1, 4, 4} arrays in 'f' and 'c' order, and logs both shape infos
     * side by side. There is no explicit assertion: the test passes as long as none of these calls
     * throws.
     */
    @Test
    public void testMysteriousCrash() {
        INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f');
        INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c');
        INDArray javaCTad = arrayC.javaTensorAlongDimension(0, 2, 3);
        INDArray javaFTad = arrayF.javaTensorAlongDimension(0, 2, 3);
        Pair<DataBuffer, DataBuffer> tadBuffersF =
                        Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, new int[] {2, 3});
        Pair<DataBuffer, DataBuffer> tadBuffersC =
                        Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayC, new int[] {2, 3});

        // Log both sides in the same int[] format so they can be compared directly.
        log.info("Got TADShapeF: {} with java {}", Arrays.toString(tadBuffersF.getFirst().asInt()),
                        Arrays.toString(javaFTad.shapeInfoDataBuffer().asInt()));
        log.info("Got TADShapeC: {} with java {}", Arrays.toString(tadBuffersC.getFirst().asInt()),
                        Arrays.toString(javaCTad.shapeInfoDataBuffer().asInt()));
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Compares rank, shape and stride for two given shape buffers.
     *
     * <p>Entry 0 is read as the rank; entries {@code 1 .. 2 * rank} (shape followed by stride) are
     * then compared element by element. Any entries after those are ignored.
     *
     * @param shapeA first shape-info buffer; must have data type INT
     * @param shapeB second shape-info buffer
     * @return {@code true} if both buffers have the same data type, rank, shape and stride
     * @throws IllegalStateException if {@code shapeA} is not an INT buffer
     */
    protected boolean compareShapes(@NonNull DataBuffer shapeA, @NonNull DataBuffer shapeB) {
        if (shapeA.dataType() != DataBuffer.Type.INT) {
            throw new IllegalStateException("ShapeBuffer should have dataType of INT");
        }

        if (shapeA.dataType() != shapeB.dataType()) {
            return false;
        }

        int rank = shapeA.getInt(0);
        if (rank != shapeB.getInt(0)) {
            return false;
        }

        for (int e = 1; e <= rank * 2; e++) {
            if (shapeA.getInt(e) != shapeB.getInt(e)) {
                return false;
            }
        }

        return true;
    }
}