/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.subgroups; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.LinkedList; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.SimplePredictionModel; import com.rapidminer.operator.learner.subgroups.hypothesis.Rule; import com.rapidminer.operator.learner.subgroups.utility.UtilityFunction; import com.rapidminer.tools.Tools; /** * A model consisting of rules which are scored by utility values. * Only the best rule (according to its utility) is used for prediction * at the moment. * * @author Tobias Malbrecht */ public class RuleSet extends SimplePredictionModel implements Iterable<Rule> { private boolean predictUncoveredRules = false; private static final long serialVersionUID = -47885282272818733L; private LinkedList<Rule> rules = null; private LinkedHashSet<UtilityFunction> utilityFunctions = null; public RuleSet(ExampleSet exampleSet) { super(exampleSet); rules = new LinkedList<Rule>(); utilityFunctions = new LinkedHashSet<UtilityFunction>(); } public void addRule(Rule rule) { rules.add(rule); utilityFunctions.addAll(rule.getUtilityFunctions()); } public Rule getRule(int index) { return rules.get(index); } public int getNumberOfRules() { return rules.size(); } public Iterator<Rule> iterator() { return rules.iterator(); } public LinkedList<Rule> getPositiveRules() { LinkedList<Rule> positiveRules = new LinkedList<Rule>(); for (Rule rule : this) { if (rule.predictsPositive()) { positiveRules.add(rule); } } return positiveRules; } public LinkedList<Rule> getNegativeRules() { LinkedList<Rule> negativeRules = new LinkedList<Rule>(); for (Rule rule : this) { if (!rule.predictsPositive()) { negativeRules.add(rule); } } return negativeRules; } public int size() { return getNumberOfRules(); } @Override public double predict(Example example) throws OperatorException { for (Rule rule : rules) { if (rule.applicable(example)) { return (rule.getPrediction()); } } return (predictUncoveredRules ? example.getAttributes().getLabel().getMapping().getNegativeIndex() : Double.NaN); } public UtilityFunction[] getUtilityFunctions() { UtilityFunction[] functions = new UtilityFunction[utilityFunctions.size()]; functions = utilityFunctions.toArray(functions); return functions; } @Override public String toString() { StringBuffer stringBuffer = new StringBuffer(); int i = 0; for (Rule rule : rules) { if (i < 10) { stringBuffer.append(rule.toStringScored()); stringBuffer.append(Tools.getLineSeparator()); } i++; } if (i > 10) { stringBuffer.append(Tools.getLineSeparators(2)); stringBuffer.append("... and " + (i-10) + " more rules!"); } return stringBuffer.toString(); } }