/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.subgroups.hypothesis;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;

import java.io.Serializable;
import java.util.Collection;
import java.util.LinkedList;
import java.util.LinkedHashMap;

/**
 * This is a hypothesis for subgroup discovery.
 * <p>
 * A hypothesis is a conjunction of {@link Literal}s, holding at most one literal
 * per attribute. It also accumulates the (weighted) number of examples it covers
 * and how much of that weight carries the positive label, see
 * {@link #apply(Example)}.
 *
 * @author Tobias Malbrecht
 */
public class Hypothesis implements Serializable {

    /**
     * FIFO queue of attributes that may still be used to refine a hypothesis in
     * {@link Hypothesis#restrictedRefine(Iterable)}.
     */
    private static class AttributeQueue extends LinkedList<Attribute> {

        private static final long serialVersionUID = 8693212785374243323L;

        private AttributeQueue() {
        }

        private AttributeQueue(Iterable<Attribute> attributes) {
            addAll(attributes);
        }

        /** Appends every attribute of the given iterable, in iteration order. */
        private void addAll(Iterable<Attribute> attributes) {
            for (Attribute attribute : attributes) {
                add(attribute);
            }
        }
    }

    private static final long serialVersionUID = 8694312785374243323L;

    /** Rule generation mode: one rule predicting the positive label. */
    public static final int POSITIVE_RULE = 0;

    /** Rule generation mode: one rule predicting the negative label. */
    public static final int NEGATIVE_RULE = 1;

    /** Rule generation mode: one rule predicting the majority label of the covered weight. */
    public static final int PREDICTION_RULE = 2;

    /** Rule generation mode: one positive and one negative rule. */
    public static final int POSITIVE_AND_NEGATIVE_RULES = 3;

    /** Display names of the rule generation modes, indexed by the mode constants above. */
    public static final String[] RULE_GENERATION_MODES = { "positive" , "negative" , "prediction" , "both" };

    /** The literals of this conjunction, keyed by the attribute each one tests (insertion ordered). */
    private final LinkedHashMap<Attribute, Literal> literalMap = new LinkedHashMap<Attribute, Literal>();

    /**
     * Attributes still available to {@link #restrictedRefine()}. This is null unless the
     * hypothesis was produced by {@link #restrictedRefine(Iterable)}.
     */
    private AttributeQueue restrictedAttributes = null;

    /** Total weight of the examples passed to {@link #apply(Example)} that this hypothesis covers. */
    private double coveredWeight = 0.0d;

    /** Part of {@link #coveredWeight} whose examples carry the positive label. */
    private double positiveWeight = 0.0d;

    /** Creates the empty hypothesis. It has no literals and is therefore applicable to every example. */
    public Hypothesis() {
    }

    /**
     * Creates a hypothesis from the given literals. If several literals test the
     * same attribute, the last one replaces the earlier ones.
     *
     * @param literals the literals of the conjunction
     */
    public Hypothesis(Collection<Literal> literals) {
        this();
        for (Literal literal : literals) {
            literalMap.put(literal.getAttribute(), literal);
        }
    }

    /**
     * Adds the given example to the coverage statistics if this hypothesis is
     * applicable to it. The example's weight is used when the example set has a
     * weight attribute; otherwise the weight is 1.
     *
     * @param example the example to count
     */
    public void apply(Example example) {
        if (!applicable(example)) {
            return;
        }
        double weight = 1.0d;
        if (example.getAttributes().getWeight() != null) {
            weight = example.getWeight();
        }
        coveredWeight += weight;
        // NOTE(review): presumably assumes a binominal label whose mapping defines a positive index.
        if (example.getLabel() == example.getAttributes().getLabel().getMapping().getPositiveIndex()) {
            positiveWeight += weight;
        }
    }

    /**
     * Checks whether this hypothesis is applicable to the given example.
     *
     * @param example the example to test
     * @return true iff every literal of this hypothesis is applicable to the example
     */
    public boolean applicable(Example example) {
        for (Literal literal : literalMap.values()) {
            if (!literal.applicable(example)) {
                return false;
            }
        }
        return true;
    }

    /** Returns a copy of this hypothesis extended, or overridden, by the literal attribute = value. */
    private Hypothesis refine(Attribute attribute, double value) {
        Hypothesis hypothesis = clone();
        hypothesis.literalMap.put(attribute, new Literal(attribute, value));
        return hypothesis;
    }

