package org.deeplearning4j.parallelism;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import java.util.Observer;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
/**
* This class provides a simple wrapper for parallel inference: incoming requests are
* distributed across multiple copies of the model and, optionally, batched together.
*
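* A minimal usage sketch (assumes {@code model} is an already trained
* MultiLayerNetwork or ComputationGraph, and {@code features} is an INDArray of
* input features; the worker and limit values shown are illustrative only):
*
* <pre>{@code
* ParallelInference pi = new ParallelInference.Builder(model)
*         .inferenceMode(InferenceMode.BATCHED)
*         .workers(2)
*         .batchLimit(32)
*         .queueLimit(64)
*         .build();
*
* INDArray output = pi.output(features);
* }</pre>
*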
* @author raver119@gmail.com
*/
@Slf4j
public class ParallelInference {
private Model model;
private long nanos;
private int workers;
private int batchLimit;
private InferenceMode inferenceMode;
private int queueLimit;
// queue of pending inference requests, consumed by the worker threads
private BlockingQueue<InferenceObservable> observables;
private final Object locker = new Object();
private InferenceWorker[] zoo;
private ObservablesProvider provider;
public static final int DEFAULT_NUM_WORKERS = Nd4j.getAffinityManager().getNumberOfDevices();
public static final int DEFAULT_BATCH_LIMIT = 32;
public static final InferenceMode DEFAULT_INFERENCE_MODE = InferenceMode.BATCHED;
public static final int DEFAULT_QUEUE_LIMIT = 64;
protected ParallelInference() {
// protected no-arg constructor; instances are created via Builder
}
protected void init() {
observables = new LinkedBlockingQueue<>(queueLimit);
zoo = new InferenceWorker[workers];
for (int i = 0; i < workers; i++) {
zoo[i] = new InferenceWorker(i, model, observables);
zoo[i].start();
}
if (inferenceMode == InferenceMode.BATCHED) {
log.info("Initializing ObservablesProvider...");
provider = new ObservablesProvider(nanos, batchLimit, observables);
}
}
protected long getWorkerCounter(int workerIdx) {
return zoo[workerIdx].getCounterValue();
}
/**
* This method performs inference on a single input vector, given as a double array.
*
* @param input features for a single example
* @return inference result as INDArray
*/
public INDArray output(double[] input) {
return output(Nd4j.create(input));
}
/**
* This method performs inference on a single input vector, given as a float array.
*
* @param input features for a single example
* @return inference result as INDArray
*/
public INDArray output(float[] input) {
return output(Nd4j.create(input));
}
public INDArray output(INDArray input) {
// delegate to the varargs overload; depending on inference mode the request
// is either dispatched directly or merged into a batch
return output(new INDArray[] {input})[0];
}
/**
* This method performs inference on the feature matrix of the given DataSet.
*
* @param dataSet DataSet whose features will be used as input
* @return inference result as INDArray
*/
public INDArray output(DataSet dataSet) {
return output(dataSet.getFeatureMatrix());
}
/**
* This method performs inference on one or more input arrays, e.g. for
* multi-input ComputationGraph models.
*
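* A usage sketch (assumes {@code pi} is a built ParallelInference instance and
* {@code in1}, {@code in2} are feature arrays for a two-input ComputationGraph):
*
* <pre>{@code
* INDArray[] out = pi.output(in1, in2);
* }</pre>
*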
* @param input one or more input arrays
* @return model output(s) as an array of INDArrays
*/
public INDArray[] output(INDArray... input) {
// depending on inference mode, the request is either queued directly (SEQUENTIAL)
// or handed to the ObservablesProvider for batching (BATCHED)
BasicInferenceObserver observer = new BasicInferenceObserver();
InferenceObservable observable;
if (inferenceMode == InferenceMode.SEQUENTIAL) {
observable = new BasicInferenceObservable(input);
observable.addObserver(observer);
try {
observables.put(observable);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
observable = provider.setInput(observer, input);
}
try {
// block until the observable has been processed by one of the workers
observer.waitTillDone();
} catch (Exception e) {
throw new RuntimeException(e);
}
return observable.getOutput();
}
public static class Builder {
private Model model;
private int workers = DEFAULT_NUM_WORKERS;
private int batchLimit = DEFAULT_BATCH_LIMIT;
private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE;
private int queueLimit = DEFAULT_QUEUE_LIMIT;
public Builder(@NonNull Model model) {
this.model = model;
}
/**
* This method allows you to define the mode that will be used during inference. Options are:
*
* SEQUENTIAL: each input is submitted to the queue unmodified and processed
* by the next available worker.
* BATCHED: multiple inputs are packed into a single batch (up to batchLimit)
* and processed together by a single worker.
*
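* For example, to enable batching (a sketch; {@code builder} is assumed to be a
* {@code ParallelInference.Builder} under construction):
* <pre>{@code
* builder.inferenceMode(InferenceMode.BATCHED);
* }</pre>
*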
* @param inferenceMode inference mode to use
* @return this Builder instance
*/
public Builder inferenceMode(@NonNull InferenceMode inferenceMode) {
this.inferenceMode = inferenceMode;
return this;
}
/**
* This method defines how many model copies will be used for inference.
*
* PLEASE NOTE: This option is primarily suited for multi-GPU systems.
*
* @param workers number of model copies (must be positive)
* @return this Builder instance
*/
public Builder workers(int workers) {
if (workers < 1)
throw new IllegalStateException("Workers should be positive value");
this.workers = workers;
return this;
}
/**
* This method defines how many input samples can
* be batched within the given time frame.
*
* PLEASE NOTE: This value has no effect in
* SEQUENTIAL inference mode.
*
* @param limit maximum number of samples per batch (must be positive)
* @return this Builder instance
*/
public Builder batchLimit(int limit) {
if (limit < 1)
throw new IllegalStateException("Batch limit should be positive value");
this.batchLimit = limit;
return this;
}
/**
* This method defines the size of the internal buffer queue.
*
* Default value: 64
*
* @param limit maximum number of pending requests in the queue (must be positive)
* @return this Builder instance
*/
public Builder queueLimit(int limit) {
if (limit < 1)
throw new IllegalStateException("Queue limit should be positive value");
this.queueLimit = limit;
return this;
}
/**
* This method builds a new ParallelInference instance.
*
* @return fully configured ParallelInference instance
*/
public ParallelInference build() {
ParallelInference inference = new ParallelInference();
inference.batchLimit = this.batchLimit;
inference.queueLimit = this.queueLimit;
inference.inferenceMode = this.inferenceMode;
inference.model = this.model;
inference.workers = this.workers;
inference.init();
return inference;
}
}
/**
* This class actually does inference with respect to device affinity.
* Each worker thread holds its own replica of the model and consumes
* requests from the shared queue.
*
*/
private class InferenceWorker extends Thread {
private BlockingQueue<InferenceObservable> inputQueue;
private AtomicBoolean shouldWork = new AtomicBoolean(true);
private AtomicBoolean isStopped = new AtomicBoolean(false);
private Model protoModel;
private Model replicatedModel;
private AtomicLong counter = new AtomicLong(0);
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue<InferenceObservable> inputQueue) {
this.inputQueue = inputQueue;
this.protoModel = model;
this.setDaemon(true);
this.setName("InferenceThread-" + id);
}
protected long getCounterValue() {
return counter.get();
}
@Override
public void run() {
try {
// model should be replicated & initialized here
if (protoModel instanceof ComputationGraph) {
this.replicatedModel = new ComputationGraph(ComputationGraphConfiguration
.fromJson(((ComputationGraph) protoModel).getConfiguration().toJson()));
this.replicatedModel.init();
synchronized (locker) {
this.replicatedModel.setParams(protoModel.params());
Nd4j.getExecutioner().commit();
}
} else if (protoModel instanceof MultiLayerNetwork) {
this.replicatedModel = new MultiLayerNetwork(MultiLayerConfiguration
.fromJson(((MultiLayerNetwork) protoModel).getLayerWiseConfigurations().toJson()));
this.replicatedModel.init();
synchronized (locker) {
this.replicatedModel.setParams(protoModel.params());
Nd4j.getExecutioner().commit();
}
}
while (shouldWork.get()) {
InferenceObservable request = inputQueue.take();
if (request != null) {
counter.incrementAndGet();
// FIXME: get rid of instanceof here, model won't change during runtime anyway
if (replicatedModel instanceof ComputationGraph) {
INDArray[] output = ((ComputationGraph) replicatedModel).output(false, request.getInput());
request.setOutput(output);
} else if (replicatedModel instanceof MultiLayerNetwork) {
INDArray output = ((MultiLayerNetwork) replicatedModel).output(request.getInput()[0]);
request.setOutput(output);
}
} else {
// take() blocks until an element is available, so this branch should be unreachable
}
}
} catch (InterruptedException e) {
// do nothing
} catch (Exception e) {
throw new RuntimeException(e);
}
isStopped.set(true);
}
protected void shutdown() {
shouldWork.set(false);
// interrupt the worker in case it is currently blocked on inputQueue.take()
this.interrupt();
while (!isStopped.get()) {
// busy-wait until the main loop has finished
}
}
}
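/**
* This class groups incoming inference requests into batched observables for BATCHED mode.
*
* A new BatchedInferenceObservable is created (and submitted to the target queue) whenever
* the current one is absent, already locked by a worker, or has reached batchLimit;
* otherwise the new input is appended to the current observable.
*/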
protected static class ObservablesProvider {
private BlockingQueue<InferenceObservable> targetQueue;
private long nanos;
private int batchLimit;
private volatile BatchedInferenceObservable currentObservable;
private final Object locker = new Object();
protected ObservablesProvider(long nanos, int batchLimit, @NonNull BlockingQueue<InferenceObservable> queue) {
this.targetQueue = queue;
this.nanos = nanos;
this.batchLimit = batchLimit;
}
protected InferenceObservable setInput(@NonNull Observer observer, INDArray... input) {
synchronized (locker) {
boolean isNew = false;
if (currentObservable == null || currentObservable.getCounter() >= batchLimit
|| currentObservable.isLocked()) {
isNew = true;
currentObservable = new BatchedInferenceObservable();
}
currentObservable.setInput(input);
currentObservable.addObserver(observer);
try {
if (isNew)
targetQueue.put(currentObservable);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return currentObservable;
}
}
}
}