/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.functions.kernel;
import java.util.Iterator;
import java.util.Map;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples.MeanVariance;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.KernelDot;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMInterface;
import com.rapidminer.tools.Tools;
/**
* The abstract superclass for the SVM models by Stefan Rueping.
*
* @author Ingo Mierswa
*/
public abstract class AbstractMySVMModel extends KernelModel implements FormulaProvider {

    private static final long serialVersionUID = 2812901947459843681L;

    /** Internal jMySVM example set holding support vectors, alphas, and labels. */
    private final com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples model;

    /** The kernel used for training; reused for prediction. */
    private final Kernel kernel;

    /**
     * Precomputed hyperplane weights, available only for the linear (dot) kernel.
     * Remains {@code null} for non-linear kernels, or is reset to {@code null}
     * when any support vector cannot be retrieved.
     */
    private double[] weights = null;

    /**
     * Creates a model wrapping the given internal SVM example set.
     * For a linear kernel the explicit hyperplane normal is precomputed as
     * {@code w = sum_i(y_i * alpha_i * x_i)} so that prediction can avoid
     * iterating over all support vectors per example.
     *
     * @param exampleSet the training example set (passed to the superclass for header data)
     * @param model      the internal jMySVM example set (support vectors, alphas, labels)
     * @param kernel     the kernel used during training
     * @param kernelType the kernel type index; NOTE(review): currently unused here,
     *                   kept for constructor compatibility with subclasses/callers
     */
    public AbstractMySVMModel(ExampleSet exampleSet,
                              com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples model,
                              Kernel kernel,
                              int kernelType) {
        super(exampleSet);
        this.model = model;
        this.kernel = kernel;
        if (this.kernel instanceof KernelDot) {
            this.weights = new double[getNumberOfAttributes()];
            for (int i = 0; i < getNumberOfSupportVectors(); i++) {
                SupportVector sv = getSupportVector(i);
                if (sv == null) {
                    // cannot build a complete hyperplane --> fall back to kernel evaluation
                    this.weights = null;
                    break;
                }
                double[] x = sv.getX();
                double factor = sv.getY() * sv.getAlpha();
                for (int j = 0; j < weights.length; j++) {
                    weights[j] += factor * x[j];
                }
            }
        }
    }

    /** Creates a new SVM for prediction. */
    public abstract SVMInterface createSVM();

    @Override
    public boolean isClassificationModel() {
        return getLabel().isNominal();
    }

    /** Returns the hyperplane offset b. */
    @Override
    public double getBias() {
        return model.get_b();
    }

    /**
     * Returns the support vector at the given index.
     * This method must divide the alpha by the label since internally the
     * alpha value is already multiplied with y.
     */
    @Override
    public SupportVector getSupportVector(int index) {
        double alpha = model.get_alpha(index);
        double y = model.get_y(index);
        if (y != 0.0d) {
            alpha /= y;
        }
        return new SupportVector(model.get_example(index).toDense(getNumberOfAttributes()), y, alpha);
    }

    /** Returns the raw (label-scaled) alpha value stored internally. */
    @Override
    public double getAlpha(int index) {
        return model.get_alpha(index);
    }

    @Override
    public String getId(int index) {
        return model.getId(index);
    }

    @Override
    public int getNumberOfSupportVectors() {
        return model.count_examples();
    }

    @Override
    public int getNumberOfAttributes() {
        return model.get_dim();
    }