    /**
     * Creates one refinement per (attribute, value) pair, for every given attribute
     * that this hypothesis does not test yet.
     *
     * @param attributes candidate attributes; their values are taken from the attribute mappings
     * @return the refinements (possibly empty, never null)
     */
    public LinkedList<Hypothesis> refine(Iterable<Attribute> attributes) {
        LinkedList<Hypothesis> hypotheses = new LinkedList<Hypothesis>();
        for (Attribute attribute : attributes) {
            if (!literalMap.containsKey(attribute)) {
                for (String valueString : attribute.getMapping().getValues()) {
                    hypotheses.add(refine(attribute, attribute.getMapping().mapString(valueString)));
                }
            }
        }
        return hypotheses;
    }

    /**
     * Like {@link #refine(Iterable)}, but each refinement remembers only the
     * attributes that come after the one it was refined on. Calling
     * {@link #restrictedRefine()} on a refinement therefore never revisits
     * earlier attributes.
     *
     * @param attributes candidate attributes, in the order they are to be consumed
     * @return the refinements (possibly empty, never null)
     */
    public LinkedList<Hypothesis> restrictedRefine(Iterable<Attribute> attributes) {
        AttributeQueue remaining = new AttributeQueue(attributes);
        LinkedList<Hypothesis> hypotheses = new LinkedList<Hypothesis>();
        Attribute attribute = null;
        while ((attribute = remaining.poll()) != null) {
            if (!literalMap.containsKey(attribute)) {
                for (String valueString : attribute.getMapping().getValues()) {
                    Hypothesis hypothesis = refine(attribute, attribute.getMapping().mapString(valueString));
                    // snapshot of the attributes after the current one
                    hypothesis.restrictedAttributes = new AttributeQueue(remaining);
                    hypotheses.add(hypothesis);
                }
            }
        }
        return hypotheses;
    }

    /**
     * Refines this hypothesis on the attributes it was left with by
     * {@link #restrictedRefine(Iterable)}.
     *
     * @return the refinements, or null if this hypothesis carries no restricted attribute list
     */
    public LinkedList<Hypothesis> restrictedRefine() {
        if (restrictedAttributes == null) {
            return null;
        }
        return restrictedRefine(restrictedAttributes);
    }

    /**
     * Builds the hypothesis consisting of the literals that both hypotheses share
     * (equal literals on common attributes).
     *
     * @param otherHypothesis the hypothesis to intersect with
     * @return the shared hypothesis, or null if a common attribute has contradicting literals
     */
    public Hypothesis subsume(Hypothesis otherHypothesis) {
        LinkedList<Literal> newLiterals = new LinkedList<Literal>();
        for (Literal otherLiteral : otherHypothesis.literalMap.values()) {
            Literal correspondingLiteral = literalMap.get(otherLiteral.getAttribute());
            if (correspondingLiteral == null) {
                continue;
            }
            if (otherLiteral.equals(correspondingLiteral)) {
                newLiterals.add(otherLiteral);
            } else if (otherLiteral.contradicts(correspondingLiteral)) {
                return null;
            }
        }
        return new Hypothesis(newLiterals);
    }

    /**
     * Builds the conjunction of both hypotheses. For attributes tested by both,
     * this hypothesis's literal is kept.
     *
     * @param otherHypothesis the hypothesis to conjoin with
     * @return the combined hypothesis, or null if a common attribute has contradicting literals
     */
    public Hypothesis combine(Hypothesis otherHypothesis) {
        LinkedList<Literal> newLiterals = new LinkedList<Literal>();
        for (Literal otherLiteral : otherHypothesis.literalMap.values()) {
            Literal correspondingLiteral = literalMap.get(otherLiteral.getAttribute());
            if (correspondingLiteral == null) {
                newLiterals.add(otherLiteral);
            } else if (otherLiteral.contradicts(correspondingLiteral)) {
                return null;
            }
        }
        // Only attributes absent from this hypothesis were collected above, so appending
        // all own literals cannot create a duplicate attribute.
        newLiterals.addAll(literalMap.values());
        return new Hypothesis(newLiterals);
    }

