package org.nd4j.parameterserver.distributed.training; import lombok.NonNull; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.parameterserver.distributed.conf.VoidConfiguration; import org.nd4j.parameterserver.distributed.logic.completion.Clipboard; import org.nd4j.parameterserver.distributed.logic.Storage; import org.nd4j.parameterserver.distributed.messages.TrainingMessage; import org.nd4j.parameterserver.distributed.transport.Transport; import org.reflections.Reflections; import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.Map; import java.util.Set; /** * @author raver119@gmail.com */ public class TrainerProvider { private static final TrainerProvider INSTANCE = new TrainerProvider(); // we use Class.getSimpleName() as key here protected Map<String, TrainingDriver<?>> trainers = new HashMap<>(); protected VoidConfiguration voidConfiguration; protected Transport transport; protected Clipboard clipboard; protected Storage storage; private TrainerProvider() { scanClasspath(); } public static TrainerProvider getInstance() { return INSTANCE; } protected void scanClasspath() { // TODO: reflection stuff to fill trainers Reflections reflections = new Reflections("org"); Set<Class<? extends TrainingDriver>> classes = reflections.getSubTypesOf(TrainingDriver.class); for (Class clazz : classes) { if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) continue; try { TrainingDriver driver = (TrainingDriver) clazz.newInstance(); trainers.put(driver.targetMessageClass(), driver); } catch (Exception e) { throw new RuntimeException(e); } } if (trainers.size() < 1) throw new ND4JIllegalStateException("No TrainingDrivers were found"); } public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @NonNull Storage storage, @NonNull Clipboard clipboard) { this.voidConfiguration = voidConfiguration; this.transport = transport; this.clipboard = clipboard; this.storage = storage; for (TrainingDriver<?> trainer : trainers.values()) { trainer.init(voidConfiguration, transport, storage, clipboard); } } @SuppressWarnings("unchecked") protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T message) { TrainingDriver<?> driver = trainers.get(message.getClass().getSimpleName()); if (driver == null) throw new ND4JIllegalStateException("Can't find trainer for [" + message.getClass().getSimpleName() + "]"); return (TrainingDriver<T>) driver; } public <T extends TrainingMessage> void doTraining(T message) { TrainingDriver<T> trainer = getTrainer(message); trainer.startTraining(message); } }