/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.meta;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.meta.HierarchicalMultiClassModel.Node;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.RandomGenerator;

/**
 * This is a meta learner for classifying multiple classes using a hierarchical approach. For a higher number of
 * classes this might prove more accurate than the one versus all or one versus one approach. If the models are
 * applied in a binary tree like structure, fewer models need to be stored in memory and fewer applications have to
 * be performed before having a result.
 * <p>
 * The hierarchy is given by the list parameter {@link #PARAMETER_HIERARCHY}, each entry being a
 * (parent class, child class) pair. One model is learned with the inner learner for every node that has children.
 * 
 * @author Tobias Malbrecht, Sebastian Land
 */
public class HierarchicalMultiClassLearner extends AbstractMetaLearner {

	public static final String PARAMETER_HIERARCHY = "hierarchy";

	public static final String PARAMETER_PARENT_CLASS = "parent_class";

	public static final String PARAMETER_CHILD_CLASS = "child_class";

	public HierarchicalMultiClassLearner(OperatorDescription description) {
		super(description);
	}

	/**
	 * Builds the class tree from the hierarchy parameter, validates it and learns one model per non-leaf node.
	 * 
	 * @param inputSet
	 *            the training data; its label attribute is read and temporarily replaced by a working label while
	 *            the models are learned
	 * @return a {@link HierarchicalMultiClassModel} built from the input set and the root node of the tree
	 * @throws OperatorException
	 *             user error 219 or 220 if the hierarchy does not fit the label (see
	 *             {@link #checkCompatibility(Attribute)}), 220 if more than one root candidate exists, 221 if no
	 *             root exists, 222 if a node has exactly one child, or whatever the inner learner throws
	 */
	public Model learn(ExampleSet inputSet) throws OperatorException {
		Attribute labelAttribute = inputSet.getAttributes().getLabel();

		// check if label attribute's value set is compatible with the defined classes
		checkCompatibility(labelAttribute);

		// create model hierarchy / tree: one node per class name, reused across all pairs mentioning it
		List<String[]> hierarchyEntryPairs = getParameterList(PARAMETER_HIERARCHY);
		Map<String, Node> nodeMap = new HashMap<String, Node>();
		// all nodes that occur as a child of some other node, i.e. everything except the root
		Set<Node> innerNodes = new HashSet<Node>();
		for (String[] entryPair : hierarchyEntryPairs) {
			String parentClass = entryPair[0];
			String childClass = entryPair[1];

			Node parentNode = nodeMap.get(parentClass);
			if (parentNode == null) {
				parentNode = new Node(parentClass);
			}
			Node childNode = nodeMap.get(childClass);
			if (childNode == null) {
				childNode = new Node(childClass);
			}
			parentNode.addChild(childNode);

			nodeMap.put(parentClass, parentNode);
			nodeMap.put(childClass, childNode);
			innerNodes.add(childNode);
		}

		// root node is the single node that never occurs as a child
		Node root = null;
		for (Node node : nodeMap.values()) {
			if (!innerNodes.contains(node)) {
				if (root == null) {
					root = node;
				} else {
					throw new UserError(this, 220, root.getClassName(), node.getClassName());
				}
			}
		}
		if (root == null) {
			throw new UserError(this, 221);
		}

		// check if each node has at least 2 children or is a leaf
		for (Node node : nodeMap.values()) {
			if (node.getChildren().size() == 1) {
				throw new UserError(this, 222, node.getClassName(), node.getChildren().size());
			}
		}

		computeModel(root, inputSet, labelAttribute);

		return new HierarchicalMultiClassModel(inputSet, root);
	}

