/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.Partition;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.RandomGenerator;
/**
* A metaclassifier for handling multi-class datasets with 2-class classifiers. This class
* supports several strategies for multiclass classification including procedures which are
* capable of using error-correcting output codes for increased accuracy.
*
* @author Helge Homburg
* @version $Id: Binary2MultiClassLearner.java,v 1.11 2008/05/09 19:22:46 ingomierswa Exp $
*/
public class Binary2MultiClassLearner extends AbstractMetaLearner {
/** The parameter name for "What strategy should be used for multi class classifications?" */
public static final String PARAMETER_CLASSIFICATION_STRATEGIES = "classification_strategies";
/** The parameter name for "A multiplicator regulating the codeword length in random code modus." */
public static final String PARAMETER_RANDOM_CODE_MULTIPLICATOR = "random_code_multiplicator";
/** The parameter name for "Use the given random seed instead of global random numbers (-1: use global)" */
public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
/**
* Display names of the selectable strategies. The position of each entry matches the
* value of the corresponding strategy constant declared below.
*/
private static final String[] STRATEGIES = { "1 against all", "1 against 1", "exhaustive code (ECOC)", "random code (ECOC)" };
/** Strategy index of "1 against all" in {@link #STRATEGIES}. */
private static final int ONE_AGAINST_ALL = 0;
/** Strategy index of "1 against 1" in {@link #STRATEGIES}. */
private static final int ONE_AGAINST_ONE = 1;
/** Strategy index of "exhaustive code (ECOC)" in {@link #STRATEGIES}. */
private static final int EXHAUSTIVE_CODE = 2;
/** Strategy index of "random code (ECOC)" in {@link #STRATEGIES}. */
private static final int RANDOM_CODE = 3;
/**
* This List stores a short description for the generated models. It is cleared and
* refilled by the "1 against all" and "1 against 1" pattern builders and passed to the
* model created in learn(); the two ECOC builders do not touch it.
*/
private LinkedList<String> modelNames = new LinkedList<String>();
/**
 * Code matrix describing how a series of binary models is trained for one
 * classification strategy. Both arrays are indexed as [class][function]:
 * <code>data</code> holds the binary label string a class receives for a function
 * (entries start out as <code>null</code>), <code>partitionEnabled</code> tells whether
 * the examples of that class take part in training the function (all entries start
 * out as <code>true</code>).
 */
private static class CodePattern {
    String[][] data;
    boolean[][] partitionEnabled;

