package org.nd4j.parameterserver.client;
import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.ParameterServerListener;
import org.nd4j.parameterserver.ParameterServerSubscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Integration test for {@link ParameterServerClient}.
 *
 * <p>Launches an embedded Aeron media driver, a master {@link ParameterServerSubscriber}, and a
 * slave subscriber configured with the master's connection URL. It then checks that an array pushed
 * by a client through the slave shows up in the master's updater. It also checks that the array can
 * be read back through the master's responder.
 *
 * <p>Created by agibsonccc on 10/3/16.
 */
public class ParameterServerClientTest {
    private static final Logger log = LoggerFactory.getLogger(ParameterServerClientTest.class);

    /** Number of elements in the parameter vector the master node is started with ("-s 1,N"). */
    private static final int PARAMETER_LENGTH = 1000;
    /** Upper bound on how long {@link #before()} waits for both subscribers to report launch. */
    private static final long SUBSCRIBER_LAUNCH_TIMEOUT_MS = 100_000L;
    /** Delay between successive checks of the subscribers' launch state. */
    private static final long SUBSCRIBER_LAUNCH_POLL_MS = 1_000L;

    // Shared fixtures created once in before().
    // NOTE(review): nothing here closes aeron, mediaDriver or the subscribers after the class runs.
    // Confirm the subscribers' shutdown API and consider adding an @AfterClass teardown.
    private static MediaDriver mediaDriver;
    private static Aeron aeron;
    private static ParameterServerSubscriber masterNode, slaveNode;

    /**
     * Starts the media driver, the Aeron client, and the master and slave subscribers.
     * Verifies their configuration and then waits until both subscribers report they have launched.
     *
     * @throws IllegalStateException if either subscriber has not launched within
     *         {@link #SUBSCRIBER_LAUNCH_TIMEOUT_MS}
     */
    @BeforeClass
    public static void before() throws Exception {
        mediaDriver = MediaDriver.launchEmbedded(AeronUtil.getMediaDriverContext(PARAMETER_LENGTH));
        System.setProperty("play.server.dir", "/tmp");
        aeron = Aeron.connect(getContext());

        masterNode = new ParameterServerSubscriber(mediaDriver);
        masterNode.setAeron(aeron);
        // Randomized to reduce port collisions with other tests/processes on the same host.
        int masterPort = 40323 + new java.util.Random().nextInt(3000);
        masterNode.run(new String[] {"-m", "true", "-s", "1," + String.valueOf(PARAMETER_LENGTH), "-p",
                        String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md",
                        mediaDriver.aeronDirectoryName(), "-sp", "33000", "-u", String.valueOf(1)});
        assertTrue(masterNode.isMaster());
        assertEquals(masterPort, masterNode.getPort());
        assertEquals("localhost", masterNode.getHost());
        assertEquals(11, masterNode.getStreamId());
        assertEquals(12, masterNode.getResponder().getStreamId());

        slaveNode = new ParameterServerSubscriber(mediaDriver);
        slaveNode.setAeron(aeron);
        // The slave is given the master's subscriber URL via "-pm".
        slaveNode.run(new String[] {"-p", String.valueOf(masterPort + 100), "-h", "localhost", "-id", "10", "-pm",
                        masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp",
                        "31000", "-u", String.valueOf(1)});
        assertFalse(slaveNode.isMaster());
        assertEquals(masterPort + 100, slaveNode.getPort());
        assertEquals("localhost", slaveNode.getHost());
        assertEquals(10, slaveNode.getStreamId());

        waitForSubscribers();

        log.info("Using media driver directory {}", mediaDriver.aeronDirectoryName());
        log.info("Launched media driver");
    }

    /**
     * Polls until both the master and the slave subscriber report they have launched.
     *
     * @throws IllegalStateException if either subscriber is still not launched once the timeout elapses
     * @throws InterruptedException if interrupted while sleeping between polls
     */
    private static void waitForSubscribers() throws InterruptedException {
        long deadline = System.nanoTime() + SUBSCRIBER_LAUNCH_TIMEOUT_MS * 1_000_000L;
        // Keep waiting while EITHER subscriber is not up yet; both are needed for the test flow.
        while (!bothSubscribersLaunched() && System.nanoTime() - deadline < 0) {
            Thread.sleep(SUBSCRIBER_LAUNCH_POLL_MS);
        }
        if (!bothSubscribersLaunched()) {
            throw new IllegalStateException("Failed to start master and slave node (master launched: "
                            + masterNode.subscriberLaunched() + ", slave launched: "
                            + slaveNode.subscriberLaunched() + ")");
        }
    }

    /** Returns {@code true} only when both the master and the slave subscriber report launch. */
    private static boolean bothSubscribersLaunched() {
        return masterNode.subscriberLaunched() && slaveNode.subscriberLaunched();
    }

    /**
     * Pushes an array of ones through the slave and checks that the master recorded exactly one update.
     * Also checks that the updater holds the pushed array and that the client reads the same array back.
     */
    @Test
    public void testServer() throws Exception {
        int subscriberPort = 40625 + new java.util.Random().nextInt(100);
        ParameterServerClient client = ParameterServerClient.builder().aeron(aeron)
                        .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl())
                        .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl()).subscriberHost("localhost")
                        .subscriberPort(subscriberPort).subscriberStream(12).build();
        assertEquals(String.format("localhost:%d:12", subscriberPort), client.connectionUrl());

        // Flow 1:
        // The client (localhost:<subscriberPort>:12) sends an array to the slave's subscriber
        // (stream 10). The slave publishes it to the master (stream 11), which adds the array for
        // parameter averaging. After a single push the master's updater should count one update.
        client.pushNDArray(Nd4j.ones(PARAMETER_LENGTH));
        log.info("Pushed ndarray");
        // Give the array time to travel client -> slave -> master before inspecting the master.
        Thread.sleep(30000);

        ParameterServerListener listener = (ParameterServerListener) masterNode.getCallback();
        assertEquals(1, listener.getUpdater().numUpdates());
        assertEquals(Nd4j.ones(PARAMETER_LENGTH), listener.getUpdater().ndArrayHolder().get());

        // Read the array back through the master's responder URL configured on the client.
        INDArray arr = client.getArray();
        assertEquals(Nd4j.ones(PARAMETER_LENGTH), arr);
    }

    /**
     * Builds the Aeron client context used by this test, bound to the embedded media driver's directory.
     * Must only be called after {@link #mediaDriver} has been launched. Aeron errors are logged
     * rather than rethrown.
     */
    private static Aeron.Context getContext() {
        return new Aeron.Context().publicationConnectionTimeout(-1)
                        .availableImageHandler(AeronUtil::printAvailableImage)
                        .unavailableImageHandler(AeronUtil::printUnavailableImage)
                        .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
                        .errorHandler(e -> log.error(e.toString(), e));
    }
}