/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.postprocessing; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.example.table.NominalMapping; import com.rapidminer.operator.AbstractExampleSetProcessing; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition; import com.rapidminer.parameter.ParameterType; import com.rapidminer.parameter.ParameterTypeBoolean; import com.rapidminer.parameter.ParameterTypeInt; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.container.Tupel; /** * This operator will generate predictions of the second, .. n-th most probable * class from the confidences attributes generated by applying a classification model. * * @author Sebastian Land */ public class GeneratePredictionRankingOperator extends AbstractExampleSetProcessing { public static final String PARAMETER_NUMBER_OF_RANKS = "number_of_ranks"; public static final String PARAMETER_REMOVE_OLD_PREDICTIONS = "remove_old_predictions"; public GeneratePredictionRankingOperator(OperatorDescription description) { super(description); getExampleSetInputPort().addPrecondition(new ExampleSetPrecondition(getExampleSetInputPort(), Attributes.PREDICTION_NAME, Ontology.NOMINAL)); getExampleSetInputPort().addPrecondition(new ExampleSetPrecondition(getExampleSetInputPort(), Attributes.CONFIDENCE_NAME, Ontology.NUMERICAL)); } @Override public ExampleSet apply(ExampleSet exampleSet) throws OperatorException { // searching confidence attributes Attributes attributes = exampleSet.getAttributes(); Attribute predictedLabel = attributes.getPredictedLabel(); if (predictedLabel == null) { throw new UserError(this, 107); } NominalMapping mapping = predictedLabel.getMapping(); int numberOfLabels = mapping.size(); Attribute[] confidences = new Attribute[numberOfLabels]; String[] labelValue = new String[numberOfLabels]; int i = 0; for (String value: mapping.getValues()) { labelValue[i] = value; confidences[i] = attributes.getConfidence(value); if (confidences[i] == null) { throw new UserError(this, 154, value); } i++; } // generating new prediction attributes int k = Math.min(numberOfLabels, getParameterAsInt(PARAMETER_NUMBER_OF_RANKS)); Attribute[] kthPredictions = new Attribute[k]; Attribute[] kthConfidences = new Attribute[k]; for (i = 0; i < k; i++) { kthPredictions[i] = AttributeFactory.createAttribute(predictedLabel.getValueType()); kthPredictions[i].setName(predictedLabel.getName() + "_" + (i + 1)); kthPredictions[i].setMapping((NominalMapping) predictedLabel.getMapping().clone()); kthConfidences[i] = AttributeFactory.createAttribute(Ontology.REAL); kthConfidences[i].setName(Attributes.CONFIDENCE_NAME + "_" + (i + 1)); attributes.addRegular(kthPredictions[i]); attributes.addRegular(kthConfidences[i]); attributes.setSpecialAttribute(kthPredictions[i], Attributes.PREDICTION_NAME + "_" + (i + 1)); attributes.setSpecialAttribute(kthConfidences[i], Attributes.CONFIDENCE_NAME + "_" + (i + 1)); } exampleSet.getExampleTable().addAttributes(Arrays.asList(kthConfidences)); exampleSet.getExampleTable().addAttributes(Arrays.asList(kthPredictions)); // now setting values for (Example example: exampleSet) { ArrayList<Tupel<Double, Integer>> labelConfidences = new ArrayList<Tupel<Double,Integer>>(numberOfLabels); for (i = 0; i < numberOfLabels; i++) { labelConfidences.add(new Tupel<Double, Integer>(example.getValue(confidences[i]), i)); } Collections.sort(labelConfidences); for (i = 0; i < k; i++) { Tupel<Double, Integer> tupel = labelConfidences.get(numberOfLabels - i - 1); example.setValue(kthPredictions[i], tupel.getSecond()); // Can use index since mapping has been cloned from above example.setValue(kthConfidences[i], tupel.getFirst()); } } // deleting old prediction / confidences attributes.remove(predictedLabel); if (getParameterAsBoolean(PARAMETER_REMOVE_OLD_PREDICTIONS)) { for (i = 0; i < confidences.length; i++) { attributes.remove(confidences[i]); } } return exampleSet; } @Override public boolean writesIntoExistingData() { return false; } @Override public List<ParameterType> getParameterTypes() { List<ParameterType> types = super.getParameterTypes(); types.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_RANKS, "This determines how many ranks will be considered. ", 2, Integer.MAX_VALUE, false)); types.add(new ParameterTypeBoolean(PARAMETER_REMOVE_OLD_PREDICTIONS, "This indicates if the old confidence attributes should be removed.", true, false)); return types; } }