/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.functions.kernel; import java.util.Iterator; import libsvm.Svm; import libsvm.svm_model; import libsvm.svm_node; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.FastExample2SparseTransform; import com.rapidminer.operator.UserError; import com.rapidminer.tools.Tools; /** * A model generated by the <a * href="http://www.csie.ntu.edu.tw/~cjlin/libsvm">libsvm</a> by Chih-Chung * Chang and Chih-Jen Lin. * * @author Ingo Mierswa * @version $Id: LibSVMModel.java,v 1.15 2008/06/04 11:33:10 ingomierswa Exp $ */ public class LibSVMModel extends KernelModel { private static final long serialVersionUID = -2654603017217487365L; private svm_model model; private int numberOfAttributes; private boolean confidenceForMultiClass = true; public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) { super(exampleSet); this.model = model; this.numberOfAttributes = numberOfAttributes; this.confidenceForMultiClass = confidenceForMultiClass; } public boolean isClassificationModel() { return getLabel().isNominal(); } public double getAlpha(int index) { return model.sv_coef[0][index]; } public String getId(int index) { return null; } public int getNumberOfSupportVectors() { return model.SV.length; } public int getNumberOfAttributes() { return numberOfAttributes; } public double getBias() { if (this.model.rho.length > 0) return this.model.rho[0]; else return 0.0d; } public SupportVector getSupportVector(int index) { svm_node[] nodes = this.model.SV[index]; double[] x = new double[getNumberOfAttributes()]; for (int i = 0; i < nodes.length; i++) x[nodes[i].index] = nodes[i].value; return new SupportVector(x, getRegressionLabel(index), Math.abs(getAlpha(index))); } public double getAttributeValue(int exampleIndex, int attributeIndex) { double[] dense = new double[numberOfAttributes]; svm_node[] node = model.SV[exampleIndex]; for (int i = 0; i < node.length; i++) { dense[node[i].index] = node[i].value; } return dense[attributeIndex]; } public String getClassificationLabel(int index) { double functionValue = getRegressionLabel(index); if (!Double.isNaN(functionValue)) return getLabel().getMapping().mapIndex((int)functionValue); else return "?"; } public double getRegressionLabel(int index) { if (model.labelValues != null) return model.labelValues[index]; else return Double.NaN; } public double getFunctionValue(int index) { if (getLabel().isNominal()) { double[] classProbs = new double[getLabel().getMapping().size()]; Svm.svm_predict_probability(model, model.SV[index], classProbs); return classProbs[0]; } else { return Svm.svm_predict(model, model.SV[index]); } } public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws UserError { FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet); Attribute label = getLabel(); Attribute[] confidenceAttributes = null; if (label.isNominal() && label.getMapping().size() >= 2) { confidenceAttributes = new Attribute[model.label.length]; for (int j = 0; j < model.label.length; j++) { String labelName = label.getMapping().mapIndex(model.label[j]); confidenceAttributes[j] = exampleSet.getAttributes().getSpecial(Attributes.CONFIDENCE_NAME + "_" + labelName); } } if (label.isNominal() && (label.getMapping().size() == 1)) { // one class SVM double[] allConfidences = new double[exampleSet.size()]; int counter = 0; double maxConfidence = Double.NEGATIVE_INFINITY; double minConfidence = Double.POSITIVE_INFINITY; Iterator<Example> i = exampleSet.iterator(); while (i.hasNext()) { Example e = i.next(); svm_node[] currentNodes = LibSVMLearner.makeNodes(e, ripper); double[] prob = new double[1]; Svm.svm_predict_values(model, currentNodes, prob); allConfidences[counter++] = prob[0]; minConfidence = Math.min(minConfidence, prob[0]); maxConfidence = Math.max(maxConfidence, prob[0]); } counter = 0; String className = predictedLabel.getMapping().mapIndex(0); i = exampleSet.iterator(); while (i.hasNext()) { Example e = i.next(); e.setValue(predictedLabel, 0); e.setConfidence(className, (allConfidences[counter++] - minConfidence) / (maxConfidence - minConfidence)); } } else { Iterator<Example> i = exampleSet.iterator(); while (i.hasNext()) { Example e = i.next(); if (label.isNominal()) { // set prediction svm_node[] currentNodes = LibSVMLearner.makeNodes(e, ripper); // set class probs (properly calculated during training) if ((model.probA != null) && (model.probB != null)) { double[] classProbs = new double[model.nr_class]; int nr_class = model.nr_class; double[] dec_values = new double[nr_class * (nr_class - 1) / 2]; Svm.svm_predict_values(model, currentNodes, dec_values); double min_prob = 1e-7; double[][] pairwise_prob = new double[nr_class][nr_class]; int k = 0; for (int a = 0; a < nr_class; a++) for (int j = a + 1; j < nr_class; j++) { pairwise_prob[a][j] = Math.min(Math.max(Svm.sigmoid_predict(dec_values[k], model.probA[k], model.probB[k]), min_prob), 1 - min_prob); pairwise_prob[j][a] = 1 - pairwise_prob[a][j]; k++; } Svm.multiclass_probability(nr_class, pairwise_prob, classProbs); for (k = 0; k < nr_class; k++) { e.setValue(confidenceAttributes[k], classProbs[k]); } if (confidenceForMultiClass) { // use highest confidence double predictedClass = Svm.svm_predict_probability(model, currentNodes, classProbs); e.setValue(predictedLabel, predictedClass); } else { // binary majority vote over 1-vs-1 classifiers double predictedClass = Svm.svm_predict(model, currentNodes); e.setValue(predictedLabel, predictedClass); } } else { double predictedClass = Svm.svm_predict(model, currentNodes); e.setValue(predictedLabel, predictedClass); // use simple calculation for binary cases... if (label.getMapping().size() == 2) { double[] functionValues = new double[model.nr_class]; Svm.svm_predict_values(model, currentNodes, functionValues); double prediction = functionValues[0]; if ((confidenceAttributes != null) && (confidenceAttributes.length > 0)) { e.setValue(confidenceAttributes[0], 1.0d / (1.0d + java.lang.Math.exp(-prediction))); if (confidenceAttributes.length > 1) { e.setValue(confidenceAttributes[1], 1.0d / (1.0d + java.lang.Math.exp(prediction))); } } } else { // ...or no proper confidence value for the multi class setting // here the confidence attribute calculated above cannot be used e.setConfidence(getLabel().getMapping().mapIndex((int)predictedClass), 1.0d); } } } else { e.setValue(predictedLabel, Svm.svm_predict(model, LibSVMLearner.makeNodes(e, ripper))); } } } return exampleSet; } public String toString() { StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator()); result.append("number of classes: " + model.nr_class + Tools.getLineSeparator()); if (getLabel().isNominal() && (getLabel().getMapping().size() >= 2)) { for (int i = 0; i < model.nSV.length; i++) { result.append("number of support vectors for class " + getLabel().getMapping().mapIndex(model.label[i]) + ": " + model.nSV[i] + Tools.getLineSeparator()); } } else { result.append("number of support vectors: " + model.l + Tools.getLineSeparator()); } return result.toString(); } }