package org.deeplearning4j.parallelism;
import com.google.common.base.Preconditions;
import lombok.Data;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.parallelism.factory.DefaultTrainerContext;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* This is a simple data-parallel wrapper
* suitable for multi-CPU/multi-GPU environments.
*
* PLEASE NOTE: This implementation is NOT NUMA-aware.
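*
* Typical usage (a minimal sketch; {@code network} and {@code trainIterator} are assumed to be configured elsewhere):
* <pre>{@code
* ParallelWrapper wrapper = new ParallelWrapper.Builder<>(network)
*         .workers(4)                       // number of concurrent training threads
*         .prefetchBuffer(16)               // background prefetch queue size
*         .averagingFrequency(3)            // average parameters every 3rd round of workers
*         .reportScoreAfterAveraging(true)
*         .build();
*
* wrapper.fit(trainIterator);
* wrapper.shutdown();
* }</pre>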
*
* @author raver119@gmail.com
*/
// TODO: We want this thing to be NUMA-aware in the foreseeable future
@Slf4j
@Data
public class ParallelWrapper implements AutoCloseable {
protected Model model;
protected int workers = 2;
protected int prefetchSize = 2;
protected int averagingFrequency = 1;
protected Trainer[] zoo;
private TrainerContext trainerContext = new DefaultTrainerContext();
protected AtomicLong iterationsCounter = new AtomicLong(0);
protected boolean reportScore = false;
protected boolean averageUpdaters = true;
protected boolean legacyAveraging = false;
protected boolean wasAveraged = false;
protected AtomicBoolean stopFit = new AtomicBoolean(false);
protected List<IterationListener> listeners = new ArrayList<>();
protected StatsStorageRouter storageRouter;
protected boolean isMQ;
protected WorkspaceMode workspaceMode;
private Object[] trainerContextArgs;
private MagicQueue mq;
// log uncaught exceptions
Thread.UncaughtExceptionHandler handler = new Thread.UncaughtExceptionHandler() {
public void uncaughtException(Thread th, Throwable ex) {
log.error("Uncaught exception: " + ex);
}
};
protected ParallelWrapper(Model model, int workers, int prefetchSize) {
this.model = model;
this.workers = workers;
this.prefetchSize = prefetchSize;
if (this.model instanceof MultiLayerNetwork) {
((MultiLayerNetwork) this.model).getUpdater();
} else if (this.model instanceof ComputationGraph) {
((ComputationGraph) this.model).getUpdater();
}
}
@Override
public void close() throws Exception {
if (zoo != null) {
for (int i = 0; i < zoo.length; i++) {
if (zoo[i] != null)
zoo[i].shutdown();
}
zoo = null;
}
}
/**
* This method causes all threads used for parallel training to stop
*/
public synchronized void shutdown() {
try {
close();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* Will stop a fit operation from continuing to iterate.
*/
public void stopFit() {
stopFit.set(true);
}
/**
* This method takes a MultiDataSetIterator and starts training over it, scheduling MultiDataSets to different executors.
*
* @param source iterator to fit over
*/
public synchronized void fit(@NonNull MultiDataSetIterator source) {
stopFit.set(false);
createZooIfNeccessary(true);
if (source.resetSupported())
source.reset();
MultiDataSetIterator iterator = source;
if (prefetchSize > 0 && source.asyncSupported()) {
if (isMQ) {
if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
Nd4j.getAffinityManager().getNumberOfDevices());
iterator = new AsyncMultiDataSetIterator(source, prefetchSize, new LinkedBlockingQueue<>(prefetchSize * workers), true, new InterleavedDataSetCallback(prefetchSize * 2));
} else iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
}
AtomicInteger locker = new AtomicInteger(0);
long time1 = System.currentTimeMillis();
while (iterator.hasNext() && !stopFit.get()) {
MultiDataSet dataSet = iterator.next();
long time2 = System.currentTimeMillis();
if (dataSet == null)
throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
/*
the dataSet is dispatched to the next free worker; once all workers are busy, we block until all of them have finished
*/
int pos = locker.getAndIncrement();
zoo[pos].feedMultiDataSet(dataSet, time2 - time1);
/*
once all workers have been dispatched, wait until all of them have finished
*/
if (pos + 1 == workers) {
iterationsCounter.incrementAndGet();
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
try {
zoo[cnt].waitTillRunning();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
Nd4j.getMemoryManager().invokeGcOccasionally();
/*
average the model parameters and propagate them to all workers
*/
if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
// averaging model
double score = getScore(locker);
// averaging updaters state
averageUpdatersState(locker, score);
}
locker.set(0);
}
time1 = System.currentTimeMillis();
}
if (prefetchSize > 0 && source.asyncSupported())
((AsyncMultiDataSetIterator) iterator).shutdown();
if (zoo != null) {
for (int i = 0; i < zoo.length; i++) {
zoo[i].shutdown();
}
zoo = null;
}
// sanity check: warn if parameters were never averaged during this fit()
if (!wasAveraged)
log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
// throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
log.debug("Iterations passed: {}", iterationsCounter.get());
// iterationsCounter.set(0);
}
private double getScore(AtomicInteger locker) {
wasAveraged = true;
double score = 0.0;
List<INDArray> params = new ArrayList<>();
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
params.add(zoo[cnt].getModel().params());
score += zoo[cnt].getModel().score();
}
Nd4j.averageAndPropagate(model.params(), params);
score /= Math.min(workers, locker.get());
// TODO: improve this
if (reportScore)
log.info("Averaged score: " + score);
return score;
}
private void averageUpdatersState(AtomicInteger locker, double score) {
// averaging updaters state
if (model instanceof MultiLayerNetwork) {
if (averageUpdaters) {
Updater updater = ((MultiLayerNetwork) model).getUpdater();
int batchSize = 0;
if (updater != null && updater.getStateViewArray() != null) {
List<INDArray> updaters = new ArrayList<>();
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
updaters.add(workerModel.getUpdater().getStateViewArray());
batchSize += workerModel.batchSize();
}
Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
}
}
((MultiLayerNetwork) model).setScore(score);
} else if (model instanceof ComputationGraph) {
if (averageUpdaters) {
ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
int batchSize = 0;
if (updater != null && updater.getStateViewArray() != null) {
List<INDArray> updaters = new ArrayList<>();
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
updaters.add(workerModel.getUpdater().getStateViewArray());
batchSize += workerModel.batchSize();
}
Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
}
}
((ComputationGraph) model).setScore(score);
}
}
/**
* This method allows you to specify IterationListeners for this model.
* Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead
* using {@link #setListeners(StatsStorageRouter, Collection)}
*
* @param listeners Listeners to set
*/
public void setListeners(@NonNull Collection<IterationListener> listeners) {
setListeners(null, listeners);
}
/**
* This method allows you to specify IterationListeners for this model.
* Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead
* using {@link #setListeners(StatsStorageRouter, Collection)}
*
* @param listeners Listeners to set
*/
public void setListeners(@NonNull IterationListener... listeners) {
setListeners(Arrays.asList(listeners));
}
/**
* Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners
* that implement the {@link RoutingIterationListener} interface)
*
* @param statsStorage Stats storage router to place the results into
* @param listeners Listeners to set
*/
public void setListeners(StatsStorageRouter statsStorage, IterationListener... listeners) {
setListeners(statsStorage, Arrays.asList(listeners));
}
/**
* Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners
* that implement the {@link RoutingIterationListener} interface)
*
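* Example (a minimal sketch; it assumes the DL4J UI module's {@code InMemoryStatsStorage} and {@code StatsListener} are available on the classpath):
* <pre>{@code
* StatsStorageRouter storage = new InMemoryStatsStorage();   // assumed: from the deeplearning4j-ui module
* wrapper.setListeners(storage, new StatsListener(null));    // the router is attached to each worker's listener copy
* }</pre>
*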
* @param statsStorage Stats storage router to place the results into
* @param listeners Listeners to set
*/
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends IterationListener> listeners) {
//Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
if (listeners != null) {
for (IterationListener l : listeners) {
if (l instanceof RoutingIterationListener) {
RoutingIterationListener rl = (RoutingIterationListener) l;
if (statsStorage == null && rl.getStorageRouter() == null) {
log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}",
l);
} else if (rl.getStorageRouter() != null && !(rl.getStorageRouter() instanceof Serializable)) {
//Spark would throw a (probably cryptic) serialization exception later anyway...
throw new IllegalStateException(
"RoutingIterationListener provided with non-serializable storage router "
+ "\nRoutingIterationListener class: " + rl.getClass().getName()
+ "\nStatsStorageRouter class: "
+ rl.getStorageRouter().getClass().getName());
}
}
}
this.listeners.addAll(listeners);
} else {
this.listeners.clear();
}
this.storageRouter = statsStorage;
}
/**
* This method takes a DataSetIterator and starts training over it, scheduling DataSets to different executors.
*
* @param source iterator to fit over
*/
public synchronized void fit(@NonNull DataSetIterator source) {
stopFit.set(false);
createZooIfNeccessary(false);
if (source.resetSupported())
source.reset();
DataSetIterator iterator = source;
if (prefetchSize > 0 && source.asyncSupported()) {
log.info("Creating asynchronous prefetcher...");
if (isMQ) {
if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
Nd4j.getAffinityManager().getNumberOfDevices());
// if (mq == null)
// mq = new MagicQueue.Builder().setCapacityPerFlow(prefetchSize).setMode(MagicQueue.Mode.SEQUENTIAL).setType(MagicQueue.Type.DS)
// .setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
iterator = new AsyncDataSetIterator(source, prefetchSize, new LinkedBlockingQueue<>(prefetchSize * workers), true, new InterleavedDataSetCallback(prefetchSize * 2));
} else
iterator = new AsyncDataSetIterator(source, prefetchSize);
}
AtomicInteger locker = new AtomicInteger(0);
long time1 = System.currentTimeMillis();
while (iterator.hasNext() && !stopFit.get()) {
DataSet dataSet = iterator.next();
long time2 = System.currentTimeMillis();
long lastEtlTime = time2 - time1;
if (dataSet == null)
throw new ND4JIllegalStateException("You can't have NULL as DataSet");
/*
the dataSet is dispatched to the next free worker; once all workers are busy, we block until all of them have finished
*/
int pos = locker.getAndIncrement();
if (zoo == null)
throw new IllegalStateException(
"ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
zoo[pos].feedDataSet(dataSet, lastEtlTime );
/*
once all workers have been dispatched, wait until all of them have finished
*/
if (pos + 1 == workers ) {
iterationsCounter.incrementAndGet();
for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
try {
zoo[cnt].waitTillRunning();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
Nd4j.getMemoryManager().invokeGcOccasionally();
/*
average the model parameters and propagate them to all workers
*/
if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
long timeA1 = System.currentTimeMillis();
// model parameter averaging happens inside getScore()
double score = getScore(locker);
// updater state averaging happens inside averageUpdatersState() (if enabled)
averageUpdatersState(locker, score);
long timeA2 = System.currentTimeMillis();
if (reportScore)
log.info("Averaging time: {} ms", timeA2 - timeA1);
}
locker.set(0);
}
time1 = System.currentTimeMillis();
}
if (prefetchSize > 0 && source.asyncSupported())
((AsyncDataSetIterator) iterator).shutdown();
if (zoo != null) {
for (int i = 0; i < zoo.length; i++) {
zoo[i].shutdown();
}
zoo = null;
}
// sanity check: warn if parameters were never averaged during this fit()
if (!wasAveraged)
log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
// throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
log.debug("Iterations passed: {}", iterationsCounter.get());
}
private void createZooIfNeccessary(boolean useMDS) {
if (zoo == null) {
trainerContext.init(model, trainerContextArgs);
zoo = new Trainer[workers];
int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
for (int cnt = 0; cnt < workers; cnt++) {
// pass useMDS here to tell the Trainer whether to use the MultiDataSet queue for training
zoo[cnt] = trainerContext.create(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(),
useMDS, this, workspaceMode, averagingFrequency);
zoo[cnt].setUncaughtExceptionHandler(handler);
if (zoo[cnt] instanceof Thread) {
Nd4j.getAffinityManager().attachThreadToDevice((Thread) zoo[cnt], cnt % numDevices);
}
zoo[cnt].start();
}
}
}
public static class Builder<T extends Model> {
protected T model;
protected int workers = Nd4j.getAffinityManager().getNumberOfDevices();
protected int prefetchSize = 16;
protected int averagingFrequency = 1;
protected boolean reportScore = false;
protected boolean averageUpdaters = true;
protected boolean legacyAveraging = true;
protected boolean isMQ = Nd4j.getAffinityManager().getNumberOfDevices() > 1;
protected TrainerContext trainerContext = new DefaultTrainerContext();
protected Object[] trainerContextArgs;
protected WorkspaceMode workspaceMode = WorkspaceMode.SEPARATE;
/**
* Trainer context args are passed to the
* {@link TrainerContext} init method
* when {@link ParallelWrapper} starts training.
* @param trainerContextArgs the args to use (may be null)
* @return builder pattern
*/
public Builder trainerContextArgs(Object... trainerContextArgs) {
this.trainerContextArgs = trainerContextArgs;
return this;
}
/**
* Specify a {@link TrainerContext}
* for the given {@link ParallelWrapper}
* instance.
* Defaults to {@link DefaultTrainerContext}
* otherwise
* @param trainerContext the trainer factory to use
* @return builder pattern
*/
public Builder trainerFactory(TrainerContext trainerContext) {
Preconditions.checkNotNull(trainerContext);
this.trainerContext = trainerContext;
return this;
}
public Builder workspaceMode(@NonNull WorkspaceMode mode) {
this.workspaceMode = mode;
return this;
}
/**
* Build ParallelWrapper for a MultiLayerNetwork or ComputationGraph
*
* @param model model to be trained in parallel
*/
public Builder(@NonNull T model) {
this.model = model;
}
/**
* This method allows you to configure the number of workers used for parallel training.
*
* @param num number of workers (must be at least 2)
* @return builder pattern
*/
public Builder workers(int num) {
if (num < 2)
throw new RuntimeException("Number of workers can't be lower then 2!");
this.workers = num;
return this;
}
/**
* Model averaging frequency.
*
* With N workers and frequency F, parameters are averaged once every N * F minibatches consumed from the iterator.
*
* @param freq number of iterations between averaging
* @return builder pattern
*/
public Builder averagingFrequency(int freq) {
this.averagingFrequency = freq;
return this;
}
/**
* This method enables/disables updaters averaging.
*
* Default value: TRUE
*
* PLEASE NOTE: This method is mostly useful for debugging. Don't change the default value unless you're sure you need to.
*
* @param reallyAverage whether to average updater state across workers
* @return builder pattern
*/
public Builder averageUpdaters(boolean reallyAverage) {
this.averageUpdaters = reallyAverage;
return this;
}
/**
* This method enables/disables MagicQueue usage.
* If set to true, all datasets will be spread among all available devices during the prefetch phase using AsyncDataSetIterator.
*
* PLEASE NOTE: This is an experimental feature. The value passed here is currently ignored; MagicQueue mode is enabled automatically when more than one device is available.
*
* @param reallyUse whether to use MagicQueue-based prefetching
* @return builder pattern
*/
public Builder useMQ(boolean reallyUse) {
//this.isMQ = reallyUse;
return this;
}
/**
* Size of the prefetch buffer that will be used for background data prefetching.
* Usually it's better to keep this value equal to the number of workers.
*
* Default value: 16
*
* @param size 0 to disable prefetching, or any positive number to set the buffer size
* @return builder pattern
*/
public Builder prefetchBuffer(int size) {
if (size < 0)
size = 0;
this.prefetchSize = size;
return this;
}
/**
* If set to true, the legacy averaging method is used. This may serve as a fallback on multi-GPU systems without P2P access available.
*
* Default value: true
*
* @param reallyUse whether to use the legacy averaging method
* @return builder pattern
*/
public Builder useLegacyAveraging(boolean reallyUse) {
this.legacyAveraging = reallyUse;
return this;
}
/**
* This method enables/disables averaged model score reporting
*
* @param reallyReport whether to log the averaged score after each averaging round
* @return builder pattern
*/
public Builder reportScoreAfterAveraging(boolean reallyReport) {
this.reportScore = reallyReport;
return this;
}
/**
* This method builds and returns the configured ParallelWrapper instance
*
* @return ParallelWrapper instance
*/
public ParallelWrapper build() {
ParallelWrapper wrapper = new ParallelWrapper(model, workers, prefetchSize);
wrapper.averagingFrequency = this.averagingFrequency;
wrapper.reportScore = this.reportScore;
wrapper.averageUpdaters = this.averageUpdaters;
wrapper.legacyAveraging = this.legacyAveraging;
wrapper.isMQ = this.isMQ;
wrapper.workspaceMode = this.workspaceMode;
wrapper.trainerContext = this.trainerContext;
wrapper.trainerContextArgs = this.trainerContextArgs;
return wrapper;
}
}
private static IterationListener cloneListener(IterationListener original) {
if (original instanceof RoutingIterationListener) {
return ((RoutingIterationListener) original).clone();
}
return original;
}
private void configureListeners(String workerUUID, Collection<IterationListener> oldListeners,
Collection<IterationListener> replicatedListeners) {
for (IterationListener listener : oldListeners) {
IterationListener l = cloneListener(listener);
if (l instanceof RoutingIterationListener) {
RoutingIterationListener rl = (RoutingIterationListener) l;
//We're assuming session ID is set by the original RoutingIterationListener constructor, which means
// it will be synced across all cloned instances
rl.setSessionID(((RoutingIterationListener) listener).getSessionID());
rl.setWorkerID(workerUUID);
StatsStorageRouter currentRouter = ((RoutingIterationListener) listener).getStorageRouter();
if (currentRouter != null) {
//User has set router on the listener/model, instead of via the
// setListeners(StatsStorageRouter, ...) method
rl.setStorageRouter(currentRouter);
} else {
rl.setStorageRouter(ParallelWrapper.this.storageRouter);
}
}
replicatedListeners.add(l);
}
}
}