/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.meta.HierarchicalMultiClassModel.Node;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.RandomGenerator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * This is a meta learner for classifying multiple classes using a hierarchical approach. For a
 * larger number of classes this may prove more accurate than the one-versus-all or one-versus-one
 * approach. When the models are applied in a binary-tree-like structure, fewer models need to be
 * kept in memory and fewer model applications have to be performed before obtaining a result.
*
* @author Tobias Malbrecht, Sebastian Land
*/
public class HierarchicalMultiClassLearner extends AbstractMetaLearner {

    /** Key of the list parameter that defines the class hierarchy as (parent, child) pairs. */
    public static final String PARAMETER_HIERARCHY = "hierarchy";

    /** Key of the list entry holding the parent class name of a hierarchy pair. */
    public static final String PARAMETER_PARENT_CLASS = "parent_class";

    /** Key of the list entry holding the child class name of a hierarchy pair. */
    public static final String PARAMETER_CHILD_CLASS = "child_class";

    public HierarchicalMultiClassLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Builds the hierarchy tree from the {@link #PARAMETER_HIERARCHY} parameter, validates it
     * (exactly one root, no inner node with a single child) and then learns one inner model per
     * inner node of the tree.
     *
     * @param inputSet the training data; its label attribute's values form the tree's leaves
     * @return a {@link HierarchicalMultiClassModel} wrapping the learned model tree
     * @throws OperatorException if the hierarchy is inconsistent or the inner learner fails
     */
    @Override
    public Model learn(ExampleSet inputSet) throws OperatorException {
        Attribute labelAttribute = inputSet.getAttributes().getLabel();

        // check if label attribute's value set is compatible with the defined classes
        checkCompatibility(labelAttribute);

        // create model hierarchy / tree from the (parent, child) parameter pairs
        List<String[]> hierarchyEntryPairs = getParameterList(PARAMETER_HIERARCHY);
        Map<String, Node> nodeMap = new HashMap<String, Node>();
        Set<Node> innerNodes = new HashSet<Node>();
        for (String[] entryPair : hierarchyEntryPairs) {
            String parentClass = entryPair[0];
            String childClass = entryPair[1];
            Node parentNode = nodeMap.get(parentClass);
            if (parentNode == null) {
                parentNode = new Node(parentClass);
            }
            Node childNode = nodeMap.get(childClass);
            if (childNode == null) {
                childNode = new Node(childClass);
            }
            parentNode.addChild(childNode);
            nodeMap.put(parentClass, parentNode);
            nodeMap.put(childClass, childNode);
            // any node occurring as a child cannot be the root
            innerNodes.add(childNode);
        }

        // the root is the single node that never occurs as a child
        Node root = null;
        for (Node node : nodeMap.values()) {
            if (!innerNodes.contains(node)) {
                if (root == null) {
                    root = node;
                } else {
                    // two candidate roots: the hierarchy is not a single tree
                    throw new UserError(this, 220, root.getClassName(), node.getClassName());
                }
            }
        }
        if (root == null) {
            // every node occurs as a child: the hierarchy contains a cycle
            throw new UserError(this, 221);
        }

        // each node must either be a leaf or have at least 2 children; a single-child inner
        // node would pose a "decision" with only one possible outcome
        for (Node node : nodeMap.values()) {
            if (node.getChildren().size() == 1) {
                throw new UserError(this, 222, node.getClassName(), node.getChildren().size());
            }
        }

        computeModel(root, inputSet, labelAttribute);
        return new HierarchicalMultiClassModel(inputSet, root);
    }

    /**
     * Checks that the hierarchy parameter is consistent with the label attribute's value set:
     * each child class must either be an original label value or occur as a parent class
     * somewhere in the hierarchy.
     *
     * @param labelAttribute the nominal label attribute whose values form the legal leaf names
     * @throws UserError if a child class is neither a label value nor a parent class
     */
    private void checkCompatibility(Attribute labelAttribute) throws UserError {
        Set<String> values = new HashSet<String>(labelAttribute.getMapping().getValues());
        // add all left hand side (parent) classes: they are legal child values as well
        List<String[]> hierarchy = getParameterList(PARAMETER_HIERARCHY);
        for (String[] pair : hierarchy) {
            values.add(pair[0]);
        }
        String rootValue = null;
        for (String[] pair : hierarchy) {
            // check if right hand side value is either defined as a parent class or is an
            // original label value
            if (!values.contains(pair[1])) {
                throw new UserError(this, 219, pair[1]);
            }
            // check if each left hand side is assigned as right hand side except the root.
            // NOTE(review): since every left hand side was added to 'values' above, this
            // condition can never be true, so this branch is dead code; the actual root /
            // multi-root detection happens in learn(). Kept as-is to preserve the existing
            // error behavior — confirm intended semantics before removing.
            if (!values.contains(pair[0])) {
                if (rootValue == null) {
                    rootValue = pair[0];
                } else {
                    throw new UserError(this, 220, pair[0], rootValue);
                }
            }
        }
    }

    /**
     * Creates a temporary working label column (a copy of the original label), assigns a
     * partition id to every hierarchy node and then learns the model tree recursively, bottom
     * up. Afterwards the working label is removed again and the original label is restored.
     *
     * @param rootNode the root of the validated hierarchy tree
     * @param exampleSet the training data (modified in place while learning, then restored)
     * @param originalLabel the original label attribute
     * @throws OperatorException if the inner learner fails
     */
    private void computeModel(HierarchicalMultiClassModel.Node rootNode, ExampleSet exampleSet, Attribute originalLabel)
            throws OperatorException {
        // create working label with copy of original label values
        exampleSet.getAttributes().setSpecialAttribute(originalLabel, "label_original");
        Attribute workingLabel = AttributeFactory.createAttribute(originalLabel.getName() + "_working",
                originalLabel.getValueType());
        exampleSet.getExampleTable().addAttribute(workingLabel);
        exampleSet.getAttributes().addRegular(workingLabel);
        exampleSet.getAttributes().setLabel(workingLabel);

        // create partition for recursive learning: each example starts in the partition given
        // by its numeric label index; remember the highest leaf index in use
        int[] partitions = new int[exampleSet.size()];
        int i = 0;
        int lastLeafId = -1;
        for (Example example : exampleSet) {
            double value = example.getValue(originalLabel);
            example.setValue(workingLabel, value);
            partitions[i] = (int) value;
            if (partitions[i] > lastLeafId) {
                lastLeafId = partitions[i];
            }
            i++;
        }

        // inner nodes receive fresh partition ids beyond the highest leaf id
        AtomicInteger nonLeafCounter = new AtomicInteger(lastLeafId);
        setPartitionIdRecursively(rootNode, nonLeafCounter, lastLeafId, workingLabel);

        // recursively walk through hierarchy and learn
        computeModelRecursively(rootNode, partitions, nonLeafCounter.get(), exampleSet);

        // remove working label again and restore the original label
        exampleSet.getAttributes().remove(workingLabel);
        exampleSet.getAttributes().setLabel(originalLabel);
        exampleSet.getExampleTable().removeAttribute(workingLabel);
    }

    /**
     * This will set the partition id by either taking the mapping value of the original label
     * mapping if the node is a leaf, or the next free integer available after the highest entry
     * in the mapping.
     * <p>
     * The inner node's id is assigned exactly once, in post-order, after all children have been
     * numbered. (Previously the assignment sat inside the loop over the children, which burned
     * one counter value per child and repeatedly overwrote the node's id, leaving gaps in the
     * partition id range and spurious empty partitions.) Post-order assignment also guarantees
     * that whenever a {@code Partition} is built in computeModelRecursively, every id present in
     * the partitions array is strictly smaller than {@code nonLeafCounter.get()}.
     */
    private void setPartitionIdRecursively(Node node, AtomicInteger nonLeafCounter, int maxLeafId, Attribute workingLabel) {
        if (node.isLeaf()) {
            node.setPartitionId(workingLabel.getMapping().mapString(node.getClassName()));
        } else {
            for (Node child : node.getChildren()) {
                setPartitionIdRecursively(child, nonLeafCounter, maxLeafId, workingLabel);
            }
            node.setPartitionId(nonLeafCounter.incrementAndGet());
        }
    }

    /**
     * This method will learn the model tree bottom up by splitting the example set into the
     * partitions defined by the partitions array and using the subsets defined by the child
     * nodes. After learning, the child partitions are merged into this node's partition so the
     * next level up sees them as a single class.
     *
     * @throws OperatorException if the inner learner fails
     */
    private void computeModelRecursively(Node node, int[] partitions, int numberOfPartitions, ExampleSet exampleSet)
            throws OperatorException {
        if (node.isLeaf()) {
            return;
        } else {
            // first learn all models below
            for (Node child : node.getChildren()) {
                computeModelRecursively(child, partitions, numberOfPartitions, exampleSet);
            }

            // then it is assured that there exist partitions with the indices of the child
            // nodes. Now use these examples
            SplittedExampleSet trainSet = new SplittedExampleSet(exampleSet, new Partition(partitions, numberOfPartitions));
            Attribute workingLabel = trainSet.getAttributes().getLabel();
            // clone before clearing so the shared mapping of other views is not destroyed
            workingLabel.setMapping((NominalMapping) workingLabel.getMapping().clone());
            workingLabel.getMapping().clear();
            // relabel each child's examples with the child's class name
            for (Node child : node.getChildren()) {
                trainSet.selectSingleSubset(child.getPartitionId());
                int nodeLabelIndex = workingLabel.getMapping().mapString(child.getClassName());
                for (Example example : trainSet) {
                    example.setValue(workingLabel, nodeLabelIndex);
                }
            }

            // select all participating subsets
            trainSet.clearSelection();
            for (Node child : node.getChildren()) {
                trainSet.selectAdditionalSubset(child.getPartitionId());
            }

            // learn model by applying inner learner
            Model model = applyInnerLearner(trainSet);
            node.setModel(model);

            // then replace partition entries of all child nodes with this node's own id so the
            // parent level sees one merged partition
            int partitionId = node.getPartitionId();
            for (Node child : node.getChildren()) {
                int childPartitionId = child.getPartitionId();
                for (int i = 0; i < partitions.length; i++) {
                    if (partitions[i] == childPartitionId) {
                        partitions[i] = partitionId;
                    }
                }
            }
        }
    }

    /**
     * This learner only supports polynominal (multi-class nominal) labels: binominal,
     * numerical, missing and one-class labels are rejected, as are updates and formulas.
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        switch (capability) {
            case NUMERICAL_LABEL:
            case NO_LABEL:
            case UPDATABLE:
            case FORMULA_PROVIDER:
            case BINOMINAL_LABEL:
            case ONE_CLASS_LABEL:
                return false;
            default:
                return true;
        }
    }

    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        // the hierarchy is entered as a list of mandatory (parent class, child class) pairs
        types.add(new ParameterTypeList(PARAMETER_HIERARCHY, "The hierarchy...", new ParameterTypeString(
                PARAMETER_PARENT_CLASS, "The parent class.", false), new ParameterTypeString(PARAMETER_CHILD_CLASS,
                "The child class.", false)));
        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }
}