/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.rules;
import java.util.LinkedList;
import java.util.List;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.tree.AbstractTreeLearner;
import com.rapidminer.operator.learner.tree.EmptyTermination;
import com.rapidminer.operator.learner.tree.NoAttributeLeftTermination;
import com.rapidminer.operator.learner.tree.SplitCondition;
import com.rapidminer.operator.learner.tree.Terminator;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
/**
* <p>
* This operator works similar to the propositional rule learner named Repeated Incremental Pruning
* to Produce Error Reduction (RIPPER, Cohen 1995). Starting with the less prevalent classes, the
* algorithm iteratively grows and prunes rules until there are no positive examples left or the
* error rate is greater than 50%.
* </p>
*
 * <p>
 * In the growing phase, conditions are greedily added to each rule until the rule is perfect
 * (i.e. 100% accurate). The procedure tries every possible value of each attribute and selects
 * the condition with the highest information gain.
 * </p>
*
 * <p>
 * In the pruning phase, for each rule, any final sequence of antecedents is pruned using the
 * pruning metric p/(p+n).
 * </p>
*
* @author Sebastian Land, Ingo Mierswa
*/
public class RuleLearner extends AbstractLearner {
// Parameter key: ratio of training examples used for growing; the remainder is used for pruning.
private static final String PARAMETER_SAMPLE_RATIO = "sample_ratio";
// Parameter key: minimal benefit margin a candidate term must not fall short of on the prune set.
private static final String PARAMETER_MINIMAL_PRUNE_BENEFIT = "minimal_prune_benefit";
// Names of the selectable split criteria; indices are parallel to CRITERIA_CLASSES.
public static final String[] CRITERIA_NAMES = { "information_gain", "accuracy" };
// Criterion implementations, parallel to CRITERIA_NAMES.
public static final Class<?>[] CRITERIA_CLASSES = { InfoGainCriterion.class, AccuracyCriterion.class };
// Index constants into CRITERIA_NAMES / CRITERIA_CLASSES.
public static final int CRITERION_INFO_GAIN = 0;
public static final int CRITERION_ACCURACY = 1;
// Stopping criteria consulted before each new rule is grown (see shouldStop()).
// NOTE(review): this list is populated inside learn(), so invoking learn() repeatedly
// on the same operator instance appends duplicate terminators -- verify callers.
private List<Terminator> terminators = new LinkedList<Terminator>();
public RuleLearner(OperatorDescription description) {
super(description);
}
@Override
public Model learn(ExampleSet exampleSet) throws OperatorException {
// init
terminators.add(new EmptyTermination());
terminators.add(new NoAttributeLeftTermination());
double pureness = getParameterAsDouble(SimpleRuleLearner.PARAMETER_PURENESS);
double sampleRatio = getParameterAsDouble(PARAMETER_SAMPLE_RATIO);
double minimalPruneBenefit = getParameterAsDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT);
boolean useLocalRandomSeed = getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED);
int localRandomSeed = getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED);
Attribute label = exampleSet.getAttributes().getLabel();
RuleModel ruleModel = new RuleModel(exampleSet);
TermDetermination termDetermination = new TermDetermination(createCriterion());
ExampleSet trainingSet = (ExampleSet) exampleSet.clone();
trainingSet.recalculateAttributeStatistics(label);
while (!shouldStop(trainingSet)) {
String labelName = getNextLabel(trainingSet);
Rule rule = new Rule(labelName);
ExampleSet oldTrainingSet = (ExampleSet) trainingSet.clone();
SplittedExampleSet growPruneSet = new SplittedExampleSet(trainingSet, sampleRatio,
SplittedExampleSet.STRATIFIED_SAMPLING, useLocalRandomSeed, true, localRandomSeed);
// growing
SplittedExampleSet growingSet = new SplittedExampleSet(growPruneSet);
growingSet.selectSingleSubset(0);
SplittedExampleSet pruneSet = growPruneSet;
pruneSet.selectSingleSubset(1);
int growOldSize = -1;
ExampleSet growSet = growingSet;
while (growSet.size() > 0 && growSet.size() != growOldSize && !rule.isPure(growSet, pureness)
&& growSet.getAttributes().size() > 0) {
SplitCondition term = termDetermination.getBestTerm(growSet, labelName);
if (term == null) {
break;
}
// before adding: Check benefit if not added
double prunedBenefit = 0;
if (pruneSet.size() > 0) {
prunedBenefit = getPruneBenefit(rule, pruneSet);
}
// add term
rule.addTerm(term);
// pruning
if (pruneSet.size() > 0) {
double unprunedBenefit = getPruneBenefit(rule, pruneSet);
if (unprunedBenefit < prunedBenefit - minimalPruneBenefit) {
rule.removeLastTerm();
// if best new term is pruned: no further extension of the rule
break;
}
}
growOldSize = growSet.size();
// removing uncovered rules
growSet = rule.getCovered(growSet);
// removing attribute
Attribute splitAttribute = growSet.getAttributes().get(term.getAttributeName());
if (splitAttribute.isNominal()) {
growSet.getAttributes().remove(splitAttribute);
}
checkForStop();
}
if (rule.getTerms().size() > 0) {
growSet = rule.getCovered(trainingSet);
growSet.recalculateAttributeStatistics(label);
int[] frequencies = new int[label.getMapping().size()];
int counter = 0;
for (String value : label.getMapping().getValues()) {
frequencies[counter++] = (int) growSet.getStatistics(label, Statistics.COUNT, value);
}
rule.setFrequencies(frequencies);
ruleModel.addRule(rule);
trainingSet = rule.removeCovered(oldTrainingSet);
} else {
break;
}
trainingSet.recalculateAttributeStatistics(label);
checkForStop();
}
// training set not empty? add default rule
if (trainingSet.size() > 0) {
trainingSet.recalculateAttributeStatistics(label);
int index = (int) trainingSet.getStatistics(label, Statistics.MODE);
String defaultLabel = label.getMapping().mapIndex(index);
Rule defaultRule = new Rule(defaultLabel);
int[] frequencies = new int[label.getMapping().size()];
int counter = 0;
for (String value : label.getMapping().getValues()) {
frequencies[counter++] = (int) (trainingSet.getStatistics(label, Statistics.COUNT, value) * sampleRatio);
}
defaultRule.setFrequencies(frequencies);
ruleModel.addRule(defaultRule);
}
return ruleModel;
}
/**
 * Computes the (weighted) benefit of the rule on the given example set as
 * (p + (N - n)) / (P + N), where p/n are the covered positive/negative weights and
 * P/N the total positive/negative weights; "positive" means matching the rule's label.
 */
