/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.clustering;
import java.util.HashMap;
import java.util.Vector;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ProcessSetupError.Severity;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SimpleMetaDataError;
/**
 * This operator estimates a mapping between a given clustering and a prediction.
 * It adjusts the given clusters with the given labels and so estimates the best fitting pairs:
 * the number of examples of every label inside every cluster is counted, and the one-to-one
 * assignment of clusters to labels maximizing the number of matching examples is computed with
 * the Hungarian (Munkres) method. Every example then receives the label assigned to its cluster
 * as predicted label.
 *
 * @author Regina Fritsch
 */
public class ClusterToPrediction extends Operator {

	/** Value stored in the mapping table to flag an assigned ("marked") zero. */
	private static final int MARKED_ZERO = Integer.MIN_VALUE;

	/** Value stored in the mapping table to flag a crossed-out ("slashed") zero. */
	private static final int SLASHED_ZERO = Integer.MAX_VALUE;

	private final InputPort exampleSetInput = getInputPorts().createPort("example set");
	private final InputPort clusterModelInput = getInputPorts().createPort("cluster model", ClusterModel.class);
	private final OutputPort exampleSetOutput = getOutputPorts().createPort("example set");
	private final OutputPort clusterModelOutput = getOutputPorts().createPort("cluster model");

	public ClusterToPrediction(OperatorDescription description) {
		super(description);
		exampleSetInput.addPrecondition(new ExampleSetPrecondition(exampleSetInput, -1, Attributes.LABEL_NAME, Attributes.CLUSTER_NAME));
		getTransformer().addRule(new ExampleSetPassThroughRule(exampleSetInput, exampleSetOutput, null) {
			@Override
			public MetaData modifyMetaData(MetaData metaData) {
				if (metaData instanceof ExampleSetMetaData) {
					ExampleSetMetaData emd = (ExampleSetMetaData) metaData;
					switch (emd.hasSpecial(Attributes.LABEL_NAME)) {
					case NO:
						exampleSetInput.addError(new SimpleMetaDataError(Severity.ERROR, exampleSetInput, "special_missing", "label"));
						return emd;
					case UNKNOWN:
						exampleSetInput.addError(new SimpleMetaDataError(Severity.WARNING, exampleSetInput, "special_unknown", "label"));
						return emd;
					case YES:
						// the output will carry a prediction (plus confidences) derived from the label
						AttributeMetaData predictionMD = AttributeMetaData.createPredictionMetaData(emd.getLabelMetaData());
						emd.addAttribute(predictionMD);
						AttributeMetaData.createConfidenceAttributeMetaData(emd);
						return emd;
					default:
						return emd;
					}
				}
				return metaData;
			}
		});
		getTransformer().addPassThroughRule(clusterModelInput, clusterModelOutput);
	}

