/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.functions.kernel;
import java.util.Iterator;
import libsvm.Svm;
import libsvm.svm_model;
import libsvm.svm_node;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.FastExample2SparseTransform;
import com.rapidminer.operator.UserError;
import com.rapidminer.tools.Tools;
/**
* A model generated by the <a
* href="http://www.csie.ntu.edu.tw/~cjlin/libsvm">libsvm</a> by Chih-Chung
* Chang and Chih-Jen Lin.
*
* @author Ingo Mierswa
* @version $Id: LibSVMModel.java,v 1.15 2008/06/04 11:33:10 ingomierswa Exp $
*/
public class LibSVMModel extends KernelModel {
private static final long serialVersionUID = -2654603017217487365L;
private svm_model model;
private int numberOfAttributes;
private boolean confidenceForMultiClass = true;
public LibSVMModel(ExampleSet exampleSet, svm_model model, int numberOfAttributes, boolean confidenceForMultiClass) {
super(exampleSet);
this.model = model;
this.numberOfAttributes = numberOfAttributes;
this.confidenceForMultiClass = confidenceForMultiClass;
}
public boolean isClassificationModel() {
return getLabel().isNominal();
}
public double getAlpha(int index) {
return model.sv_coef[0][index];
}
public String getId(int index) {
return null;
}
public int getNumberOfSupportVectors() {
return model.SV.length;
}
public int getNumberOfAttributes() {
return numberOfAttributes;
}
public double getBias() {
if (this.model.rho.length > 0)
return this.model.rho[0];
else
return 0.0d;
}
public SupportVector getSupportVector(int index) {
svm_node[] nodes = this.model.SV[index];
double[] x = new double[getNumberOfAttributes()];
for (int i = 0; i < nodes.length; i++)
x[nodes[i].index] = nodes[i].value;
return new SupportVector(x, getRegressionLabel(index), Math.abs(getAlpha(index)));
}
public double getAttributeValue(int exampleIndex, int attributeIndex) {
double[] dense = new double[numberOfAttributes];
svm_node[] node = model.SV[exampleIndex];
for (int i = 0; i < node.length; i++) {
dense[node[i].index] = node[i].value;
}
return dense[attributeIndex];
}
public String getClassificationLabel(int index) {
double functionValue = getRegressionLabel(index);
if (!Double.isNaN(functionValue))
return getLabel().getMapping().mapIndex((int)functionValue);
else
return "?";
}
public double getRegressionLabel(int index) {
if (model.labelValues != null)
return model.labelValues[index];
else
return Double.NaN;
}
public double getFunctionValue(int index) {
if (getLabel().isNominal()) {
double[] classProbs = new double[getLabel().getMapping().size()];
Svm.svm_predict_probability(model, model.SV[index], classProbs);
return classProbs[0];
} else {
return Svm.svm_predict(model, model.SV[index]);
}
}
public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws UserError {
FastExample2SparseTransform ripper = new FastExample2SparseTransform(exampleSet);
Attribute label = getLabel();
Attribute[] confidenceAttributes = null;
if (label.isNominal() && label.getMapping().size() >= 2) {
confidenceAttributes = new Attribute[model.label.length];
for (int j = 0; j < model.label.length; j++) {
String labelName = label.getMapping().mapIndex(model.label[j]);
confidenceAttributes[j] = exampleSet.getAttributes().getSpecial(Attributes.CONFIDENCE_NAME + "_" + labelName);
}
}
if (label.isNominal() && (label.getMapping().size() == 1)) { // one class SVM
double[] allConfidences = new double[exampleSet.size()];
int counter = 0;
double maxConfidence = Double.NEGATIVE_INFINITY;
double minConfidence = Double.POSITIVE_INFINITY;
Iterator<Example> i = exampleSet.iterator();
while (i.hasNext()) {
Example e = i.next();
svm_node[] currentNodes = LibSVMLearner.makeNodes(e, ripper);
double[] prob = new double[1];
Svm.svm_predict_values(model, currentNodes, prob);
allConfidences[counter++] = prob[0];
minConfidence = Math.min(minConfidence, prob[0]);
maxConfidence = Math.max(maxConfidence, prob[0]);
}
counter = 0;
String className = predictedLabel.getMapping().mapIndex(0);
i = exampleSet.iterator();
while (i.hasNext()) {
Example e = i.next();
e.setValue(predictedLabel, 0);
e.setConfidence(className, (allConfidences[counter++] - minConfidence) / (maxConfidence - minConfidence));
}
} else {
Iterator<Example> i = exampleSet.iterator();
while (i.hasNext()) {
Example e = i.next();
if (label.isNominal()) {
// set prediction
svm_node[] currentNodes = LibSVMLearner.makeNodes(e, ripper);
// set class probs (properly calculated during training)
if ((model.probA != null) && (model.probB != null)) {
double[] classProbs = new double[model.nr_class];
int nr_class = model.nr_class;
double[] dec_values = new double[nr_class * (nr_class - 1) / 2];
Svm.svm_predict_values(model, currentNodes, dec_values);
double min_prob = 1e-7;
double[][] pairwise_prob = new double[nr_class][nr_class];
int k = 0;
for (int a = 0; a < nr_class; a++)
for (int j = a + 1; j < nr_class; j++) {
pairwise_prob[a][j] = Math.min(Math.max(Svm.sigmoid_predict(dec_values[k], model.probA[k], model.probB[k]), min_prob), 1 - min_prob);
pairwise_prob[j][a] = 1 - pairwise_prob[a][j];
k++;
}
Svm.multiclass_probability(nr_class, pairwise_prob, classProbs);
for (k = 0; k < nr_class; k++) {
e.setValue(confidenceAttributes[k], classProbs[k]);
}
if (confidenceForMultiClass) { // use highest confidence
double predictedClass = Svm.svm_predict_probability(model, currentNodes, classProbs);
e.setValue(predictedLabel, predictedClass);
} else { // binary majority vote over 1-vs-1 classifiers
double predictedClass = Svm.svm_predict(model, currentNodes);
e.setValue(predictedLabel, predictedClass);
}
} else {
double predictedClass = Svm.svm_predict(model, currentNodes);
e.setValue(predictedLabel, predictedClass);
// use simple calculation for binary cases...
if (label.getMapping().size() == 2) {
double[] functionValues = new double[model.nr_class];
Svm.svm_predict_values(model, currentNodes, functionValues);
double prediction = functionValues[0];
if ((confidenceAttributes != null) && (confidenceAttributes.length > 0)) {
e.setValue(confidenceAttributes[0], 1.0d / (1.0d + java.lang.Math.exp(-prediction)));
if (confidenceAttributes.length > 1) {
e.setValue(confidenceAttributes[1], 1.0d / (1.0d + java.lang.Math.exp(prediction)));
}
}
} else {
// ...or no proper confidence value for the multi class setting
// here the confidence attribute calculated above cannot be used
e.setConfidence(getLabel().getMapping().mapIndex((int)predictedClass), 1.0d);
}
}
} else {
e.setValue(predictedLabel, Svm.svm_predict(model, LibSVMLearner.makeNodes(e, ripper)));
}
}
}
return exampleSet;
}
public String toString() {
StringBuffer result = new StringBuffer(super.toString() + Tools.getLineSeparator());
result.append("number of classes: " + model.nr_class + Tools.getLineSeparator());
if (getLabel().isNominal() && (getLabel().getMapping().size() >= 2)) {
for (int i = 0; i < model.nSV.length; i++) {
result.append("number of support vectors for class " + getLabel().getMapping().mapIndex(model.label[i]) + ": " + model.nSV[i] + Tools.getLineSeparator());
}
} else {
result.append("number of support vectors: " + model.l + Tools.getLineSeparator());
}
return result.toString();
}
}