package org.nd4j.parameterserver.node; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import lombok.Data; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.agrona.CloseHelper; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayCallback; import org.nd4j.parameterserver.ParameterServerListener; import org.nd4j.parameterserver.ParameterServerSubscriber; import org.nd4j.parameterserver.status.play.InMemoryStatusStorage; import org.nd4j.parameterserver.status.play.StatusServer; import play.server.Server; import java.util.ArrayList; import java.util.Arrays; import java.util.List; /** * Integrated node for running * the parameter server. * This includes the status server, * media driver, and parameter server subscriber * * @author Adam Gibson */ @Slf4j @NoArgsConstructor @Data public class ParameterServerNode implements AutoCloseable { private Server server; private ParameterServerSubscriber[] subscriber; private MediaDriver mediaDriver; private Aeron aeron; private int statusPort; private int numWorkers; /** * * @param mediaDriver the media driver to sue for communication * @param statusPort the port for the server status */ public ParameterServerNode(MediaDriver mediaDriver, int statusPort) { this(mediaDriver, statusPort, Runtime.getRuntime().availableProcessors()); } /** * * @param mediaDriver the media driver to sue for communication * @param statusPort the port for the server status */ public ParameterServerNode(MediaDriver mediaDriver, int statusPort, int numWorkers) { this.mediaDriver = mediaDriver; this.statusPort = statusPort; this.numWorkers = numWorkers; subscriber = new ParameterServerSubscriber[numWorkers]; } /** * Pass in the media driver used for communication * and a defualt status port of 9000 * @param mediaDriver */ public ParameterServerNode(MediaDriver mediaDriver) { this(mediaDriver, 9000); } /** * Run this node with the given args * These args are the same ones * that a {@link ParameterServerSubscriber} takes * @param args the arguments for the {@link ParameterServerSubscriber} */ public void runMain(String[] args) { server = StatusServer.startServer(new InMemoryStatusStorage(), statusPort); if (mediaDriver == null) mediaDriver = MediaDriver.launchEmbedded(); log.info("Started media driver with aeron directory " + mediaDriver.aeronDirectoryName()); //cache a reference to the first listener. //The reason we do this is to share an updater and listener across *all* subscribers //This will create a shared pool of subscribers all updating the same "server". //This will simulate a shared pool but allow an accumulative effect of anything //like averaging we try. NDArrayCallback parameterServerListener = null; ParameterServerListener cast = null; for (int i = 0; i < numWorkers; i++) { subscriber[i] = new ParameterServerSubscriber(mediaDriver); //ensure reuse of aeron wherever possible if (aeron == null) aeron = Aeron.connect(getContext(mediaDriver)); subscriber[i].setAeron(aeron); List<String> multiArgs = new ArrayList<>(Arrays.asList(args)); if (multiArgs.contains("-id")) { int streamIdIdx = multiArgs.indexOf("-id") + 1; int streamId = Integer.parseInt(multiArgs.get(streamIdIdx)) + i; multiArgs.set(streamIdIdx, String.valueOf(streamId)); } else if (multiArgs.contains("--streamId")) { int streamIdIdx = multiArgs.indexOf("--streamId") + 1; int streamId = Integer.parseInt(multiArgs.get(streamIdIdx)) + i; multiArgs.set(streamIdIdx, String.valueOf(streamId)); } if (i == 0) { subscriber[i].run(multiArgs.toArray(new String[args.length])); parameterServerListener = subscriber[i].getCallback(); cast = subscriber[i].getParameterServerListener(); } else { //note that we set both the callback AND the listener here subscriber[i].setCallback(parameterServerListener); subscriber[i].setParameterServerListener(cast); //now run the callback initialized with this callback instead //in the run method it will use this reference instead of creating it //itself subscriber[i].run(multiArgs.toArray(new String[args.length])); } } } /** * Returns true if all susbcribers in the * subscriber pool have been launched * @return */ public boolean subscriberLaunched() { boolean launched = true; for (int i = 0; i < numWorkers; i++) { launched = launched && subscriber[i].subscriberLaunched(); } return launched; } /** * Stop the server * @throws Exception */ @Override public void close() throws Exception { if (subscriber != null) { for (int i = 0; i < subscriber.length; i++) { if (subscriber[i] != null) { subscriber[i].close(); } } } if (server != null) server.stop(); if (mediaDriver != null) CloseHelper.quietClose(mediaDriver); if (aeron != null) CloseHelper.quietClose(aeron); } private static Aeron.Context getContext(MediaDriver mediaDriver) { return new Aeron.Context().publicationConnectionTimeout(-1) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .errorHandler(e -> log.error(e.toString(), e)); } public static void main(String[] args) { new ParameterServerNode().runMain(args); } }