	@Override
	public void doWork() throws OperatorException {
		ExampleSet exampleSet = exampleSetInput.getData();
		ClusterModel model = clusterModelInput.getData();
		int numberOfClusters = model.getNumberOfClusters();

		// generate the predicted attribute
		Attribute labelAttribute = exampleSet.getAttributes().getLabel();
		PredictionModel.createPredictedLabel(exampleSet, labelAttribute);
		Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
		Attribute clusterAttribute = exampleSet.getAttributes().getCluster();

		// give every distinct label a table index, in order of first appearance
		HashMap<Integer, String> intToClusterMapping = new HashMap<Integer, String>();
		HashMap<String, Integer> labelToIndex = new HashMap<String, Integer>();
		for (Example example : exampleSet) {
			String label = example.getValueAsString(labelAttribute);
			if (!labelToIndex.containsKey(label)) {
				int index = labelToIndex.size();
				labelToIndex.put(label, index);
				intToClusterMapping.put(index, label);
			}
		}
		// a one-to-one mapping requires exactly as many labels as clusters
		if (labelToIndex.size() != numberOfClusters) {
			throw new UserError(this, 943, labelToIndex.size(), numberOfClusters);
		}

		// count the occurrence of each label with every cluster: mappingTable[cluster][label]
		int[][] mappingTable = new int[numberOfClusters][numberOfClusters];
		for (Example example : exampleSet) {
			double clusterValue = example.getValue(clusterAttribute);
			int clusterIndex = (int) clusterValue;
			// only integral cluster values inside the table are counted; NaN and fractional values are not
			if (clusterIndex == clusterValue && clusterIndex >= 0 && clusterIndex < numberOfClusters) {
				int labelIndex = labelToIndex.get(example.getValueAsString(labelAttribute));
				mappingTable[clusterIndex][labelIndex]++;
			}
		}

		// compute the mapping cluster index -> label (there must be a possible assignment)
		int[] assignment = computeAssignment(mappingTable);
		HashMap<Integer, String> clusterToPrediction = new HashMap<Integer, String>();
		for (int cluster = 0; cluster < assignment.length; cluster++) {
			clusterToPrediction.put(cluster, intToClusterMapping.get(assignment[cluster]));
		}

		// set the predicted label in the example table and remember for each prediction the cluster
		// number parsed from the textual cluster value (needed for the fuzzy model's probabilities)
		HashMap<String, Integer> predictionToCluster = new HashMap<String, Integer>();
		for (Example example : exampleSet) {
			String resultLabel = clusterToPrediction.get((int) example.getValue(clusterAttribute));
			example.setValue(predictedLabel, resultLabel);
			if (predictionToCluster.size() < numberOfClusters) {
				String prediction = example.getValueAsString(predictedLabel);
				if (!predictionToCluster.containsKey(prediction)) {
					String clusterNumber = example.getValueAsString(clusterAttribute).replaceAll("[^\\d]+", "");
					try {
						predictionToCluster.put(prediction, Integer.parseInt(clusterNumber));
					} catch (NumberFormatException e) {
						throw new UserError(this, 145, clusterAttribute.getName());
					}
				}
			}
		}

		// set the confidence in the example table
		FlatFuzzyClusterModel fuzzyModel = model.getClass() == FlatFuzzyClusterModel.class ? (FlatFuzzyClusterModel) model : null;
		int exampleIndex = 0;
		for (Example example : exampleSet) {
			if (fuzzyModel != null) {
				for (int j = 0; j < clusterToPrediction.size(); j++) {
					String label = clusterToPrediction.get(j);
					Integer clusterNumber = predictionToCluster.get(label);
					// a cluster without examples never appears as prediction, so its number is unknown:
					// leave that confidence untouched instead of failing with a NullPointerException
					if (clusterNumber != null) {
						example.setConfidence(label, fuzzyModel.getExampleInClusterProbability(exampleIndex, clusterNumber));
					}
				}
			} else {
				example.setConfidence(clusterToPrediction.get((int) example.getValue(clusterAttribute)), 1);
			}
			exampleIndex++;
		}

		exampleSetOutput.deliver(exampleSet);
		clusterModelOutput.deliver(model);
	}