    /**
     * Allocates both matrices and enables every class for every function.
     *
     * @param numberOfClasses   number of rows (one per class)
     * @param numberOfFunctions number of columns (one per binary model)
     */
    public CodePattern(int numberOfClasses, int numberOfFunctions) {
        this.data = new String[numberOfClasses][numberOfFunctions];
        this.partitionEnabled = new boolean[numberOfClasses][numberOfFunctions];
        for (boolean[] classRow : this.partitionEnabled) {
            java.util.Arrays.fill(classRow, true);
        }
    }
}
/**
* Creates the operator; the description is passed on unchanged to the superclass
* constructor.
*
* @param description the operator description handed to the superclass
*/
public Binary2MultiClassLearner(OperatorDescription description) {
super(description);
}
/**
 * Wraps a clone of the input set into a {@link SplittedExampleSet} whose partition
 * element for each example is that example's label value cast to <code>int</code>.
 * The number of partitions equals the size of the label mapping.
 *
 * @param inputSet the example set to partition by class
 * @return a split example set over a clone of <code>inputSet</code>
 */
private SplittedExampleSet constructClassPartitionSet(ExampleSet inputSet){
    final Attribute label = inputSet.getAttributes().getLabel();
    final int partitionCount = label.getMapping().size();
    final int[] partitionOfExample = new int[inputSet.size()];

    // Record, in iteration order, the label value of every example as its partition index.
    int position = 0;
    for (Iterator<Example> it = inputSet.iterator(); it.hasNext(); position++) {
        partitionOfExample[position] = (int) it.next().getValue(label);
    }

    Partition byClass = new Partition(partitionOfExample, partitionCount);
    ExampleSet copy = (ExampleSet) inputSet.clone();
    return new SplittedExampleSet(copy, byClass);
}
/**
 * Trains a series of models depending on the classification method specified by a
 * certain code pattern: one inner model per column ("function") of the pattern.
 *
 * For every function the selection of <code>seSet</code> is reset, the partitions of
 * all classes enabled for that function are selected, and a temporary binominal
 * attribute "multiclass_working_label" is filled with the code string of each selected
 * example's class and installed as label before the inner learner is called.
 *
 * The temporary label is removed again and the original label restored in a
 * <code>finally</code> block, so a failing inner learner (or an exception from
 * <code>inApplyLoop()</code>) no longer leaves the temporary column behind in the
 * example table.
 *
 * @param seSet      the input examples partitioned by class (partition index = label value)
 * @param classLabel the original multi-class label
 * @param cP         the code pattern; row k belongs to the k-th value returned by
 *                   <code>classLabel.getMapping().getValues()</code>
 * @return the trained models, one per function, in column order
 * @throws OperatorException if the inner learner (or <code>inApplyLoop()</code>) fails
 */
private Model[] applyCodePattern(SplittedExampleSet seSet, Attribute classLabel, CodePattern cP) throws OperatorException {
    int numberOfClasses = classLabel.getMapping().size();
    int numberOfFunctions = cP.data[0].length;
    Model[] models = new Model[numberOfFunctions];

    // Hash maps are used for addressing particular class values using indices without relying
    // upon a consistent index distribution of the corresponding substructure:
    // key = mapping index of the class value, value = row of that class in the code pattern.
    HashMap<Integer, Integer> classIndexMap = new HashMap<Integer, Integer>(numberOfClasses);

    for (int currentFunction = 0; currentFunction < numberOfFunctions; currentFunction++) {

        // 1. Configure a split example set and add a temporary label.
        int counter = 0;
        seSet.clearSelection();
        for (String currentClass : classLabel.getMapping().getValues()) {
            classIndexMap.put(classLabel.getMapping().mapString(currentClass), counter);
            if (cP.partitionEnabled[counter][currentFunction]) {
                seSet.selectAdditionalSubset(classLabel.getMapping().mapString(currentClass));
            }
            counter++;
        }

        Attribute workingLabel = AttributeFactory.createAttribute("multiclass_working_label", Ontology.BINOMINAL);
        seSet.getExampleTable().addAttribute(workingLabel);
        seSet.getAttributes().addRegular(workingLabel);

        Iterator<Example> iterator = seSet.iterator();
        while (iterator.hasNext()) {
            Example e = iterator.next();
            int currentIndex = classIndexMap.get((int)e.getValue(classLabel));
            if (cP.partitionEnabled[currentIndex][currentFunction]) {
                e.setValue(workingLabel, workingLabel.getMapping().mapString(cP.data[currentIndex][currentFunction]));
            }
        }

        // Switch the role of the temporary attribute from regular to label.
        seSet.getAttributes().remove(workingLabel);
        seSet.getAttributes().setLabel(workingLabel);

        try {
            // 2. Apply the example set to the inner learner.
            models[currentFunction] = applyInnerLearner(seSet);
            inApplyLoop();
        } finally {
            // 3. Clean up for the next run - also when training failed, because the
            // temporary column was added to the example table itself.
            // NOTE(review): the table is presumably shared with the caller's example set
            // (seSet is built on a clone) - confirm.
            seSet.getAttributes().setLabel(classLabel);
            seSet.getExampleTable().removeAttribute(workingLabel);
        }
    }
    return models;
}
/**
 * Builds a code pattern according to the "1 against all" classification scheme.
 *
 * The pattern is square (one function per class). Function k labels the k-th class
 * value with its own name and every other class with "all_other_classes". As a side
 * effect the model name list is reset and receives one "&lt;class&gt; vs. all other"
 * entry per class.
 *
 * @param classLabel the multi-class label whose mapping defines the classes
 * @return the code pattern with all partitions enabled
 */
private CodePattern buildCodePattern_ONE_VS_ALL(Attribute classLabel) {
    final int classCount = classLabel.getMapping().size();
    final CodePattern pattern = new CodePattern(classCount, classCount);
    final Iterator<String> classNames = classLabel.getMapping().getValues().iterator();
    modelNames.clear();

    for (int diagonal = 0; diagonal < classCount; diagonal++) {
        String[] codeword = pattern.data[diagonal];
        // Start with the "rest" label everywhere ...
        for (int function = 0; function < classCount; function++) {
            codeword[function] = "all_other_classes";
        }
        // ... and put the class' own name on the diagonal.
        String ownName = classNames.next();
        codeword[diagonal] = ownName;
        modelNames.add(ownName + " vs. all other");
    }
    return pattern;
}
/**
 * Builds a code pattern according to the "1 against 1" classification scheme.
 *
 * There is one function per unordered pair of classes, enumerated as
 * (0,1), (0,2), ..., (1,2), ... in mapping order. For each function only the two
 * classes of the pair are enabled, and each of them is labelled with its own name.
 * As a side effect the model name list is reset and receives one "A vs. B" entry
 * per function.
 *
 * @param classLabel the multi-class label whose mapping defines the classes
 * @return the code pattern with exactly two enabled classes per function
 */
private CodePattern buildCodePattern_ONE_VS_ONE(Attribute classLabel) {
    final int classCount = classLabel.getMapping().size();
    final int pairCount = (classCount * (classCount - 1)) / 2;
    final CodePattern pattern = new CodePattern(classCount, pairCount);
    modelNames.clear();

    // The constructor enables everything; here every class starts out disabled.
    for (boolean[] enabledRow : pattern.partitionEnabled) {
        for (int function = 0; function < enabledRow.length; function++) {
            enabledRow[function] = false;
        }
    }

    // Class names addressed by their position in the mapping's value collection.
    final String[] nameOfClass = new String[classCount];
    int position = 0;
    for (String name : classLabel.getMapping().getValues()) {
        nameOfClass[position++] = name;
    }

    // Enumerate the pairs; pairCount bounds the columns that were allocated.
    int function = 0;
    for (int first = 0; first < classCount - 1 && function < pairCount; first++) {
        for (int second = first + 1; second < classCount && function < pairCount; second++) {
            pattern.partitionEnabled[first][function] = true;
            pattern.partitionEnabled[second][function] = true;
            String firstName = nameOfClass[first];
            String secondName = nameOfClass[second];
            pattern.data[first][function] = firstName;
            pattern.data[second][function] = secondName;
            modelNames.add(firstName + " vs. " + secondName);
            function++;
        }
    }
    return pattern;
}
/**
 * Builds a code pattern according to the "exhaustive code" classification scheme.
 *
 * For k classes the pattern has 2^(k-1) - 1 functions. The first class is "true" in
 * every function; class i (i &gt;= 1) alternates between "false" and "true" in runs of
 * 2^(k-(i+1)) functions, starting with "false".
 *
 * @param classLabel the multi-class label whose mapping defines the classes
 * @return the code pattern with all partitions enabled
 */
private CodePattern buildCodePattern_EXHAUSTIVE_CODE(Attribute classLabel) {
    final int classCount = classLabel.getMapping().size();
    final int functionCount = (int) Math.pow(2, classCount - 1) - 1;
    final CodePattern pattern = new CodePattern(classCount, functionCount);

    // Codeword of the first class: all ones.
    for (int function = 0; function < functionCount; function++) {
        pattern.data[0][function] = "true";
    }

    // Remaining classes: alternating runs whose length halves from class to class.
    for (int classRow = 1; classRow < classCount; classRow++) {
        final int runLength = (int) Math.pow(2, classCount - (classRow + 1));
        final String[] codeword = pattern.data[classRow];
        for (int function = 0; function < functionCount; function++) {
            boolean bit = ((function / runLength) % 2) > 0;
            codeword[function] = String.valueOf(bit);
        }
    }
    return pattern;
}
/**
 * Builds a code pattern according to the "random code" classification scheme.
 *
 * The codeword length is <code>(int)(numberOfClasses * multiplicator)</code>, where the
 * multiplicator is read from {@link #PARAMETER_RANDOM_CODE_MULTIPLICATOR}. Every entry is
 * drawn with <code>nextBoolean()</code> from the generator obtained for
 * {@link #PARAMETER_LOCAL_RANDOM_SEED}. Afterwards each column that lacks a "true" (or a
 * "false") gets that value written into one randomly chosen row.
 *
 * The row for such a repair is now drawn from all classes. The former scaling by
 * <code>length - 1</code> could never pick the last class. Exactly one
 * <code>nextDouble()</code> is still consumed per repair.
 *
 * @param classLabel the multi-class label whose mapping defines the classes
 * @return the code pattern with all partitions enabled
 * @throws OperatorException if a parameter cannot be read
 */
private CodePattern buildCodePattern_RANDOM_CODE(Attribute classLabel) throws OperatorException {
    double multiplicator = getParameterAsDouble(PARAMETER_RANDOM_CODE_MULTIPLICATOR);
    int randomSeed = getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED);
    int numberOfClasses = classLabel.getMapping().size();
    int codewordLength = (int)(numberOfClasses * multiplicator);
    CodePattern cP = new CodePattern(numberOfClasses, codewordLength);
    Random randomGenerator = RandomGenerator.getRandomGenerator(randomSeed);

