/** * Copyright (C) 2001-2017 by RapidMiner and the contributors * * Complete list of developers available at our web site: * * http://rapidminer.com * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.preprocessing.filter; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.logging.Level; import com.rapidminer.example.Attribute; import com.rapidminer.example.AttributeRole; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.SimpleAttributes; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.example.table.ViewAttribute; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.OperatorProgress; import com.rapidminer.operator.ProcessStoppedException; import com.rapidminer.operator.preprocessing.PreprocessingModel; import com.rapidminer.tools.LogService; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.Tools; import com.rapidminer.tools.container.Pair; /** * The model class for the {@link NominalToNumericModel} operator. Can either transform nominals to * numeric by simply replacing the nominal values by the respective integer mapping, or by using * effect coding or dummy coding. * * @author Marius Helf */ public class NominalToNumericModel extends PreprocessingModel { private static final long serialVersionUID = -4203775081616082145L; private int codingType; /** * maps a target attribute to the value for which it becomes one (for dummy coding) */ private Map<String, Double> attributeTo1ValueMap = null; /** * maps a target attribute to the value for which it becomes -1 or 1 respectively (for effect * coding). The first value of the pair is the value for 1, the second for -1. */ private Map<String, Pair<Double, Double>> attributeToValuesMap = null; /** * maps an original attribute name to the list of all its values which occurred in the training * data. Used for dummy and effect coding. */ private Map<String, List<String>> attributeToAllNominalValues = null; /** * maps source attributes to their comparison group. */ private Map<String, Double> sourceAttributeToComparisonGroupMap = null; /** * maps target attributes in the output set to their respective source attributes in the * training set. */ private Map<String, String> targetAttributeToSourceAttributeMap = null; /** * maps source attributes to the comparison group string. Only used with dummy/effect coding. * * This map is *only* used for displaying the model (i.e. in toResultString()). */ private Map<String, String> sourceAttributeToComparisonGroupStringsMap = null; /** * Relevant only when using dummy coding or effect coding. * * If true, the naming scheme for target attributes is "sourceAttribute_value", if false, * "sourceAttribute = value" */ private boolean useUnderscoreInName = false; private boolean useComparisonGroups = false; // how unexpected values are handled. One of ALL_ZEROES_AND_WARNING or ERROR_AND_ABORT. private int unexpectedValueHandling = NominalToNumeric.ALL_ZEROES_AND_WARNING; /** * Constructs a new model. Use this ctor to create a model for value encoding. * * @param exampleSet * @param codingType * the coding type. Should be NominalToNumeric.INTEGERS when called manually. */ public NominalToNumericModel(ExampleSet exampleSet, int codingType) { super(exampleSet); this.codingType = codingType; } /** * Constructs a new model. Use this ctor to create a model for dummy encoding or effect * encoding. * * @param exampleSet * @param codingType * the coding type. Should be NominalToNumeric.EFFECT_CODING or DUMMY_CODING. * @param useUnderscoreInName * @see NominalToNumericModel#useUnderscoreInName * @param sourceAttributeToComparisonGroupMap * @see NominalToNumericModel#sourceAttributeToComparisonGroupMap @see * NominalToNumeric#getSourceAttributeToComparisonGroupMap * @param attributeTo1ValueMap * @see NominalToNumericModel#attributeTo1ValueMap be non-null for dummy coding, should be null * for effect coding. @see NominalToNumeric#getAttributeTo1ValueMap * @param attributeToValuesMap * @see NominalToNumericModel#attributeToValuesMap be non-null for effect coding, should be null * for dummy coding. @see NominalToNumeric#getAttributeToValuesMap * @param useComparisonGroup * Indicates if comparison groups for dummy coding should be used. Is ignored if * codingType == EFFECT_CODING. * @param unexpectedValueHandling * Defines how unexpected values are handled. @see * NominalToNumericModel#unexpectedValueHandling. */ public NominalToNumericModel(ExampleSet exampleSet, int codingType, boolean useUnderscoreInName, Map<String, Double> sourceAttributeToComparisonGroupMap, Map<String, Double> attributeTo1ValueMap, Map<String, Pair<Double, Double>> attributeToValuesMap, boolean useComparisonGroups, int unexpectedValueHandling) { this(exampleSet, codingType); this.useUnderscoreInName = useUnderscoreInName; this.sourceAttributeToComparisonGroupMap = sourceAttributeToComparisonGroupMap; this.attributeTo1ValueMap = attributeTo1ValueMap; this.attributeToValuesMap = attributeToValuesMap; this.useComparisonGroups = useComparisonGroups || codingType == NominalToNumeric.EFFECT_CODING; this.unexpectedValueHandling = unexpectedValueHandling; if (useComparisonGroups) { // store comparison group strings for display assert sourceAttributeToComparisonGroupMap != null; // must not be null for // dummy/effect coding sourceAttributeToComparisonGroupStringsMap = new LinkedHashMap<>(); for (Map.Entry<String, Double> entry : sourceAttributeToComparisonGroupMap.entrySet()) { String attributeName = entry.getKey(); double comparisonGroup = entry.getValue(); Attribute attribute = exampleSet.getAttributes().get(attributeName); String comparisonGroupString = attribute.getMapping().mapIndex((int) comparisonGroup); sourceAttributeToComparisonGroupStringsMap.put(attributeName, comparisonGroupString); } } if (codingType == NominalToNumeric.DUMMY_CODING || codingType == NominalToNumeric.EFFECT_CODING) { // remember all nominal values from training data attributeToAllNominalValues = new HashMap<>(); for (Attribute attribute : exampleSet.getAttributes()) { if (!attribute.isNumerical()) { String attributeName = attribute.getName(); List<String> values = new LinkedList<>(); for (String value : attribute.getMapping().getValues()) { values.add(value); } attributeToAllNominalValues.put(attributeName, values); } } // remember source attributes for each target attribute targetAttributeToSourceAttributeMap = new HashMap<>(); for (Attribute sourceAttribute : exampleSet.getAttributes()) { if (!sourceAttribute.isNumerical()) { String sourceAttributeName = sourceAttribute.getName(); for (String targetAttribute : getTargetAttributesFromSourceAttribute(sourceAttribute)) { targetAttributeToSourceAttributeMap.put(targetAttribute, sourceAttributeName); } } } } } @Override public ExampleSet applyOnData(ExampleSet exampleSet) throws OperatorException { switch (codingType) { case NominalToNumeric.INTEGERS_CODING: return applyOnDataIntegers(exampleSet); case NominalToNumeric.DUMMY_CODING: return applyOnDataDummyCoding(exampleSet, false); case NominalToNumeric.EFFECT_CODING: return applyOnDataDummyCoding(exampleSet, true); default: assert false; // codingType must be one of the above return null; } } /** * Returns a list containing the names of those attributes which will represent the coding of * the given source attribute. */ private List<String> getTargetAttributesFromSourceAttribute(Attribute sourceAttribute) { List<String> targetNames = new ArrayList<>(); double comparisonGroup = -1; if (useComparisonGroups) { comparisonGroup = sourceAttributeToComparisonGroupMap.get(sourceAttribute.getName()); } List<String> originalAttributeValues = attributeToAllNominalValues.get(sourceAttribute.getName()); String comparisonGroupValue = null; if (comparisonGroup != -1) { comparisonGroupValue = originalAttributeValues.get((int) comparisonGroup); } for (String currentValue : originalAttributeValues) { if (!useComparisonGroups || !currentValue.equals(comparisonGroupValue)) { targetNames.add(NominalToNumeric.getTargetAttributeName(sourceAttribute.getName(), currentValue, useUnderscoreInName)); } } return targetNames; } /** * Creates a dummy coding or effect coding from the given example set. * * @param effectCoding * If true, the function does effect coding. If false, dummy coding. * @throws ProcessStoppedException */ private ExampleSet applyOnDataDummyCoding(ExampleSet exampleSet, boolean effectCoding) throws ProcessStoppedException { // selecting transformation attributes and creating new numeric attributes List<Attribute> nominalAttributes = new ArrayList<>(); List<Attribute> transformedAttributes = new ArrayList<>(); Map<Attribute, List<Attribute>> targetAttributesFromSources = new HashMap<>(); for (Attribute attribute : exampleSet.getAttributes()) { if (!attribute.isNumerical()) { nominalAttributes.add(attribute); List<String> targetNames = getTargetAttributesFromSourceAttribute(attribute); List<Attribute> targets = new ArrayList<>(); for (String targetName : targetNames) { Attribute createAttribute = AttributeFactory.createAttribute(targetName, Ontology.INTEGER); transformedAttributes.add(createAttribute); targets.add(createAttribute); } targetAttributesFromSources.put(attribute, targets); } } // ensuring capacity in ExampleTable exampleSet.getExampleTable().addAttributes(transformedAttributes); for (Attribute attribute : transformedAttributes) { exampleSet.getAttributes().addRegular(attribute); } // initialize progress long progressCompletedCounter = 0; long progressTotal = (long) nominalAttributes.size() * exampleSet.size(); OperatorProgress progress = null; if (getShowProgress() && getOperator() != null && getOperator().getProgress() != null) { progress = getOperator().getProgress(); progress.setTotal(1000); } // copying values for (Attribute nominalAttribute : nominalAttributes) { for (Example example : exampleSet) { double sourceValue = example.getValue(nominalAttribute); for (Attribute targetAttribute : targetAttributesFromSources.get(nominalAttribute)) { example.setValue(targetAttribute, getValue(targetAttribute, sourceValue)); } if (++progressCompletedCounter % 10_000 == 0) { progress.setCompleted((int) (1000.0d * progressCompletedCounter / progressTotal)); } } } // remove nominal attributes for (Attribute nominalAttribute : nominalAttributes) { exampleSet.getAttributes().remove(nominalAttribute); } return exampleSet; } /** * Transforms the numerical attributes to integer values (corresponding to the internal * mapping). * * @throws ProcessStoppedException */ private ExampleSet applyOnDataIntegers(ExampleSet exampleSet) throws ProcessStoppedException { // selecting transformation attributes and creating new numeric attributes List<Attribute> nominalAttributes = new ArrayList<>(); LinkedList<Attribute> transformedAttributes = new LinkedList<>(); for (Attribute attribute : exampleSet.getAttributes()) { if (!attribute.isNumerical()) { nominalAttributes.add(attribute); // creating new attributes for nominal attributes transformedAttributes.add(AttributeFactory.createAttribute(attribute.getName(), Ontology.NUMERICAL)); } } // ensuring capacity in ExampleTable exampleSet.getExampleTable().addAttributes(transformedAttributes); // initialize progress long progressCompletedCounter = 0; long progressTotal = (long) nominalAttributes.size() * exampleSet.size(); OperatorProgress progress = null; if (getShowProgress() && getOperator() != null && getOperator().getProgress() != null) { progress = getOperator().getProgress(); progress.setTotal(1000); } // copying values Iterator<Attribute> target = transformedAttributes.iterator(); for (Attribute attribute : nominalAttributes) { Attribute targetAttribute = target.next(); for (Example example : exampleSet) { example.setValue(targetAttribute, example.getValue(attribute)); if (progress != null && ++progressCompletedCounter % 100_000 == 0) { progress.setCompleted((int) (1000.0d * progressCompletedCounter / progressTotal)); } } } // removing nominal attributes from example Set Attributes attributes = exampleSet.getAttributes(); for (Attribute attribute : exampleSet.getAttributes()) { if (!attribute.isNumerical()) { attributes.replace(attribute, transformedAttributes.poll()); } } return exampleSet; } @Override public Attributes getTargetAttributes(ExampleSet parentSet) { SimpleAttributes attributes = new SimpleAttributes(); // add special attributes to new attributes Iterator<AttributeRole> specialRoles = parentSet.getAttributes().specialAttributes(); while (specialRoles.hasNext()) { attributes.add(specialRoles.next()); } // add regular attributes for (Attribute attribute : parentSet.getAttributes()) { if (!attribute.isNumerical()) { if (codingType == NominalToNumeric.EFFECT_CODING || codingType == NominalToNumeric.DUMMY_CODING) { double comparisonGroup = -1; if (useComparisonGroups) { comparisonGroup = sourceAttributeToComparisonGroupMap.get(attribute.getName()); } List<String> valueList = attributeToAllNominalValues.get(attribute.getName()); if (valueList != null) { int currentValue = 0; for (String attributeValue : valueList) { if (currentValue != comparisonGroup) { ViewAttribute viewAttribute = new ViewAttribute(this, attribute, NominalToNumeric .getTargetAttributeName(attribute.getName(), attributeValue, useUnderscoreInName), Ontology.INTEGER, null); attributes.addRegular(viewAttribute); } ++currentValue; } } } else if (codingType == NominalToNumeric.INTEGERS_CODING) { attributes.addRegular(new ViewAttribute(this, attribute, attribute.getName(), Ontology.INTEGER, null)); } else { assert false; // unsupported coding } } else { attributes.addRegular(attribute); } } return attributes; } @Override public double getValue(Attribute targetAttribute, double value) { if (codingType == NominalToNumeric.DUMMY_CODING) { String targetName = targetAttribute.getName(); Double oneValue = attributeTo1ValueMap.get(targetName); if (oneValue != null && oneValue == value) { return 1; } else { // check if the value has been present in the training set if (unexpectedValueHandling != NominalToNumeric.ALL_ZEROES_AND_NO_WARNING && !isValueInTrainingSet(targetAttribute, value)) { handleUnexpectedValue(targetName); } return 0; } } else if (codingType == NominalToNumeric.EFFECT_CODING) { String targetName = targetAttribute.getName(); Pair<Double, Double> storedValue = attributeToValuesMap.get(targetName); if (storedValue.getFirst() == value) { return 1; } else if (storedValue.getSecond() == value) { return -1; } else { // check if the value has been present in the training set if (unexpectedValueHandling != NominalToNumeric.ALL_ZEROES_AND_NO_WARNING && !isValueInTrainingSet(targetAttribute, value)) { handleUnexpectedValue(targetName); } return 0; } } else if (codingType == NominalToNumeric.INTEGERS_CODING) { return value; } else { assert false; // unsupported coding return Double.NaN; } } private int handleUnexpectedValue(String targetName) { switch (unexpectedValueHandling) { case NominalToNumeric.ALL_ZEROES_AND_WARNING: LogService.getRoot().log(Level.WARNING, "com.rapidminer.operator.preprocessing.filter.NominalToNumericModel.unexpected_value", targetName); return 0; case NominalToNumeric.ALL_ZEROES_AND_NO_WARNING: return 0; default: assert false; // should be one of the above values return 0; } } private boolean isValueInTrainingSet(Attribute targetAttribute, double value) { String sourceAttribute = targetAttributeToSourceAttributeMap.get(targetAttribute.getName()); if (sourceAttribute != null) { List<String> trainingValues = attributeToAllNominalValues.get(sourceAttribute); if (trainingValues != null) { int valueCount = trainingValues.size(); if (value >= valueCount) { return false; } } else { return false; } } return true; } @Override public String getName() { return "Nominal2Numerical Model"; } @Override public String toResultString() { StringBuilder builder = new StringBuilder(); Attributes trainAttributes = getTrainingHeader().getAttributes(); builder.append(getName() + Tools.getLineSeparators(2)); String codingTypeString = ""; switch (codingType) { case NominalToNumeric.INTEGERS_CODING: codingTypeString = "unique integers"; break; case NominalToNumeric.DUMMY_CODING: codingTypeString = "dummy coding"; break; case NominalToNumeric.EFFECT_CODING: codingTypeString = "effect coding"; break; } builder.append("Coding Type: " + codingTypeString + Tools.getLineSeparator()); if (!useComparisonGroups) { builder.append("Model covering " + trainAttributes.size() + " attributes:" + Tools.getLineSeparator()); for (Attribute attribute : trainAttributes) { builder.append(" - " + attribute.getName() + Tools.getLineSeparator()); } } else { builder.append("Model covering " + trainAttributes.size() + " attributes (with comparison group):" + Tools.getLineSeparator()); for (Attribute attribute : trainAttributes) { builder.append(" - " + attribute.getName() + " ('" + sourceAttributeToComparisonGroupStringsMap.get(attribute.getName()) + "')" + Tools.getLineSeparator()); } } return builder.toString(); } public int getCodingType() { return codingType; } public Map<String, Double> getAttributeTo1ValueMap() { return attributeTo1ValueMap; } public Map<String, Pair<Double, Double>> getAttributeToValuesMap() { return attributeToValuesMap; } public Map<String, List<String>> getAttributeToAllNominalValues() { return attributeToAllNominalValues; } public Map<String, Double> getSourceAttributeToComparisonGroupMap() { return sourceAttributeToComparisonGroupMap; } public Map<String, String> getTargetAttributeToSourceAttributeMap() { return targetAttributeToSourceAttributeMap; } public boolean shouldUseUnderscoreInName() { return useUnderscoreInName; } public boolean shouldUseComparisonGroups() { return useComparisonGroups; } public int getUnexpectedValueHandling() { return unexpectedValueHandling; } }