/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.preprocessing.filter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.SimpleAttributes;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.ViewAttribute;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.preprocessing.PreprocessingModel;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.container.Pair;
/**
* The model class for the {@link NominalToNumericModel} operator. Can either
* transform nominals to numeric by simply replacing the nominal values by the
* respective integer mapping, or by using effect coding or dummy coding.
*
* @author Marius Helf
*/
class NominalToNumericModel extends PreprocessingModel {
private static final long serialVersionUID = -4203775081616082145L;
private int codingType;
/**
* maps a target attribute to the value for which it becomes one (for dummy coding)
*/
private Map<String,Double> attributeTo1ValueMap = null;
/**
* maps a target attribute to the value for which it becomes -1 or 1 respectively (for effect coding).
* The first value of the pair is the value for 1, the second for -1.
*/
private Map<String,Pair<Double,Double>> attributeToValuesMap = null;
/**
* maps an original attribute name to the list of all its values which occurred in the training data.
* Used for dummy and effect coding.
*/
private Map<String,List<String>> attributeToAllNominalValues = null;
/**
* maps source attributes to their comparison group.
*/
private Map<String,Double> sourceAttributeToComparisonGroupMap = null;
/**
* maps target attributes in the output set to their respective source attributes in the training set.
*/
private Map<String,String> targetAttributeToSourceAttributeMap = null;
/**
* maps source attributes to the comparison group string.
* Only used with dummy/effect coding.
*
* This map is *only* used for displaying the model (i.e. in toResultString()).
*/
private Map<String,String> sourceAttributeToComparisonGroupStringsMap = null;
/**
* Relevant only when using dummy coding or effect coding.
*
* If true, the naming scheme for target attributes is "sourceAttribute_value",
* if false, "sourceAttribute = value"
*/
private boolean useUnderscoreInName = false;
private boolean useComparisonGroups = false;
// how unexpected values are handled. One of ALL_ZEROES_AND_WARNING or ERROR_AND_ABORT.
private int unexpectedValueHandling = NominalToNumeric.ALL_ZEROES_AND_WARNING;
/**
* Constructs a new model. Use this ctor to create a model for value encoding.
* @param exampleSet
* @param codingType the coding type. Should be NominalToNumeric.INTEGERS when called manually.
*/
protected NominalToNumericModel(ExampleSet exampleSet, int codingType) {
super(exampleSet);
this.codingType = codingType;
}
/**
* Constructs a new model. Use this ctor to create a model for dummy encoding or effect encoding.
* @param exampleSet
* @param codingType the coding type. Should be NominalToNumeric.EFFECT_CODING or DUMMY_CODING.
* @param useUnderscoreInName @see NominalToNumericModel#useUnderscoreInName
* @param sourceAttributeToComparisonGroupMap @see NominalToNumericModel#sourceAttributeToComparisonGroupMap @see NominalToNumeric#getSourceAttributeToComparisonGroupMap
* @param attributeTo1ValueMap @see NominalToNumericModel#attributeTo1ValueMap. Must be non-null for dummy coding, should be null for effect coding. @see NominalToNumeric#getAttributeTo1ValueMap
* @param attributeToValuesMap @see NominalToNumericModel#attributeToValuesMap. Must be non-null for effect coding, should be null for dummy coding. @see NominalToNumeric#getAttributeToValuesMap
* @param useComparisonGroup Indicates if comparison groups for dummy coding should be used. Is ignored if codingType == EFFECT_CODING.
* @param unexpectedValueHandling Defines how unexpected values are handled. @see NominalToNumericModel#unexpectedValueHandling.
*/
protected NominalToNumericModel(
ExampleSet exampleSet,
int codingType,
boolean useUnderscoreInName,
Map<String,Double> sourceAttributeToComparisonGroupMap,
Map<String,Double> attributeTo1ValueMap,
Map<String, Pair<Double,Double>> attributeToValuesMap,
boolean useComparisonGroups,
int unexpectedValueHandling) {
this(exampleSet, codingType);
this.useUnderscoreInName = useUnderscoreInName;
this.sourceAttributeToComparisonGroupMap = sourceAttributeToComparisonGroupMap;
this.attributeTo1ValueMap = attributeTo1ValueMap;
this.attributeToValuesMap = attributeToValuesMap;
this.useComparisonGroups = useComparisonGroups || (codingType == NominalToNumeric.EFFECT_CODING);
this.unexpectedValueHandling = unexpectedValueHandling;
if (useComparisonGroups) {
// store comparison group strings for display
assert(sourceAttributeToComparisonGroupMap != null); // must not be null for dummy/effect coding
sourceAttributeToComparisonGroupStringsMap = new LinkedHashMap<String, String>();
for ( Map.Entry<String, Double> entry : sourceAttributeToComparisonGroupMap.entrySet() ) {
String attributeName = entry.getKey();
double comparisonGroup = entry.getValue();
Attribute attribute = exampleSet.getAttributes().get(attributeName);
String comparisonGroupString = attribute.getMapping().mapIndex((int)comparisonGroup);
sourceAttributeToComparisonGroupStringsMap.put(attributeName, comparisonGroupString);
}
}
if (codingType == NominalToNumeric.DUMMY_CODING || codingType == NominalToNumeric.EFFECT_CODING) {
// remember all nominal values from training data
attributeToAllNominalValues = new HashMap<String, List<String>>();
for (Attribute attribute : exampleSet.getAttributes()) {
if (!attribute.isNumerical()) {
String attributeName = attribute.getName();
List<String> values = new LinkedList<String>();
for (String value : attribute.getMapping().getValues()) {
values.add(value);
}
attributeToAllNominalValues.put(attributeName, values);
}
}
// remember source attributes for each target attribute
targetAttributeToSourceAttributeMap = new HashMap<String, String>();
for (Attribute sourceAttribute : exampleSet.getAttributes()) {
if (!sourceAttribute.isNumerical()) {
String sourceAttributeName = sourceAttribute.getName();
for (String targetAttribute : getTargetAttributesFromSourceAttribute(sourceAttribute)) {
targetAttributeToSourceAttributeMap.put(targetAttribute, sourceAttributeName);
}
}
}
}
}
@Override
public ExampleSet applyOnData(ExampleSet exampleSet) throws OperatorException {
switch(codingType) {
case NominalToNumeric.INTEGERS_CODING:
return applyOnDataIntegers(exampleSet);
case NominalToNumeric.DUMMY_CODING:
return applyOnDataDummyCoding(exampleSet, false);
case NominalToNumeric.EFFECT_CODING:
return applyOnDataDummyCoding(exampleSet, true);
default:
assert(false); // codingType must be one of the above
return null;
}
}
/**
* Returns a list containing the names of those attributes which will
* represent the coding of the given source attribute.
*/
private List<String> getTargetAttributesFromSourceAttribute(Attribute sourceAttribute) {
List<String> targetNames = new LinkedList<String>();
double comparisonGroup = -1;
if ( useComparisonGroups ) {
comparisonGroup = sourceAttributeToComparisonGroupMap.get( sourceAttribute.getName() );
}
List<String> originalAttributeValues = attributeToAllNominalValues.get(sourceAttribute.getName());
String comparisonGroupValue = null;
if (comparisonGroup != -1) {
comparisonGroupValue = originalAttributeValues.get((int)comparisonGroup);
}
for ( String currentValue : originalAttributeValues ) {
if ( (!useComparisonGroups) || (!currentValue.equals(comparisonGroupValue)) ) {
targetNames.add( NominalToNumeric.getTargetAttributeName(sourceAttribute.getName(), currentValue, useUnderscoreInName) );
}
}
return targetNames;
}
/**
* Creates a dummy coding or effect coding from the given example set.
* @param effectCoding If true, the function does effect coding. If false, dummy coding.
*/
private ExampleSet applyOnDataDummyCoding(ExampleSet exampleSet, boolean effectCoding ) {
// selecting transformation attributes and creating new numeric attributes
LinkedList<Attribute> nominalAttributes = new LinkedList<Attribute>();
LinkedList<Attribute> transformedAttributes = new LinkedList<Attribute>();
for (Attribute attribute : exampleSet.getAttributes()) {
if (!attribute.isNumerical()) {
nominalAttributes.add(attribute);
List<String> targetNames = getTargetAttributesFromSourceAttribute(attribute);
for ( String targetName : targetNames ) {
transformedAttributes.add(AttributeFactory.createAttribute(targetName, Ontology.INTEGER));
}
}
}
// ensuring capacity in ExampleTable
exampleSet.getExampleTable().addAttributes(transformedAttributes);
for ( Attribute attribute : transformedAttributes ) {
exampleSet.getAttributes().addRegular(attribute);
}
// copying values
for (Example example: exampleSet) {
for ( Attribute nominalAttribute : nominalAttributes ) {
double sourceValue = example.getValue(nominalAttribute);
for ( String targetName : getTargetAttributesFromSourceAttribute(nominalAttribute) ) {
Attribute targetAttribute = exampleSet.getAttributes().get(targetName);
example.setValue(targetAttribute, getValue(targetAttribute, sourceValue));
}
}
}
// remove nominal attributes
for ( Attribute nominalAttribute : nominalAttributes ) {
exampleSet.getAttributes().remove(nominalAttribute);
}
return exampleSet;
}
/**
* Transforms the numerical attributes to integer values (corresponding to the internal mapping).
*/
private ExampleSet applyOnDataIntegers(ExampleSet exampleSet) {
// selecting transformation attributes and creating new numeric attributes
LinkedList<Attribute> nominalAttributes = new LinkedList<Attribute>();
LinkedList<Attribute> transformedAttributes = new LinkedList<Attribute>();
for (Attribute attribute : exampleSet.getAttributes()) {
if (!attribute.isNumerical()) {
nominalAttributes.add(attribute);
// creating new attributes for nominal attributes
transformedAttributes.add(AttributeFactory.createAttribute(attribute.getName(), Ontology.NUMERICAL));
}
}
// ensuring capacity in ExampleTable
exampleSet.getExampleTable().addAttributes(transformedAttributes);
// copying values
for (Example example: exampleSet) {
Iterator<Attribute> target = transformedAttributes.iterator();
for (Attribute attribute: nominalAttributes) {
example.setValue(target.next(), example.getValue(attribute));
}
}
// removing nominal attributes from example Set
Attributes attributes = exampleSet.getAttributes();
for(Attribute attribute: exampleSet.getAttributes()) {
if (!attribute.isNumerical())
attributes.replace(attribute, transformedAttributes.poll());
}
return exampleSet;
}
@Override
public Attributes getTargetAttributes(ExampleSet parentSet) {
SimpleAttributes attributes = new SimpleAttributes();
// add special attributes to new attributes
Iterator<AttributeRole> specialRoles = parentSet.getAttributes().specialAttributes();
while (specialRoles.hasNext()) {
attributes.add(specialRoles.next());
}
// add regular attributes
for (Attribute attribute : parentSet.getAttributes()) {
if (!attribute.isNumerical()) {
if ( codingType == NominalToNumeric.EFFECT_CODING || codingType == NominalToNumeric.DUMMY_CODING ) {
double comparisonGroup = -1;
if (useComparisonGroups) {
comparisonGroup = sourceAttributeToComparisonGroupMap.get(attribute.getName());
}
List<String> valueList = attributeToAllNominalValues.get(attribute.getName());
if (valueList != null) {
int currentValue = 0;
for (String attributeValue : valueList) {
if ( currentValue != comparisonGroup ) {
ViewAttribute viewAttribute = new ViewAttribute(
this,
attribute,
NominalToNumeric.getTargetAttributeName(attribute.getName(), attributeValue, useUnderscoreInName),
Ontology.INTEGER,
null);
attributes.addRegular(viewAttribute);
}
++currentValue;
}
}
} else if ( codingType == NominalToNumeric.INTEGERS_CODING ) {
attributes.addRegular(new ViewAttribute(this, attribute, attribute.getName(), Ontology.INTEGER, null));
} else {
assert(false); // unsupported coding
}
} else {
attributes.addRegular(attribute);
}
}
return attributes;
}
@Override
public double getValue(Attribute targetAttribute, double value) {
if ( codingType == NominalToNumeric.DUMMY_CODING ) {
String targetName = targetAttribute.getName();
Double oneValue = attributeTo1ValueMap.get(targetName);
if ( oneValue != null && oneValue == value ) {
return 1;
} else {
// check if the value has been present in the training set
if (unexpectedValueHandling != NominalToNumeric.ALL_ZEROES_AND_NO_WARNING && !isValueInTrainingSet(targetAttribute, value)) {
handleUnexpectedValue(targetName);
}
return 0;
}
} else if ( codingType == NominalToNumeric.EFFECT_CODING ) {
String targetName = targetAttribute.getName();
Pair<Double,Double> storedValue = attributeToValuesMap.get(targetName);
if ( storedValue.getFirst() == value ) {
return 1;
} else if ( storedValue.getSecond() == value ) {
return -1;
} else {
// check if the value has been present in the training set
if (unexpectedValueHandling != NominalToNumeric.ALL_ZEROES_AND_NO_WARNING && !isValueInTrainingSet(targetAttribute, value)) {
handleUnexpectedValue(targetName);
}
return 0;
}
} else if ( codingType == NominalToNumeric.INTEGERS_CODING ) {
return value;
} else {
assert(false); // unsupported coding
return Double.NaN;
}
}
private int handleUnexpectedValue(String targetName) {
switch(unexpectedValueHandling) {
case NominalToNumeric.ALL_ZEROES_AND_WARNING:
getLog().logWarning("unexpected value during application of Nominal to Numerical Model for attribute '" + targetName + "'. Setting to 0."); // TODO i18n
return 0;
case NominalToNumeric.ALL_ZEROES_AND_NO_WARNING:
return 0;
default:
assert(false); // should be one of the above values
return 0;
}
}
private boolean isValueInTrainingSet(Attribute targetAttribute, double value) {
String sourceAttribute = targetAttributeToSourceAttributeMap.get(targetAttribute.getName());
if (sourceAttribute != null) {
List<String> trainingValues = attributeToAllNominalValues.get(sourceAttribute);
if (trainingValues != null) {
int valueCount = trainingValues.size();
if (value >= valueCount) {
return false;
}
} else {
return false;
}
}
return true;
}
@Override
public String getName() {
return "Nominal2Numerical Model";
}
@Override
public String toResultString() {
StringBuilder builder = new StringBuilder();
Attributes trainAttributes = getTrainingHeader().getAttributes();
builder.append(getName()+ Tools.getLineSeparators(2));
String codingTypeString = "";
switch(codingType) {
case NominalToNumeric.INTEGERS_CODING:
codingTypeString = "unique integers";
break;
case NominalToNumeric.DUMMY_CODING:
codingTypeString = "dummy coding";
break;
case NominalToNumeric.EFFECT_CODING:
codingTypeString = "effect coding";
break;
}
builder.append("Coding Type: " + codingTypeString + Tools.getLineSeparator());
if ( codingType == NominalToNumeric.INTEGERS_CODING ) {
builder.append("Model covering " + trainAttributes.size() + " attributes:" + Tools.getLineSeparator());
for (Attribute attribute: trainAttributes) {
builder.append(" - " + attribute.getName() + Tools.getLineSeparator());
}
} else {
builder.append("Model covering " + trainAttributes.size() + " attributes (with comparison group):" + Tools.getLineSeparator());
for (Attribute attribute: trainAttributes) {
builder.append(" - " + attribute.getName() + " ('" + sourceAttributeToComparisonGroupStringsMap.get(attribute.getName()) + "')" + Tools.getLineSeparator());
}
}
return builder.toString();
}
}