package org.nd4j.aeron.ipc; import io.aeron.Aeron; import io.aeron.driver.MediaDriver; import org.agrona.CloseHelper; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertFalse; /** * Created by agibsonccc on 9/22/16. */ public class NdArrayIpcTest { private MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); private Aeron.Context ctx; private String channel = "aeron:udp?endpoint=localhost:" + (40132 + new java.util.Random().nextInt(3000)); private int streamId = 10; private int length = (int) 1e7; @Before public void before() { MediaDriver.Context ctx = AeronUtil.getMediaDriverContext(length); mediaDriver = MediaDriver.launchEmbedded(ctx); System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName()); System.out.println("Launched media driver"); } @After public void after() { CloseHelper.quietClose(mediaDriver); } @Test public void testMultiThreadedIpc() throws Exception { ExecutorService executorService = Executors.newFixedThreadPool(4); INDArray arr = Nd4j.scalar(1.0); final AtomicBoolean running = new AtomicBoolean(true); Aeron aeron = Aeron.connect(getContext()); int numSubscribers = 10; AeronNDArraySubscriber[] subscribers = new AeronNDArraySubscriber[numSubscribers]; for (int i = 0; i < numSubscribers; i++) { AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder().streamId(streamId).ctx(getContext()) .channel(channel).aeron(aeron).running(running).ndArrayCallback(new NDArrayCallback() { /** * A listener for ndarray message * * @param message the message for the callback */ @Override public void onNDArrayMessage(NDArrayMessage message) { System.out.println("Callback invoked for subscriber on ndarray ipc test"); running.set(false); } @Override public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) { } @Override public void onNDArray(INDArray arr) { System.out.println("Callback invoked for subscriber on ndarray ipc test"); running.set(false); } }).build(); Thread t = new Thread(() -> { try { subscriber.launch(); } catch (Exception e) { e.printStackTrace(); } }); t.setDaemon(true); t.start(); subscribers[i] = subscriber; } AeronNDArrayPublisher publisher = AeronNDArrayPublisher.builder().streamId(streamId).channel(channel).aeron(aeron).build(); Thread.sleep(10000); for (int i = 0; i < 10 && running.get(); i++) { executorService.execute(() -> { try { log.info("About to send array."); publisher.publish(arr); log.info("Sent array"); } catch (Exception e) { e.printStackTrace(); } }); } Thread.sleep(30000); for (int i = 0; i < numSubscribers; i++) CloseHelper.close(subscribers[i]); CloseHelper.close(publisher); CloseHelper.close(aeron); assertFalse(running.get()); } @Test public void testIpc() throws Exception { INDArray arr = Nd4j.scalar(1.0); final AtomicBoolean running = new AtomicBoolean(true); Aeron aeron = Aeron.connect(getContext()); AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder().streamId(streamId).aeron(aeron) .channel(channel).running(running).ndArrayCallback(new NDArrayCallback() { /** * A listener for ndarray message * * @param message the message for the callback */ @Override public void onNDArrayMessage(NDArrayMessage message) { System.out.println(arr); running.set(false); } @Override public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) { } @Override public void onNDArray(INDArray arr) { } }).build(); Thread t = new Thread(() -> { try { subscriber.launch(); } catch (Exception e) { e.printStackTrace(); } }); t.start(); while (!subscriber.launched()) Thread.sleep(1000); Thread.sleep(10000); AeronNDArrayPublisher publisher = AeronNDArrayPublisher.builder().streamId(streamId).aeron(aeron).channel(channel).build(); for (int i = 0; i < 1 && running.get(); i++) { publisher.publish(arr); } Thread.sleep(30000); assertFalse(running.get()); publisher.close(); subscriber.close(); } private Aeron.Context getContext() { if (ctx == null) ctx = new Aeron.Context().publicationConnectionTimeout(1000) .availableImageHandler(image -> System.out.println(image)) .unavailableImageHandler(AeronUtil::printUnavailableImage) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .errorHandler(e -> log.error(e.toString(), e)); return ctx; } }