package org.deeplearning4j.ui.play;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.google.common.collect.Sets;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.flow.FlowListenerModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.histogram.HistogramModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.deeplearning4j.ui.play.misc.FunctionUtil;
import org.deeplearning4j.ui.play.staticroutes.Assets;
import org.deeplearning4j.ui.play.staticroutes.I18NRoute;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.QueuePairStatsStorageListener;
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.reflections.ReflectionUtils;
import org.reflections.Reflections;
import play.Mode;
import play.api.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

import static play.mvc.Results.ok;

/**
 * A UI server based on the Play framework.
 * <p>
 * {@link #runMain(String[])} builds the HTTP routes from the registered {@link UIModule}s, starts the
 * Play server and starts a daemon thread that drains {@link StatsStorageEvent}s from an internal queue
 * and forwards them to the modules.
 *
 * @author Alex Black
 */
@Slf4j
@Data
public class PlayUIServer extends UIServer {

    /**
     * System property for setting the UI port. Defaults to 9000.
     * Set to 0 to use a random port
     */
    public static final String UI_SERVER_PORT_PROPERTY = "org.deeplearning4j.ui.port";
    public static final int DEFAULT_UI_PORT = 9000;

    /**
     * System property to enable classpath scanning for custom UI modules. Disabled by default.
     */
    public static final String UI_CUSTOM_MODULE_PROPERTY = "org.deeplearning4j.ui.custommodule.enable";

    public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";

    private Server server;

    //Events posted by the listeners registered in attach(...); consumed by the event routing thread
    private final BlockingQueue<StatsStorageEvent> eventQueue = new LinkedBlockingQueue<>();
    private List<Pair<StatsStorage, StatsStorageListener>> listeners = new ArrayList<>();
    private List<StatsStorage> statsStorageInstances = new ArrayList<>();

    private List<UIModule> uiModules = new ArrayList<>();
    private RemoteReceiverModule remoteReceiverModule;
    //typeIDModuleMap: Records which modules are registered for which type IDs
    private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();

    private long uiProcessingDelay = 500; //500ms. TODO make configurable
    //NOTE(review): nothing in this class sets this flag to true, so the routing loop never exits via it
    private final AtomicBoolean shutdown = new AtomicBoolean(false);

    private Thread uiEventRoutingThread;

    @Parameter(names = {"-r", "-enableRemote"}, description = "Whether to enable remote or not", arity = 1)
    private boolean enableRemote;

