package org.nd4j.linalg.compression; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.nio.ByteBuffer; import static junit.framework.TestCase.assertFalse; import static org.junit.Assert.*; /** * @author raver119@gmail.com */ @RunWith(Parameterized.class) public class CompressionTests extends BaseNd4jTest { public CompressionTests(Nd4jBackend backend) { super(backend); } @Test public void testCompressionDescriptorSerde() { CompressionDescriptor descriptor = new CompressionDescriptor(); descriptor.setCompressedLength(4); descriptor.setOriginalElementSize(4); descriptor.setNumberOfElements(4); descriptor.setCompressionAlgorithm("GZIP"); descriptor.setOriginalLength(4); descriptor.setCompressionType(CompressionType.LOSSY); ByteBuffer toByteBuffer = descriptor.toByteBuffer(); CompressionDescriptor fromByteBuffer = CompressionDescriptor.fromByteBuffer(toByteBuffer); assertEquals(descriptor, fromByteBuffer); } @Test public void testGzipInPlaceCompression() { INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); Nd4j.getCompressor().setDefaultCompression("GZIP"); Nd4j.getCompressor().compressi(array); assertTrue(array.isCompressed()); Nd4j.getCompressor().decompressi(array); assertFalse(array.isCompressed()); } @Test public void testFP16Compression1() { INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("INT8"); BasicNDArrayCompressor.getInstance().printAvailableCompressors(); INDArray compr = BasicNDArrayCompressor.getInstance().compress(array); assertEquals(DataBuffer.Type.COMPRESSED, compr.data().dataType()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(5.0f, decomp.getFloat(4), 0.01f); } @Test public void testFP16Compression2() { DataBuffer buffer = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 5f}); DataBuffer exp = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("FLOAT16"); DataBuffer compr = BasicNDArrayCompressor.getInstance().compress(buffer); assertEquals(DataBuffer.Type.COMPRESSED, compr.dataType()); DataBuffer decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(5.0f, decomp.getFloat(4), 0.01f); } @Test public void testFP16Compression3() { INDArray buffer = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("FLOAT16"); INDArray compr = BasicNDArrayCompressor.getInstance().compress(buffer); assertEquals(false, buffer.isCompressed()); assertEquals(true, compr.isCompressed()); assertEquals(DataBuffer.Type.COMPRESSED, compr.data().dataType()); // assertNotEquals(exp, compr); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(false, decomp.isCompressed()); assertEquals(DataBuffer.Type.FLOAT, decomp.data().dataType()); assertEquals(exp, decomp); } @Test public void testUint8Compression1() { DataBuffer buffer = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 5f}); DataBuffer exp = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("UINT8"); DataBuffer compr = BasicNDArrayCompressor.getInstance().compress(buffer); assertEquals(DataBuffer.Type.COMPRESSED, compr.dataType()); DataBuffer decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(5.0f, decomp.getFloat(4), 0.01f); } @Test public void testUint8Compression2() { DataBuffer buffer = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 1005f}); DataBuffer exp = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 1005f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("UINT8"); DataBuffer compr = BasicNDArrayCompressor.getInstance().compress(buffer); assertEquals(DataBuffer.Type.COMPRESSED, compr.dataType()); DataBuffer decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(255.0f, decomp.getFloat(4), 0.01f); } @Test public void testInt8Compression1() { DataBuffer buffer = Nd4j.createBuffer(new float[] {1f, 2f, 3f, 4f, 1005f, -3.7f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("INT8"); DataBuffer compr = BasicNDArrayCompressor.getInstance().compress(buffer); assertEquals(DataBuffer.Type.COMPRESSED, compr.dataType()); DataBuffer decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(127.0f, decomp.getFloat(4), 0.01f); assertEquals(-3.0f, decomp.getFloat(5), 0.01f); } @Test public void testGzipCompression1() { INDArray array = Nd4j.linspace(1, 10000, 20000); INDArray exp = array.dup(); BasicNDArrayCompressor.getInstance().setDefaultCompression("GZIP"); INDArray compr = BasicNDArrayCompressor.getInstance().compress(array); assertEquals(DataBuffer.Type.COMPRESSED, compr.data().dataType()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(exp, array); assertEquals(exp, decomp); } @Test public void testNoOpCompression1() { INDArray array = Nd4j.linspace(1, 10000, 20000); INDArray exp = Nd4j.linspace(1, 10000, 20000); INDArray mps = Nd4j.linspace(1, 10000, 20000); BasicNDArrayCompressor.getInstance().setDefaultCompression("NOOP"); INDArray compr = BasicNDArrayCompressor.getInstance().compress(array); assertEquals(DataBuffer.Type.COMPRESSED, compr.data().dataType()); assertTrue(compr.isCompressed()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(DataBuffer.Type.FLOAT, decomp.data().dataType()); assertFalse(decomp.isCompressed()); assertFalse(decomp.data() instanceof CompressedDataBuffer); assertFalse(exp.data() instanceof CompressedDataBuffer); assertFalse(exp.isCompressed()); assertFalse(array.data() instanceof CompressedDataBuffer); assertEquals(exp, decomp); } @Test public void testFP8Compression1() { INDArray array = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("FLOAT8"); BasicNDArrayCompressor.getInstance().printAvailableCompressors(); INDArray compr = BasicNDArrayCompressor.getInstance().compress(array); assertEquals(DataBuffer.Type.COMPRESSED, compr.data().dataType()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compr); assertEquals(1.0f, decomp.getFloat(0), 0.01f); assertEquals(2.0f, decomp.getFloat(1), 0.01f); assertEquals(3.0f, decomp.getFloat(2), 0.01f); assertEquals(4.0f, decomp.getFloat(3), 0.01f); assertEquals(5.0f, decomp.getFloat(4), 0.01f); } @Test public void testJVMCompression1() throws Exception { INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("FLOAT16"); INDArray compressed = BasicNDArrayCompressor.getInstance().compress(new float[] {1f, 2f, 3f, 4f, 5f}); assertNotEquals(null, compressed.data()); assertNotEquals(null, compressed.shapeInfoDataBuffer()); assertTrue(compressed.isCompressed()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compressed); assertEquals(exp, decomp); } @Test public void testJVMCompression2() throws Exception { INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("INT8"); INDArray compressed = BasicNDArrayCompressor.getInstance().compress(new float[] {1f, 2f, 3f, 4f, 5f}); assertNotEquals(null, compressed.data()); assertNotEquals(null, compressed.shapeInfoDataBuffer()); assertTrue(compressed.isCompressed()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compressed); assertEquals(exp, decomp); } @Test public void testJVMCompression3() throws Exception { INDArray exp = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f}); BasicNDArrayCompressor.getInstance().setDefaultCompression("NOOP"); INDArray compressed = BasicNDArrayCompressor.getInstance().compress(new float[] {1f, 2f, 3f, 4f, 5f}); assertNotEquals(null, compressed.data()); assertNotEquals(null, compressed.shapeInfoDataBuffer()); assertTrue(compressed.isCompressed()); INDArray decomp = BasicNDArrayCompressor.getInstance().decompress(compressed); assertEquals(exp, decomp); } @Override public char ordering() { return 'c'; } }