	/**
	 * Checks the hierarchy parameter against the values of the label attribute.
	 * <ul>
	 * <li>Every child class must either be an original label value or be used as a parent class somewhere in the
	 * hierarchy; otherwise user error 219 is thrown with the offending child class.</li>
	 * <li>At most one parent class may never occur as a child class (that one is the root); a second such class
	 * raises user error 220.</li>
	 * </ul>
	 * 
	 * @param labelAttribute
	 *            the label attribute whose nominal mapping provides the original class values
	 * @throws UserError
	 *             if one of the two conditions above is violated or the hierarchy parameter cannot be read
	 */
	private void checkCompatibility(Attribute labelAttribute) throws UserError {
		List<String[]> hierarchy = getParameterList(PARAMETER_HIERARCHY);

		Set<String> knownValues = new HashSet<String>(labelAttribute.getMapping().getValues());
		Set<String> childValues = new HashSet<String>();
		for (String[] pair : hierarchy) {
			// every left hand side is a defined (virtual) class
			knownValues.add(pair[0]);
			childValues.add(pair[1]);
		}

		// check if each right hand side value is either defined as left hand side or is an original label value
		for (String[] pair : hierarchy) {
			if (!knownValues.contains(pair[1])) {
				throw new UserError(this, 219, pair[1]);
			}
		}

		// check if each left hand side is assigned as right hand side except the root. The root normally occurs
		// in several pairs (once per child), hence the comparison by name instead of a plain "already seen" flag.
		String rootValue = null;
		for (String[] pair : hierarchy) {
			if (!childValues.contains(pair[0])) {
				if (rootValue == null) {
					rootValue = pair[0];
				} else if (!rootValue.equals(pair[0])) {
					throw new UserError(this, 220, pair[0], rootValue);
				}
			}
		}
	}

	/**
	 * This method will first create a working label column and after this run through the tree recursively.
	 * <p>
	 * The original label is kept as special attribute "label_original" while the working label takes the label
	 * role. The working label is detached again and the original label is restored as label in any case, also
	 * if learning fails.
	 * 
	 * @param rootNode
	 *            root of the class hierarchy; every non-leaf node below it receives a model
	 * @param exampleSet
	 *            the training data, modified temporarily as described above
	 * @param originalLabel
	 *            the label attribute of the example set
	 * @throws OperatorException
	 *             if the inner learner fails
	 */
	private void computeModel(HierarchicalMultiClassModel.Node rootNode, ExampleSet exampleSet, Attribute originalLabel) throws OperatorException {
		// create working label with copy of original label values
		exampleSet.getAttributes().setSpecialAttribute(originalLabel, "label_original");
		Attribute workingLabel = AttributeFactory.createAttribute(originalLabel.getName() + "_working", originalLabel.getValueType());
		exampleSet.getExampleTable().addAttribute(workingLabel);
		exampleSet.getAttributes().addRegular(workingLabel);
		exampleSet.getAttributes().setLabel(workingLabel);

		try {
			// create partition for recursive learning: each example starts in the partition of its label index.
			// Start the maximum at the highest index the label mapping may hand out, so that ids of classes
			// without any example can never collide with the ids given to non-leaf nodes below.
			int[] partitions = new int[exampleSet.size()];
			int i = 0;
			int lastLeafId = originalLabel.getMapping().getValues().size() - 1;
			for (Example example : exampleSet) {
				double value = example.getValue(originalLabel);
				example.setValue(workingLabel, value);
				partitions[i] = (int) value;
				if (partitions[i] > lastLeafId) {
					lastLeafId = partitions[i];
				}
				i++;
			}

			// Leaf ids must be looked up in the ORIGINAL label's mapping, because the partitions array above
			// holds the original label's indices. The working label was created from scratch and does not carry
			// that mapping, so resolving class names against it would yield unrelated indices.
			AtomicInteger nonLeafCounter = new AtomicInteger(lastLeafId);
			setParitionIdRecursivly(rootNode, nonLeafCounter, lastLeafId, originalLabel);

			// recursively walk through hierarchy and learn
			computeModelRecursivly(rootNode, partitions, nonLeafCounter.get(), exampleSet);
		} finally {
			// remove working label again and restore the original label, even if learning failed
			exampleSet.getAttributes().remove(workingLabel);
			exampleSet.getAttributes().setLabel(originalLabel);
			exampleSet.getExampleTable().removeAttribute(workingLabel);
		}
	}