    @Parameter(names = {"--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
    private int port = DEFAULT_UI_PORT;

    /**
     * Creates a server configured for {@link #DEFAULT_UI_PORT}. The server is not started until
     * {@link #runMain(String[])} is called.
     */
    public PlayUIServer() {
        this(DEFAULT_UI_PORT);
    }

    /**
     * Creates a server configured for the given port. The server is not started until
     * {@link #runMain(String[])} is called.
     *
     * @param port port to bind to (may be overridden by command line arguments or by the
     *             {@link #UI_SERVER_PORT_PROPERTY} system property in {@link #runMain(String[])})
     */
    public PlayUIServer(int port) {
        this.port = port;
    }

    /**
     * Parses the command line arguments, registers the built-in (and optionally custom) UI modules,
     * starts the Play server and starts the event routing thread.
     * <p>
     * If the arguments cannot be parsed, usage information is printed and the JVM exits with status 1.
     * The {@link #UI_SERVER_PORT_PROPERTY} system property, when set and parseable, takes precedence
     * over the port given via constructor or command line.
     *
     * @param args command line arguments (see the {@link Parameter} annotated fields)
     */
    public void runMain(String[] args) {
        JCommander jcmdr = new JCommander(this);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provides invalid input -> print the usage info
            jcmdr.usage();
            try {
                Thread.sleep(500);
            } catch (Exception ignored) {
                //Best-effort pause before exiting; nothing useful to do if it is interrupted
            }
            System.exit(1);
        }

        RoutingDsl routingDsl = new RoutingDsl();

        //Set up index page and assets routing
        //The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent
        // definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions
        //This way, routing is not directly dependent on Play API. Furthermore, Play 2.5 switches to using these Java interfaces
        // anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls...
        routingDsl.GET("/setlang/:to").routeTo(FunctionUtil.function(new I18NRoute()));
        routingDsl.GET("/lang/getCurrent").routeTo(() -> ok(I18NProvider.getInstance().getDefaultLanguage()));
        routingDsl.GET("/assets/*file").routeTo(FunctionUtil.function(new Assets(ASSETS_ROOT_DIRECTORY)));

        uiModules.add(new DefaultModule()); //For: navigation page "/"
        uiModules.add(new HistogramModule());
        uiModules.add(new TrainModule());
        uiModules.add(new ConvolutionalListenerModule());
        uiModules.add(new FlowListenerModule());
        uiModules.add(new TsneModule());
        remoteReceiverModule = new RemoteReceiverModule();
        uiModules.add(remoteReceiverModule);

        //Check if custom UI modules are enabled... (Boolean.parseBoolean(null) is false -> disabled by default)
        if (Boolean.parseBoolean(System.getProperty(UI_CUSTOM_MODULE_PROPERTY))) {
            List<Class<?>> excludeClasses = new ArrayList<>();
            for (UIModule u : uiModules) {
                excludeClasses.add(u.getClass());
            }
            uiModules.addAll(getCustomUIModules(excludeClasses));
        }

        for (UIModule m : uiModules) {
            registerRoutes(routingDsl, m);
            registerCallbackTypeIDs(m);
        }

        String portProperty = System.getProperty(UI_SERVER_PORT_PROPERTY);
        if (portProperty != null) {
            try {
                port = Integer.parseInt(portProperty);
            } catch (NumberFormatException e) {
                log.warn("Could not parse UI port property \"{}\" with value \"{}\"", UI_SERVER_PORT_PROPERTY,
                                portProperty, e);
            }
        }

        Router router = routingDsl.build();
        server = Server.forRouter(router, Mode.DEV, port);
        //Record the port reported by the server's main address rather than the requested one,
        // so that getPort() is meaningful when port 0 (random port) was requested
        this.port = server.mainAddress().getPort();

        String addr = server.mainAddress().toString();
        if (addr.startsWith("/0:0:0:0:0:0:0:0")) {
            //Wildcard address: show a clickable localhost URL instead
            int last = addr.lastIndexOf(':');
            if (last > 0) {
                addr = "http://localhost:" + addr.substring(last + 1);
            }
        }
        log.info("UI Server started at {}", addr);

        uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
        uiEventRoutingThread.setDaemon(true);
        uiEventRoutingThread.start();

        if (enableRemote) {
            enableRemoteListener();
        }
    }

    /**
     * Command line entry point: creates a server with default settings and calls
     * {@link #runMain(String[])}.
     *
     * @param args command line arguments
     */
    public static void main(String[] args) {
        new PlayUIServer().runMain(args);
    }

    /**
     * Adds all routes declared by the given module to the routing DSL.
     *
     * @throws RuntimeException if a route uses a function type other than Supplier or Function
     */
    private void registerRoutes(RoutingDsl routingDsl, UIModule module) {
        for (Route r : module.getRoutes()) {
            RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute());
            switch (r.getFunctionType()) {
                case Supplier:
                    ppm.routeTo(FunctionUtil.function0(r.getSupplier()));
                    break;
                case Function:
                    ppm.routeTo(FunctionUtil.function(r.getFunction()));
                    break;
                case BiFunction:
                case Function3:
                default:
                    throw new RuntimeException("Not yet implemented");
            }
        }
    }

    /**
     * Records, in {@link #typeIDModuleMap}, which type IDs the given module wants to receive.
     */
    private void registerCallbackTypeIDs(UIModule module) {
        for (String typeID : module.getCallbackTypeIDs()) {
            typeIDModuleMap.computeIfAbsent(typeID, k -> Collections.synchronizedList(new ArrayList<UIModule>()))
                            .add(module);
        }
    }