    /**
     * Returns the value of the given attribute of the given (sparse) example,
     * or 0 if the dense representation is shorter than the requested index.
     */
    @Override
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExample sVMExample = model.get_example(exampleIndex);
        double value = 0.0d;
        try {
            value = sVMExample.toDense(getNumberOfAttributes())[attributeIndex];
        } catch (ArrayIndexOutOfBoundsException e) {
            // dense array too short --> use default value 0
        }
        return value;
    }

    /** Maps the internal label sign to the nominal mapping's class string. */
    @Override
    public String getClassificationLabel(int index) {
        double y = model.get_y(index);
        if (y < 0) {
            return getLabel().getMapping().getNegativeString();
        } else {
            return getLabel().getMapping().getPositiveString();
        }
    }

    @Override
    public double getRegressionLabel(int index) {
        return model.get_y(index);
    }

    /**
     * Returns the decision function value for the internal example at the
     * given index. NOTE(review): creates and initializes a fresh SVM on every
     * call; callers invoking this in a loop pay that cost each time.
     */
    @Override
    public double getFunctionValue(int index) {
        SVMInterface svm = createSVM();
        svm.init(kernel, model);
        return svm.predict(model.get_example(index));
    }

    /** Gets the kernel. */
    public Kernel getKernel() {
        return kernel;
    }

    /** Gets the model, i.e. an SVM example set. */
    public com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples getExampleSet() {
        return model;
    }

    /**
     * Sets the correct prediction to the example from the result value of the
     * SVM.
     */
    public abstract void setPrediction(Example example, double prediction);

    /**
     * Predicts all examples. For the linear kernel with precomputed weights the
     * prediction is the (normalized) dot product with the hyperplane normal plus
     * the bias; otherwise the full kernel-based SVM prediction is used.
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabelAttribute) throws OperatorException {
        if (kernel instanceof KernelDot && weights != null) {
            // fast path: explicit hyperplane w*x + b, with the same
            // normalization (mean/variance) that was applied during training
            Map<Integer, MeanVariance> meanVariances = model.getMeanVariances();
            for (Example example : exampleSet) {
                double prediction = getBias();
                int a = 0;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    double value = example.getValue(attribute);
                    MeanVariance meanVariance = meanVariances.get(a);
                    if (meanVariance != null) {
                        if (meanVariance.getVariance() == 0.0d) {
                            value = 0.0d;
                        } else {
                            value = (value - meanVariance.getMean()) / Math.sqrt(meanVariance.getVariance());
                        }
                    }
                    prediction += weights[a] * value;
                    a++;
                }
                setPrediction(example, prediction);
            }
            return exampleSet;
        }

        // only if not simple dot hyperplane (see above)...
        com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples toPredict =
                new com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples(
                        exampleSet, exampleSet.getAttributes().getPredictedLabel(), model.getMeanVariances());
        SVMInterface svm = createSVM();
        svm.init(kernel, model);
        svm.predict(toPredict);
        // copy predictions from toPredict back to the original examples (same order)
        Iterator<Example> reader = exampleSet.iterator();
        int k = 0;
        while (reader.hasNext()) {
            setPrediction(reader.next(), toPredict.get_y(k++));
        }
        return exampleSet;
    }

    /**
     * Builds a human-readable decision-function formula: one signed
     * {@code y*alpha * kernel(x, .)} term per support vector with non-zero
     * alpha, followed by the bias term if non-zero.
     */
    public String getFormula() {
        // StringBuilder: purely local, no synchronization needed
        StringBuilder result = new StringBuilder();
        Kernel kernel = getKernel();
        boolean first = true;
        for (int i = 0; i < getNumberOfSupportVectors(); i++) {
            SupportVector sv = getSupportVector(i);
            if (sv != null) {
                double alpha = sv.getAlpha();
                if (!Tools.isZero(alpha)) {
                    result.append(Tools.getLineSeparator());
                    double[] x = sv.getX();
                    double factor = sv.getY() * alpha;
                    if (factor < 0.0d) {
                        // original code branched on 'first' here with identical
                        // bodies; a negative term is always rendered the same way
                        result.append("- ").append(Math.abs(factor));
                    } else if (first) {
                        result.append(" ").append(factor);
                    } else {
                        result.append("+ ").append(factor);
                    }
                    result.append(" * (").append(kernel.getDistanceFormula(x, getAttributeConstructions())).append(")");
                    first = false;
                }
            }
        }
        double bias = getBias();
        if (!Tools.isZero(bias)) {
            result.append(Tools.getLineSeparator());
            if (bias < 0.0d) {
                // same collapse as above: both 'first' branches were identical
                result.append("- ").append(Math.abs(bias));
            } else if (first) {
                result.append(bias);
            } else {
                result.append("+ ").append(bias);
            }
        }
        return result.toString();
    }
}