package org.nd4j.parameterserver.distributed.training;

import lombok.NonNull;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.transport.Transport;

import java.util.Arrays;

/**
 * Abstract base class for {@link TrainingDriver} implementations.
 *
 * <p>It holds the collaborators handed over through
 * {@link #init(VoidConfiguration, Transport, Storage, Clipboard)} and a
 * {@link FrameCompletionHandler} that collects completion hooks registered via
 * {@link #addCompletionHook(long, long, long)}. Concrete trainers supply the
 * message-specific training logic.
 *
 * <p>NOTE(review): no synchronization is performed here; thread-safety depends on
 * subclasses and on {@code FrameCompletionHandler} — confirm before sharing an
 * instance across threads.
 *
 * @param <T> the kind of {@link TrainingMessage} this trainer handles
 * @author raver119@gmail.com
 */
public abstract class BaseTrainer<T extends TrainingMessage> implements TrainingDriver<T> {
    /** Configuration supplied via {@code init}; {@code null} until then. */
    protected VoidConfiguration voidConfiguration;
    /** Transport supplied via {@code init}; {@code null} until then. */
    protected Transport transport;
    /** Clipboard supplied via {@code init}; {@code null} until then. */
    protected Clipboard clipboard;
    /** Storage supplied via {@code init}; {@code null} until then. */
    protected Storage storage;

    /** Receives the hooks registered through {@link #addCompletionHook(long, long, long)}. */
    protected FrameCompletionHandler completionHandler = new FrameCompletionHandler();

    /**
     * Stores the collaborators this trainer works with. All arguments are required.
     *
     * @param voidConfiguration configuration to keep; must not be {@code null}
     * @param transport transport to keep; must not be {@code null}
     * @param storage storage to keep; must not be {@code null}
     * @param clipboard clipboard to keep; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null} (null checks generated by
     *     Lombok's {@code @NonNull})
     */
    @Override
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport,
                    @NonNull Storage storage, @NonNull Clipboard clipboard) {
        this.clipboard = clipboard;
        this.transport = transport;
        this.voidConfiguration = voidConfiguration;
        this.storage = storage;
    }

    /**
     * Creates a new array of {@code size} elements, each set to {@code value}.
     *
     * @param value value written to every element
     * @param size length of the returned array
     * @return a freshly allocated array filled with {@code value}
     * @throws NegativeArraySizeException if {@code size} is negative
     */
    protected int[] replicate(int value, int size) {
        int[] result = new int[size];
        Arrays.fill(result, value);
        return result;
    }

    /**
     * Registers a completion hook by passing the three identifiers to
     * {@link FrameCompletionHandler#addHook(long, long, long)}.
     *
     * @param originatorId identifier of the originator, forwarded unchanged
     * @param frameId identifier of the frame, forwarded unchanged
     * @param messageId identifier of the message, forwarded unchanged
     */
    @Override
    public void addCompletionHook(long originatorId, long frameId, long messageId) {
        completionHandler.addHook(originatorId, frameId, messageId);
    }
}