/** * Copyright (C) 2001-2017 by RapidMiner and the contributors * * Complete list of developers available at our web site: * * http://rapidminer.com * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.performance; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.tools.Tools; import com.rapidminer.tools.math.Averagable; /** * Measures the accuracy and classification error for both binary classification problems and multi * class problems. Additionally, this performance criterion can also compute the kappa statistics * for multi class problems. This is calculated as k = (P(A) - P(E)) / (1 - P(E)) with [ P(A) = * diagonal sum / number of examples ] and [ P(E) = sum over i of ((sum of i-th row * sum of i-th * column) / (n to the power of 2) ]. * * @author Ingo Mierswa */ public class MultiClassificationPerformance extends MeasuredPerformance { private static final long serialVersionUID = 3068421566038331525L; /** Indicates an undefined type (should not happen). */ public static final int UNDEFINED = -1; /** Indicates accuracy. */ public static final int ACCURACY = 0; /** Indicates classification error. */ public static final int ERROR = 1; /** Indicates kappa statistics. */ public static final int KAPPA = 2; /** The names of the criteria. */ public static final String[] NAMES = { "accuracy", "classification_error", "kappa" }; /** The descriptions of the criteria. */ public static final String[] DESCRIPTIONS = { "Relative number of correctly classified examples", "Relative number of misclassified examples", "The kappa statistics for the classification" }; /** * The counter for true labels and the prediction. */ private double[][] counter; /** The class names of the label. Used for logging and result display. */ private String[] classNames; /** Maps class names to indices. */ private Map<String, Integer> classNameMap = new HashMap<String, Integer>(); /** The currently used label attribute. */ private Attribute labelAttribute; /** The currently used predicted label attribute. */ private Attribute predictedLabelAttribute; /** The weight attribute. Might be null. */ private Attribute weightAttribute; /** The type of this performance: accuracy or classification error. */ private int type = ACCURACY; /** Creates a MultiClassificationPerformance with undefined type. */ public MultiClassificationPerformance() { this(UNDEFINED); } /** Creates a MultiClassificationPerformance with the given type. */ public MultiClassificationPerformance(int type) { this.type = type; } /** Clone constructor. */ public MultiClassificationPerformance(MultiClassificationPerformance m) { super(m); this.type = m.type; this.classNames = new String[m.classNames.length]; for (int i = 0; i < this.classNames.length; i++) { this.classNames[i] = m.classNames[i]; this.classNameMap.put(this.classNames[i], i); } this.counter = new double[m.counter.length][m.counter.length]; for (int i = 0; i < this.counter.length; i++) { for (int j = 0; j < this.counter[i].length; j++) { this.counter[i][j] = m.counter[i][j]; } } this.labelAttribute = (Attribute) m.labelAttribute.clone(); this.predictedLabelAttribute = (Attribute) m.predictedLabelAttribute.clone(); if (m.weightAttribute != null) { this.weightAttribute = (Attribute) m.weightAttribute.clone(); } } /** Creates a MultiClassificationPerformance with the given type. */ public static MultiClassificationPerformance newInstance(String name) { for (int i = 0; i < NAMES.length; i++) { if (NAMES[i].equals(name)) { return new MultiClassificationPerformance(i); } } return null; } @Override public double getExampleCount() { double total = 0; for (int i = 0; i < counter.length; i++) { for (int j = 0; j < counter[i].length; j++) { total += counter[i][j]; } } return total; } /** Initializes the criterion and sets the label. */ @Override public void startCounting(ExampleSet eSet, boolean useExampleWeights) throws OperatorException { super.startCounting(eSet, useExampleWeights); this.labelAttribute = eSet.getAttributes().getLabel(); if (!this.labelAttribute.isNominal()) { throw new UserError(null, 101, "calculation of classification performance criteria", this.labelAttribute.getName()); } this.predictedLabelAttribute = eSet.getAttributes().getPredictedLabel(); if (this.predictedLabelAttribute == null || !this.predictedLabelAttribute.isNominal()) { throw new UserError(null, 101, "calculation of classification performance criteria", "predicted label attribute"); } if (useExampleWeights) { this.weightAttribute = eSet.getAttributes().getWeight(); } Collection<String> labelValues = this.labelAttribute.getMapping().getValues(); Collection<String> predictedLabelValues = this.predictedLabelAttribute.getMapping().getValues(); // searching for greater mapping for making symmetric matrix in case of different mapping // sizes Collection<String> unionedMapping = new LinkedHashSet<String>(labelValues); unionedMapping.addAll(predictedLabelValues); this.counter = new double[unionedMapping.size()][unionedMapping.size()]; this.classNames = new String[unionedMapping.size()]; int n = 0; for (String labelValue : unionedMapping) { classNames[n] = labelValue; classNameMap.put(classNames[n], n); n++; } } /** Increases the prediction value in the matrix. */ @Override public void countExample(Example example) { int label = classNameMap.get(example.getNominalValue(labelAttribute)); int plabel = classNameMap.get(example.getNominalValue(predictedLabelAttribute)); double weight = 1.0d; if (weightAttribute != null) { weight = example.getValue(weightAttribute); } counter[label][plabel] += weight; } /** Returns either the accuracy or the classification error. */ @Override public double getMikroAverage() { double diagonal = 0, total = 0; for (int i = 0; i < counter.length; i++) { diagonal += counter[i][i]; for (int j = 0; j < counter[i].length; j++) { total += counter[i][j]; } } if (total == 0) { return Double.NaN; } // returns either the accuracy, the error, or the kappa statistics double accuracy = diagonal / total; switch (type) { case ACCURACY: return accuracy; case ERROR: return 1.0d - accuracy; case KAPPA: double pa = accuracy; double pe = 0.0d; for (int i = 0; i < counter.length; i++) { double row = 0.0d; double column = 0.0d; for (int j = 0; j < counter[i].length; j++) { row += counter[i][j]; column += counter[j][i]; } // pe += ((row * column) / Math.pow(total, counter.length)); pe += row * column / (total * total); } return (pa - pe) / (1.0d - pe); default: throw new RuntimeException("Unknown type " + type + " for multi class performance criterion!"); } } /** Returns true. */ @Override public boolean formatPercent() { if (type == KAPPA) { return false; } else { return true; } } @Override public double getMikroVariance() { return Double.NaN; } /** Returns the name. */ @Override public String getName() { return NAMES[type]; } /** Returns the description. */ @Override public String getDescription() { return DESCRIPTIONS[type]; } // ================================================================================ /** Returns the accuracy or 1 - error. */ @Override public double getFitness() { if (type == ERROR) { return 1.0d - getAverage(); } else { return getAverage(); } } /** Returns 1. */ @Override public double getMaxFitness() { return 1.0d; } @Override public void buildSingleAverage(Averagable performance) { MultiClassificationPerformance other = (MultiClassificationPerformance) performance; // can only add the counter matrices if they have the same "headers" (classNames) if (this.classNames.length == other.classNames.length && Arrays.equals(this.classNames, other.classNames)) { for (int i = 0; i < this.counter.length; i++) { for (int j = 0; j < this.counter[i].length; j++) { this.counter[i][j] += other.counter[i][j]; } } } else { // if the classNames are different build the union and create new counter matrix // associated to this union String[] unionClassNames = combineArrays(this.classNames, other.classNames); double[][] unionCounter = new double[unionClassNames.length][unionClassNames.length]; Map<String, Integer> unionClassNameMap = new HashMap<String, Integer>(); for (int i = 0; i < unionClassNames.length; i++) { String nameI = unionClassNames[i]; unionClassNameMap.put(nameI, i); for (int j = 0; j < unionClassNames.length; j++) { String nameJ = unionClassNames[j]; double thisValue = 0; Integer indexNameI = this.classNameMap.get(nameI); Integer indexNameJ = this.classNameMap.get(nameJ); if (indexNameI != null && indexNameJ != null) { thisValue = this.counter[indexNameI][indexNameJ]; } double otherValue = 0; indexNameI = other.classNameMap.get(nameI); indexNameJ = other.classNameMap.get(nameJ); if (indexNameI != null && indexNameJ != null) { otherValue = other.counter[indexNameI][indexNameJ]; } unionCounter[i][j] = thisValue + otherValue; } } this.classNames = unionClassNames; this.classNameMap = unionClassNameMap; this.counter = unionCounter; } } /** * Builds the union of the arrays, taking first the elements of the firstArray and then the new * elements of the secondArray. * * @param firstArray * @param secondArray * @return the union of firstArray and secondArray */ private String[] combineArrays(String[] firstArray, String[] secondArray) { Set<String> unionSet = new LinkedHashSet<>(); unionSet.addAll(Arrays.asList(firstArray)); unionSet.addAll(Arrays.asList(secondArray)); return unionSet.toArray(new String[unionSet.size()]); } // ================================================================================ @Override public String toString() { StringBuffer result = new StringBuffer(super.toString()); result.append(Tools.getLineSeparator() + "ConfusionMatrix:" + Tools.getLineSeparator() + "True:"); for (int i = 0; i < this.counter.length; i++) { result.append("\t" + classNames[i]); } for (int i = 0; i < this.counter.length; i++) { result.append(Tools.getLineSeparator() + classNames[i] + ":"); for (int j = 0; j < this.counter[i].length; j++) { result.append("\t" + Tools.formatIntegerIfPossible(this.counter[j][i])); } } return result.toString(); } public String getTitle() { return super.toString(); } public String[] getClassNames() { return classNames; } public double[][] getCounter() { return counter; } }