private double getPruneBenefit(Rule rule, ExampleSet exampleSet) {
    Attribute label = exampleSet.getAttributes().getLabel();
    Attribute weightAttribute = exampleSet.getAttributes().getWeight();
    // The rule's label index is loop-invariant, so compute it once.
    double ruleLabelIndex = label.getMapping().getIndex(rule.getLabel());
    double totalPositives = 0.0d;
    double totalNegatives = 0.0d;
    double coveredPositives = 0.0d;
    double coveredNegatives = 0.0d;
    for (Example example : exampleSet) {
        // Unweighted example sets count every example with weight 1.
        double exampleWeight = weightAttribute == null ? 1.0d : example.getValue(weightAttribute);
        boolean isPositive = example.getValue(label) == ruleLabelIndex;
        if (isPositive) {
            totalPositives += exampleWeight;
        } else {
            totalNegatives += exampleWeight;
        }
        if (rule.coversExample(example)) {
            if (isPositive) {
                coveredPositives += exampleWeight;
            } else {
                coveredNegatives += exampleWeight;
            }
        }
    }
    return (coveredPositives + totalNegatives - coveredNegatives) / (totalPositives + totalNegatives);
}
/**
 * Returns the label value the next rule should predict: the mode (most frequent value)
 * of the label attribute according to the example set's current statistics.
 * Requires that recalculateAttributeStatistics() was called for the label beforehand.
 */
private String getNextLabel(ExampleSet exampleSet) {
    Attribute labelAttribute = exampleSet.getAttributes().getLabel();
    int modeIndex = (int) exampleSet.getStatistics(labelAttribute, Statistics.MODE);
    return labelAttribute.getMapping().mapIndex(modeIndex);
}
/**
 * Returns true as soon as any registered terminator signals that rule induction
 * should stop for the given example set; evaluation short-circuits on the first hit.
 */
private boolean shouldStop(ExampleSet remainingSet) {
    for (int i = 0; i < terminators.size(); i++) {
        if (terminators.get(i).shouldStop(remainingSet, 0)) {
            return true;
        }
    }
    return false;
}
/**
 * Creates the split criterion selected via the criterion parameter. The parameter value
 * is first matched against CRITERIA_NAMES; an unknown value is treated as a fully
 * qualified class name. On any failure the method logs a warning and falls back to
 * {@link InfoGainCriterion} (information gain).
 *
 * @throws UndefinedParameterError if the criterion parameter is not defined
 */