    for (int i = 0; i < numberOfClasses; i++) {
        for (int j = 0; j < codewordLength; j++) {
            cP.data[i][j] = "" + randomGenerator.nextBoolean();
        }
    }

    //TODO: Improve random codeword quality
    // Ensure that each column shows at least one occurrence of "1" (true) or "0" (false),
    // otherwise the following two-class classification procedure fails.
    for (int column = 0; column < codewordLength; column++) {
        boolean containsNoOne = true, containsNoZero = true;
        for (int row = 0; row < numberOfClasses; row++) {
            if ("true".equals(cP.data[row][column])) {
                containsNoOne = false;
            } else {
                containsNoZero = false;
            }
        }
        // nextDouble() lies in [0, 1), so the product truncates to 0 .. numberOfClasses - 1.
        if (containsNoOne) {
            cP.data[(int)(randomGenerator.nextDouble() * numberOfClasses)][column] = "true";
        }
        if (containsNoZero) {
            cP.data[(int)(randomGenerator.nextDouble() * numberOfClasses)][column] = "false";
        }
    }
    return cP;
}
/**
 * Learns a multi-class model by training a series of binary inner models.
 *
 * If the label mapping has exactly two values the inner learner is applied directly
 * and its model is returned. Otherwise the code pattern of the strategy selected via
 * {@link #PARAMETER_CLASSIFICATION_STRATEGIES} is built, one inner model is trained
 * per function of the pattern, and the models are wrapped into a
 * <code>Binary2MultiClassModel</code>.
 *
 * For the "1 against all" and "1 against 1" strategies the model receives a copy of
 * the model name list. The list itself is an instance field that is cleared and
 * refilled by every subsequent call, so passing it directly would let later runs
 * alter the names seen by a model returned earlier (if the model keeps the reference).
 *
 * @param inputSet the training examples
 * @return the inner model for two-class labels, otherwise the combined multi-class model
 * @throws OperatorException if the strategy is unknown or training fails
 */