    /**
     * Scans the classpath for {@link UIModule} subtypes and instantiates each one via its no-arg
     * constructor. Classes that cannot be instantiated are logged and skipped.
     *
     * @param excludeClasses classes that must not be instantiated (e.g. the already registered modules)
     * @return the successfully created custom module instances (possibly empty)
     */
    private List<UIModule> getCustomUIModules(List<Class<?>> excludeClasses) {
        //Scan classpath for UI module instances, but ignore the 'excludeClasses' classes
        List<String> classNames = Collections.singletonList(UIModule.class.getName());
        Reflections reflections = new Reflections();
        org.reflections.Store store = reflections.getStore();

        Iterable<String> subtypesByName =
                        store.getAll(org.reflections.scanners.SubTypesScanner.class.getSimpleName(), classNames);

        Set<? extends Class<?>> subtypeClasses = Sets.newHashSet(ReflectionUtils.forNames(subtypesByName));

        List<UIModule> ret = new ArrayList<>();
        for (Class<?> c : subtypeClasses) {
            if (excludeClasses.contains(c)) {
                continue;
            }
            UIModule m;
            try {
                m = (UIModule) c.newInstance();
            } catch (Exception e) {
                log.warn("Could not create instance of custom UIModule of type {}; skipping", c, e);
                continue;
            }
            log.debug("Created instance of custom UI module: {}", c);
            ret.add(m);
        }

        return ret;
    }

    /**
     * @return the configured port; after {@link #runMain(String[])} has started the server, the port
     *         reported by the server's main address
     */
    @Override
    public int getPort() {
        return port;
    }

    /**
     * Attaches a StatsStorage instance: registers a queue-backed listener on it and notifies all UI
     * modules. Attaching an instance that is already attached is a no-op.
     *
     * @param statsStorage storage to attach
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public synchronized void attach(StatsStorage statsStorage) {
        if (statsStorage == null)
            throw new IllegalArgumentException("StatsStorage cannot be null");
        if (statsStorageInstances.contains(statsStorage))
            return;
        StatsStorageListener listener = new QueueStatsStorageListener(eventQueue);
        listeners.add(new Pair<>(statsStorage, listener));
        statsStorage.registerStatsStorageListener(listener);
        statsStorageInstances.add(statsStorage);

        for (UIModule uiModule : uiModules) {
            uiModule.onAttach(statsStorage);
        }

        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }

    /**
     * Detaches a previously attached StatsStorage instance: deregisters its listener, forgets the
     * instance and notifies all UI modules. Detaching an instance that is not attached is a no-op.
     *
     * @param statsStorage storage to detach
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public synchronized void detach(StatsStorage statsStorage) {
        if (statsStorage == null)
            throw new IllegalArgumentException("StatsStorage cannot be null");
        if (!statsStorageInstances.contains(statsStorage))
            return; //No op

        boolean found = false;
        //Explicit iterator: removing through the list itself while iterating would throw
        // ConcurrentModificationException
        Iterator<Pair<StatsStorage, StatsStorageListener>> iter = listeners.iterator();
        while (iter.hasNext()) {
            Pair<StatsStorage, StatsStorageListener> p = iter.next();
            if (p.getFirst() == statsStorage) { //Same object, not equality
                statsStorage.deregisterStatsStorageListener(p.getSecond());
                iter.remove();
                found = true;
            }
        }
        //Forget the instance (same identity check as above) so that isAttached(...) reports false
        // and the storage can be attached again later
        statsStorageInstances.removeIf(s -> s == statsStorage);

        for (UIModule uiModule : uiModules) {
            uiModule.onDetach(statsStorage);
        }

        if (found) {
            log.info("StatsStorage instance detached from UI: {}", statsStorage);
        }
    }

    /**
     * @param statsStorage storage to check
     * @return true if the given storage is currently attached
     */
    @Override
    public synchronized boolean isAttached(StatsStorage statsStorage) {
        //Synchronized for consistency with attach/detach, which mutate the same list
        return statsStorageInstances.contains(statsStorage);
    }

