package de.jungblut.online.regression;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Test;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.SigmoidActivationFunction;
import de.jungblut.math.dense.DenseDoubleVector;
/**
 * Tests for {@link RegressionModel}.
 */
public class TestRegressionModel {

  /**
   * Round-trips a {@link RegressionModel} through {@code serialize} /
   * {@code deserialize} and checks that the weights and the activation
   * function type of the deserialized model match the original.
   */
  @Test
  public void testSerDe() throws IOException {
    DoubleVector weights = new DenseDoubleVector(new double[] { 1, 2, 3, 4, 5 });
    SigmoidActivationFunction activation = new SigmoidActivationFunction();
    RegressionModel model = new RegressionModel(weights, activation);

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(baos);
    model.serialize(dos);
    // Make sure everything written through the data stream has reached the
    // byte array before taking a snapshot of it.
    dos.flush();

    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    DataInputStream dis = new DataInputStream(bais);
    // deserialize is an instance method, so it is invoked on an empty model.
    // NOTE(review): whether it fills that instance or returns a new one is not
    // visible here; all assertions therefore target the returned object.
    RegressionModel emptyModel = new RegressionModel();
    RegressionModel deserialized = emptyModel.deserialize(dis);

    Assert.assertArrayEquals(weights.toArray(), deserialized.getWeights()
        .toArray(), 1e-8);
    // Check the activation function on the deserialized model (previously this
    // inspected the pre-deserialization instance instead).
    Assert.assertEquals(activation.getClass(), deserialized
        .getActivationFunction().getClass());
  }
}