/** * Copyright (C) 2001-2017 by RapidMiner and the contributors * * Complete list of developers available at our web site: * * http://rapidminer.com * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.postprocessing; import com.rapidminer.example.Attribute; import com.rapidminer.example.AttributeRole; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.operator.AbstractExampleSetProcessing; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition; import com.rapidminer.parameter.ParameterType; import com.rapidminer.parameter.ParameterTypeString; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.container.Tupel; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; /** * Generates predictions from confidence attributes. * * @author Tobias Malbrecht */ public class GeneratePredictionOperator extends AbstractExampleSetProcessing { public static final String PARAMETER_PREDICTION_NAME = "prediction_name"; public GeneratePredictionOperator(OperatorDescription description) { super(description); getExampleSetInputPort().addPrecondition( new ExampleSetPrecondition(getExampleSetInputPort(), Attributes.CONFIDENCE_NAME, Ontology.NUMERICAL)); } @Override public ExampleSet apply(ExampleSet exampleSet) throws OperatorException { // searching confidence attributes Map<Attribute, String> confidenceAttributes = new LinkedHashMap<Attribute, String>(); for (Iterator<AttributeRole> iterator = exampleSet.getAttributes().specialAttributes(); iterator.hasNext();) { AttributeRole role = iterator.next(); if (role.getSpecialName().matches(Attributes.CONFIDENCE_NAME + "_.*")) { confidenceAttributes.put(role.getAttribute(), role.getSpecialName().replaceAll("^" + Attributes.CONFIDENCE_NAME + "_(.*)$", "$1")); } } if (confidenceAttributes.size() > 0) { String predictionName = getParameterAsString(PARAMETER_PREDICTION_NAME); String attributeName = "prediction(" + predictionName + ")"; Attribute predictionAttribute = AttributeFactory.createAttribute(attributeName, Ontology.NOMINAL); // check if an attribute with the resulting name already exists Attribute oldAttribute = exampleSet.getAttributes().get(attributeName); if (oldAttribute != null) { if (exampleSet.getAttributes().getSpecial(Attributes.PREDICTION_NAME) == oldAttribute) { // remove it iff it is the prediction attribute (since it would be removed later // anyway, but causes an error if not removed here) exampleSet.getAttributes().remove(oldAttribute); } else { // otherwise throw an error throw new UserError(this, 152, attributeName); } } for (String value : confidenceAttributes.values()) { predictionAttribute.getMapping().mapString(value); } exampleSet.getExampleTable().addAttribute(predictionAttribute); exampleSet.getAttributes().addRegular(predictionAttribute); exampleSet.getAttributes().setSpecialAttribute(predictionAttribute, Attributes.PREDICTION_NAME); for (Example example : exampleSet) { ArrayList<Tupel<Double, String>> labelConfidences = new ArrayList<Tupel<Double, String>>( confidenceAttributes.size()); for (Map.Entry<Attribute, String> entry : confidenceAttributes.entrySet()) { labelConfidences.add(new Tupel<Double, String>(example.getValue(entry.getKey()), entry.getValue())); } Collections.sort(labelConfidences); example.setValue( predictionAttribute, predictionAttribute.getMapping().mapString( labelConfidences.get(labelConfidences.size() - 1).getSecond())); } } return exampleSet; } @Override public boolean writesIntoExistingData() { return false; } @Override public List<ParameterType> getParameterTypes() { List<ParameterType> types = super.getParameterTypes(); types.add(new ParameterTypeString(PARAMETER_PREDICTION_NAME, "The name of the label that should be predicted.", false, false)); return types; } }