package org.nd4j.aeron.ipc.response; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import io.aeron.driver.ThreadingMode; import lombok.extern.slf4j.Slf4j; import org.agrona.CloseHelper; import org.agrona.concurrent.BusySpinIdleStrategy; import org.junit.Before; import org.junit.Test; import org.nd4j.aeron.ipc.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 10/3/16. */ @Slf4j public class AeronNDArrayResponseTest { private MediaDriver mediaDriver; @Before public void before() { final MediaDriver.Context ctx = new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) .senderIdleStrategy(new BusySpinIdleStrategy()); mediaDriver = MediaDriver.launchEmbedded(ctx); System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); System.out.println("Launched media driver"); } @Test public void testResponse() throws Exception { int streamId = 10; int responderStreamId = 11; String host = "127.0.0.1"; Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .errorHandler(e -> log.error(e.toString(), e)); int baseSubscriberPort = 40123 + new java.util.Random().nextInt(1000); Aeron aeron = Aeron.connect(ctx); AeronNDArrayResponder responder = AeronNDArrayResponder.startSubscriber(aeron, host, baseSubscriberPort + 1, new NDArrayHolder() { /** * Set the ndarray * * @param arr the ndarray for this holder * to use */ @Override public void setArray(INDArray arr) { } /** * The number of updates * that have been sent to this older. * * @return */ @Override public int totalUpdates() { return 1; } /** * Retrieve an ndarray * * @return */ @Override public INDArray get() { return Nd4j.scalar(1.0); } /** * Retrieve a partial view of the ndarray. * This method uses tensor along dimension internally * Note this will call dup() * * @param idx the index of the tad to get * @param dimensions the dimensions to use * @return the tensor along dimension based on the index and dimensions * from the master array. */ @Override public INDArray getTad(int idx, int... dimensions) { return Nd4j.scalar(1.0); } } , responderStreamId); AtomicInteger count = new AtomicInteger(0); AtomicBoolean running = new AtomicBoolean(true); AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, baseSubscriberPort, new NDArrayCallback() { /** * A listener for ndarray message * * @param message the message for the callback */ @Override public void onNDArrayMessage(NDArrayMessage message) { count.incrementAndGet(); } @Override public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) { count.incrementAndGet(); } @Override public void onNDArray(INDArray arr) { count.incrementAndGet(); } }, streamId, running); int expectedResponses = 10; HostPortPublisher publisher = HostPortPublisher.builder().aeron(aeron) .uriToSend(host + String.format(":%d:", baseSubscriberPort) + streamId) .channel(AeronUtil.aeronChannel(host, baseSubscriberPort + 1)).streamId(responderStreamId) .build(); for (int i = 0; i < expectedResponses; i++) { publisher.send(); } Thread.sleep(60000); assertEquals(expectedResponses, count.get()); System.out.println("After"); CloseHelper.close(responder); CloseHelper.close(subscriber); CloseHelper.close(publisher); CloseHelper.close(aeron); } }