    /**
     * Returns a rule predicting the positive label if more than half of the covered
     * weight is positive, and the negative label otherwise. With no covered weight the
     * ratio is NaN, so the negative label is predicted.
     */
    private Rule getPredictionRule(Attribute label) {
        double predictionIndex = positiveWeight / coveredWeight > 0.5 ? label.getMapping().getPositiveIndex() : label.getMapping().getNegativeIndex();
        return new Rule(this, new Literal(label, predictionIndex));
    }

    /** Returns a rule with this hypothesis as premise that predicts the positive label. */
    private Rule getPositiveRule(Attribute label) {
        return new Rule(this, new Literal(label, label.getMapping().getPositiveIndex()));
    }

    /** Returns a rule with this hypothesis as premise that predicts the negative label. */
    private Rule getNegativeRule(Attribute label) {
        return new Rule(this, new Literal(label, label.getMapping().getNegativeIndex()));
    }

    /**
     * Generates rules with this hypothesis as premise.
     *
     * @param ruleGenerationMode one of {@link #POSITIVE_RULE}, {@link #NEGATIVE_RULE},
     *        {@link #PREDICTION_RULE} or {@link #POSITIVE_AND_NEGATIVE_RULES}
     * @param label the label attribute the rules predict
     * @return the generated rules; empty for an unknown mode
     */
    public LinkedList<Rule> generateRules(int ruleGenerationMode, Attribute label) {
        LinkedList<Rule> rules = new LinkedList<Rule>();
        switch (ruleGenerationMode) {
            case POSITIVE_RULE:
                rules.add(getPositiveRule(label));
                break;
            case NEGATIVE_RULE:
                rules.add(getNegativeRule(label));
                break;
            case PREDICTION_RULE:
                rules.add(getPredictionRule(label));
                break;
            case POSITIVE_AND_NEGATIVE_RULES:
                rules.add(getPositiveRule(label));
                rules.add(getNegativeRule(label));
                break;
            default:
                break;
        }
        return rules;
    }

    /** Returns the literal testing the given attribute, or null if there is none. */
    private Literal getLiteral(Attribute attribute) {
        return literalMap.get(attribute);
    }

    /** Returns a live view of this hypothesis's literals. */
    private Collection<Literal> getLiterals() {
        return literalMap.values();
    }

    /** Returns the number of literals, which is the length of the conjunction. */
    public int getNumberOfLiterals() {
        return literalMap.size();
    }

    /** Returns the total weight of covered examples accumulated by {@link #apply(Example)}. */
    public double getCoveredWeight() {
        return coveredWeight;
    }

    /** Returns the covered weight whose examples carry the positive label. */
    public double getPositiveWeight() {
        return positiveWeight;
    }

    /**
     * Two hypotheses are equal iff they test the same attributes with equal literals.
     * Coverage statistics are not compared. The literal counts are compared first, so
     * that a hypothesis is not reported equal to a strict extension of itself; this
     * keeps the relation symmetric.
     */
    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (object == null || getClass() != object.getClass()) {
            return false;
        }
        Hypothesis otherHypothesis = (Hypothesis) object;
        if (getNumberOfLiterals() != otherHypothesis.getNumberOfLiterals()) {
            return false;
        }
        for (Literal literal : getLiterals()) {
            if (!literal.equals(otherHypothesis.getLiteral(literal.getAttribute()))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hashes the set of tested attributes. This is consistent with {@link #equals(Object)}:
     * equal hypotheses test the same attributes. It relies on every literal being stored
     * under its own attribute, which the constructors and refinement methods maintain.
     */
    @Override
    public int hashCode() {
        return literalMap.keySet().hashCode();
    }

    /**
     * Returns a copy holding the same literals. The copy's coverage statistics
     * start at zero and it carries no restricted attribute list.
     */
    @Override
    public Hypothesis clone() {
        Hypothesis newHypothesis = new Hypothesis();
        newHypothesis.literalMap.putAll(literalMap);
        return newHypothesis;
    }

    /** Returns the literals in insertion order, separated by {@code " , "}. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        boolean first = true;
        for (Literal literal : literalMap.values()) {
            if (!first) {
                builder.append(" , ");
            }
            builder.append(literal);
            first = false;
        }
        return builder.toString();
    }
}