package org.nd4j.aeron.ipc;

import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import lombok.extern.slf4j.Slf4j;
import org.agrona.CloseHelper;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.concurrent.atomic.AtomicBoolean;

import static org.junit.Assert.assertFalse;

/**
 * Publishes one large (1e7 element) ndarray over an Aeron UDP channel served by an
 * embedded media driver and checks that a subscriber callback fires.
 *
 * Created by agibsonccc on 9/22/16.
 */
@Slf4j
public class LargeNdArrayIpcTest {
    /** Time given to the subscriber threads to start before the publisher is created. */
    private static final long SUBSCRIBER_STARTUP_MS = 10000;
    /** Upper bound on how long the test waits for the subscriber callback after publishing. */
    private static final long RECEIVE_TIMEOUT_MS = 30000;
    /** Interval between checks of the {@code running} flag while waiting for delivery. */
    private static final long POLL_INTERVAL_MS = 100;

    private MediaDriver mediaDriver;
    private Aeron.Context ctx;
    // Port is randomised within [40123, 40252] so that runs do not all bind the same endpoint.
    private String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130));
    private int streamId = 10;
    // Number of elements in the published array; also handed to AeronUtil when building the driver context.
    private int length = (int) 1e7;

    /**
     * Launches an embedded media driver configured through
     * {@code AeronUtil.getMediaDriverContext(length)}.
     */
    @Before
    public void before() {
        //MediaDriver.loadPropertiesFile("aeron.properties");
        MediaDriver.Context driverCtx = AeronUtil.getMediaDriverContext(length);
        mediaDriver = MediaDriver.launchEmbedded(driverCtx);
        log.info("Using media driver directory {}", mediaDriver.aeronDirectoryName());
        log.info("Launched media driver");
    }

    /** Shuts the embedded media driver down, ignoring any error raised while closing. */
    @After
    public void after() {
        CloseHelper.quietClose(mediaDriver);
    }

    /**
     * Starts one subscriber on a background thread, publishes an array of ones and asserts
     * that the subscriber callback cleared the shared {@code running} flag.
     *
     * @throws Exception if publishing fails or the test thread is interrupted while waiting
     */
    @Test
    public void testMultiThreadedIpcBig() throws Exception {
        INDArray arr = Nd4j.ones(length);
        final AtomicBoolean running = new AtomicBoolean(true);
        Aeron aeron = Aeron.connect(getContext());

        int numSubscribers = 1;
        AeronNDArraySubscriber[] subscribers = new AeronNDArraySubscriber[numSubscribers];
        AeronNDArrayPublisher publisher = null;
        try {
            for (int i = 0; i < numSubscribers; i++) {
                AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder().streamId(streamId)
                                .ctx(getContext()).channel(channel).aeron(aeron).running(running)
                                .ndArrayCallback(new NDArrayCallback() {
                                    /**
                                     * A listener for ndarray message
                                     *
                                     * @param message the message for the callback
                                     */
                                    @Override
                                    public void onNDArrayMessage(NDArrayMessage message) {
                                        running.set(false);
                                    }

                                    @Override
                                    public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) {
                                        // partial updates are not exercised by this test
                                    }

                                    @Override
                                    public void onNDArray(INDArray arr) {
                                        running.set(false);
                                    }
                                }).build();
                subscribers[i] = subscriber;

                Thread t = new Thread(() -> {
                    try {
                        subscriber.launch();
                    } catch (Exception e) {
                        log.error("Subscriber terminated with an error", e);
                    }
                });
                // Daemon so that a failed test run cannot keep the JVM alive through this thread.
                t.setDaemon(true);
                t.start();
            }

            // Give the subscriber threads time to start before anything is published.
            Thread.sleep(SUBSCRIBER_STARTUP_MS);

            publisher = AeronNDArrayPublisher.builder().publishRetryTimeOut(3000).streamId(streamId)
                            .channel(channel).aeron(aeron).build();

            if (running.get()) {
                log.info("About to send array.");
                publisher.publish(arr);
                log.info("Sent array");
            }

            // Wait for a callback to clear the flag, but never longer than the timeout.
            long deadlineNanos = System.nanoTime() + RECEIVE_TIMEOUT_MS * 1_000_000L;
            while (running.get() && System.nanoTime() - deadlineNanos < 0) {
                Thread.sleep(POLL_INTERVAL_MS);
            }
        } finally {
            // Release subscribers and publisher before the Aeron client they were built on;
            // quietClose skips nulls, so this is safe even when setup failed part-way.
            for (AeronNDArraySubscriber subscriber : subscribers) {
                CloseHelper.quietClose(subscriber);
            }
            CloseHelper.quietClose(publisher);
            CloseHelper.quietClose(aeron);
        }

        assertFalse("Subscriber callback was not invoked within the timeout", running.get());
    }

    /**
     * Returns the Aeron client context shared by this test instance, creating it on first
     * use against the embedded media driver's directory.
     *
     * @return the lazily created, cached {@link Aeron.Context}
     */
    private Aeron.Context getContext() {
        if (ctx == null) {
            ctx = new Aeron.Context().publicationConnectionTimeout(-1)
                            .availableImageHandler(AeronUtil::printAvailableImage)
                            .unavailableImageHandler(AeronUtil::printUnavailableImage)
                            .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000)
                            .errorHandler(err -> log.error("Aeron client error", err));
        }
        return ctx;
    }
}