package com.blazebit.ai.decisiontree.impl;

import com.blazebit.ai.decisiontree.Attribute;
import com.blazebit.ai.decisiontree.AttributeSelector;
import com.blazebit.ai.decisiontree.DecisionNode;
import com.blazebit.ai.decisiontree.DecisionNodeFactory;
import com.blazebit.ai.decisiontree.DecisionTree;
import com.blazebit.ai.decisiontree.Example;
import com.blazebit.ai.decisiontree.Item;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
 * {@link DecisionTree} implementation that creates its root node in the
 * constructor and answers every query by delegating to that root node.
 *
 * @param <T> the result type produced by the tree
 * @author Christian Beikov
 */
public class SimpleDecisionTree<T> implements DecisionTree<T> {

    /** Private copy of the attribute set handed to the constructor. */
    private final Set<Attribute> attributes;
    /**
     * Selector consulted for every node that is created. Held as a raw type;
     * it is invoked from the generic {@code createNode} method, whose type
     * parameter is declared independently of this tree's {@code T}.
     */
    private final AttributeSelector attributeSelector;
    /** Entry point for {@link #apply(Item)} and {@link #applySingle(Item)}. */
    private final DecisionNode<T> root;

    /**
     * Creates the tree and its root node.
     *
     * @param attributes        the attributes offered to the selector; copied into a new set
     * @param examples          the examples passed to the node factory for the root node
     * @param attributeSelector the selector that picks the attribute for each node
     */
    public SimpleDecisionTree(final Set<Attribute> attributes, final Set<Example<T>> examples,
            final AttributeSelector<T> attributeSelector) {
        this.attributes = new HashSet<Attribute>(attributes);
        this.attributeSelector = attributeSelector;

        // Both fields above must be assigned first: the factory is an inner
        // class and reads them while creating the root node.
        final Set<Attribute> nothingUsedYet = new HashSet<Attribute>(0);
        final SimpleDecisionNodeFactory rootFactory = new SimpleDecisionNodeFactory(nothingUsedYet);
        this.root = rootFactory.createNode(null, examples);
    }

    /**
     * Delegates to the root node.
     *
     * @param test the item to evaluate
     * @return whatever the root node's {@code apply} returns for the item
     */
    @Override
    public Set<T> apply(final Item test) {
        return root.apply(test);
    }

    /**
     * Delegates to the root node.
     *
     * @param test the item to evaluate
     * @return whatever the root node's {@code applySingle} returns for the item
     */
    @Override
    public T applySingle(final Item test) {
        return root.applySingle(test);
    }

    /**
     * Node without an attribute that always answers with the results captured
     * at construction time, regardless of the item it is asked about.
     *
     * @param <T> the result type
     */
    private static class LeafNode<T> implements DecisionNode<T> {

        /** The one stored result, or {@code null} when no single result is stored. */
        private final T result;
        /** All distinct results of this leaf; never modified after construction. */
        private final Set<T> results;

        /** Creates a leaf with an empty result set and no single result. */
        public LeafNode() {
            this.results = Collections.emptySet();
            this.result = null;
        }

        /**
         * Creates a leaf holding the distinct results of the given examples.
         * The single result is only kept when the examples do not yield more
         * than one distinct value.
         *
         * @param examples the examples whose results are collected; must not be
         *                 empty, because the first collected result is read
         *                 unconditionally when there is at most one
         */
        public LeafNode(final Set<Example<T>> examples) {
            final Set<T> distinct = new HashSet<T>(examples.size());
            for (final Example<T> example : examples) {
                distinct.add(example.getResult());
            }

            this.result = distinct.size() > 1 ? null : distinct.iterator().next();
            this.results = Collections.unmodifiableSet(distinct);
        }

        /**
         * @return always {@code null}; a leaf has no attribute
         */
        @Override
        public Attribute getAttribute() {
            return null;
        }

        /**
         * @param item ignored
         * @return the unmodifiable set of results stored in this leaf
         */
        @Override
        public Set<T> apply(final Item item) {
            return results;
        }

        /**
         * @param item ignored
         * @return the single result stored in this leaf
         * @throws IllegalArgumentException if this leaf stores no single result
         */
        @Override
        public T applySingle(final Item item) {
            if (result != null) {
                return result;
            }
            throw new IllegalArgumentException("Ambigious result for the given item!");
        }
    }

    /**
     * Node factory bound to the enclosing tree. Each instance carries the set
     * of attributes already used on the way to the nodes it creates.
     */
    private class SimpleDecisionNodeFactory implements DecisionNodeFactory {

        /** Attributes already used; stored as given, not copied. */
        private final Set<Attribute> usedAttributes;

        /**
         * @param usedAttributes the attributes already used when this factory is asked for a node
         */
        public SimpleDecisionNodeFactory(final Set<Attribute> usedAttributes) {
            this.usedAttributes = usedAttributes;
        }

        /**
         * Creates the node for the given examples.
         * <ul>
         * <li>No examples: an empty leaf.</li>
         * <li>The selector returns {@code null}: a leaf holding the examples' results.</li>
         * <li>Otherwise: the node produced by the selected attribute, which
         * receives a new factory carrying the updated used-attribute set.</li>
         * </ul>
         *
         * @param usedAttribute the attribute to add to the used set before
         *                      selecting, or {@code null} to use the set as is
         * @param examples      the examples the node is created for
         * @param <R>           the result type of the examples
         * @return the created node
         */
        @Override
        public <R> DecisionNode<R> createNode(final Attribute usedAttribute, final Set<Example<R>> examples) {
            if (examples.isEmpty()) {
                return new LeafNode<R>();
            }

            final Set<Attribute> nextUsed = usedAttribute == null
                    ? usedAttributes
                    : extendedWith(usedAttribute);
            final Attribute chosen = attributeSelector.select(examples, attributes, nextUsed);

            if (chosen != null) {
                return chosen.createNode(new SimpleDecisionNodeFactory(nextUsed), examples);
            }
            return new LeafNode<R>(examples);
        }

        /** Returns a new set holding this factory's used attributes plus the given one. */
        private Set<Attribute> extendedWith(final Attribute attribute) {
            final Set<Attribute> extended = new HashSet<Attribute>(usedAttributes.size() + 1);
            extended.addAll(usedAttributes);
            extended.add(attribute);
            return extended;
        }
    }
}