package de.jungblut.online.regression; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import org.apache.hadoop.util.ReflectionUtils; import com.google.common.base.Preconditions; import de.jungblut.math.DoubleVector; import de.jungblut.math.activation.ActivationFunction; import de.jungblut.online.ml.Model; import de.jungblut.writable.VectorWritable; public class RegressionModel implements Model { private DoubleVector weights; private ActivationFunction activationFunction; // deserialization constructor public RegressionModel() { } public RegressionModel(DoubleVector weights, ActivationFunction activationFunction) { this.weights = Preconditions.checkNotNull(weights, "weights"); this.activationFunction = Preconditions.checkNotNull(activationFunction, "activationFunction"); } @Override public void serialize(DataOutput out) throws IOException { out.writeUTF(activationFunction.getClass().getName()); VectorWritable.writeVector(weights, out); } @Override public RegressionModel deserialize(DataInput in) throws IOException { String clzName = in.readUTF(); try { this.activationFunction = (ActivationFunction) ReflectionUtils .newInstance(Class.forName(clzName), null); } catch (ClassNotFoundException e) { throw new IOException(e); } weights = VectorWritable.readVector(in); return this; } public DoubleVector getWeights() { return this.weights; } public ActivationFunction getActivationFunction() { return this.activationFunction; } }