package org.nd4j.linalg.crash; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.RandomUtils; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative; import org.nd4j.linalg.api.ops.impl.transforms.Sqrt; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; /** * This set of test launches different ops in different order, to check for possible data corruption cases * * @author raver119@gmail.com */ @Slf4j @RunWith(Parameterized.class) public class CrashTest extends BaseNd4jTest { public CrashTest(Nd4jBackend backend) { super(backend); } private static final int ITERATIONS = 10; private static final boolean[] paramsA = new boolean[] {true, false}; private static final boolean[] paramsB = new boolean[] {true, false}; /** * tensorAlongDimension() produces shapeInfo without EWS defined */ @Test public void testNonEWSViews1() { System.out.println("non-EWS 1"); INDArray x = Nd4j.create(64, 1024, 64); INDArray y = Nd4j.create(64, 64, 1024); for (int i = 0; i < ITERATIONS; i++) { int slice = RandomUtils.nextInt(0, x.size(0)); op(x.tensorAlongDimension(slice, 1, 2), y.tensorAlongDimension(slice, 1, 2), i); } } @Test public void testNonEWSViews2() { System.out.println("non-EWS 2"); INDArray x = Nd4j.create(new int[] {64, 1024, 64}, 'f'); INDArray y = Nd4j.create(new int[] {64, 64, 1024}, 'f'); for (int i = 0; i < ITERATIONS; i++) { int slice = RandomUtils.nextInt(0, x.size(0)); op(x.tensorAlongDimension(slice, 1, 2), y.tensorAlongDimension(slice, 1, 2), i); } } /** * slice() produces shapeInfo with EWS being 1 in our case */ @Test public void testEWSViews1() { System.out.println("EWS 1"); INDArray x = Nd4j.create(64, 1024, 64); INDArray y = Nd4j.create(64, 64, 1024); for (int i = 0; i < ITERATIONS; i++) { int slice = RandomUtils.nextInt(0, x.shape()[0]); op(x.slice(slice), y.slice(slice), i); } } @Test public void testEWSViews2() { System.out.println("EWS 2"); INDArray x = Nd4j.create(new int[] {96, 1024, 64}, 'f'); INDArray y = Nd4j.create(new int[] {96, 64, 1024}, 'f'); for (int i = 0; i < 1; i++) { int slice = 0; //RandomUtils.nextInt(0, x.shape()[0]); op(x.slice(slice), y.slice(slice), i); } } protected void op(INDArray x, INDArray y, int i) { // broadcast along row & column INDArray row = Nd4j.ones(64); INDArray column = Nd4j.ones(1024, 1); x.addiRowVector(row); x.addiColumnVector(column); // casual scalar x.addi(i * 2); // reduction along all dimensions float sum = x.sumNumber().floatValue(); // index reduction Nd4j.getExecutioner().exec(new IMax(x), Integer.MAX_VALUE); // casual transform Nd4j.getExecutioner().exec(new Sqrt(x, x)); // dup INDArray x1 = x.dup(x.ordering()); INDArray x2 = x.dup(x.ordering()); INDArray x3 = x.dup('c'); INDArray x4 = x.dup('f'); // vstack && hstack INDArray vstack = Nd4j.vstack(x, x1, x2, x3, x4); INDArray hstack = Nd4j.hstack(x, x1, x2, x3, x4); // reduce3 call Nd4j.getExecutioner().exec(new ManhattanDistance(x, x2)); // flatten call INDArray flat = Nd4j.toFlattened(x, x1, x2, x3, x4); // reduction along dimension: row & column INDArray max_0 = x.max(0); INDArray max_1 = x.max(1); // index reduction along dimension: row & column INDArray imax_0 = Nd4j.argMax(x, 0); INDArray imax_1 = Nd4j.argMax(x, 1); // logisoftmax, softmax & softmax derivative Nd4j.getExecutioner().exec(new SoftMax(x)); Nd4j.getExecutioner().exec(new SoftMaxDerivative(x)); Nd4j.getExecutioner().exec(new LogSoftMax(x)); // BooleanIndexing BooleanIndexing.replaceWhere(x, 5f, Conditions.lessThan(8f)); // assing on view BooleanIndexing.assignIf(x, x1, Conditions.greaterThan(-1000000000f)); // std var along all dimensions float std = x.stdNumber().floatValue(); // std var along row & col INDArray xStd_0 = x.std(0); INDArray xStd_1 = x.std(1); // blas call float dot = (float) Nd4j.getBlasWrapper().dot(x, x1); // mmul for (boolean tA : paramsA) { for (boolean tB : paramsB) { INDArray xT = tA ? x.dup() : x.dup().transpose(); INDArray yT = tB ? y.dup() : y.dup().transpose(); Nd4j.gemm(xT, yT, tA, tB); } } // specially for views, checking here without dup and rollover Nd4j.gemm(x, y, false, false); System.out.println("Iteration passed: " + i); } @Override public char ordering() { return 'c'; } }