/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.functions.kernel;

import java.util.List;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.Kernel;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel.KernelDot;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMInterface;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMpattern;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.svm.SVMregression;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import com.rapidminer.tools.RandomGenerator;

/**
 * This class implements a special case of the MySVM by restricting it to the linear (dot) kernel.
 * This way the weights of the linear combination can be extracted and stored solely in the
 * resulting model. The model is optimized for small size / fast store and retrieve operations
 * as well as time efficient during application.
 *
 * @author Sebastian Land
 */
public class LinearMySVMLearner extends AbstractKernelBasedLearner {

	/** The parameter name for &quot;Size of the cache for kernel evaluations im MB &quot; */
	public static final String PARAMETER_KERNEL_CACHE = "kernel_cache";

	/** The parameter name for &quot;Precision on the KKT conditions&quot; */
	public static final String PARAMETER_CONVERGENCE_EPSILON = "convergence_epsilon";

	/** The parameter name for &quot;Stop after this many iterations&quot; */
	public static final String PARAMETER_MAX_ITERATIONS = "max_iterations";

	/** The parameter name for &quot;Scale the example values and store the scaling parameters for test set.&quot; */
	public static final String PARAMETER_SCALE = "scale";

	/** The parameter name for &quot;The SVM complexity constant. Use -1 for different C values for positive and negative.&quot; */
	public static final String PARAMETER_C = "C";

	/** The parameter name for &quot;A factor for the SVM complexity constant for positive examples&quot; */
	public static final String PARAMETER_L_POS = "L_pos";

	/** The parameter name for &quot;A factor for the SVM complexity constant for negative examples&quot; */
	public static final String PARAMETER_L_NEG = "L_neg";

	/** The parameter name for &quot;Insensitivity constant. No loss if prediction lies this close to true value&quot; */
	public static final String PARAMETER_EPSILON = "epsilon";

	/** The parameter name for &quot;Epsilon for positive deviation only&quot; */
	public static final String PARAMETER_EPSILON_PLUS = "epsilon_plus";

	/** The parameter name for &quot;Epsilon for negative deviation only&quot; */
	public static final String PARAMETER_EPSILON_MINUS = "epsilon_minus";

	/** The parameter name for &quot;Adapts Cpos and Cneg to the relative size of the classes&quot; */
	public static final String PARAMETER_BALANCE_COST = "balance_cost";

	/** The parameter name for &quot;Use quadratic loss for positive deviation&quot; */
	public static final String PARAMETER_QUADRATIC_LOSS_POS = "quadratic_loss_pos";

	/** The parameter name for &quot;Use quadratic loss for negative deviation&quot; */
	public static final String PARAMETER_QUADRATIC_LOSS_NEG = "quadratic_loss_neg";

	/** Indicates a linear kernel. */
	public static final int KERNEL_DOT = 0;

	/**
	 * The SVM example set. Only non-null while {@link #learn(ExampleSet)} is running; it is
	 * reset to <code>null</code> when learning finishes, whether normally or with an exception.
	 */
	private SVMExamples svmExamples;

	/**
	 * Creates the operator by passing the given description on to the super class.
	 *
	 * @param description the operator description handed to the super constructor
	 */
	public LinearMySVMLearner(OperatorDescription description) {
		super(description);
	}

	/**
	 * Creates the SVM matching the type of the label: an {@link SVMpattern} for a nominal label
	 * and an {@link SVMregression} otherwise. Both are created with the global random generator.
	 *
	 * @param label the label attribute deciding which SVM variant is created
	 * @param kernel the kernel passed to the SVM
	 * @param sVMExamples the SVM representation of the training data
	 * @param rapidMinerExamples the original example set
	 * @return the newly created SVM
	 * @throws OperatorException declared for the SVM constructors
	 */
	protected SVMInterface createSVM(Attribute label, Kernel kernel, SVMExamples sVMExamples, ExampleSet rapidMinerExamples) throws OperatorException {
		if (label.isNominal()) {
			return new SVMpattern(this, kernel, sVMExamples, rapidMinerExamples, RandomGenerator.getGlobalRandomGenerator());
		} else {
			return new SVMregression(this, kernel, sVMExamples, rapidMinerExamples, RandomGenerator.getGlobalRandomGenerator());
		}
	}