	/**
	 * Computes the one-to-one assignment of clusters to labels that maximizes the summed counts of
	 * the given square table, using the Munkres algorithm (Hungarian method).
	 * <p>
	 * Terminology follows {@link #assignmentAvailable(int[][])}: the first table index is called
	 * "column" (cluster), the second "row" (label). The table is modified in place.
	 *
	 * @param mappingTable square table, mappingTable[cluster][label] = number of examples
	 * @return assignment[cluster] = index of the assigned label, or -1 if none was found
	 */
	private static int[] computeAssignment(int[][] mappingTable) {
		int columns = mappingTable.length;
		int[] assignment = new int[columns];
		if (columns == 0) {
			return assignment;
		}
		int rows = mappingTable[0].length;

		// find the maximum
		int maxValue = -1;
		for (int[] column : mappingTable) {
			for (int value : column) {
				maxValue = Math.max(maxValue, value);
			}
		}

		// invert the table (maximization -> minimization) and subtract the column-minima
		for (int col = 0; col < columns; col++) {
			int minimum = Integer.MAX_VALUE;
			for (int row = 0; row < rows; row++) {
				mappingTable[col][row] = maxValue - mappingTable[col][row];
				minimum = Math.min(minimum, mappingTable[col][row]);
			}
			if (minimum > 0) {
				for (int row = 0; row < rows; row++) {
					mappingTable[col][row] -= minimum;
				}
			}
		}

		// compute and subtract the row-minima
		for (int row = 0; row < rows; row++) {
			int minimum = Integer.MAX_VALUE;
			for (int col = 0; col < columns; col++) {
				minimum = Math.min(minimum, mappingTable[col][row]);
			}
			if (minimum > 0) {
				for (int col = 0; col < columns; col++) {
					mappingTable[col][row] -= minimum;
				}
			}
		}

		while (!assignmentAvailable(mappingTable)) {
			adjustTable(mappingTable);
		}

		// read the assignment off the marked zeros
		for (int col = 0; col < columns; col++) {
			assignment[col] = -1;
			for (int row = 0; row < rows; row++) {
				if (mappingTable[col][row] == MARKED_ZERO) {
					assignment[col] = row;
					break;
				}
			}
		}
		return assignment;
	}

	/**
	 * Performs one improvement step of the Hungarian method on a table whose zeros were marked or
	 * slashed by {@link #assignmentAvailable(int[][])} without yielding a complete assignment.
	 * <p>
	 * The step labels rows and columns to get a cover of all zeros with fewer lines than the table
	 * size. It then subtracts the smallest uncovered value from every uncovered cell and adds it to
	 * every doubly covered cell. All zero markers are cleared before the adjustment. Otherwise a
	 * slashed zero ({@code Integer.MAX_VALUE}) in a doubly covered cell would overflow when the
	 * minimum is added.
	 */
	private static void adjustTable(int[][] mappingTable) {
		int columns = mappingTable.length;
		int rows = mappingTable[0].length;
		boolean[] markedRows = new boolean[rows];
		boolean[] markedColumns = new boolean[columns];

		// mark all rows which have no marked zero (start labeling)
		for (int row = 0; row < rows; row++) {
			boolean hasMarkedZero = false;
			for (int col = 0; col < columns; col++) {
				if (mappingTable[col][row] == MARKED_ZERO) {
					hasMarkedZero = true;
					break;
				}
			}
			markedRows[row] = !hasMarkedZero;
		}

		boolean newMarked = true;
		while (newMarked) {
			newMarked = false;
			// mark all columns with a slashed zero in a marked row
			for (int col = 0; col < columns; col++) {
				if (markedColumns[col]) {
					continue;
				}
				for (int row = 0; row < rows; row++) {
					if (mappingTable[col][row] == SLASHED_ZERO && markedRows[row]) {
						markedColumns[col] = true;
						newMarked = true;
						break;
					}
				}
			}
			// mark all rows with a marked zero in a marked column
			for (int row = 0; row < rows; row++) {
				if (markedRows[row]) {
					continue;
				}
				for (int col = 0; col < columns; col++) {
					if (mappingTable[col][row] == MARKED_ZERO && markedColumns[col]) {
						markedRows[row] = true;
						newMarked = true;
						break;
					}
				}
			}
		}

		// invert the marked columns: afterwards markedColumns holds the uncovered columns, while
		// unmarked rows and (now) unmarked columns are the covering lines
		for (int col = 0; col < columns; col++) {
			markedColumns[col] = !markedColumns[col];
		}

		// the markers must be cleared BEFORE adjusting, so doubly covered slashed zeros become
		// 0 + minimum instead of overflowing Integer.MAX_VALUE
		resetMarkers(mappingTable);

		// find the minimum in the uncovered range (marked rows x marked columns)
		int minimum = Integer.MAX_VALUE;
		for (int col = 0; col < columns; col++) {
			if (!markedColumns[col]) {
				continue;
			}
			for (int row = 0; row < rows; row++) {
				if (markedRows[row]) {
					minimum = Math.min(minimum, mappingTable[col][row]);
				}
			}
		}

		for (int col = 0; col < columns; col++) {
			for (int row = 0; row < rows; row++) {
				if (markedColumns[col] && markedRows[row]) {
					// subtract the minimum from all elements in the uncovered range
					mappingTable[col][row] -= minimum;
				} else if (!markedColumns[col] && !markedRows[row]) {
					// add the minimum to all elements covered both by a row and by a column line
					mappingTable[col][row] += minimum;
				}
			}
		}
	}

