package org.nd4j.linalg.rng;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.impl.transforms.LegacyDropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
/**
 * Micro-benchmark style test that times {@link LegacyDropOutInverted} against
 * {@link DropOutInverted} on arrays of increasing length and logs the results.
 * <p>
 * The test makes no assertions; it only reports timings through the logger.
 *
 * @author raver119@gmail.com
 */
@Slf4j
@RunWith(Parameterized.class)
public class RandomPerformanceTests extends BaseNd4jTest {

    /** Number of untimed executions of each op before measuring. */
    private static final int WARMUP_ITERATIONS = 100;

    /** Length of the arrays used during warm-up. */
    private static final int WARMUP_LENGTH = 1000000;

    /** Smallest array length that is timed. */
    private static final int MIN_LENGTH = 100;

    /** Largest array length that is timed; lengths grow by a factor of {@value #LENGTH_FACTOR}. */
    private static final int MAX_LENGTH = 100000000;

    /** Multiplier applied to the array length between measurements. */
    private static final int LENGTH_FACTOR = 10;

    /** Second constructor argument passed to both dropout ops. */
    private static final double DROPOUT_P = 0.8;

    public RandomPerformanceTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Times one execution of each dropout op for every array length from
     * {@link #MIN_LENGTH} to {@link #MAX_LENGTH} and logs total time (microseconds)
     * and time per element (nanoseconds; presumably what "NPE" abbreviates).
     *
     * @throws Exception declared for compatibility with the original signature
     */
    @Test
    public void testDropoutPerformance() throws Exception {
        // Warm up BOTH ops so neither timed run pays first-use costs the other one skipped.
        INDArray warmupLegacy = Nd4j.createUninitialized(WARMUP_LENGTH);
        INDArray warmupCurrent = Nd4j.createUninitialized(WARMUP_LENGTH);
        for (int i = 0; i < WARMUP_ITERATIONS; i++) {
            Nd4j.getExecutioner().exec(new LegacyDropOutInverted(warmupLegacy, DROPOUT_P));
            Nd4j.getExecutioner().exec(new DropOutInverted(warmupCurrent, DROPOUT_P), Nd4j.getRandom());
        }
        // Make sure no warm-up work is still pending when the timed section starts.
        Nd4j.getExecutioner().commit();

        // MAX_LENGTH * LENGTH_FACTOR (1e9) still fits in an int, so the loop variable cannot overflow.
        for (int length = MIN_LENGTH; length <= MAX_LENGTH; length *= LENGTH_FACTOR) {
            INDArray legacyInput = Nd4j.createUninitialized(length);
            INDArray currentInput = Nd4j.createUninitialized(length);

            LegacyDropOutInverted legacyOp = new LegacyDropOutInverted(legacyInput, DROPOUT_P);
            long start = System.nanoTime();
            Nd4j.getExecutioner().exec(legacyOp);
            // commit() is inside the timed region so queued/asynchronous execution is included.
            Nd4j.getExecutioner().commit();
            long timeLegacy = System.nanoTime() - start;

            DropOutInverted currentOp = new DropOutInverted(currentInput, DROPOUT_P);
            start = System.nanoTime();
            Nd4j.getExecutioner().exec(currentOp, Nd4j.getRandom());
            Nd4j.getExecutioner().commit();
            long timeRecent = System.nanoTime() - start;

            // Floating-point division: integer division truncated sub-nanosecond costs to 0.
            double legacyPerElement = (double) timeLegacy / length;
            double recentPerElement = (double) timeRecent / length;

            log.info("Length: {}; Legacy time: {} us, Current time: {} us; Legacy NPE: {} ns; Current NPE: {} ns",
                    length, timeLegacy / 1000, timeRecent / 1000, legacyPerElement, recentPerElement);
        }
    }

    /**
     * @return the array ordering used by this test, always {@code 'c'}
     */
    @Override
    public char ordering() {
        return 'c';
    }
}