/**
 * Copyright (c) 2007-2014 The LIBLINEAR Project. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this list of conditions
 * and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials provided with
 * the distribution.
 *
 * 3. Neither name of copyright holders nor the names of its contributors may be used to endorse or
 * promote products derived from this software without specific prior written permission.
 *
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
 * THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package de.bwaldvogel.liblinear;

import static de.bwaldvogel.liblinear.Linear.copyOf;

import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.Serializable;
import java.io.Writer;
import java.util.Arrays;

/**
 * <p>
 * Model stores the model obtained from the training procedure
 * </p>
 *
 * <p>
 * use {@link Linear#loadModel(File)} and {@link Linear#saveModel(File, Model)} to load/save it
 * </p>
 */
public final class Model implements Serializable {

    private static final long serialVersionUID = -6456047576741854834L;

    double bias;

    /** label of each class */
    public int[] label;

    public int nr_class;

    int nr_feature;

    SolverType solverType;

    /** feature weight array */
    public double[] w;

    /**
     * @return number of classes
     */
    public int getNrClass() {
        return nr_class;
    }

    /**
     * @return number of features
     */
    public int getNrFeature() {
        return nr_feature;
    }

    public int[] getLabels() {
        return copyOf(label, nr_class);
    }

    /**
     * The nr_feature*nr_class array w gives feature weights. We use one against the rest for
     * multi-class classification, so each feature index corresponds to nr_class weight values.
     * Weights are organized in the following way
     *
     * <pre>
     * +------------------+------------------+------------+
     * | nr_class weights | nr_class weights |    ...
     * | for 1st feature  | for 2nd feature  |
     * +------------------+------------------+------------+
     * </pre>
     *
     * If bias >= 0, x becomes [x; bias]. The number of features is increased by one, so w is a
     * (nr_feature+1)*nr_class array. The value of bias is stored in the variable bias.
     *
     * @see #getBias()
     * @return a <b>copy of</b> the feature weight array as described
     */
    public double[] getFeatureWeights() {
        return Linear.copyOf(w, w.length);
    }

    /**
     * @return true for logistic regression solvers
     */
    public boolean isProbabilityModel() {
        return solverType.isLogisticRegressionSolver();
    }

    /**
     * @see #getFeatureWeights()
     */
    public double getBias() {
        return bias;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Model");
        sb.append(" bias=").append(bias);
        sb.append(" nr_class=").append(nr_class);
        sb.append(" nr_feature=").append(nr_feature);
        sb.append(" solverType=").append(solverType);
        return sb.toString();
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        long temp;
        temp = Double.doubleToLongBits(bias);
        result = prime * result + (int) (temp ^ temp >>> 32);
        result = prime * result + Arrays.hashCode(label);
        result = prime * result + nr_class;
        result = prime * result + nr_feature;
        result = prime * result + (solverType == null ? 0 : solverType.hashCode());
        result = prime * result + Arrays.hashCode(w);
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        Model other = (Model) obj;
        if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) {
            return false;
        }
        if (!Arrays.equals(label, other.label)) {
            return false;
        }
        if (nr_class != other.nr_class) {
            return false;
        }
        if (nr_feature != other.nr_feature) {
            return false;
        }
        if (solverType == null) {
            if (other.solverType != null) {
                return false;
            }
        } else if (!solverType.equals(other.solverType)) {
            return false;
        }
        if (!equals(w, other.w)) {
            return false;
        }
        return true;
    }

    /**
     * don't use {@link Arrays#equals(double[], double[])} here, because 0.0 and -0.0 should be
     * handled the same
     *
     * @see Linear#saveModel(java.io.Writer, Model)
     */
    protected static boolean equals(double[] a, double[] a2) {
        if (a == a2) {
            return true;
        }
        if (a == null || a2 == null) {
            return false;
        }
        int length = a.length;
        if (a2.length != length) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            if (a[i] != a2[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * see {@link Linear#saveModel(java.io.File, Model)}
     */
    public void save(File file) throws IOException {
        Linear.saveModel(file, this);
    }

    /**
     * see {@link Linear#saveModel(Writer, Model)}
     */
    public void save(Writer writer) throws IOException {
        Linear.saveModel(writer, this);
    }

    /**
     * see {@link Linear#loadModel(File)}
     */
    public static Model load(File file) throws IOException {
        return Linear.loadModel(file);
    }

    /**
     * see {@link Linear#loadModel(Reader)}
     */
    public static Model load(Reader inputReader) throws IOException {
        return Linear.loadModel(inputReader);
    }
}
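
/*
 * Usage sketch (not part of the original LIBLINEAR / liblinear-java sources): a minimal,
 * hypothetical example showing how to load a saved model and read a single weight following the
 * per-feature layout documented on Model#getFeatureWeights(). The file name "model.bin", the
 * example class name, and the chosen indices are illustrative assumptions only; for some
 * two-class solvers the weight array may collapse to a single vector, which is why the layout is
 * checked before indexing.
 */
final class ModelUsageExample {

    public static void main(String[] args) throws IOException {
        // load a model previously written with Model#save(File) / Linear#saveModel(File, Model)
        Model model = Model.load(new File("model.bin"));

        double[] w = model.getFeatureWeights(); // returns a copy of the weight array
        int nrClass = model.getNrClass();
        int nrFeature = model.getNrFeature();

        // documented layout: nr_class consecutive weights per feature
        // (plus one extra "feature" holding the bias weights if bias >= 0)
        boolean groupedPerFeature = w.length == nrFeature * nrClass
            || w.length == (nrFeature + 1) * nrClass;

        int featureIndex = 1; // 1-based feature index, as used by LIBLINEAR
        int labelIndex = 0;   // position within model.getLabels()
        if (groupedPerFeature && featureIndex <= nrFeature && labelIndex < nrClass) {
            double weight = w[(featureIndex - 1) * nrClass + labelIndex];
            System.out.printf("weight of feature %d for label %d: %.6f%n",
                featureIndex, model.getLabels()[labelIndex], weight);
        }

        System.out.println("bias = " + model.getBias());
    }
}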