    /**
     * @return a copy of the list of currently attached StatsStorage instances
     */
    @Override
    public synchronized List<StatsStorage> getStatsStorageInstances() {
        //Synchronized so the copy is not taken while attach/detach is modifying the list
        return new ArrayList<>(statsStorageInstances);
    }

    /**
     * Enables the remote listener using a new {@link InMemoryStatsStorage}, which is also attached to
     * the UI. Does nothing if the remote listener is already enabled.
     */
    @Override
    public void enableRemoteListener() {
        if (remoteReceiverModule == null)
            remoteReceiverModule = new RemoteReceiverModule();
        if (remoteReceiverModule.isEnabled())
            return;
        enableRemoteListener(new InMemoryStatsStorage(), true);
    }

    /**
     * Enables the remote listener, storing received data in the given router.
     *
     * @param statsStorage destination for remotely received data
     * @param attach       if true and {@code statsStorage} is a {@link StatsStorage}, also attach it to the UI
     */
    @Override
    public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
        //Same lazy creation as the no-arg overload, so this does not fail when called before runMain(...)
        if (remoteReceiverModule == null)
            remoteReceiverModule = new RemoteReceiverModule();
        remoteReceiverModule.setEnabled(true);
        remoteReceiverModule.setStatsStorage(statsStorage);
        if (attach && statsStorage instanceof StatsStorage) {
            attach((StatsStorage) statsStorage);
        }
    }

    /**
     * Disables the remote listener. No-op if the remote receiver module has not been created yet.
     */
    @Override
    public void disableRemoteListener() {
        if (remoteReceiverModule != null) {
            remoteReceiverModule.setEnabled(false);
        }
    }

    /**
     * @return true if the remote receiver module exists and is enabled
     */
    @Override
    public boolean isRemoteListenerEnabled() {
        return remoteReceiverModule != null && remoteReceiverModule.isEnabled();
    }

    /**
     * Stops the Play server, if one was started.
     * <p>
     * NOTE(review): this does not set the {@code shutdown} flag or interrupt the event routing thread;
     * that (daemon) thread keeps draining the event queue after the server is stopped.
     */
    @Override
    public void stop() {
        if (server != null)
            server.stop();
    }

    /**
     * Background task: waits for events on the queue, batches them and reports to each UI module
     * the events whose type ID the module has declared interest in.
     */
    private class StatsEventRouterRunnable implements Runnable {

        @Override
        public void run() {
            try {
                runHelper();
            } catch (Exception e) {
                log.error("Unexpected exception from Event routing runnable", e);
            }
        }

        private void runHelper() throws Exception {
            log.debug("PlayUIServer.StatsEventRouterRunnable started");
            //Idea: collect all event stats, and route them to the appropriate modules
            while (!shutdown.get()) {

                List<StatsStorageEvent> events = new ArrayList<>();
                StatsStorageEvent sse = eventQueue.take(); //Blocking operation
                events.add(sse);
                eventQueue.drainTo(events); //Non-blocking

                for (UIModule m : uiModules) {
                    try {
                        List<String> callbackTypes = m.getCallbackTypeIDs();
                        List<StatsStorageEvent> out = new ArrayList<>();
                        for (StatsStorageEvent e : events) {
                            if (callbackTypes.contains(e.getTypeID())) {
                                out.add(e);
                            }
                        }

                        m.reportStorageEvents(out);
                    } catch (Exception e) {
                        //Isolate failures: one misbehaving module must not terminate event routing
                        // for every other module
                        log.warn("Error reporting storage events to UI module {}", m.getClass(), e);
                    }
                }

                //Throttle: wait before collecting the next batch of events
                try {
                    Thread.sleep(uiProcessingDelay);
                } catch (InterruptedException e) {
                    if (!shutdown.get()) {
                        throw new RuntimeException("Unexpected interrupted exception", e);
                    }
                }
            }
        }
    }
}