package org.deeplearning4j.parallelism.inference.observers;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * This class holds reference input, and implements second use case: BATCHED inference.
 *
 * <p>Several callers each register their own input via {@link #setInput(INDArray...)}; the
 * consumer then retrieves all of them as one batch via {@link #getInput()} and hands the batched
 * result back via {@link #setOutput(INDArray...)}. Each caller finally reads its own slice of the
 * result with {@link #getOutput()}, using the position recorded for its thread.
 *
 * <p>Locking protocol as implemented here: {@link #isLocked()} acquires the read lock when it
 * reports {@code false}; the matching release happens in {@link #setInput(INDArray...)}.
 * {@link #getInput()} takes the write lock and marks the observable as locked; that mark is never
 * cleared within this class.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class BatchedInferenceObservable extends BasicInferenceObservable implements InferenceObservable {
    /** Inputs in registration order; one {@code INDArray[]} per registered query. */
    private final List<INDArray[]> inputs = new ArrayList<>();
    /** Outputs, one {@code INDArray[]} per query, filled by {@link #setOutput(INDArray...)}. */
    private final List<INDArray[]> outputs = new ArrayList<>();
    /** Number of registered queries; also the source of per-thread positions. */
    private final AtomicInteger counter = new AtomicInteger(0);
    /** Index of the calling thread's query within {@link #inputs} / {@link #outputs}. */
    private final ThreadLocal<Integer> position = new ThreadLocal<>();

    /** Monitor guarding mutation of {@link #inputs} together with the counter increment. */
    private final Object locker = new Object();

    private final ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock();
    /** Set once {@link #getInput()} has been entered; never reset in this class. */
    private final AtomicBoolean isLocked = new AtomicBoolean(false);
    /** Set once any {@link #isLocked()} call has left a read lock held for a pending setInput. */
    private final AtomicBoolean isReadLocked = new AtomicBoolean(false);

    public BatchedInferenceObservable() {
    }

    /**
     * Registers one query and remembers its index for the calling thread.
     *
     * <p>If a previous {@link #isLocked()} call left a read lock held by this thread, it is
     * released here.
     *
     * @param input input arrays of a single query
     */
    @Override
    public void setInput(INDArray... input) {
        synchronized (locker) {
            inputs.add(input);
            position.set(counter.getAndIncrement());

            // isReadLocked is shared between threads, so additionally make sure that THIS thread
            // really holds a read lock; unlocking otherwise throws IllegalMonitorStateException
            // after the state above has already been modified.
            if (isReadLocked.get() && realLocker.getReadHoldCount() > 0) {
                realLocker.readLock().unlock();
            }
        }
    }

    /**
     * Returns the input for the model and marks this observable as locked.
     *
     * <p>When more than one query is registered, the i-th arrays of all queries are combined with
     * {@code Nd4j.pile} into the i-th element of the result (presumably stacking the examples
     * into a batch - confirm against the ND4J documentation). Otherwise the single registered
     * input is returned as is.
     *
     * @return batched input arrays, or the only registered input
     * @throws IndexOutOfBoundsException if no input has been registered
     */
    @Override
    public INDArray[] getInput() {
        realLocker.writeLock().lock();
        try {
            isLocked.set(true);

            // this method should pile individual examples into single batch
            if (counter.get() > 1) {
                INDArray[] result = new INDArray[inputs.get(0).length];
                for (int i = 0; i < result.length; i++) {
                    List<INDArray> examples = new ArrayList<>(inputs.size());
                    for (INDArray[] query : inputs) {
                        examples.add(query[i]);
                    }
                    result[i] = Nd4j.pile(examples);
                }
                return result;
            }

            return inputs.get(0);
        } finally {
            // always release, even if piling or the lookup above throws
            realLocker.writeLock().unlock();
        }
    }

    /**
     * Stores the model output, calls {@code setChanged()} and {@code notifyObservers()}.
     *
     * <p>With more than one registered query, every output array is split with {@code Nd4j.tear}
     * over all dimensions except dimension 0, and the e-th piece is assigned to the e-th query.
     * With zero or one query the output is stored unchanged.
     *
     * @param output output arrays produced for the (possibly batched) input
     * @throws ND4JIllegalStateException if an output array does not split into exactly as many
     *                                   pieces as there are registered queries
     */
    @Override
    public void setOutput(INDArray... output) {
        //this method should split batched output INDArray[] into multiple separate INDArrays

        // single snapshot, so allocation, validation and indexing all agree on the same count
        final int queries = counter.get();

        if (queries > 1) {
            // pre-create outputs
            for (int i = 0; i < queries; i++) {
                outputs.add(new INDArray[output.length]);
            }

            // pull back results for individual examples
            int cnt = 0;
            for (INDArray array : output) {
                // every dimension except the leading one: 1 .. rank-1
                int[] dimensions = new int[array.rank() - 1];
                for (int i = 1; i < array.rank(); i++) {
                    dimensions[i - 1] = i;
                }

                INDArray[] split = Nd4j.tear(array, dimensions);

                if (split.length != queries) {
                    throw new ND4JIllegalStateException("Number of splits [" + split.length
                            + "] doesn't match number of queries [" + queries + "]");
                }

                for (int e = 0; e < queries; e++) {
                    outputs.get(e)[cnt] = split[e];
                }
                cnt++;
            }
        } else {
            outputs.add(output);
        }

        this.setChanged();
        notifyObservers();
    }

    /**
     * PLEASE NOTE: This method is for tests only
     *
     * @return the internal (live, mutable) list of per-query outputs
     */
    protected List<INDArray[]> getOutputs() {
        return outputs;
    }

    /**
     * Overrides the number of registered queries without touching the stored inputs.
     *
     * @param value new counter value
     */
    protected void setCounter(int value) {
        counter.set(value);
    }

    /**
     * Sets the query index used by {@link #getOutput()} for the calling thread.
     *
     * @param pos index into the outputs list
     */
    public void setPosition(int pos) {
        position.set(pos);
    }

    /**
     * @return number of queries registered so far
     */
    public int getCounter() {
        return counter.get();
    }

    /**
     * Reports whether this observable no longer accepts new inputs.
     *
     * <p>Side effect: when {@code false} is returned, the calling thread keeps holding the read
     * lock; it is released by the following {@link #setInput(INDArray...)} call on the same
     * thread. When {@code true} is returned, no lock is left held by this call.
     *
     * @return {@code true} if the read lock could not be acquired or {@link #getInput()} has
     *         already been entered, {@code false} otherwise
     */
    public boolean isLocked() {
        if (!realLocker.readLock().tryLock()) {
            return true;
        }

        if (isLocked.get()) {
            // nobody will release this hold for us on the "locked" path, so do it right away
            realLocker.readLock().unlock();
            return true;
        }

        isReadLocked.set(true);
        return false;
    }

    /**
     * Returns the output belonging to the calling thread's query.
     *
     * <p>The index is read from a thread-local that has no initial value, so calling this on a
     * thread that neither called {@link #setInput(INDArray...)} nor {@link #setPosition(int)}
     * fails with a {@link NullPointerException} on unboxing.
     *
     * @return output arrays stored at the calling thread's position
     */
    @Override
    public INDArray[] getOutput() {
        // basically we should take care of splits here: each client should get its own part of output, wrt order number
        return outputs.get(position.get());
    }
}