package org.nd4j.linalg.profiling;
import lombok.extern.slf4j.Slf4j;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.data.StackAggregator;
import org.nd4j.linalg.profiler.data.primitives.StackDescriptor;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Tests for {@link StackAggregator}: event counting and unique-branch tracking, exercised both
 * directly and through the aggregators exposed by {@link OpProfiler}.
 *
 * <p>NOTE: several assertions depend on the source line of each profiled call, so calls that are
 * expected to produce distinct branches must stay on distinct lines.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class StackAggregatorTests {

    /**
     * Switches the executioner to {@code ProfilingMode.ALL} and resets the profiler singleton
     * before every test.
     */
    @Before
    public void setUp() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        executioner.setProfilingMode(OpExecutioner.ProfilingMode.ALL);

        OpProfiler profiler = OpProfiler.getInstance();
        profiler.reset();
    }

    /**
     * Switches the executioner back to {@code ProfilingMode.DISABLED} after every test.
     */
    @After
    public void tearDown() {
        OpExecutioner executioner = Nd4j.getExecutioner();
        executioner.setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
    }

    /**
     * Two increments issued from two different source lines are expected to be reported as
     * two events and two unique branches.
     */
    @Test
    public void testBasicBranching1() {
        StackAggregator stackAggregator = new StackAggregator();

        // keep these two calls on separate lines: each line is its own call site
        stackAggregator.incrementCount();
        stackAggregator.incrementCount();

        assertEquals(2, stackAggregator.getTotalEventsNumber());
        assertEquals(2, stackAggregator.getUniqueBranchesNumber());
    }

    /**
     * Ten increments issued from a single call site inside a loop are expected to be reported
     * as ten events but only one unique branch.
     */
    @Test
    public void testBasicBranching2() {
        StackAggregator stackAggregator = new StackAggregator();

        int remaining = 10;
        while (remaining-- > 0) {
            stackAggregator.incrementCount();
        }

        assertEquals(10, stackAggregator.getTotalEventsNumber());

        // since the method is called in a loop, there should be only 1 unique code branch
        assertEquals(1, stackAggregator.getUniqueBranchesNumber());
    }

    /**
     * Checks that the final frame of the last recorded descriptor belongs to this test class.
     */
    @Test
    public void testTrailingFrames1() {
        StackAggregator stackAggregator = new StackAggregator();
        stackAggregator.incrementCount();

        StackDescriptor lastTrace = stackAggregator.getLastDescriptor();
        String rendered = lastTrace.toString();
        log.info("Trace: {}", rendered);

        // we just want to make sure that OpProfiler methods are NOT included in the trace
        int tailIndex = lastTrace.size() - 1;
        String tailClassName = lastTrace.getStackTrace()[tailIndex].getClassName();
        assertTrue(tailClassName.contains("StackAggregatorTests"));
    }

    /**
     * Assigns a 'c'-ordered array into an 'f'-ordered one twice, from two different source
     * lines, and expects the profiler's mixed-order aggregator to report two events and two
     * unique branches.
     */
    @Test
    public void testTrailingFrames2() {
        INDArray fOrdered = Nd4j.create(new int[] {10, 10}, 'f');
        INDArray cOrdered = Nd4j.create(new int[] {10, 10}, 'c');

        // keep these two calls on separate lines: each line is its own call site
        fOrdered.assign(cOrdered);
        fOrdered.assign(cOrdered);

        // commit() is invoked before the aggregator is inspected
        Nd4j.getExecutioner().commit();

        OpProfiler profiler = OpProfiler.getInstance();
        StackAggregator mixedOrder = profiler.getMixedOrderAggregator();

        StackDescriptor lastTrace = mixedOrder.getLastDescriptor();
        String rendered = lastTrace.toString();
        log.info("Trace: {}", rendered);

        assertEquals(2, mixedOrder.getTotalEventsNumber());
        assertEquals(2, mixedOrder.getUniqueBranchesNumber());

        mixedOrder.renderTree();
    }

    /**
     * Performs one scalar write and one scalar read on an array and expects the profiler's
     * scalar aggregator to report two events and two unique branches.
     */
    @Test
    public void testScalarAggregator() {
        INDArray vector = Nd4j.create(10);

        // the write and the read sit on separate lines: each line is its own call site
        vector.putScalar(0, 1.0);
        double firstValue = vector.getDouble(0);

        assertEquals(1.0, firstValue, 1e-5);

        OpProfiler profiler = OpProfiler.getInstance();
        StackAggregator scalarAggregator = profiler.getScalarAggregator();

        assertEquals(2, scalarAggregator.getTotalEventsNumber());
        assertEquals(2, scalarAggregator.getUniqueBranchesNumber());

        scalarAggregator.renderTree(false);
    }
}