package org.nd4j.linalg.profiling;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
/**
* @author raver119@gmail.com
*/
@RunWith(Parameterized.class)
public class InfNanTests extends BaseNd4jTest {
public InfNanTests(Nd4jBackend backend) {
super(backend);
}
@Before
public void setUp() {
}
@After
public void cleanUp() {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
}
@Test(expected = ND4JIllegalStateException.class)
public void testInf1() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC);
INDArray x = Nd4j.create(100);
x.putScalar(2, Float.NEGATIVE_INFINITY);
OpExecutionerUtil.checkForAny(x);
}
@Test(expected = ND4JIllegalStateException.class)
public void testInf2() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
INDArray x = Nd4j.create(100);
x.putScalar(2, Float.NEGATIVE_INFINITY);
OpExecutionerUtil.checkForAny(x);
}
@Test
public void testInf3() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
INDArray x = Nd4j.create(100);
OpExecutionerUtil.checkForAny(x);
}
@Test
public void testInf4() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
INDArray x = Nd4j.create(100);
OpExecutionerUtil.checkForAny(x);
}
@Test(expected = ND4JIllegalStateException.class)
public void testNaN1() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC);
INDArray x = Nd4j.create(100);
x.putScalar(2, Float.NaN);
OpExecutionerUtil.checkForAny(x);
}
@Test(expected = ND4JIllegalStateException.class)
public void testNaN2() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
INDArray x = Nd4j.create(100);
x.putScalar(2, Float.NaN);
OpExecutionerUtil.checkForAny(x);
}
@Test
public void testNaN3() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
INDArray x = Nd4j.create(100);
OpExecutionerUtil.checkForAny(x);
}
@Test
public void testNaN4() throws Exception {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
INDArray x = Nd4j.create(100);
OpExecutionerUtil.checkForAny(x);
}
@Override
public char ordering() {
return 'c';
}
}