package org.nd4j.linalg.compression;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.ByteArrayInputStream;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
 * Serialization / deserialization round-trip tests for compressed arrays.
 *
 * <p>Each test compresses an array with a named algorithm, writes it out with
 * {@code Nd4j.write}, reads it back with {@code Nd4j.read}, and compares the
 * restored array against the uncompressed original.
 *
 * @author raver119@gmail.com
 */
@RunWith(Parameterized.class)
public class CompressionSerDeTests extends BaseNd4jTest {

    public CompressionSerDeTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Writes the given array into an in-memory buffer and returns a stream
     * positioned at the start of the written bytes.
     *
     * @param source array to serialize
     * @return input stream over the serialized bytes
     * @throws Exception if writing the array fails
     */
    private static ByteArrayInputStream toSerializedStream(INDArray source) throws Exception {
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        Nd4j.write(sink, source);
        return new ByteArrayInputStream(sink.toByteArray());
    }

    /**
     * Round-trips a "UINT8"-compressed array and expects the array that comes
     * back from {@code Nd4j.read} to equal the uncompressed original without
     * any explicit decompress call, i.e. decompression is expected to happen
     * automatically after deserialization.
     */
    @Test
    public void testAutoDecompression1() throws Exception {
        INDArray original = Nd4j.linspace(1, 250, 250);
        INDArray packed = Nd4j.getCompressor().compress(original, "UINT8");

        INDArray restored = Nd4j.read(toSerializedStream(packed));

        assertEquals(original, restored);
    }

    /**
     * Round-trips a "FLOAT16"-compressed array and compares the raw float
     * contents of the restored buffer with the original, element by element,
     * using a tolerance of 0.1.
     *
     * <p>No explicit decompress call is made on the restored array; its data
     * buffer is read directly. A check that the compressed array is backed by
     * a CompressedDataBuffer is currently not performed.
     */
    @Test
    public void testManualDecompression1() throws Exception {
        INDArray original = Nd4j.linspace(1, 5, 10);
        INDArray packed = Nd4j.getCompressor().compress(original, "FLOAT16");

        // The compressor must have flagged its output as compressed.
        boolean flaggedCompressed = packed.isCompressed();
        assertEquals(true, flaggedCompressed);

        INDArray restored = Nd4j.read(toSerializedStream(packed));

        float[] expected = original.data().asFloat();
        float[] actual = restored.data().asFloat();
        assertArrayEquals(expected, actual, 0.1f);
    }

    /**
     * Round-trips a "GZIP"-compressed array, then passes the restored array
     * through the compressor's {@code decompress} before comparing it with
     * the original. Progress markers are printed before the read and before
     * the decompress step.
     */
    @Test
    public void testAutoDecompression2() throws Exception {
        INDArray original = Nd4j.linspace(1, 10, 11);
        INDArray packed = Nd4j.getCompressor().compress(original, "GZIP");

        ByteArrayInputStream serialized = toSerializedStream(packed);

        System.out.println("Restoring -------------------------");
        INDArray restored = Nd4j.read(serialized);

        System.out.println("Decomp -------------------------");
        INDArray unpacked = Nd4j.getCompressor().decompress(restored);

        assertEquals(original, unpacked);
    }

    /**
     * @return {@code 'c'}, the array ordering used by this test class
     */
    @Override
    public char ordering() {
        return 'c';
    }
}