/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.postprocessing; import java.util.HashMap; import java.util.List; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition; import com.rapidminer.operator.preprocessing.AbstractDataProcessing; import com.rapidminer.parameter.ParameterType; import com.rapidminer.parameter.ParameterTypeCategory; import com.rapidminer.parameter.ParameterTypeDouble; import com.rapidminer.parameter.ParameterTypeList; import com.rapidminer.parameter.ParameterTypeString; import com.rapidminer.parameter.conditions.EqualTypeCondition; import com.rapidminer.tools.Ontology; /** * This operator sets all predictions which do not have a higher confidence than the * specified one to "unknown" (missing value). This operator is a quite simple * version of the CostBasedThresholdLearner which might be useful in simple binominal * classification settings (although it does also work for polynominal classifications). * * @author Ingo Mierswa */ public class SimpleUncertainPredictionsTransformation extends AbstractDataProcessing { public static final String PARAMETER_CLASS_HANDLING = "class_handling"; public static final String[] CLASS_HANDLING_MODES = { "balanced" , "unbalanced" }; public static final int CLASS_HANDLING_BALANCED = 0; public static final int CLASS_HANDLING_UNBALANCED = 1; public static final String PARAMETER_MIN_CONFIDENCE = "min_confidence"; public static final String PARAMETER_MIN_CONFIDENCES = "min_confidences"; public static final String PARAMETER_CLASS_VALUE = "class"; public SimpleUncertainPredictionsTransformation(OperatorDescription description) { super(description); getExampleSetInputPort().addPrecondition(new ExampleSetPrecondition(getExampleSetInputPort(), Ontology.VALUE_TYPE, Attributes.PREDICTION_NAME, Attributes.CONFIDENCE_NAME)); } @Override public ExampleSet apply(ExampleSet exampleSet) throws OperatorException { // checks Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel(); if (predictedLabel == null) { throw new UserError(this, 107); } if (!predictedLabel.isNominal()) { throw new UserError(this, 119, predictedLabel, getName()); } switch (getParameterAsInt(PARAMETER_CLASS_HANDLING)) { case CLASS_HANDLING_BALANCED: double minConfidence = getParameterAsDouble(PARAMETER_MIN_CONFIDENCE); for (Example example : exampleSet) { double predictionValue = example.getValue(predictedLabel); String predictionClass = predictedLabel.getMapping().mapIndex((int) predictionValue); double confidence = example.getConfidence(predictionClass); if (!Double.isNaN(confidence)) { if (confidence < minConfidence) { example.setValue(predictedLabel, Double.NaN); } } } break; case CLASS_HANDLING_UNBALANCED: HashMap<String, Double> thresholdMap = new HashMap<String, Double>(); for (String[] threshold : getParameterList(PARAMETER_MIN_CONFIDENCES)) { thresholdMap.put(threshold[0], Double.valueOf(threshold[1])); } for (Example example : exampleSet) { double predictionValue = example.getValue(predictedLabel); String predictionClass = predictedLabel.getMapping().mapIndex((int) predictionValue); double confidence = example.getConfidence(predictionClass); Double threshold = thresholdMap.get(predictionClass); if (!Double.isNaN(confidence) && threshold != null) { if (confidence < threshold.doubleValue()) { example.setValue(predictedLabel, Double.NaN); } } } break; } return exampleSet; } @Override public boolean writesIntoExistingData() { return true; } @Override public List<ParameterType> getParameterTypes() { List<ParameterType> list = super.getParameterTypes(); list.add(new ParameterTypeCategory(PARAMETER_CLASS_HANDLING, "The mode which defines if all classes are handled equally or if class individual thresholds are set.", CLASS_HANDLING_MODES, CLASS_HANDLING_BALANCED, false)); ParameterType type = new ParameterTypeDouble(PARAMETER_MIN_CONFIDENCE, "The minimal confidence necessary for not setting the prediction to 'unknown'.", 0.0d, 1.0d, 0.5d); type.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_CLASS_HANDLING, CLASS_HANDLING_MODES, true, CLASS_HANDLING_BALANCED)); type.setExpert(false); list.add(type); type = new ParameterTypeList(PARAMETER_MIN_CONFIDENCES, "A list which defines individual thresholds for classes.", new ParameterTypeString(PARAMETER_CLASS_VALUE, "The class for which the confidence threshold should be set."), new ParameterTypeDouble(PARAMETER_MIN_CONFIDENCE, "The minimal confidence necessary for not setting the prediction to 'unknown'.", 0.0d, 1.0d, 0.5d), false); type.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_CLASS_HANDLING, CLASS_HANDLING_MODES, true, CLASS_HANDLING_UNBALANCED)); list.add(type); return list; } }