/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.functions.kernel.evosvm;

import java.util.Iterator;
import java.util.List;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.FormulaProvider;
import com.rapidminer.operator.learner.functions.kernel.KernelModel;
import com.rapidminer.operator.learner.functions.kernel.SupportVector;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.kernels.DotKernel;
import com.rapidminer.tools.math.kernels.Kernel;

/**
 * The model for the evolutionary SVM. Basically the same as other SVM models.
 *
 * <p>
 * When the kernel is a {@link DotKernel} and no support vector is <code>null</code>, the
 * constructor pre-computes one weight per attribute (the sum of <code>y * alpha * x</code>
 * over all support vectors) so that {@link #performPrediction(ExampleSet, Attribute)} can
 * evaluate a plain weighted sum instead of asking the kernel for every example.
 * </p>
 *
 * @author Ingo Mierswa
 */
public class EvoSVMModel extends KernelModel implements FormulaProvider {

    private static final long serialVersionUID = 2848059541066828127L;

    /** The used kernel function. */
    private Kernel kernel;

    /** The list of all support vectors. */
    private List<SupportVector> supportVectors;

    /** The bias. */
    private double bias;

    /**
     * Per-attribute weights for the dot kernel shortcut, or <code>null</code> if the
     * shortcut is not available (other kernel, or a <code>null</code> support vector).
     */
    private double[] weights = null;

    /**
     * Creates a classification model.
     *
     * @param exampleSet     passed on unchanged to the super class constructor
     * @param supportVectors the support vectors; must be neither <code>null</code> nor empty
     * @param kernel         the kernel function used for prediction
     * @param bias           the bias added to every function value
     * @throws IllegalArgumentException if <code>supportVectors</code> is <code>null</code> or empty
     */
    public EvoSVMModel(ExampleSet exampleSet, List<SupportVector> supportVectors, Kernel kernel, double bias) {
        super(exampleSet);
        if ((supportVectors == null) || (supportVectors.size() == 0)) {
            throw new IllegalArgumentException("Null or empty support vector collection: not possible to predict values!");
        }
        this.supportVectors = supportVectors;
        this.kernel = kernel;
        this.bias = bias;
        if (this.kernel instanceof DotKernel) {
            this.weights = computeLinearWeights();
        }
    }

    /**
     * Computes the per-attribute weights <code>sum_i (y_i * alpha_i * x_ij)</code> over all
     * support vectors.
     *
     * @return the weight vector, or <code>null</code> as soon as a <code>null</code> support
     *         vector is encountered
     */
    private double[] computeLinearWeights() {
        double[] result = new double[getNumberOfAttributes()];
        for (int i = 0; i < getNumberOfSupportVectors(); i++) {
            SupportVector sv = getSupportVector(i);
            if (sv == null) {
                // an incomplete support vector list makes the shortcut unusable
                return null;
            }
            double[] x = sv.getX();
            double alpha = sv.getAlpha();
            double y = sv.getY();
            for (int j = 0; j < result.length; j++) {
                result[j] += y * alpha * x[j];
            }
        }
        return result;
    }

    /** Returns <code>true</code> if the label of this model is nominal. */
    @Override
    public boolean isClassificationModel() {
        return getLabel().isNominal();
    }

    /** Returns the alpha value of the support vector with the given index. */
    @Override
    public double getAlpha(int index) {
        return supportVectors.get(index).getAlpha();
    }

    /** Always returns <code>null</code>; this model does not provide ids. */
    @Override
    public String getId(int index) {
        return null;
    }

    /** Returns the bias. */
    @Override
    public double getBias() {
        return this.bias;
    }

    /** Returns the support vector with the given index. */
    @Override
    public SupportVector getSupportVector(int index) {
        return supportVectors.get(index);
    }

    /** Returns the size of the support vector list. */
    @Override
    public int getNumberOfSupportVectors() {
        return supportVectors.size();
    }

    /** Returns the length of the value array of the first support vector. */
    @Override
    public int getNumberOfAttributes() {
        return supportVectors.get(0).getX().length;
    }

    /**
     * Returns a single attribute value of a support vector.
     *
     * @param exampleIndex   the index of the support vector
     * @param attributeIndex the index into the value array of that support vector
     */
    @Override
    public double getAttributeValue(int exampleIndex, int attributeIndex) {
        return this.supportVectors.get(exampleIndex).getX()[attributeIndex];
    }

    /**
     * Returns the negative string of the label mapping if the y value of the support vector
     * with the given index is smaller than zero, and the positive string otherwise.
     */
    @Override
    public String getClassificationLabel(int index) {
        double y = getRegressionLabel(index);
        if (y < 0) {
            return getLabel().getMapping().getNegativeString();
        }
        return getLabel().getMapping().getPositiveString();
    }

    /** Returns the y value of the support vector with the given index. */
    @Override
    public double getRegressionLabel(int index) {
        return this.supportVectors.get(index).getY();
    }

