package org.nd4j.linalg.compression; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.*; /** * @author raver119@gmail.com */ @RunWith(Parameterized.class) public class CompressionMagicTests extends BaseNd4jTest { public CompressionMagicTests(Nd4jBackend backend) { super(backend); } @Test public void testMagicDecompression1() throws Exception { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); compressed.muli(1.0); assertEquals(array, compressed); } @Test public void testMagicDecompression2() throws Exception { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "FLOAT16"); compressed.muli(1.0); assertArrayEquals(array.data().asFloat(), compressed.data().asFloat(), 0.1f); } @Test public void testMagicDecompression3() throws Exception { INDArray array = Nd4j.linspace(1, 2500, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "INT16"); compressed.muli(1.0); assertEquals(array, compressed); } @Test public void testMagicDecompression4() throws Exception { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); for (int cnt = 0; cnt < array.length(); cnt++) { float a = array.getFloat(cnt); float c = compressed.getFloat(cnt); assertEquals(a, c, 0.01f); } } @Test public void testDupSkipDecompression1() { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); INDArray newArray = compressed.dup(); assertTrue(newArray.isCompressed()); Nd4j.getCompressor().decompressi(compressed); Nd4j.getCompressor().decompressi(newArray); assertEquals(array, compressed); assertEquals(array, newArray); } @Test public void testDupSkipDecompression2() { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); INDArray newArray = compressed.dup('c'); assertTrue(newArray.isCompressed()); Nd4j.getCompressor().decompressi(compressed); Nd4j.getCompressor().decompressi(newArray); assertEquals(array, compressed); assertEquals(array, newArray); } @Test public void testDupSkipDecompression3() { INDArray array = Nd4j.linspace(1, 100, 2500); INDArray compressed = Nd4j.getCompressor().compress(array, "GZIP"); INDArray newArray = compressed.dup('f'); assertFalse(newArray.isCompressed()); Nd4j.getCompressor().decompressi(compressed); // Nd4j.getCompressor().decompressi(newArray); assertEquals(array, compressed); assertEquals(array, newArray); assertEquals('f', newArray.ordering()); assertEquals('c', compressed.ordering()); } @Override public char ordering() { return 'c'; } }