/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program. If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.functions.kernel;

import libsvm.Svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.Tools;

/**
 * A model generated by <a href="http://www.csie.ntu.edu.tw/~cjlin/libsvm">libsvm</a> by Chih-Chung Chang and
 * Chih-Jen Lin.
 *
 * @author Ingo Mierswa
 */
public class LibSVMModel extends KernelModel implements FormulaProvider {

    private static final long serialVersionUID = -2654603017217487365L;

    private svm_model model;

    private int numberOfAttributes;

    private boolean confidenceForMultiClass = true;

    public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) {
        super(exampleSet);
        this.model = model;
        this.numberOfAttributes = numberOfAttributes;
        this.confidenceForMultiClass = confidenceForMultiClass;
    }

    @Override
    public boolean isClassificationModel() {
        return getLabel().isNominal();
    }

    @Override
    public double getAlpha(int index) {
        return model.sv_coef[0][index];
    }

    @Override
    public String getId(int index) {
        return null;
    }

    @Override
    public int getNumberOfSupportVectors() {
        return model.SV.length;
    }

    @Override
    public int getNumberOfAttributes() {
        return numberOfAttributes;
    }

    @Override
    public double getBias() {
        if (this.model.rho.length > 0)
            return this.model.rho[0];
        else
            return 0.0d;
    }

    @Override
    public SupportVector getSupportVector(int index) {
        // expand the sparse libsvm node representation into a dense value array
        svm_node[] nodes = this.model.SV[index];
        double[] x = new double[getNumberOfAttributes()];
        for (int i = 0; i < nodes.length; i++)
            x[nodes[i].index] = nodes[i].value;
        return new SupportVector(x, getRegressionLabel(index), Math.abs(getAlpha(index)));
    }

    @Override
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        double[] dense = new double[numberOfAttributes];
        svm_node[] node = model.SV[exampleIndex];
        for (int i = 0; i < node.length; i++) {
            dense[node[i].index] = node[i].value;
        }
        return dense[attributeIndex];
    }

    @Override
    public String getClassificationLabel(int index) {
        double functionValue = getRegressionLabel(index);
        if (!Double.isNaN(functionValue))
            return getLabel().getMapping().mapIndex((int) functionValue);
        else
            return "?";
    }

    @Override
    public double getRegressionLabel(int index) {
        if (model.labelValues != null)
            return model.labelValues[index];
        else
            return Double.NaN;
    }

    @Override
    public double getFunctionValue(int index) {
        if (getLabel().isNominal()) {
            double[] classProbs = new double[getLabel().getMapping().size()];
            Svm.svm_predict_probability(model, model.SV[index], classProbs);
            return classProbs[0];
        } else {
            return Svm.svm_predict(model, model.SV[index]);
        }
    }

    /**
     * Applies the model. Handles three cases: one-class SVMs (fixed "inside" / "outside" prediction),
     * nominal labels (classification with optional probability estimates), and numerical labels (regression).
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws UserError {
        FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
        Attribute label = getLabel();

        // check if one class SVM is used
        if (model.param.svm_type == LibSVMLearner.SVM_TYPE_ONE_CLASS) {
            // if yes, then clear predictedLabel mapping: we use a fixed one
            predictedLabel.getMapping().getValues().clear();
            predictedLabel.getMapping().getValues().add("outside");
            predictedLabel.getMapping().getValues().add("inside");

            // create own confidence attribute
            Attribute confidenceAttribute = AttributeFactory.createAttribute(Attributes.CONFIDENCE_NAME + "(inside)", Ontology.REAL);
            exampleSet.getExampleTable().addAttribute(confidenceAttribute);
            AttributeRole confidenceRole = new AttributeRole(confidenceAttribute);
            confidenceRole.setSpecial(Attributes.CONFIDENCE_NAME + "_inside");
            exampleSet.getAttributes().add(confidenceRole);

            // now calculate weights: a decision value >= 0 means the example lies inside the estimated region
            int counter = 0;
            double[] allConfidences = new double[exampleSet.size()];
            int[] allLabels = new int[exampleSet.size()];
            double maxConfidence = Double.NEGATIVE_INFINITY;
            double minConfidence = Double.POSITIVE_INFINITY;
            double confidence;
            for (Example example : exampleSet) {
                svm_node[] currentNodes = LibSVMLearner.makeNodes(example, ripper);
                double[] prob = new double[1];
                Svm.svm_predict_values(model, currentNodes, prob);
                allLabels[counter] = (prob[0] >= 0) ? 1 : 0;
                allConfidences[counter] = prob[0];
                minConfidence = Math.min(minConfidence, prob[0]);
                maxConfidence = Math.max(maxConfidence, prob[0]);
                counter++;
            }
            counter = 0;
            for (Example example : exampleSet) {
                confidence = allConfidences[counter]; // (allConfidences[counter] - minConfidence) / (maxConfidence - minConfidence);
                example.setValue(predictedLabel, allLabels[counter]);
                example.setValue(confidenceAttribute, confidence);
                counter++;
            }
        } else {
            // performing regular classification or regression
            Attribute[] confidenceAttributes = null;
            if (label.isNominal() && label.getMapping().size() >= 2) {
                confidenceAttributes = new Attribute[model.label.length];
                for (int j = 0; j < model.label.length; j++) {
                    String labelName = label.getMapping().mapIndex(model.label[j]);
                    confidenceAttributes[j] = exampleSet.getAttributes().getSpecial(Attributes.CONFIDENCE_NAME + "_" + labelName);
                }
            }
            for (Example example : exampleSet) {
                if (label.isNominal()) {
                    // set prediction
                    svm_node[] currentNodes = LibSVMLearner.makeNodes(example, ripper);
                    // set class probs (properly calculated during training)
                    if ((model.probA != null) && (model.probB != null)) {
                        double[] classProbs = new double[model.nr_class];
                        int nr_class = model.nr_class;
                        double[] dec_values = new double[nr_class * (nr_class - 1) / 2];
                        Svm.svm_predict_values(model, currentNodes, dec_values);
                        // turn the pairwise 1-vs-1 decision values into pairwise probabilities via the
                        // sigmoids fitted during training and couple them into class probabilities
                        double min_prob = 1e-7;
                        double[][] pairwise_prob = new double[nr_class][nr_class];
                        int k = 0;
                        for (int a = 0; a < nr_class; a++)
                            for (int j = a + 1; j < nr_class; j++) {
                                pairwise_prob[a][j] = Math.min(Math.max(Svm.sigmoid_predict(dec_values[k], model.probA[k], model.probB[k]), min_prob), 1 - min_prob);
                                pairwise_prob[j][a] = 1 - pairwise_prob[a][j];
                                k++;
                            }
                        Svm.multiclass_probability(nr_class, pairwise_prob, classProbs);
                        for (k = 0; k < nr_class; k++) {
                            example.setValue(confidenceAttributes[k], classProbs[k]);
                        }
                        if (confidenceForMultiClass) {
                            // use highest confidence
                            double predictedClass = Svm.svm_predict_probability(model, currentNodes, classProbs);
                            example.setValue(predictedLabel, predictedClass);
                        } else {
                            // binary majority vote over 1-vs-1 classifiers
                            double predictedClass = Svm.svm_predict(model, currentNodes);
                            example.setValue(predictedLabel, predictedClass);
                        }
                    } else {
                        double predictedClass = Svm.svm_predict(model, currentNodes);
                        example.setValue(predictedLabel, predictedClass);
                        // use simple calculation for binary cases...
                        if (label.getMapping().size() == 2) {
                            double[] functionValues = new double[model.nr_class];
                            Svm.svm_predict_values(model, currentNodes, functionValues);
                            double prediction = functionValues[0];
                            if ((confidenceAttributes != null) && (confidenceAttributes.length > 0)) {
                                example.setValue(confidenceAttributes[0], 1.0d / (1.0d + java.lang.Math.exp(-prediction)));
                                if (confidenceAttributes.length > 1) {
                                    example.setValue(confidenceAttributes[1], 1.0d / (1.0d + java.lang.Math.exp(prediction)));
                                }
                            }
                        } else {
                            // ...or no proper confidence value for the multi class setting
                            // here the confidence attribute calculated above cannot be used
                            example.setConfidence(getLabel().getMapping().mapIndex((int) predictedClass), 1.0d);
                        }
                    }
                } else {
                    example.setValue(predictedLabel, Svm.svm_predict(model, LibSVMLearner.makeNodes(example, ripper)));
                }
            }
        }
        return exampleSet;
    }

    @Override
    protected boolean supportsConfidences(Attribute label) {
        return super.supportsConfidences(label) && model.param.svm_type != LibSVMLearner.SVM_TYPE_ONE_CLASS;
    }

    @Override
    public String toString() {
        StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator());
        result.append("number of classes: " + model.nr_class + Tools.getLineSeparator());
        if (getLabel().isNominal() && (getLabel().getMapping().size() >= 2) && model.nSV != null) {
            for (int i = 0; i < model.nSV.length; i++) {
                result.append("number of support vectors for class " + getLabel().getMapping().mapIndex(model.label[i]) + ": " + model.nSV[i] + Tools.getLineSeparator());
            }
        } else {
            result.append("number of support vectors: " + model.l + Tools.getLineSeparator());
        }
        return result.toString();
    }

    public String getFormula() {
        StringBuffer result = new StringBuffer();

        int kernelType = this.model.param.kernel_type;
        if (kernelType == svm_parameter.PRECOMPUTED) {
            return "Precomputed kernel, no formula possible.";
        } else if (kernelType == svm_parameter.RBF) {
            return "RBF kernel, no formula possible.";
        }

        boolean first = true;
        for (int i = 0; i < getNumberOfSupportVectors(); i++) {
            SupportVector sv = getSupportVector(i);
            if (sv != null) {
                double alpha = sv.getAlpha();
                if (!Tools.isZero(alpha)) {
                    result.append(Tools.getLineSeparator());
                    double[] x = sv.getX();
                    double y = sv.getY();
                    double factor = y * alpha;
                    if (factor < 0.0d) {
                        if (first)
                            result.append("- " + Math.abs(factor));
                        else
                            result.append("- " + Math.abs(factor));
                    } else {
                        if (first) {
                            result.append(" " + factor);
                        } else {
                            result.append("+ " + factor);
                        }
                    }
                    result.append(" * (" + getDistanceFormula(x, getAttributeConstructions()) + ")");
                    first = false;
                }
            }
        }

        double bias = getBias();
        if (!Tools.isZero(bias)) {
            result.append(Tools.getLineSeparator());
            if (bias < 0.0d) {
                if (first)
                    result.append("- " + Math.abs(bias));
                else
                    result.append("- " + Math.abs(bias));
            } else {
                if (first) {
                    result.append(bias);
                } else {
                    result.append("+ " + bias);
                }
            }
        }

        return result.toString();
    }

    private String getDistanceFormula(double[] x, String[] attributeConstructions) {
        int kernelType = this.model.param.kernel_type;
        switch (kernelType) {
            case svm_parameter.LINEAR:
                StringBuffer result = new StringBuffer();
                boolean first = true;
                for (int i = 0; i < x.length; i++) {
                    double value = x[i];
                    if (!Tools.isZero(value)) {
                        if (value < 0.0d) {
                            if (first)
                                result.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                            else
                                result.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            if (first)
                                result.append(value + " * " + attributeConstructions[i]);
                            else
                                result.append(" + " + value + " * " + attributeConstructions[i]);
                        }
                        first = false;
                    }
                }
                return result.toString();
            case svm_parameter.POLY:
                StringBuffer dotResult = new StringBuffer();
                first = true;
                for (int i = 0; i < x.length; i++) {
                    double value = x[i];
                    if (!Tools.isZero(value)) {
                        if (value < 0.0d) {
                            if (first)
                                dotResult.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                            else
                                dotResult.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            if (first)
                                dotResult.append(value + " * " + attributeConstructions[i]);
                            else
                                dotResult.append(" + " + value + " * " + attributeConstructions[i]);
                        }
                        first = false;
                    }
                }
                return "pow((" + model.param.gamma + " * (" + dotResult.toString() + ") + " + model.param.coef0 + "), " + model.param.degree + ")";
            case svm_parameter.SIGMOID:
                dotResult = new StringBuffer();
                first = true;
                for (int i = 0; i < x.length; i++) {
                    double value = x[i];
                    if (!Tools.isZero(value)) {
                        if (value < 0.0d) {
                            if (first)
                                dotResult.append("-" + Math.abs(value) + " * " + attributeConstructions[i]);
                            else
                                dotResult.append(" - " + Math.abs(value) + " * " + attributeConstructions[i]);
                        } else {
                            if (first)
                                dotResult.append(value + " * " + attributeConstructions[i]);
                            else
                                dotResult.append(" + " + value + " * " + attributeConstructions[i]);
                        }
                        first = false;
                    }
                }
                return "tanh(" + model.param.gamma + " * (" + dotResult.toString() + ") + " + model.param.coef0 + ")";
            default:
                return "";
        }
    }
}