	/**
	 * Trains an SVM with a dot kernel on the given example set and wraps the result in a
	 * {@link LinearMySVMModel}.
	 *
	 * @param exampleSet the training data; its label attribute is read without a null check
	 * @return the learned linear model
	 * @throws UserError (code 114) if the label is nominal and its mapping does not contain
	 *         exactly two values
	 * @throws OperatorException if parameter lookup, SVM creation or training fails
	 */
	@Override
	public Model learn(ExampleSet exampleSet) throws OperatorException {
		Attribute label = exampleSet.getAttributes().getLabel();
		if (label.isNominal() && label.getMapping().size() != 2) {
			throw new UserError(this, 114, getName(), label.getName());
		}

		this.svmExamples = new SVMExamples(exampleSet, label, getParameterAsBoolean(PARAMETER_SCALE));
		try {
			// kernel
			int cacheSize = getParameterAsInt(PARAMETER_KERNEL_CACHE);
			Kernel kernel = new KernelDot();
			kernel.init(svmExamples, cacheSize);

			// SVM
			SVMInterface svm = createSVM(label, kernel, svmExamples, exampleSet);
			svm.init(kernel, svmExamples);
			svm.train();

			// the model is built before the finally block clears the field
			return new LinearMySVMModel(exampleSet, svmExamples, kernel, KERNEL_DOT);
		} finally {
			// always drop the reference, so that a failed or aborted training run does not
			// keep the SVM examples reachable through this operator
			this.svmExamples = null;
		}
	}

	/**
	 * Reports the capabilities of this learner: numerical attributes, binominal and numerical
	 * labels, weighted examples and formula provision.
	 *
	 * @param lc the capability in question
	 * @return <code>true</code> for the capabilities listed above, <code>false</code> otherwise
	 */
	@Override
	public boolean supportsCapability(OperatorCapability lc) {
		return lc == OperatorCapability.NUMERICAL_ATTRIBUTES
				|| lc == OperatorCapability.BINOMINAL_LABEL
				|| lc == OperatorCapability.NUMERICAL_LABEL
				|| lc == OperatorCapability.WEIGHTED_EXAMPLES
				|| lc == OperatorCapability.FORMULA_PROVIDER;
	}

	/**
	 * Extends the parameter types of the super class by the kernel cache size, the complexity
	 * constant C (the only one explicitly marked as non-expert), the convergence settings, the
	 * scaling switch and the loss related parameters.
	 *
	 * @return the list obtained from the super class with the additional parameter types appended
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		types.add(new ParameterTypeInt(PARAMETER_KERNEL_CACHE, "Size of the cache for kernel evaluations im MB ", 0, Integer.MAX_VALUE, 200));

		ParameterType type = new ParameterTypeDouble(PARAMETER_C, "The SVM complexity constant. Use -1 for different C values for positive and negative.", -1, Double.POSITIVE_INFINITY, 0.0d);
		type.setExpert(false);
		types.add(type);

		type = new ParameterTypeDouble(PARAMETER_CONVERGENCE_EPSILON, "Precision on the KKT conditions", 0.0d, Double.POSITIVE_INFINITY, 1e-3);
		types.add(type);
		types.add(new ParameterTypeInt(PARAMETER_MAX_ITERATIONS, "Stop after this many iterations", 1, Integer.MAX_VALUE, 100000));
		types.add(new ParameterTypeBoolean(PARAMETER_SCALE, "Scale the example values and store the scaling parameters for test set.", true));

		// NOTE(review): the following keys are taken from JMySVMLearner although this class
		// declares constants with the same names - presumably the values are identical; confirm
		// before switching to the local constants.
		types.add(new ParameterTypeDouble(JMySVMLearner.PARAMETER_L_POS, "A factor for the SVM complexity constant for positive examples", 0, Double.POSITIVE_INFINITY, 1.0d));
		types.add(new ParameterTypeDouble(JMySVMLearner.PARAMETER_L_NEG, "A factor for the SVM complexity constant for negative examples", 0, Double.POSITIVE_INFINITY, 1.0d));
		types.add(new ParameterTypeDouble(JMySVMLearner.PARAMETER_EPSILON, "Insensitivity constant. No loss if prediction lies this close to true value", 0.0d, Double.POSITIVE_INFINITY, 0.0d));
		types.add(new ParameterTypeDouble(JMySVMLearner.PARAMETER_EPSILON_PLUS, "Epsilon for positive deviation only", 0.0d, Double.POSITIVE_INFINITY, 0.0d));
		types.add(new ParameterTypeDouble(JMySVMLearner.PARAMETER_EPSILON_MINUS, "Epsilon for negative deviation only", 0.0d, Double.POSITIVE_INFINITY, 0.0d));
		types.add(new ParameterTypeBoolean(JMySVMLearner.PARAMETER_BALANCE_COST, "Adapts Cpos and Cneg to the relative size of the classes", false));
		types.add(new ParameterTypeBoolean(JMySVMLearner.PARAMETER_QUADRATIC_LOSS_POS, "Use quadratic loss for positive deviation", false));
		types.add(new ParameterTypeBoolean(JMySVMLearner.PARAMETER_QUADRATIC_LOSS_NEG, "Use quadratic loss for negative deviation", false));
		return types;
	}

	/**
	 * Looks up the resource consumption estimator registered for this operator class, based on
	 * the example set input port.
	 *
	 * @return the estimator delivered by the {@link OperatorResourceConsumptionHandler}
	 */
	@Override
	public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
		return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(getExampleSetInputPort(), LinearMySVMLearner.class, null);
	}
}