package org.nd4j.aeron.ipc; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.lang3.time.StopWatch; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.BufferedOutputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 9/23/16. */ public class AeronNDArraySerdeTest { @Test public void testToAndFrom() { INDArray arr = Nd4j.scalar(1.0); UnsafeBuffer buffer = AeronNDArraySerde.toBuffer(arr); INDArray back = AeronNDArraySerde.toArray(buffer); assertEquals(arr, back); } @Test public void testToAndFromCompressed() { INDArray arr = Nd4j.scalar(1.0); INDArray compress = Nd4j.getCompressor().compress(arr, "GZIP"); assertTrue(compress.isCompressed()); UnsafeBuffer buffer = AeronNDArraySerde.toBuffer(compress); INDArray back = AeronNDArraySerde.toArray(buffer); INDArray decompressed = Nd4j.getCompressor().decompress(compress); assertEquals(arr, decompressed); assertEquals(arr, back); } @Test public void testToAndFromCompressedLarge() { INDArray arr = Nd4j.zeros((int) 1e7); INDArray compress = Nd4j.getCompressor().compress(arr, "GZIP"); assertTrue(compress.isCompressed()); UnsafeBuffer buffer = AeronNDArraySerde.toBuffer(compress); INDArray back = AeronNDArraySerde.toArray(buffer); INDArray decompressed = Nd4j.getCompressor().decompress(compress); assertEquals(arr, decompressed); assertEquals(arr, back); } @Test public void timeOldVsNew() throws Exception { int numTrials = 1000; long oldTotal = 0; long newTotal = 0; INDArray arr = Nd4j.create(100000); Nd4j.getCompressor().compressi(arr, "GZIP"); for (int i = 0; i < numTrials; i++) { StopWatch oldStopWatch = new StopWatch(); BufferedOutputStream bos = new BufferedOutputStream(new ByteArrayOutputStream(arr.length())); DataOutputStream dos = new DataOutputStream(bos); oldStopWatch.start(); Nd4j.write(arr, dos); oldStopWatch.stop(); // System.out.println("Old " + oldStopWatch.getNanoTime()); oldTotal += oldStopWatch.getNanoTime(); StopWatch newStopWatch = new StopWatch(); newStopWatch.start(); AeronNDArraySerde.toBuffer(arr); newStopWatch.stop(); // System.out.println("New " + newStopWatch.getNanoTime()); newTotal += newStopWatch.getNanoTime(); } oldTotal /= numTrials; newTotal /= numTrials; System.out.println("Old avg " + oldTotal + " New avg " + newTotal); } }