	/**
	 * This will set the partition id by either taking the mapping value of the given label mapping if the node is
	 * a leaf, or the next free integer available after the highest entry in the mapping.
	 * 
	 * @param node
	 *            the node whose subtree receives partition ids
	 * @param nonLeafCounter
	 *            counter providing fresh ids for non-leaf nodes; advanced once per child of every non-leaf node
	 * @param maxLeafId
	 *            currently not used by this method, only handed down the recursion
	 * @param labelAttribute
	 *            attribute whose nominal mapping translates leaf class names into partition ids
	 */
	private void setParitionIdRecursivly(Node node, AtomicInteger nonLeafCounter, int maxLeafId, Attribute labelAttribute) {
		if (node.isLeaf()) {
			node.setPartitionId(labelAttribute.getMapping().mapString(node.getClassName()));
		} else {
			for (Node child : node.getChildren()) {
				setParitionIdRecursivly(child, nonLeafCounter, maxLeafId, labelAttribute);
				// The id is reassigned after every child; the value set after the last child is the one that
				// remains and is therefore larger than every id handed out inside this node's subtree.
				node.setPartitionId(nonLeafCounter.incrementAndGet());
			}
		}
	}

	/**
	 * This method will learn the model tree bottom up by splitting the example set into the partitions defined
	 * by the partitions array and use the ones defined by the child nodes.
	 * <p>
	 * After the model of a node has been learned, all examples belonging to its children are moved into the
	 * node's own partition, so that the parent of this node sees them as one class.
	 * 
	 * @param node
	 *            the node to learn a model for; leaves are skipped
	 * @param partitions
	 *            partition id per example, updated in place
	 * @param numberOfPartitions
	 *            number of partitions passed to each {@link Partition}
	 * @param exampleSet
	 *            the training data with the working label set as label
	 * @throws OperatorException
	 *             if the inner learner fails
	 */
	private void computeModelRecursivly(Node node, int[] partitions, int numberOfPartitions, ExampleSet exampleSet) throws OperatorException {
		if (node.isLeaf()) {
			return;
		}

		// first learn all models below
		for (Node child : node.getChildren()) {
			computeModelRecursivly(child, partitions, numberOfPartitions, exampleSet);
		}

		// then it is assured that there exist partitions with the index of the child nodes. Now use these examples
		SplittedExampleSet trainSet = new SplittedExampleSet(exampleSet, new Partition(partitions, numberOfPartitions));

		// give the label of the training set a mapping of its own that contains exactly the child classes
		Attribute workingLabel = trainSet.getAttributes().getLabel();
		workingLabel.setMapping((NominalMapping) workingLabel.getMapping().clone());
		workingLabel.getMapping().clear();
		for (Node child : node.getChildren()) {
			trainSet.selectSingleSubset(child.getPartitionId());
			int nodeLabelIndex = workingLabel.getMapping().mapString(child.getClassName());
			for (Example example : trainSet) {
				example.setValue(workingLabel, nodeLabelIndex);
			}
		}

		// select all participating subsets
		trainSet.clearSelection();
		for (Node child : node.getChildren()) {
			trainSet.selectAdditionalSubset(child.getPartitionId());
		}

		// learn model by applying inner learner
		Model model = applyInnerLearner(trainSet);
		node.setModel(model);

		// then replace partition entries of all child nodes with own
		int partitionId = node.getPartitionId();
		for (Node child : node.getChildren()) {
			int childPartitionId = child.getPartitionId();
			for (int i = 0; i < partitions.length; i++) {
				if (partitions[i] == childPartitionId) {
					partitions[i] = partitionId;
				}
			}
		}
	}

	/**
	 * Reports every capability as supported except numerical, missing, binominal and one-class labels as well as
	 * the updatable and formula provider capabilities.
	 */
	@Override
	public boolean supportsCapability(OperatorCapability capability) {
		switch (capability) {
		case NUMERICAL_LABEL:
		case NO_LABEL:
		case UPDATABLE:
		case FORMULA_PROVIDER:
		case BINOMINAL_LABEL:
		case ONE_CLASS_LABEL:
			return false;
		default:
			return true;
		}
	}

	/**
	 * Adds the hierarchy list parameter (pairs of parent class and child class) and the random generator
	 * parameters to the parameters of the super class.
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();

		types.add(new ParameterTypeList(PARAMETER_HIERARCHY, "The hierarchy...", new ParameterTypeString(PARAMETER_PARENT_CLASS, "The parent class.", false), new ParameterTypeString(PARAMETER_CHILD_CLASS, "The child class.", false)));

		types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
		return types;
	}
}