public Model learn(ExampleSet inputSet) throws OperatorException {
    Attribute classLabel = inputSet.getAttributes().getLabel();

    if (classLabel.getMapping().size() == 2) {
        return applyInnerLearner(inputSet);
    }

    int classificationStrategy = getParameterAsInt(PARAMETER_CLASSIFICATION_STRATEGIES);
    SplittedExampleSet seSet = constructClassPartitionSet(inputSet);

    // Select the code pattern; only this step differs between the strategies.
    CodePattern cP;
    boolean describedByModelNames;
    switch (classificationStrategy) {
        case ONE_AGAINST_ALL:
            log("Binary2MultiCLassLearner set to <<1-vs-all>>");
            cP = buildCodePattern_ONE_VS_ALL(classLabel);
            describedByModelNames = true;
            break;
        case ONE_AGAINST_ONE:
            log("Binary2MultiCLassLearner set to <<1-vs-1>>");
            cP = buildCodePattern_ONE_VS_ONE(classLabel);
            describedByModelNames = true;
            break;
        case EXHAUSTIVE_CODE:
            log("Binary2MultiCLassLearner set to <<exhaustive code>>");
            cP = buildCodePattern_EXHAUSTIVE_CODE(classLabel);
            describedByModelNames = false;
            break;
        case RANDOM_CODE:
            log("Binary2MultiCLassLearner set to <<random code>>");
            cP = buildCodePattern_RANDOM_CODE(classLabel);
            describedByModelNames = false;
            break;
        default:
            throw new OperatorException("Binary2MultiCLassLearner: Unknown classification strategy selected");
    }

    Model[] models = applyCodePattern(seSet, classLabel, cP);

    if (describedByModelNames) {
        // Defensive copy: modelNames is reused (cleared and refilled) by the next call.
        return new Binary2MultiClassModel(inputSet, models, classificationStrategy, new LinkedList<String>(modelNames));
    }
    // The ECOC strategies hand the code matrix itself to the model.
    return new Binary2MultiClassModel(inputSet, models, classificationStrategy, cP.data);
}
/**
 * Reports polynominal class labels as supported; every other capability is decided
 * by the superclass.
 *
 * @param capability the capability in question
 * @return <code>true</code> for <code>POLYNOMINAL_CLASS</code>, otherwise the
 *         superclass' answer
 */