    /**
     * Returns the bias plus the kernel sum over all support vectors, evaluated at the values
     * of the support vector with the given index.
     */
    @Override
    public double getFunctionValue(int index) {
        double[] values = this.supportVectors.get(index).getX();
        return (bias + kernel.getSum(supportVectors, values));
    }

    /**
     * Applies the model to each example of the example set.
     *
     * <p>
     * If the pre-computed dot kernel weights are available the function value is the bias
     * plus the weighted sum of the attribute values; otherwise it is the bias plus the
     * kernel sum over all support vectors.
     * </p>
     *
     * @param exampleSet the examples to predict; is modified in place
     * @param predLabel  the attribute which receives the prediction
     * @return the given example set
     * @throws RuntimeException if the number of attributes of the example set differs from
     *                          {@link #getNumberOfAttributes()}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predLabel) {
        if (exampleSet.getAttributes().size() != getNumberOfAttributes()) {
            throw new RuntimeException("Cannot apply model: incompatible numbers of attributes (" + exampleSet.getAttributes().size() + " != " + getNumberOfAttributes() + ")!");
        }

        if ((kernel instanceof DotKernel) && (weights != null)) {
            // linear shortcut: no kernel evaluation necessary
            for (Example example : exampleSet) {
                double sum = getBias();
                int a = 0;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    sum += weights[a] * example.getValue(attribute);
                    a++;
                }
                writePrediction(example, predLabel, sum);
            }
            return exampleSet;
        }

        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example current = reader.next();
            double[] currentX = new double[exampleSet.getAttributes().size()];
            int x = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                currentX[x++] = current.getValue(attribute);
            }
            double sum = bias + kernel.getSum(supportVectors, currentX);
            writePrediction(current, predLabel, sum);
        }
        return exampleSet;
    }

    /**
     * Writes the prediction for a single example. For a nominal label the positive index is
     * predicted if the function value is greater than zero (the negative index otherwise)
     * and the confidences are set to <code>1 / (1 + exp(-f))</code> for the positive and
     * <code>1 / (1 + exp(f))</code> for the negative class. For all other labels the
     * function value itself is the prediction.
     *
     * @param example       the example receiving the prediction
     * @param predLabel     the predicted label attribute
     * @param functionValue the function value calculated for the example
     */
    private void writePrediction(Example example, Attribute predLabel, double functionValue) {
        if (getLabel().isNominal()) {
            int index = functionValue > 0 ? getLabel().getMapping().getPositiveIndex() : getLabel().getMapping().getNegativeIndex();
            example.setValue(predLabel, index);
            example.setConfidence(predLabel.getMapping().getPositiveString(), 1.0d / (1.0d + Math.exp(-functionValue)));
            example.setConfidence(predLabel.getMapping().getNegativeString(), 1.0d / (1.0d + Math.exp(functionValue)));
        } else {
            example.setValue(predLabel, functionValue);
        }
    }

    /**
     * Returns a textual formula of this model. Each non-<code>null</code> support vector with
     * a non-zero alpha contributes one line of the form
     * <code>[sign] |y * alpha| * (kernel distance formula)</code>; a non-zero bias is
     * appended as a last line. Every line is preceded by a line separator.
     *
     * @return the formula, or an empty string if there is nothing to show
     */
    public String getFormula() {
        StringBuilder result = new StringBuilder();
        boolean first = true;
        for (int i = 0; i < getNumberOfSupportVectors(); i++) {
            SupportVector sv = getSupportVector(i);
            if (sv == null) {
                continue;
            }
            double alpha = sv.getAlpha();
            if (Tools.isZero(alpha)) {
                continue;
            }
            result.append(Tools.getLineSeparator());
            double[] x = sv.getX();
            double factor = sv.getY() * alpha;
            // a leading positive factor is written with a single blank instead of a sign
            appendSigned(result, factor, first, " ");
            result.append(" * (" + kernel.getDistanceFormula(x, getAttributeConstructions()) + ")");
            first = false;
        }

        double offset = getBias();
        if (!Tools.isZero(offset)) {
            result.append(Tools.getLineSeparator());
            // a leading positive bias is written without any prefix
            appendSigned(result, offset, first, "");
        }
        return result.toString();
    }

    /**
     * Appends a value together with its sign: <code>"- "</code> and the absolute value for
     * values smaller than zero, otherwise the given prefix for the first term of the formula
     * and <code>"+ "</code> for all later terms.
     *
     * @param target      the builder to append to
     * @param value       the value to append
     * @param first       whether this is the first term of the formula
     * @param firstPrefix the prefix used for a non-negative first term
     */
    private static void appendSigned(StringBuilder target, double value, boolean first, String firstPrefix) {
        if (value < 0.0d) {
            target.append("- ").append(Math.abs(value));
        } else if (first) {
            target.append(firstPrefix).append(value);
        } else {
            target.append("+ ").append(value);
        }
    }
}