/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.functions; import com.rapidminer.datatable.DataTable; import com.rapidminer.datatable.SimpleDataTable; import com.rapidminer.datatable.SimpleDataTableRow; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.SimplePredictionModel; import com.rapidminer.tools.Tools; import com.rapidminer.tools.math.kernels.DotKernel; import com.rapidminer.tools.math.kernels.Kernel; /** * This model is a separating hyperplane for two classes. * * @author Sebastian Land */ public class HyperplaneModel extends SimplePredictionModel { private static final long serialVersionUID = -4990692589416639697L; private String[] coefficientNames; private double[] coefficients; private double intercept; private String classNegative; private String classPositive; private Kernel kernel; public HyperplaneModel(ExampleSet exampleSet) { this(exampleSet, null, null); } public HyperplaneModel(ExampleSet exampleSet, String classNegative, String classPositive) { this(exampleSet, classNegative, classPositive, new DotKernel()); } public HyperplaneModel(ExampleSet exampleSet, String classNegative, String classPositive, Kernel kernel) { super(exampleSet); this.coefficientNames = com.rapidminer.example.Tools.getRegularAttributeNames(exampleSet); this.classNegative = classNegative; this.classPositive = classPositive; this.kernel = kernel; } @Override public double predict(Example example) throws OperatorException { int i = 0; double distance = intercept; // using kernel for distance calculation double[] values = new double[example.getAttributes().size()]; for (Attribute currentAttribute : example.getAttributes()) { values[i] = example.getValue(currentAttribute); i++; } distance += kernel.calculateDistance(values, coefficients); if (getLabel().isNominal()) { int positiveMapping = getLabel().getMapping().mapString(classPositive); int negativeMapping = getLabel().getMapping().mapString(classNegative); boolean isApplying = example.getAttributes().getPredictedLabel() != null; if (isApplying) { example.setConfidence(classPositive, 1.0d / (1.0d + java.lang.Math.exp(-distance))); example.setConfidence(classNegative, 1.0d / (1.0d + java.lang.Math.exp(distance))); } if (distance < 0) { return negativeMapping; } else { return positiveMapping; } } else { return distance; } } public void init(double[] coefficients, double intercept) { this.coefficients = coefficients; this.intercept = intercept; } public double[] getCoefficients() { return coefficients; } public double getIntercept() { return intercept; } public void setCoefficients(double[] coefficients) { this.coefficients = coefficients; } public void setIntercept(double intercept) { this.intercept = intercept; } @Override public String toString() { StringBuffer buffer = new StringBuffer(); if ((classPositive != null) && (classNegative != null)) buffer.append("Hyperplane seperating " + classPositive + " and " + classNegative + "." + Tools.getLineSeparator()); else buffer.append("Hyperplane for linear regression." + Tools.getLineSeparator()); buffer.append("Intercept: "); buffer.append(Double.toString(intercept)); buffer.append(Tools.getLineSeparator()); buffer.append("Coefficients: " + Tools.getLineSeparator()); int counter = 0; for (double value : coefficients) { buffer.append("w(" + coefficientNames[counter] + ") = " + Tools.formatIntegerIfPossible(value, 3) + Tools.getLineSeparator()); counter++; } buffer.append(Tools.getLineSeparator()); return buffer.toString(); } public DataTable createWeightsTable() { SimpleDataTable weightTable = new SimpleDataTable("Hyperplane Model Weights", new String[] { "Attribute", "Weight" } ); for (int j = 0; j < this.coefficientNames.length; j++) { int nameIndex = weightTable.mapString(0, this.coefficientNames[j]); weightTable.add(new SimpleDataTableRow(new double[] { nameIndex, this.coefficients[j]})); } return weightTable; } }