/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.preprocessing.transformation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SortedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DoubleArrayDataRow;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.AttributeSetPrecondition;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeAttribute;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import com.rapidminer.tools.math.function.aggregation.AbstractAggregationFunction;
import com.rapidminer.tools.math.function.aggregation.AggregationFunction;
/**
* <p>Transforms an example set by grouping multiple examples of single groups
* into single examples. The parameter <em>group_attribute</em> specifies an
* attribute which identifies examples belonging to the groups. The parameter
* <em>index_attribute</em> specifies an attribute whose values are used to
* identify the examples inside the groups. The values of this attributes are
* used to name the group attributes which are created during the pivoting.
* Typically the values of such an attribute capture subgroups or dates.
* If the source example set contains example weights, these weights may be
* aggregated in each group to maintain the weightings among groups.</p>
*
* @author Tobias Malbrecht
*/
public class Example2AttributePivoting extends ExampleSetTransformationOperator {
public static final String PARAMETER_GROUP_ATTRIBUTE = "group_attribute";
public static final String PARAMETER_INDEX_ATTRIBUTE = "index_attribute";
public static final String PARAMETER_CONSIDER_WEIGHTS = "consider_weights";
public static final String PARAMETER_WEIGHT_AGGREGATION = "weight_aggregation";
public static final String PARAMETER_SKIP_CONSTANT_ATTRIBUTES = "skip_constant_attributes";
public Example2AttributePivoting(OperatorDescription description) {
super(description);
getExampleSetInputPort().addPrecondition(new AttributeSetPrecondition(getExampleSetInputPort(), AttributeSetPrecondition.getAttributesByParameter(this, PARAMETER_GROUP_ATTRIBUTE, PARAMETER_INDEX_ATTRIBUTE)));
}
@Override
protected MetaData modifyMetaData(ExampleSetMetaData metaData) throws UndefinedParameterError {
AttributeMetaData groupAttribute = metaData.getAttributeByName(getParameterAsString(PARAMETER_GROUP_ATTRIBUTE));
AttributeMetaData indexAttribute = metaData.getAttributeByName(getParameterAsString(PARAMETER_INDEX_ATTRIBUTE));
if (groupAttribute != null && indexAttribute != null) {
ExampleSetMetaData emd = new ExampleSetMetaData();
// number of examples
if (groupAttribute.isNominal()) {
emd.setNumberOfExamples(groupAttribute.getValueSet().size());
if (groupAttribute.getValueSetRelation() == SetRelation.SUBSET)
emd.getNumberOfExamples().reduceByUnknownAmount();
else if (groupAttribute.getValueSetRelation() == SetRelation.SUPERSET)
emd.getNumberOfExamples().increaseByUnknownAmount();
} else {
emd.setNumberOfExamples(new MDInteger());
}
// attributes
if (indexAttribute.isNominal()) {
//nominal index attribute
for (AttributeMetaData originalAMD: metaData.getAllAttributes()) {
if (!originalAMD.isSpecial() && originalAMD != indexAttribute && originalAMD != groupAttribute) {
if (indexAttribute.getValueSet().size() > 1) {
for (String value: indexAttribute.getValueSet()) {
AttributeMetaData newIndexedAttribute = originalAMD.clone();
newIndexedAttribute.setName(originalAMD.getName() + "_" + value);
newIndexedAttribute.getNumberOfMissingValues().increaseByUnknownAmount();
newIndexedAttribute.setValueSetRelation(SetRelation.SUBSET);
emd.addAttribute(newIndexedAttribute);
emd.mergeSetRelation(newIndexedAttribute.getValueSetRelation());
}
} else {
AttributeMetaData newIndexedAttribute = originalAMD.clone();
emd.addAttribute(newIndexedAttribute);
}
}
if (originalAMD == groupAttribute) {
emd.addAttribute(originalAMD.clone());
}
}
} else {
// numerical index attribute
// add range borders in order to have an example how it could look like
for (AttributeMetaData originalAMD: metaData.getAllAttributes()) {
if (!originalAMD.isSpecial() && originalAMD != indexAttribute && originalAMD != groupAttribute) {
AttributeMetaData newIndexedAttribute = originalAMD.clone();
newIndexedAttribute.setName(originalAMD.getName() + "_" + newIndexedAttribute.getValueRange().getLower());
newIndexedAttribute.getNumberOfMissingValues().increaseByUnknownAmount();
newIndexedAttribute.setValueSetRelation(SetRelation.SUBSET);
emd.addAttribute(newIndexedAttribute);
newIndexedAttribute = originalAMD.clone();
newIndexedAttribute.setName(originalAMD.getName() + "_" + newIndexedAttribute.getValueRange().getUpper());
newIndexedAttribute.getNumberOfMissingValues().increaseByUnknownAmount();
newIndexedAttribute.setValueSetRelation(SetRelation.SUBSET);
emd.addAttribute(newIndexedAttribute);
}
if (originalAMD == groupAttribute) {
emd.addAttribute(originalAMD.clone());
}
}
emd.mergeSetRelation(SetRelation.SUPERSET);
}
return emd;
} else {
return new ExampleSetMetaData();
}
}
@Override
public ExampleSet apply(ExampleSet sourceExampleSet) throws OperatorException {
boolean skipConstantAttributes = getParameterAsBoolean(PARAMETER_SKIP_CONSTANT_ATTRIBUTES);
String groupAttributeName = getParameterAsString(PARAMETER_GROUP_ATTRIBUTE);
String indexAttributeName = getParameterAsString(PARAMETER_INDEX_ATTRIBUTE);
boolean considerWeights = getParameterAsBoolean(PARAMETER_CONSIDER_WEIGHTS);
int weightAggregationFunctionIndex = getParameterAsInt(PARAMETER_WEIGHT_AGGREGATION);
Attribute groupAttribute = sourceExampleSet.getAttributes().get(groupAttributeName);
if (groupAttribute == null) {
throw new UserError(this, 111, groupAttributeName);
}
Attribute indexAttribute = sourceExampleSet.getAttributes().get(indexAttributeName);
if (indexAttribute == null) {
throw new UserError(this, 111, indexAttributeName);
}
Attribute weightAttribute = sourceExampleSet.getAttributes().getWeight();
SortedExampleSet exampleSet = new SortedExampleSet(new SortedExampleSet(sourceExampleSet, indexAttribute, SortedExampleSet.INCREASING), groupAttribute, SortedExampleSet.INCREASING);
// identify static or dynamic attributes and record index values
List<String> indexValues = new Vector<String>();
Attribute[] attributes = exampleSet.getAttributes().createRegularAttributeArray();
boolean[] constantAttributeValues = new boolean[attributes.length];
for (int i = 0; i < constantAttributeValues.length; i++) {
constantAttributeValues[i] = true;
}
Example lastExample = null;
for (Example example : exampleSet) {
if (lastExample != null) {
if (lastExample.getValue(groupAttribute) == example.getValue(groupAttribute)) {
for (int i = 0; i < attributes.length; i++) {
Attribute attribute = attributes[i];
if (Double.isNaN(lastExample.getValue(attribute)) && Double.isNaN(example.getValue(attribute))) {
continue;
}
if (lastExample.getValue(attribute) != example.getValue(attribute)) {
constantAttributeValues[i] = false;
continue;
}
}
}
}
String indexValue = example.getValueAsString(indexAttribute);
if (!indexValues.contains(indexValue)) {
indexValues.add(indexValue);
}
lastExample = example;
}
if (!indexAttribute.isNominal()) {
Collections.sort(indexValues);
}
List<String> attributeNames = new Vector<String>();
List<Attribute> newAttributes = new Vector<Attribute>();
Attribute newWeightAttribute = null;
if (weightAttribute != null && considerWeights) {
newWeightAttribute = AttributeFactory.createAttribute(weightAttribute.getName(), Ontology.REAL);
newAttributes.add(newWeightAttribute);
attributeNames.add(newWeightAttribute.getName());
}
for (int i = 0; i < attributes.length; i++) {
Attribute attribute = attributes[i];
if (!attribute.equals(indexAttribute)) {
if ((skipConstantAttributes && constantAttributeValues[i]) || attribute.equals(groupAttribute)) {
newAttributes.add(AttributeFactory.createAttribute(attribute.getName(), attribute.getValueType()));
attributeNames.add(attribute.getName());
} else {
for (String indexValue : indexValues) {
String newAttributeName = attribute.getName() + "_" + indexValue;
newAttributes.add(AttributeFactory.createAttribute(newAttributeName, attribute.getValueType()));
attributeNames.add(newAttributeName);
}
}
}
}
MemoryExampleTable table = new MemoryExampleTable(newAttributes);
AggregationFunction aggregationFunction = null;
if (newWeightAttribute != null && considerWeights) {
try {
aggregationFunction = AbstractAggregationFunction.createAggregationFunction(weightAggregationFunctionIndex);
} catch (Exception e) {
throw new UserError(this, 904, AbstractAggregationFunction.KNOWN_AGGREGATION_FUNCTION_NAMES[weightAggregationFunctionIndex], e.getMessage());
}
}
double lastGroupValue = Double.NaN;
double[] data = new double[newAttributes.size()];
for (int i = 0; i < data.length; i++) {
data[i] = Double.NaN;
}
for (Example example : exampleSet) {
double currentGroupValue = example.getValue(groupAttribute);
if (!Double.isNaN(lastGroupValue)) {
if (lastGroupValue != currentGroupValue) {
if (aggregationFunction != null) {
data[0] = aggregationFunction.getValue();
try {
aggregationFunction = AbstractAggregationFunction.createAggregationFunction(weightAggregationFunctionIndex);
} catch (Exception e) {
throw new UserError(this, 904, AbstractAggregationFunction.KNOWN_AGGREGATION_FUNCTION_NAMES[weightAggregationFunctionIndex], e.getMessage());
}
}
table.addDataRow(new DoubleArrayDataRow(data));
data = new double[newAttributes.size()];
for (int i = 0; i < data.length; i++) {
data[i] = Double.NaN;
}
}
}
if (aggregationFunction != null) {
aggregationFunction.update(example.getWeight());
}
for (int i = 0; i < attributes.length; i++) {
Attribute attribute = attributes[i];
int newIndex = -1;
if ((skipConstantAttributes && constantAttributeValues[i]) || (attribute.equals(groupAttribute))) {
newIndex = attributeNames.indexOf(attribute.getName());
} else {
String newAttributeName = attribute.getName() + "_" + example.getValueAsString(indexAttribute);
newIndex = attributeNames.indexOf(newAttributeName);
}
if (newIndex != -1) {
double value = example.getValue(attribute);
if (!Double.isNaN(value)) {
if (attribute.isNominal()) {
data[newIndex] = newAttributes.get(newIndex).getMapping().mapString(attribute.getMapping().mapIndex((int) value));
} else {
data[newIndex] = value;
}
}
}
}
lastGroupValue = currentGroupValue;
}
if (aggregationFunction != null) {
data[0] = aggregationFunction.getValue();
}
table.addDataRow(new DoubleArrayDataRow(data));
// create and deliver example set
ExampleSet result = table.createExampleSet();
if (newWeightAttribute != null) {
result.getAttributes().setWeight(newWeightAttribute);
}
result.recalculateAllAttributeStatistics();
return result;
}
@Override
public List<ParameterType> getParameterTypes() {
List<ParameterType> types = super.getParameterTypes();
types.add(new ParameterTypeAttribute(PARAMETER_GROUP_ATTRIBUTE, "Attribute that groups the examples which form one example after pivoting.", getExampleSetInputPort(), false));
types.add(new ParameterTypeAttribute(PARAMETER_INDEX_ATTRIBUTE, "Attribute which differentiates examples inside a group.", getExampleSetInputPort(), false));
types.add(new ParameterTypeBoolean(PARAMETER_CONSIDER_WEIGHTS, "Determines whether weights will be kept and aggregated or ignored.", true, false));
ParameterType type = new ParameterTypeCategory(PARAMETER_WEIGHT_AGGREGATION, "Specifies how example weights are aggregated in the groups.", AbstractAggregationFunction.KNOWN_AGGREGATION_FUNCTION_NAMES, AbstractAggregationFunction.SUM, false);
type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_CONSIDER_WEIGHTS, true, true));
types.add(type);
types
.add(new ParameterTypeBoolean(PARAMETER_SKIP_CONSTANT_ATTRIBUTES, "Skips attributes if their value never changes within a group.",
true));
return types;
}
@Override
public boolean writesIntoExistingData() {
return false;
}
@Override
public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(getInputPort(), Example2AttributePivoting.class, null);
}
}