public boolean supportsCapability(LearnerCapability capability) {
    return capability == LearnerCapability.POLYNOMINAL_CLASS
            || super.supportsCapability(capability);
}
/**
 * Returns the superclass' parameter types extended by the three parameters of this
 * operator, appended in this order: the classification strategy (category over
 * {@link #STRATEGIES}), the random code multiplicator and the local random seed.
 *
 * @return the list obtained from the superclass with the three entries appended
 */
public List<ParameterType> getParameterTypes() {
    final List<ParameterType> parameterTypes = super.getParameterTypes();

    ParameterType strategy = new ParameterTypeCategory(
            PARAMETER_CLASSIFICATION_STRATEGIES,
            "What strategy should be used for multi class classifications?",
            STRATEGIES,
            ONE_AGAINST_ALL);
    parameterTypes.add(strategy);

    ParameterType multiplicator = new ParameterTypeDouble(
            PARAMETER_RANDOM_CODE_MULTIPLICATOR,
            "A multiplicator regulating the codeword length in random code modus.",
            1.0d,
            Double.POSITIVE_INFINITY,
            2.0d);
    parameterTypes.add(multiplicator);

    ParameterType localSeed = new ParameterTypeInt(
            PARAMETER_LOCAL_RANDOM_SEED,
            "Use the given random seed instead of global random numbers (-1: use global)",
            -1,
            Integer.MAX_VALUE,
            -1);
    parameterTypes.add(localSeed);

    return parameterTypes;
}
}