	/** Replaces all marked and slashed zero markers in the table by a plain zero. */
	private static void resetMarkers(int[][] mappingTable) {
		for (int[] column : mappingTable) {
			for (int row = 0; row < column.length; row++) {
				if (column[row] == MARKED_ZERO || column[row] == SLASHED_ZERO) {
					column[row] = 0;
				}
			}
		}
	}

	/**
	 * Marks independent zeros (one per column and row) and slashes the zeros that share a column or
	 * row with a marked zero. Unique zeros are processed first, then ambiguous ones.
	 * <p>
	 * The table is modified in place. On return, all remaining zeros are flagged as marked or
	 * slashed unless a complete assignment was found early.
	 *
	 * @return true, if there is a solution available (as many marked zeros as the table has columns)
	 */
	private static boolean assignmentAvailable(int[][] mappingTable) {
		int markedZeros = 0;
		boolean modificationDone = true;
		while (modificationDone) {
			while (modificationDone) {
				modificationDone = false;
				// column by column: mark the zero of every column that contains exactly one zero
				for (int i = 0; i < mappingTable.length; i++) {
					int position = -1;
					for (int j = 0; j < mappingTable[i].length; j++) {
						if (mappingTable[i][j] == 0) {
							if (position == -1) {
								position = j;
							} else {
								position = -1;
								break;
							}
						}
					}
					if (position != -1) {
						modificationDone = true;
						mappingTable[i][position] = MARKED_ZERO;
						for (int k = 0; k < mappingTable.length; k++) {
							if (mappingTable[k][position] == 0) {
								mappingTable[k][position] = SLASHED_ZERO;
							}
						}
						markedZeros++;
					}
				}
				if (markedZeros == mappingTable.length) {
					return true;
				}
				// line by line: mark the zero of every row that contains exactly one zero
				for (int i = 0; i < mappingTable[0].length; i++) {
					int position = -1;
					for (int j = 0; j < mappingTable.length; j++) {
						if (mappingTable[j][i] == 0) {
							if (position == -1) {
								position = j;
							} else {
								position = -1;
								break;
							}
						}
					}
					if (position != -1) {
						modificationDone = true;
						mappingTable[position][i] = MARKED_ZERO;
						for (int k = 0; k < mappingTable[0].length; k++) {
							if (mappingTable[position][k] == 0) {
								mappingTable[position][k] = SLASHED_ZERO;
							}
						}
						markedZeros++;
					}
				}
				if (markedZeros == mappingTable.length) {
					return true;
				}
			}
			// modificationDone is here always false
			// ambiguous zeros: mark the first remaining zero, slash its column and row, then retry
			int currentMarkedZeros = markedZeros;
			for (int i = 0; i < mappingTable.length; i++) {
				for (int j = 0; j < mappingTable[i].length; j++) {
					if (mappingTable[i][j] == 0) {
						mappingTable[i][j] = MARKED_ZERO;
						for (int k = j + 1; k < mappingTable[i].length; k++) {
							if (mappingTable[i][k] == 0) {
								mappingTable[i][k] = SLASHED_ZERO; // slashed zeros in the same column
							}
						}
						for (int k = 0; k < mappingTable.length; k++) {
							if (mappingTable[k][j] == 0) {
								mappingTable[k][j] = SLASHED_ZERO; // slashed zeros in the same row
							}
						}
						modificationDone = true;
						markedZeros++;
						break;
					}
				}
				if (currentMarkedZeros != markedZeros) {
					break;
				}
			}
			if (markedZeros == mappingTable.length) {
				return true;
			}
		}
		return false;
	}
}