package org.encog.examples.neural.opencl; import org.encog.Encog; import org.encog.engine.network.train.prop.OpenCLTrainingProfile; import org.encog.engine.network.train.prop.TrainFlatNetworkOpenCL; import org.encog.engine.opencl.EncogCLDevice; import org.encog.engine.opencl.EncogCLError; import org.encog.engine.util.Format; import org.encog.engine.util.Stopwatch; import org.encog.neural.data.NeuralDataSet; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation; import org.encog.neural.networks.training.strategy.end.EndIterationsStrategy; import org.encog.util.benchmark.RandomTrainingFactory; import org.encog.util.logging.Logging; import org.encog.util.simple.EncogUtility; /** * Performs a simple benchmark of your first OpenCL device, compared to the CPU. * If you have multiple OpenCL devices(i.e. two GPU's) this benchmark will only * take advantage of one. To see multiple OpenCL devices used in parallel, use * the BenchmarkConcurrent example. * */ public class BenchmarkCL { public static final int GLOBAL_SIZE = 200; public static final int BENCHMARK_ITERATIONS = 100; public static final double OPENCL_RATIO = 1.0; public static final int ITERATIONS_PER_CYCLE = 1; public static OpenCLTrainingProfile profile; public static long benchmarkCPU(BasicNetwork network, NeuralDataSet training) { ResilientPropagation train = new ResilientPropagation(network, training); EndIterationsStrategy stop; train.addStrategy(stop = new EndIterationsStrategy(BENCHMARK_ITERATIONS)); train.iteration(); // warmup Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); while( !stop.shouldStop() ) { train.iteration(ITERATIONS_PER_CYCLE); } stopwatch.stop(); return stopwatch.getElapsedMilliseconds(); } public static long benchmarkCL(BasicNetwork network, NeuralDataSet training) { profile = new OpenCLTrainingProfile(Encog.getInstance().getCL().chooseDevice()); System.out.println("Using device: " + profile.getDevice().toString()); ResilientPropagation train = new ResilientPropagation(network, training, profile); train.iteration(); // warmup EndIterationsStrategy stop; train.addStrategy(stop = new EndIterationsStrategy(BENCHMARK_ITERATIONS)); Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); while( !stop.shouldStop() ) { train.iteration(ITERATIONS_PER_CYCLE); } stopwatch.stop(); return stopwatch.getElapsedMilliseconds(); } public static void main(String[] args) { try { Logging.stopConsoleLogging(); int outputSize = 2; int inputSize = 10; int trainingSize = 100000; NeuralDataSet training = RandomTrainingFactory.generate(1000, trainingSize, inputSize, outputSize, -1, 1); BasicNetwork network = EncogUtility.simpleFeedForward(training .getInputSize(), 6, 0, training.getIdealSize(), true); network.reset(); System.out.println("Running non-OpenCL test."); long cpuTime = benchmarkCPU(network, training); System.out.println("Non-OpenCL test took " + cpuTime + "ms."); System.out.println(); System.out.println("Starting OpenCL"); Encog.getInstance().initCL(); int i = 0; System.out .println("OpenCL Devices: (Encog will use the first GPU, or CPU if no GPU's)"); for (EncogCLDevice device : Encog.getInstance().getCL() .getDevices()) { System.out.println("Device " + i + ": " + device.toString()); i++; } System.out.println("Running OpenCL test."); long clTime = benchmarkCL(network, training); System.out.println("OpenCL test took " + clTime + "ms."); System.out.println(); System.out.println("ITERATIONS_PER_CYCLE: " + ITERATIONS_PER_CYCLE); System.out.println(); System.out.println(profile.toString()); System.out.println(); String percent = Format.formatPercent((double) cpuTime / (double) clTime); System.out.println("OpenCL Performed at " + percent + " the speed of non-OpenCL"); System.out.println("You will likely get better performance by tuning: ITERATIONS_PER_CYCLE, local ratio, global ratio & segmentation ratio."); } catch (EncogCLError ex) { System.out .println("Can't startup CL, make sure you have drivers loaded."); System.out.println(ex.toString()); } finally { Encog.getInstance().shutdown(); } } }