package org.deeplearning4j.parallelism.trainer;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
/**
* Trains a model on incoming datasets using standard in-memory parameter averaging.
* This worker is the simplest form of parameter-averaging trainer.
*
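* <p>
* A minimal usage sketch (assuming a pre-built {@link MultiLayerNetwork} {@code net} and a
* {@code DataSetIterator} {@code iterator}); trainers of this class are normally created and
* driven by {@link ParallelWrapper} rather than instantiated directly:
* <pre>{@code
* ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
*         .workers(4)                // number of trainer threads
*         .averagingFrequency(3)     // average parameters every 3 iterations
*         .prefetchBuffer(24)        // datasets buffered per worker
*         .build();
* wrapper.fit(iterator);            // fans batches out to the trainer threads
* }</pre>
*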
* @author Adam Gibson
*/
@Builder
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public class DefaultTrainer extends Thread implements Trainer {
protected Model originalModel;
protected Model replicatedModel;
@Builder.Default protected LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>();
@Builder.Default protected LinkedBlockingQueue<MultiDataSet> queueMDS = new LinkedBlockingQueue<>();
@Builder.Default protected AtomicInteger running = new AtomicInteger(0);
protected int threadId;
@Builder.Default protected AtomicBoolean shouldUpdate = new AtomicBoolean(false);
@Builder.Default protected AtomicBoolean shouldStop = new AtomicBoolean(false);
protected Exception thrownException;
@Builder.Default protected volatile boolean useMDS = false;
protected final String uuid = UUID.randomUUID().toString();
@Builder.Default protected boolean onRootModel = false;
protected ParallelWrapper parallelWrapper;
protected WorkspaceMode workspaceMode;
@Builder.Default protected AtomicLong lastEtlTime = new AtomicLong(0);
@Builder.Default protected AtomicBoolean nullMode = new AtomicBoolean(false);
protected DataSet nullDataSet;
@Builder.Default protected AtomicBoolean isStopped = new AtomicBoolean(false);
protected int averagingFrequency;
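/**
* Enqueues a MultiDataSet for this worker and records the ETL time for the batch.
*/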
@Override
public void feedMultiDataSet(@NonNull MultiDataSet dataSet, long etlTime) {
setupIfNeccessary();
running.incrementAndGet();
queueMDS.add(dataSet);
if (lastEtlTime == null)
lastEtlTime = new AtomicLong(0);
lastEtlTime.set(etlTime);
}
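/**
* Enqueues a DataSet for this worker and records the ETL time; a {@code null} dataset switches the
* trainer into null mode, where a fixed placeholder dataset is fit instead.
*/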
@Override
public void feedDataSet(DataSet dataSet, long etlTime) {
setupIfNeccessary();
running.incrementAndGet();
if (dataSet != null)
queue.add(dataSet);
else {
if (nullMode == null)
nullMode = new AtomicBoolean(false);
nullMode.set(true);
}
if (lastEtlTime == null)
lastEtlTime = new AtomicLong(0);
lastEtlTime.set(etlTime);
}
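/**
* Returns this worker's replica of the model.
*/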
@Override
public Model getModel() {
return replicatedModel;
}
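/**
* Copies the given model's parameters and updater state into this worker's replica,
* typically used to push averaged parameters back to each worker.
*/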
@Override
public void updateModel(@NonNull Model model) {
this.shouldUpdate.set(true);
if (replicatedModel instanceof MultiLayerNetwork) {
replicatedModel.setParams(model.params().dup());
Updater updater = ((MultiLayerNetwork) model).getUpdater();
INDArray view = updater.getStateViewArray();
if (view != null) {
updater = ((MultiLayerNetwork) replicatedModel).getUpdater();
INDArray viewD = view.dup();
Nd4j.getExecutioner().commit();
updater.setStateViewArray((MultiLayerNetwork) replicatedModel, viewD, false);
}
} else if (replicatedModel instanceof ComputationGraph) {
replicatedModel.setParams(model.params().dup());
ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
INDArray view = updater.getStateViewArray();
if (view != null) {
INDArray viewD = view.dup();
Nd4j.getExecutioner().commit();
updater = ((ComputationGraph) replicatedModel).getUpdater();
updater.setStateViewArray(viewD);
}
}
Nd4j.getExecutioner().commit();
}
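/**
* Lazily initializes the queues and state flags in case the builder/constructors left them null.
*/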
protected void setupIfNeccessary() {
if (queue == null)
queue = new LinkedBlockingQueue<>();
if (queueMDS == null)
queueMDS = new LinkedBlockingQueue<>();
if (running == null)
running = new AtomicInteger(0);
if (shouldStop == null)
shouldStop = new AtomicBoolean(false);
if (shouldUpdate == null)
shouldUpdate = new AtomicBoolean(false);
if (isStopped == null)
isStopped = new AtomicBoolean(false);
}
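/**
* Rethrows any exception raised on the trainer thread; returns {@code true} once this worker
* has no datasets queued or in flight.
*/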
@Override
public boolean isRunning() {
// if Trainer thread got exception during training - rethrow it here
if (thrownException != null)
throw new RuntimeException(thrownException);
return running.get() == 0;
}
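/**
* Signals the training loop to stop, blocks until the worker thread reports it has stopped,
* then resets the stop flags.
*/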
@Override
public void shutdown() {
shouldStop.set(true);
while (!isStopped.get())
LockSupport.parkNanos(1000L);
shouldStop.set(false);
isStopped.set(false);
}
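/**
* Worker loop: builds a replica of the original model (unless operating on the root model),
* then polls the queue and fits incoming datasets until shutdown is requested.
*/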
@Override
public void run() {
setupIfNeccessary();
AtomicInteger iterationsCounter = new AtomicInteger(0);
try {
// we create a fresh network with the same configuration as the one initially created by the user;
// a full clone isn't needed here
if (originalModel instanceof MultiLayerNetwork) {
if (!onRootModel) {
MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(
((MultiLayerNetwork) originalModel).getLayerWiseConfigurations().toJson());
conf.setTrainingWorkspaceMode(workspaceMode);
this.replicatedModel = new MultiLayerNetwork(conf);
replicatedModel.init();
// we replicate the original model's params & updater state, in case it's a pre-trained model
synchronized (originalModel) {
replicatedModel.setParams(originalModel.params());
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
Updater updaterOriginal = ((MultiLayerNetwork) originalModel).getUpdater();
if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null)
updaterReplica.setStateViewArray((MultiLayerNetwork) replicatedModel,
updaterOriginal.getStateViewArray().dup(), false);
Nd4j.getExecutioner().commit();
}
Collection<IterationListener> oldListeners = ((MultiLayerNetwork) originalModel).getListeners();
oldListeners = (oldListeners == null ? new ArrayList<>() : new ArrayList<>(oldListeners));
Collection<IterationListener> replicatedListeners = new ArrayList<>();
if (parallelWrapper.getListeners() != null) {
oldListeners.addAll(parallelWrapper.getListeners());
}
configureListeners(uuid, oldListeners, replicatedListeners);
this.replicatedModel.setListeners(replicatedListeners);
}
} else if (originalModel instanceof ComputationGraph) {
if (!onRootModel) {
ComputationGraphConfiguration conf = ComputationGraphConfiguration
.fromJson(((ComputationGraph) originalModel).getConfiguration().toJson());
conf.setTrainingWorkspaceMode(workspaceMode);
this.replicatedModel = new ComputationGraph(conf);
this.replicatedModel.init();
// we replicate the original model's params & updater state, in case it's a pre-trained model
synchronized (originalModel) {
replicatedModel.setParams(originalModel.params());
ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();
ComputationGraphUpdater updaterOriginal = ((ComputationGraph) originalModel).getUpdater();
if (updaterOriginal != null && updaterOriginal.getStateViewArray() != null)
updaterReplica.setStateViewArray(updaterOriginal.getStateViewArray().dup());
Nd4j.getExecutioner().commit();
}
Collection<IterationListener> oldListeners = ((ComputationGraph) originalModel).getListeners();
oldListeners = (oldListeners == null ? new ArrayList<>() : new ArrayList<>(oldListeners));
Collection<IterationListener> replicatedListeners = new ArrayList<>();
if (parallelWrapper.getListeners() != null) {
oldListeners.addAll(parallelWrapper.getListeners());
}
configureListeners(uuid, oldListeners, replicatedListeners);
this.replicatedModel.setListeners(replicatedListeners);
}
}
if (!useMDS) {
while (!shouldStop.get()) {
DataSet dataSet = null;
if (nullMode == null || !nullMode.get())
dataSet = queue.poll(100, TimeUnit.MILLISECONDS);
else {
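// null mode: fit a fixed placeholder dataset instead of real data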
if (nullDataSet == null)
nullDataSet = new org.nd4j.linalg.dataset.DataSet(Nd4j.create(64, 28 * 28), Nd4j.create(64, 10));
dataSet = nullDataSet;
}
if (dataSet != null) {
//if (Nd4j.getAffinityManager().getDeviceForCurrentThread() != Nd4j.getAffinityManager().getDeviceForArray(dataSet.getFeatures()))
// log.debug("Thread: {}; Bad align for data: {}/{}", Thread.currentThread().getId(), Nd4j.getAffinityManager().getDeviceForCurrentThread(), Nd4j.getAffinityManager().getDeviceForArray(dataSet.getFeatures()));
if (replicatedModel instanceof MultiLayerNetwork) {
((MultiLayerNetwork) replicatedModel).setLastEtlTime(lastEtlTime.get());
((MultiLayerNetwork) replicatedModel).fit(dataSet);
} else if (replicatedModel instanceof ComputationGraph) {
((ComputationGraph) replicatedModel).setLastEtlTime(lastEtlTime.get());
((ComputationGraph) replicatedModel).fit(dataSet);
}
// if cross-device access isn't supported (e.g. multi-GPU on Windows), sync parameters back to host memory
if (!Nd4j.getAffinityManager().isCrossDeviceAccessSupported() && iterationsCounter.incrementAndGet() % averagingFrequency == 0) {
// we ensure all operations are finished in this training round
Nd4j.getExecutioner().commit();
// we ensure memory is updated on host side
Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(), AffinityManager.Location.HOST);
if (replicatedModel instanceof MultiLayerNetwork) {
Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater();
if (updaterReplica.getStateViewArray() != null)
Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST);
} else {
ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();
if (updaterReplica.getStateViewArray() != null)
Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST);
}
}
running.decrementAndGet();
}
}
} else {
// loop for MultiDataSet
while (!shouldStop.get()) {
MultiDataSet dataSet = queueMDS.poll(100, TimeUnit.MILLISECONDS);
if (dataSet != null) {
if (replicatedModel instanceof ComputationGraph) {
((ComputationGraph) replicatedModel).setLastEtlTime(lastEtlTime.get());
((ComputationGraph) replicatedModel).fit(dataSet);
} else
throw new RuntimeException("MultiDataSet can only be fit with a ComputationGraph");
// if cross-device access isn't supported (e.g. multi-GPU on Windows), sync parameters back to host memory
if (!Nd4j.getAffinityManager().isCrossDeviceAccessSupported() && iterationsCounter.incrementAndGet() % averagingFrequency == 0) {
// we ensure all operations are finished in this training round
Nd4j.getExecutioner().commit();
// we ensure memory is updated on host side
Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(), AffinityManager.Location.HOST);
ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater();
if (updaterReplica.getStateViewArray() != null)
Nd4j.getAffinityManager().ensureLocation(updaterReplica.getStateViewArray(), AffinityManager.Location.HOST);
}
running.decrementAndGet();
}
}
}
} catch (Exception e) {
this.thrownException = e;
throw new RuntimeException(e);
} finally {
log.debug("Terminating all workspaces for trainer_{}", threadId);
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
isStopped.set(true);
}
}
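/**
* Blocks until this worker has drained its queue, rethrowing any exception raised on the trainer thread.
*/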
@Override
public void waitTillRunning() {
while (running.get() != 0) {
// if Trainer thread got exception during training - rethrow it here
if (thrownException != null)
throw new RuntimeException(thrownException);
LockSupport.parkNanos(1000L);
}
}
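/**
* Clones routing listeners so each worker gets its own instance; other listeners are shared as-is.
*/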
protected static IterationListener cloneListener(IterationListener original) {
if (original instanceof RoutingIterationListener) {
return ((RoutingIterationListener) original).clone();
}
return original;
}
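/**
* Clones the original model's listeners for this worker, assigning this worker's UUID and a stats
* storage router to any RoutingIterationListener instances.
*/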
protected void configureListeners(String workerUUID, Collection<IterationListener> oldListeners,
Collection<IterationListener> replicatedListeners) {
for (IterationListener listener : oldListeners) {
IterationListener l = cloneListener(listener);
if (l instanceof RoutingIterationListener) {
RoutingIterationListener rl = (RoutingIterationListener) l;
//We're assuming session ID is set by the original RoutingIterationListener constructor, which means
// it will be synced across all cloned instances
rl.setSessionID(((RoutingIterationListener) listener).getSessionID());
rl.setWorkerID(workerUUID);
StatsStorageRouter currentRouter = ((RoutingIterationListener) listener).getStorageRouter();
if (currentRouter != null) {
//User has set router on the listener/model, instead of via the
// setListeners(StatsStorageRouter, ...) method
rl.setStorageRouter(currentRouter);
} else {
rl.setStorageRouter(parallelWrapper.getStorageRouter());
}
}
replicatedListeners.add(l);
}
}
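/**
* Explicit builder class declaration so Lombok's generated builder exposes a public no-arg
* constructor (e.g. for subclasses to extend).
*/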
public static class DefaultTrainerBuilder {
public DefaultTrainerBuilder() {}
}
}