package org.nd4j.linalg.crash;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
/**
 * Smoke tests that run a dimensional reduction followed by element-wise arithmetic,
 * both on full 3-d arrays and on interval views of them.
 *
 * <p>The tests contain no assertions: they pass as long as the operations complete
 * without throwing. NOTE(review): the package name suggests these guard against
 * backend crashes - confirm before adding value assertions.
 *
 * @author raver119@gmail.com
 */
@Slf4j
@RunWith(Parameterized.class)
public class SpecialTests extends BaseNd4jTest {

    /** Shape of the random input arrays used by every test. */
    private static final int[] INPUT_SHAPE = {20, 30, 50};

    /** Number of elements taken along the last dimension when building a view. */
    private static final int VIEW_LENGTH = 5;

    /** How many random views {@link #testDimensionalThings2()} exercises. */
    private static final int VIEW_ITERATIONS = 1;

    /**
     * Creates the test for one backend supplied by the parameterized runner.
     *
     * @param backend the backend passed through to {@link BaseNd4jTest}
     */
    public SpecialTests(Nd4jBackend backend) {
        super(backend);
    }

    /** Runs {@link #transform(INDArray, INDArray)} on two full random arrays of the same shape. */
    @Test
    public void testDimensionalThings1() {
        INDArray x = Nd4j.rand(INPUT_SHAPE.clone());
        INDArray y = Nd4j.rand(x.shape());

        // Only successful completion matters here, so the result is intentionally discarded.
        transform(x, y);
    }

    /**
     * Runs {@link #transform(INDArray, INDArray)} on matching interval views of two random
     * arrays, using a random start offset along the last dimension.
     */
    @Test
    public void testDimensionalThings2() {
        INDArray x = Nd4j.rand(INPUT_SHAPE.clone());
        INDArray y = Nd4j.rand(x.shape());

        for (int i = 0; i < VIEW_ITERATIONS; i++) {
            // Upper bound keeps start + VIEW_LENGTH inside the last dimension.
            int start = RandomUtils.nextInt(0, x.shape()[2] - VIEW_LENGTH);
            // The same length is used for the bound above and for both views.
            transform(getView(x, start, VIEW_LENGTH), getView(y, start, VIEW_LENGTH));
        }
    }

    /**
     * Returns the sub-array of {@code x} that keeps the first two dimensions whole and
     * restricts the third to the interval {@code [from, from + number)}.
     *
     * @param x      a 3-d array to slice
     * @param from   first index of the interval along the third dimension
     * @param number length of the interval
     * @return the result of {@code x.get(...)} with the indexes described above
     */
    protected static INDArray getView(INDArray x, int from, int number) {
        return x.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(from, from + number));
    }

    /**
     * Sums both inputs over dimensions 1 and 2 and returns
     * {@code abs(sum(a) - sum(b)) / sum(a)}.
     *
     * @param a first input; its reduced sum is also the divisor
     * @param b second input, reduced the same way as {@code a}
     * @return the element-wise quotient described above
     */
    protected static INDArray transform(INDArray a, INDArray b) {
        int[] reduceDimensions = {1, 2};
        INDArray aReduced = a.sum(reduceDimensions);
        INDArray bReduced = b.sum(reduceDimensions);
        return Transforms.abs(aReduced.sub(bReduced)).div(aReduced);
    }

    /**
     * @return {@code 'c'}, the array ordering these tests run with
     */
    @Override
    public char ordering() {
        return 'c';
    }
}