/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program. If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.rules;

import java.util.LinkedList;
import java.util.List;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.tree.AbstractTreeLearner;
import com.rapidminer.operator.learner.tree.EmptyTermination;
import com.rapidminer.operator.learner.tree.NoAttributeLeftTermination;
import com.rapidminer.operator.learner.tree.SplitCondition;
import com.rapidminer.operator.learner.tree.Terminator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.Tools;
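/*
 * Implementation note: the pruning metric computed in
 * getPruneBenefit(Rule, ExampleSet) is (p + (N - n)) / (P + N), where P and N
 * are the weighted numbers of positive and negative examples on the prune set
 * and p and n are the covered positives and negatives. As a worked example
 * with made-up numbers: P = 6, N = 4, p = 3, n = 1 yields (3 + 3) / 10 = 0.6.
 */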
/**
 * <p>This operator works similarly to the propositional rule learner named
 * Repeated Incremental Pruning to Produce Error Reduction (RIPPER, Cohen
 * 1995). Starting with the less prevalent classes, the algorithm iteratively
 * grows and prunes rules until there are no positive examples left or the
 * error rate is greater than 50%.</p>
 *
 * <p>In the growing phase, conditions are greedily added to the rule until it
 * is perfect (i.e. 100% accurate) on the growing set. The procedure tries
 * every possible value of each attribute and selects the condition with the
 * highest information gain.</p>
 *
 * <p>In the pruning phase, the benefit of the rule on a separate prune set is
 * checked after each added term, using the pruning metric
 * (p + (N - n)) / (P + N); if the benefit decreases by more than the minimal
 * prune benefit, the last term is removed again and the rule is not extended
 * further.</p>
 *
 * @author Sebastian Land, Ingo Mierswa
 * @version $Id: RuleLearner.java,v 1.16 2008/05/09 19:23:13 ingomierswa Exp $
 */
public class RuleLearner extends AbstractLearner {

    private static final String PARAMETER_SAMPLE_RATIO = "sample_ratio";

    private static final String PARAMETER_MINIMAL_PRUNE_BENEFIT = "minimal_prune_benefit";

    public static final String[] CRITERIA_NAMES = {
        "information_gain",
        "accuracy"
    };

    public static final Class[] CRITERIA_CLASSES = {
        InfoGainCriterion.class,
        AccuracyCriterion.class
    };

    public static final int CRITERION_INFO_GAIN = 0;

    public static final int CRITERION_ACCURACY = 1;

    private List<Terminator> terminators = new LinkedList<Terminator>();

    public RuleLearner(OperatorDescription description) {
        super(description);
    }

    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // init (clear first so repeated calls to learn() do not accumulate terminators)
        terminators.clear();
        terminators.add(new EmptyTermination());
        terminators.add(new NoAttributeLeftTermination());

        double pureness = getParameterAsDouble(SimpleRuleLearner.PARAMETER_PURENESS);
        double sampleRatio = getParameterAsDouble(PARAMETER_SAMPLE_RATIO);
        double minimalPruneBenefit = getParameterAsDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT);

        Attribute label = exampleSet.getAttributes().getLabel();
        RuleModel ruleModel = new RuleModel(exampleSet);
        TermDetermination termDetermination = new TermDetermination(createCriterion());

        ExampleSet trainingSet = (ExampleSet) exampleSet.clone();
        trainingSet.recalculateAttributeStatistics(label);

        while (!shouldStop(trainingSet)) {
            String labelName = getNextLabel(trainingSet);
            Rule rule = new Rule(labelName);

            ExampleSet oldTrainingSet = (ExampleSet) trainingSet.clone();
            SplittedExampleSet growPruneSet = new SplittedExampleSet(trainingSet, sampleRatio, SplittedExampleSet.STRATIFIED_SAMPLING, -1);

            // growing
            SplittedExampleSet growingSet = (SplittedExampleSet) growPruneSet.clone();
            growingSet.selectSingleSubset(0);
            SplittedExampleSet pruneSet = (SplittedExampleSet) growPruneSet.clone();
            pruneSet.selectSingleSubset(1);

            int growOldSize = -1;
            ExampleSet growSet = (ExampleSet) growingSet.clone();
            while ((growSet.size() > 0) && (growSet.size() != growOldSize) && (!rule.isPure(growSet, pureness)) && (growSet.getAttributes().size() > 0)) {
                SplitCondition term = termDetermination.getBestTerm(growSet, labelName);
                if (term == null)
                    break;

                // before adding: compute the benefit without the new term
                double prunedBenefit = 0.0d;
                if (pruneSet.size() > 0)
                    prunedBenefit = getPruneBenefit(rule, pruneSet);

                // add term
                rule.addTerm(term);

                // pruning
                if (pruneSet.size() > 0) {
                    double unprunedBenefit = getPruneBenefit(rule, pruneSet);
                    if (unprunedBenefit < prunedBenefit - minimalPruneBenefit) {
                        rule.removeLastTerm();
                        // if the best new term is pruned: no further extension of the rule
                        break;
                    }
                }

                growOldSize = growSet.size();

                // keep only the examples covered by the rule so far
                growSet = rule.getCovered(growSet);

                // remove the used attribute (nominal attributes can only be split once)
                Attribute splitAttribute = growSet.getAttributes().get(term.getAttributeName());
                if (splitAttribute.isNominal()) {
                    growSet.getAttributes().remove(splitAttribute);
                }

                checkForStop();
            }

            if (rule.getTerms().size() > 0) {
                growSet = rule.getCovered(trainingSet);
                growSet.recalculateAttributeStatistics(label);
                int[] frequencies = new int[label.getMapping().size()];
                int counter = 0;
                for (String value : label.getMapping().getValues())
                    frequencies[counter++] = (int) growSet.getStatistics(label, Statistics.COUNT, value);
                rule.setFrequencies(frequencies);
                ruleModel.addRule(rule);
                trainingSet = rule.removeCovered(oldTrainingSet);
            } else {
                break;
            }

            trainingSet.recalculateAttributeStatistics(label);
            checkForStop();
        }

        // training set not empty? add default rule
        if (trainingSet.size() > 0) {
            trainingSet.recalculateAttributeStatistics(label);
            int index = (int) trainingSet.getStatistics(label, Statistics.MODE);
            String defaultLabel = label.getMapping().mapIndex(index);
            Rule defaultRule = new Rule(defaultLabel);
            int[] frequencies = new int[label.getMapping().size()];
            int counter = 0;
            for (String value : label.getMapping().getValues())
                frequencies[counter++] = (int) (trainingSet.getStatistics(label, Statistics.COUNT, value) * sampleRatio);
            defaultRule.setFrequencies(frequencies);
            ruleModel.addRule(defaultRule);
        }

        return ruleModel;
    }
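    /**
     * Computes the value of the given rule on the given example set as
     * (p + (N - n)) / (P + N), where P and N are the weighted counts of
     * positive and negative examples (with respect to the rule's label)
     * and p and n are the covered positive and negative examples.
     */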
    private double getPruneBenefit(Rule rule, ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        Attribute weight = exampleSet.getAttributes().getWeight();
        double pTotal = 0.0d;
        double nTotal = 0.0d;
        double p = 0.0d;
        double n = 0.0d;
        for (Example e : exampleSet) {
            double currentWeight = 1.0d;
            if (weight != null)
                currentWeight = e.getValue(weight);
            if (e.getValue(label) == label.getMapping().getIndex(rule.getLabel()))
                pTotal += currentWeight;
            else
                nTotal += currentWeight;
            if (rule.coversExample(e)) {
                if (e.getValue(label) == label.getMapping().getIndex(rule.getLabel()))
                    p += currentWeight;
                else
                    n += currentWeight;
            }
        }
        return (p + nTotal - n) / (pTotal + nTotal);
    }

    private String getNextLabel(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        int index = (int) exampleSet.getStatistics(label, Statistics.MODE);
        return label.getMapping().mapIndex(index);
    }

    private boolean shouldStop(ExampleSet exampleSet) {
        for (Terminator terminator : terminators) {
            if (terminator.shouldStop(exampleSet, 0))
                return true;
        }
        return false;
    }

    private Criterion createCriterion() throws UndefinedParameterError {
        String criterionName = getParameterAsString(AbstractTreeLearner.PARAMETER_CRITERION);
        Class criterionClass = null;
        for (int i = 0; i < CRITERIA_NAMES.length; i++) {
            if (CRITERIA_NAMES[i].equals(criterionName)) {
                criterionClass = CRITERIA_CLASSES[i];
            }
        }
        if ((criterionClass == null) && (criterionName != null)) {
            try {
                criterionClass = Tools.classForName(criterionName);
            } catch (ClassNotFoundException e) {
                logWarning("Cannot find criterion '" + criterionName + "' and cannot instantiate a class with this name. Using the information gain criterion instead.");
            }
        }
        if (criterionClass != null) {
            try {
                return (Criterion) criterionClass.newInstance();
            } catch (InstantiationException e) {
                logWarning("Cannot instantiate criterion class '" + criterionClass.getName() + "'. Using the information gain criterion instead.");
                return new InfoGainCriterion();
            } catch (IllegalAccessException e) {
                logWarning("Cannot access criterion class '" + criterionClass.getName() + "'. Using the information gain criterion instead.");
                return new InfoGainCriterion();
            }
        } else {
            log("No relevance criterion defined, using information gain...");
            return new InfoGainCriterion();
        }
    }
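    /**
     * This learner can handle nominal (binominal and polynominal) and
     * numerical attributes, nominal labels, and weighted examples.
     */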
    public boolean supportsCapability(LearnerCapability capability) {
        if (capability == LearnerCapability.BINOMINAL_ATTRIBUTES)
            return true;
        if (capability == LearnerCapability.POLYNOMINAL_ATTRIBUTES)
            return true;
        if (capability == LearnerCapability.NUMERICAL_ATTRIBUTES)
            return true;
        if (capability == LearnerCapability.POLYNOMINAL_CLASS)
            return true;
        if (capability == LearnerCapability.BINOMINAL_CLASS)
            return true;
        if (capability == LearnerCapability.WEIGHTED_EXAMPLES)
            return true;
        return false;
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterType type = new ParameterTypeStringCategory(AbstractTreeLearner.PARAMETER_CRITERION, "Specifies the criterion used for selecting attributes and numerical splits.", CRITERIA_NAMES, CRITERIA_NAMES[CRITERION_INFO_GAIN]);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_SAMPLE_RATIO, "The ratio of training data used for growing; the remainder is used for pruning.", 0.0d, 1.0d, 0.9d);
        type.setExpert(false);
        types.add(type);
        types.add(new ParameterTypeDouble(SimpleRuleLearner.PARAMETER_PURENESS, "The desired pureness, i.e. the necessary ratio of the major class in a covered subset in order to become pure.", 0.0d, 1.0d, 0.9d));
        types.add(new ParameterTypeDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT, "The minimum amount by which the pruned benefit must exceed the unpruned benefit in order for the last term to be pruned.", 0.0d, 1.0d, 0.25d));
        return types;
    }
}