private Criterion createCriterion() throws UndefinedParameterError {
    String criterionName = getParameterAsString(AbstractTreeLearner.PARAMETER_CRITERION);
    Class<?> criterionClass = null;
    for (int i = 0; i < CRITERIA_NAMES.length; i++) {
        if (CRITERIA_NAMES[i].equals(criterionName)) {
            criterionClass = CRITERIA_CLASSES[i];
            break; // first match wins; no need to scan further
        }
    }
    if (criterionClass == null && criterionName != null) {
        // Not a predefined name: interpret the parameter as a class name.
        try {
            criterionClass = Tools.classForName(criterionName);
        } catch (ClassNotFoundException e) {
            // Fallback actually instantiates InfoGainCriterion, so say so (the old
            // message incorrectly claimed "gain ratio").
            logWarning("Cannot find criterion '" + criterionName
                    + "' and cannot instantiate a class with this name. Using information gain criterion instead.");
        }
    }
    if (criterionClass != null) {
        try {
            return (Criterion) criterionClass.newInstance();
        } catch (InstantiationException e) {
            logWarning("Cannot instantiate criterion class '" + criterionClass.getName()
                    + "'. Using information gain criterion instead.");
            return new InfoGainCriterion();
        } catch (IllegalAccessException e) {
            logWarning("Cannot access criterion class '" + criterionClass.getName()
                    + "'. Using information gain criterion instead.");
            return new InfoGainCriterion();
        }
    } else {
        log("No relevance criterion defined, using information gain...");
        return new InfoGainCriterion();
    }
}
/** Returns the concrete model type produced by this learner: {@link RuleModel}. */
@Override
public Class<? extends PredictionModel> getModelClass() {
return RuleModel.class;
}
/**
 * Declares which input characteristics this learner handles: nominal and numerical
 * attributes, nominal (binominal/polynominal) labels, weighted examples, and
 * missing values. Everything else is unsupported.
 */
@Override
public boolean supportsCapability(OperatorCapability capability) {
    if (capability == OperatorCapability.BINOMINAL_ATTRIBUTES
            || capability == OperatorCapability.POLYNOMINAL_ATTRIBUTES
            || capability == OperatorCapability.NUMERICAL_ATTRIBUTES) {
        return true;
    }
    if (capability == OperatorCapability.POLYNOMINAL_LABEL
            || capability == OperatorCapability.BINOMINAL_LABEL) {
        return true;
    }
    return capability == OperatorCapability.WEIGHTED_EXAMPLES
            || capability == OperatorCapability.MISSING_VALUES;
}
/**
 * Extends the inherited parameter list with this learner's parameters: split
 * criterion, grow/prune sample ratio, desired pureness, minimal prune benefit,
 * and the shared random generator parameters.
 */
@Override
public List<ParameterType> getParameterTypes() {
    List<ParameterType> types = super.getParameterTypes();

    // Criterion used for selecting attributes and numerical splits (non-expert).
    ParameterType criterionType = new ParameterTypeStringCategory(AbstractTreeLearner.PARAMETER_CRITERION,
            "Specifies the used criterion for selecting attributes and numerical splits.", CRITERIA_NAMES,
            CRITERIA_NAMES[CRITERION_INFO_GAIN], false);
    criterionType.setExpert(false);
    types.add(criterionType);

    // Fraction of the training data used for growing; the rest is the prune set (non-expert).
    ParameterType sampleRatioType = new ParameterTypeDouble(PARAMETER_SAMPLE_RATIO,
            "The sample ratio of training data used for growing and pruning.", 0.0d, 1.0d, 0.9d);
    sampleRatioType.setExpert(false);
    types.add(sampleRatioType);

    // Required purity of a covered subset before rule growth stops.
    ParameterType purenessType = new ParameterTypeDouble(SimpleRuleLearner.PARAMETER_PURENESS,
            "The desired pureness, i.e. the necessary amount of the major class in a covered subset in order become pure.",
            0.0d, 1.0d, 0.9d, false);
    types.add(purenessType);

    // Margin by which the pruned benefit must beat the unpruned benefit to trigger pruning.
    ParameterType pruneBenefitType = new ParameterTypeDouble(PARAMETER_MINIMAL_PRUNE_BENEFIT,
            "The minimum amount of benefit which must be exceeded over unpruned benefit in order to be pruned.", 0.0d,
            1.0d, 0.25d);
    types.add(pruneBenefitType);

    types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
    return types;
}
}