package org.nd4j.linalg.serde;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import static junit.framework.TestCase.assertEquals;
/**
* Created by raver119 on 21.12.16.
*/
@RunWith(Parameterized.class)
@Slf4j
public class BasicSerDeTests extends BaseNd4jTest {
public BasicSerDeTests(Nd4jBackend backend) {
super(backend);
}
@Test
public void testBasicDataTypeSwitch1() throws Exception {
DataBuffer.Type initialType = Nd4j.dataType();
Nd4j.setDataType(DataBuffer.Type.FLOAT);
INDArray array = Nd4j.create(new float[] {1, 2, 3, 4, 5, 6});
ByteArrayOutputStream bos = new ByteArrayOutputStream();
Nd4j.write(bos, array);
Nd4j.setDataType(DataBuffer.Type.DOUBLE);
INDArray restored = Nd4j.read(new ByteArrayInputStream(bos.toByteArray()));
assertEquals(Nd4j.create(new float[] {1, 2, 3, 4, 5, 6}), restored);
assertEquals(8, restored.data().getElementSize());
assertEquals(4, restored.shapeInfoDataBuffer().getElementSize());
Nd4j.setDataType(initialType);
}
@Override
public char ordering() {
return 'f';
}
}