package org.nd4j.linalg.benchmark.app;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.benchmark.api.BenchMarkPerformer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.reflections.Reflections;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
 * Discovers every concrete subclass of {@link BenchMarkPerformer} on the class path and runs
 * each one against every {@link Nd4jBackend} registered through {@link ServiceLoader}.
 *
 * <p>Command-line options:
 * <ul>
 *   <li>{@code --nTrials} / {@code -n}: number of trials passed to each performer's
 *       {@code (int)} constructor (default 1000; must be at least 1).</li>
 *   <li>{@code --run} / {@code -r}: comma-separated fully qualified class names of the
 *       performers to run. Surrounding whitespace is ignored. When omitted, every discovered
 *       non-abstract performer is run.</li>
 * </ul>
 *
 * @author Adam Gibson
 */
public class BenchmarkRunnerApp {
    private static final String BANNER_BEGIN = "=========================";
    private static final String BANNER_END = "===========================";

    @Option(name = "--nTrials", usage = "Number of trials to run", aliases = "-n")
    private int nTrials = 1000;

    @Option(name = "--run", usage = "Trials to run", aliases = "-r")
    private String benchmarksToRun;

    /**
     * Parses the command line and runs each selected performer on each backend.
     *
     * <p>For every combination it instantiates the performer through its {@code (int)}
     * constructor with {@code nTrials}, initializes {@link Nd4j} with the backend, calls
     * {@code run(backend)}, and prints the reported average time.
     *
     * <p>On a command-line parse error, or a non-positive trial count, it prints a message to
     * {@code System.err} and returns without running anything.
     *
     * @param args the command-line arguments
     * @throws Exception if a performer lacks a public {@code (int)} constructor, cannot be
     *                   instantiated, or fails while running
     */
    public void doMain(String[] args) throws Exception {
        CmdLineParser parser = new CmdLineParser(this);
        try {
            parser.parseArgument(args);
        } catch (CmdLineException e) {
            System.err.println(e.getMessage());
            parser.printUsage(System.err);
            return;
        }
        if (nTrials < 1) {
            System.err.println("Number of trials must be at least 1 but was " + nTrials);
            return;
        }

        Set<String> selected = parseSelection(benchmarksToRun);

        // Materialize the providers once and reuse the list for every performer.
        List<Nd4jBackend> allBackends = new ArrayList<>();
        for (Nd4jBackend backend : ServiceLoader.load(Nd4jBackend.class)) {
            allBackends.add(backend);
        }

        // The class-path scan is done only after the arguments have been validated.
        Reflections reflections = new Reflections();
        Set<Class<? extends BenchMarkPerformer>> performers =
                reflections.getSubTypesOf(BenchMarkPerformer.class);

        for (Class<? extends BenchMarkPerformer> perfClazz : performers) {
            boolean isAbstract = Modifier.isAbstract(perfClazz.getModifiers());
            boolean notSelected = !selected.isEmpty() && !selected.contains(perfClazz.getName());
            if (isAbstract || notSelected) {
                continue;
            }

            System.out.println(BANNER_BEGIN + " Benchmark: " + perfClazz.getName() + " " + BANNER_END);
            // The constructor does not depend on the backend, so look it up once per performer.
            Constructor<? extends BenchMarkPerformer> performerConstructor =
                    perfClazz.getConstructor(int.class);
            for (Nd4jBackend backend : allBackends) {
                Nd4j nd4j = new Nd4j();
                nd4j.initWithBackend(backend);
                BenchMarkPerformer performer = performerConstructor.newInstance(nTrials);
                String backendName = backend.getClass().getName();
                System.out.println("Running " + backendName);
                performer.run(backend);
                // Read once so the nanosecond and millisecond figures describe the same value.
                long averageNanos = performer.averageTime();
                System.out.println("Backend " + backendName + " took (in nanoseconds) "
                        + averageNanos + " (in milliseconds) "
                        + TimeUnit.MILLISECONDS.convert(averageNanos, TimeUnit.NANOSECONDS));
            }
            System.out.println(BANNER_BEGIN + BANNER_END);
        }
    }

    /**
     * Splits the {@code --run} value into a set of trimmed, non-empty class names.
     *
     * @param csv the raw option value, possibly {@code null}
     * @return the selected class names; empty when {@code csv} is {@code null} or contains no
     *         names, which the caller treats as "run every performer"
     */
    private static Set<String> parseSelection(String csv) {
        Set<String> names = new HashSet<>();
        if (csv == null) {
            return names;
        }
        for (String token : csv.split(",")) {
            String name = token.trim();
            if (!name.isEmpty()) {
                names.add(name);
            }
        }
        return names;
    }

    public static void main(String[] args) throws Exception {
        new BenchmarkRunnerApp().doMain(args);
    }
}