package org.nd4j.parameterserver.node; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import lombok.extern.slf4j.Slf4j; import org.junit.BeforeClass; import org.junit.Test; import org.nd4j.aeron.ipc.AeronUtil; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.parameterserver.client.ParameterServerClient; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import static org.junit.Assert.*; /** * Created by agibsonccc on 12/3/16. */ @Slf4j public class ParameterServerNodeTest { private static MediaDriver mediaDriver; private static Aeron aeron; private static ParameterServerNode parameterServerNode; private static int parameterLength = 4; private static int masterStatusPort = 40323 + new java.util.Random().nextInt(15999); private static int statusPort = masterStatusPort - 1299; @BeforeClass public static void before() throws Exception { mediaDriver = MediaDriver.launchEmbedded(AeronUtil.getMediaDriverContext(parameterLength)); System.setProperty("play.server.dir", "/tmp"); aeron = Aeron.connect(getContext()); parameterServerNode = new ParameterServerNode(mediaDriver, statusPort); parameterServerNode.runMain(new String[] {"-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", String.valueOf(masterStatusPort), "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(statusPort), "-sh", "localhost", "-u", String.valueOf(Runtime.getRuntime().availableProcessors())}); while (!parameterServerNode.subscriberLaunched()) { Thread.sleep(10000); } } @Test public void testSimulateRun() throws Exception { int numCores = Runtime.getRuntime().availableProcessors(); ExecutorService executorService = Executors.newFixedThreadPool(numCores); ParameterServerClient[] clients = new ParameterServerClient[numCores]; String host = "localhost"; for (int i = 0; i < numCores; i++) { clients[i] = ParameterServerClient.builder().aeron(aeron).masterStatusHost(host) .masterStatusPort(statusPort).subscriberHost(host).subscriberPort(40325 + i) .subscriberStream(10 + i) .ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()) .ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()) .build(); } Thread.sleep(60000); //no arrays have been sent yet for (int i = 0; i < numCores; i++) { assertFalse(clients[i].isReadyForNext()); } //send "numCores" arrays, the default parameter server updater //is synchronous so it should be "ready" when number of updates == number of workers for (int i = 0; i < numCores; i++) { clients[i].pushNDArrayMessage(NDArrayMessage.wholeArrayUpdate(Nd4j.ones(parameterLength))); } Thread.sleep(10000); //all arrays should have been sent for (int i = 0; i < numCores; i++) { assertTrue(clients[i].isReadyForNext()); } Thread.sleep(10000); for (int i = 0; i < 1; i++) { assertEquals(Nd4j.valueArrayOf(1, parameterLength, numCores), clients[i].getArray()); Thread.sleep(1000); } executorService.shutdown(); Thread.sleep(60000); parameterServerNode.close(); } private static Aeron.Context getContext() { return new Aeron.Context().publicationConnectionTimeout(-1) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .errorHandler(e -> log.error(e.toString(), e)); } }