/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.postprocessing; import java.util.List; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.set.SortedExampleSet; import com.rapidminer.operator.Operator; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.ports.InputPort; import com.rapidminer.operator.ports.OutputPort; import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition; import com.rapidminer.parameter.ParameterType; import com.rapidminer.parameter.ParameterTypeBoolean; import com.rapidminer.parameter.ParameterTypeDouble; import com.rapidminer.parameter.ParameterTypeString; import com.rapidminer.tools.Ontology; /** * This operator finds the lowest threshold which reaches a given recall. * * @author Marius Helf */ public class RecallChooser extends Operator { // The parameters of this operator: public static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights"; public static final String PARAMETER_RECALL = "min_recall"; public static final String PARAMETER_POSITIVE_LABEL = "positive_label"; private InputPort exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class); private OutputPort exampleSetOutput = getOutputPorts().createPort("example set"); private OutputPort thresholdOutput = getOutputPorts().createPort("threshold"); public RecallChooser(OperatorDescription description) { super(description); exampleSetInput.addPrecondition(new ExampleSetPrecondition(exampleSetInput, Ontology.VALUE_TYPE, Attributes.LABEL_NAME, Attributes.PREDICTION_NAME, Attributes.CONFIDENCE_NAME)); getTransformer().addPassThroughRule(exampleSetInput, exampleSetOutput); getTransformer().addGenerationRule(thresholdOutput, Threshold.class); } @Override public void doWork() throws OperatorException { ExampleSet exampleSet = exampleSetInput.getData(); boolean useWeights = getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS); // checking preconditions Attribute label = exampleSet.getAttributes().getLabel(); exampleSet.recalculateAttributeStatistics(label); if (label == null) throw new UserError(this, 105); if (!label.isNominal()) throw new UserError(this, 101, label, "threshold finding"); if (label.getMapping().size() != 2) throw new UserError(this, 118, new Object[] { label, Integer.valueOf(label.getMapping().getValues().size()), Integer.valueOf(2) }); if (exampleSet.getAttributes().getPredictedLabel() == null) { throw new UserError(this, 107); } // find positive class String positiveClassName = null; if (isParameterSet(PARAMETER_POSITIVE_LABEL)) { positiveClassName = getParameterAsString(PARAMETER_POSITIVE_LABEL); if (label.getMapping().getIndex(positiveClassName) < 0) { throw new UserError(this, 143, positiveClassName, label.getName()); } } else { if (label.isNominal() && (label.getMapping().size() == 2)) { int positiveIndex = label.getMapping().getPositiveIndex(); positiveClassName = label.getMapping().mapIndex(positiveIndex); } else if (label.isNominal() && (label.getMapping().size() == 1)) { positiveClassName = label.getMapping().mapIndex(0); } else { throw new UserError(this, 954); } } double positiveIndex = label.getMapping().getIndex(positiveClassName); // calculate weighted count of positive class double totalSum = 0; for (Example e : exampleSet) { if (e.getLabel() == positiveIndex) { if (useWeights) { double w = e.getWeight(); if (Double.isNaN(w)) { w = 1.0; } totalSum += w; } else { totalSum += 1.0; } } } // now find the actual threshold double currentSum = 0; double desiredRecall = getParameterAsDouble(PARAMETER_RECALL); double thresholdValue = 0; Attribute confidenceAttribute = exampleSet.getAttributes().getSpecial(Attributes.CONFIDENCE_NAME + "_" + positiveClassName); SortedExampleSet sortedExampleSet = new SortedExampleSet(exampleSet, confidenceAttribute, SortedExampleSet.INCREASING); for (Example e : sortedExampleSet) { if (e.getLabel() == positiveIndex) { if (useWeights) { double w = e.getWeight(); if (Double.isNaN(w)) { w = 1.0; } currentSum += w; } else { currentSum += 1.0; } if (currentSum / totalSum >= 1 - desiredRecall) { break; } thresholdValue = (e.getConfidence(positiveClassName) + thresholdValue) / 2.0; } } // create and return output exampleSetOutput.deliver(exampleSet); thresholdOutput.deliver(new Threshold(thresholdValue, label.getMapping().getNegativeString(), label.getMapping().getPositiveString())); } @Override public List<ParameterType> getParameterTypes() { List<ParameterType> list = super.getParameterTypes(); list.add(new ParameterTypeDouble(PARAMETER_RECALL, "The minimal desired recall on the positive class.", 0, 1, .7, false)); list.add(new ParameterTypeBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS, "Indicates if example weights should be used.", true)); list.add(new ParameterTypeString(PARAMETER_POSITIVE_LABEL, "If set, this value of the label attribute is treated as positive.", true)); return list; } }