package de.jungblut.online.regression.multinomial; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import com.google.common.base.Preconditions; import de.jungblut.online.ml.Model; import de.jungblut.online.regression.RegressionModel; public class MultinomialRegressionModel implements Model { private RegressionModel[] trainedModels; // deserialization constructor public MultinomialRegressionModel() { } public MultinomialRegressionModel(RegressionModel[] trainedModels) { this.trainedModels = Preconditions.checkNotNull(trainedModels, "weights"); for (int i = 0; i < trainedModels.length; i++) { Preconditions.checkNotNull(trainedModels[i], "model at index " + i); } } @Override public void serialize(DataOutput out) throws IOException { out.writeInt(trainedModels.length); for (RegressionModel model : trainedModels) { model.serialize(out); } } @Override public MultinomialRegressionModel deserialize(DataInput in) throws IOException { trainedModels = new RegressionModel[in.readInt()]; for (int i = 0; i < trainedModels.length; i++) { trainedModels[i] = new RegressionModel().deserialize(in); } return this; } public RegressionModel[] getModels() { return trainedModels; } }