package org.nd4j.parameterserver.background;
import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import io.aeron.driver.ThreadingMode;
import lombok.extern.slf4j.Slf4j;
import org.agrona.CloseHelper;
import org.agrona.concurrent.BusySpinIdleStrategy;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.client.ParameterServerClient;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
/**
* Created by agibsonccc on 10/5/16.
*/
@Slf4j
public class RemoteParameterServerClientTests {
private int parameterLength = 1000;
private Aeron.Context ctx;
private MediaDriver mediaDriver;
private AtomicInteger masterStatus = new AtomicInteger(0);
private AtomicInteger slaveStatus = new AtomicInteger(0);
private Aeron aeron;
@Before
public void before() throws Exception {
final MediaDriver.Context ctx =
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true)
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy());
mediaDriver = MediaDriver.launchEmbedded(ctx);
aeron = Aeron.connect(getContext());
Thread t = new Thread(() -> {
try {
masterStatus.set(
BackgroundDaemonStarter.startMaster(parameterLength, mediaDriver.aeronDirectoryName()));
} catch (Exception e) {
e.printStackTrace();
}
});
t.start();
log.info("Started master");
Thread t2 = new Thread(() -> {
try {
slaveStatus.set(BackgroundDaemonStarter.startSlave(parameterLength, mediaDriver.aeronDirectoryName()));
} catch (Exception e) {
e.printStackTrace();
}
});
t2.start();
log.info("Started slave");
//wait on the http servers
Thread.sleep(30000);
}
@After
public void after() throws Exception {
CloseHelper.close(mediaDriver);
CloseHelper.close(aeron);
}
@Test
public void remoteTests() throws Exception {
if (masterStatus.get() != 0 || slaveStatus.get() != 0)
throw new IllegalStateException("Master or slave failed to start. Exiting");
ParameterServerClient client = ParameterServerClient.builder().aeron(aeron)
.ndarrayRetrieveUrl(BackgroundDaemonStarter.masterResponderUrl())
.ndarraySendUrl(BackgroundDaemonStarter.slaveConnectionUrl()).subscriberHost("localhost")
.masterStatusHost("localhost").masterStatusPort(9200).subscriberPort(40125).subscriberStream(12)
.build();
assertEquals("localhost:40125:12", client.connectionUrl());
while (!client.masterStarted()) {
Thread.sleep(1000);
log.info("Waiting on master starting.");
}
//flow 1:
/**
* Client (40125:12): sends array to listener on slave(40126:10)
* which publishes to master (40123:11)
* which adds the array for parameter averaging.
* In this case totalN should be 1.
*/
log.info("Pushing ndarray");
client.pushNDArray(Nd4j.ones(parameterLength));
while (client.arraysSentToResponder() < 1) {
Thread.sleep(1000);
log.info("Waiting on ndarray responder to receive array");
}
log.info("Pushed ndarray");
INDArray arr = client.getArray();
assertEquals(Nd4j.ones(1000), arr);
/*
StopWatch stopWatch = new StopWatch();
long nanoTimeTotal = 0;
int n = 1000;
for(int i = 0; i < n; i++) {
stopWatch.start();
client.getArray();
stopWatch.stop();
nanoTimeTotal += stopWatch.getNanoTime();
stopWatch.reset();
}
System.out.println(nanoTimeTotal / n);
*/
}
private Aeron.Context getContext() {
if (ctx == null)
ctx = new Aeron.Context().publicationConnectionTimeout(-1)
.availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
.errorHandler(e -> log.error(e.toString(), e));
return ctx;
}
}