/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.features.weighting;
import static com.rapidminer.operator.features.weighting.AbstractWeighting.PARAMETER_NORMALIZE_WEIGHTS;
import static com.rapidminer.operator.learner.tree.AbstractTreeLearner.*;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.tree.Edge;
import com.rapidminer.operator.learner.tree.RandomForestModel;
import com.rapidminer.operator.learner.tree.Tree;
import com.rapidminer.operator.learner.tree.TreeModel;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;
import com.rapidminer.operator.learner.tree.criterions.Criterion;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeStringCategory;
/**
 * This weighting scheme uses a given random forest to extract the implicit
 * importance of the used attributes. To this end, each inner node of each tree is visited
 * and the benefit created by the respective split is accumulated for the attribute the
 * split was performed on. The accumulated benefit, averaged over the number of trees in
 * the forest, is used as the importance of an attribute.
 *
 * @author Sebastian Land
 */
public class ForestBasedWeighting extends Operator {

	/** The parameter name for "Specifies the used criterion for selecting attributes and numerical splits." */
	public static final String PARAMETER_CRITERION = "criterion";

	private InputPort forestInput = getInputPorts().createPort("random forest");
	private OutputPort weightsOutput = getOutputPorts().createPort("weights");
	private OutputPort forestOutput = getOutputPorts().createPort("random forest");

	/**
	 * Creates the operator, declares a mandatory {@link RandomForestModel} precondition on the
	 * forest input and registers the meta data rules for both output ports.
	 *
	 * @param description the operator description provided by the framework
	 */
	public ForestBasedWeighting(OperatorDescription description) {
		super(description);
		// the input must deliver a random forest model; the precondition is mandatory (true)
		forestInput.addPrecondition(new SimplePrecondition(forestInput, new ModelMetaData(RandomForestModel.class, new ExampleSetMetaData()), true));
		getTransformer().addPassThroughRule(forestInput, forestOutput);
		getTransformer().addGenerationRule(weightsOutput, AttributeWeights.class);
	}

	/**
	 * Extracts attribute weights from the random forest: the split benefit of every inner node
	 * of every tree is accumulated per attribute, averaged over the number of trees, optionally
	 * normalized, and delivered as an {@link AttributeWeights} object. The forest itself is
	 * passed through unchanged.
	 *
	 * @throws OperatorException if the input data cannot be retrieved or a parameter is invalid
	 */
	@Override
	public void doWork() throws OperatorException {
		RandomForestModel forest = forestInput.getData(RandomForestModel.class);
		// the possible label values define the columns of the per-node weight matrices below
		String[] labelValues = forest.getTrainingHeader().getAttributes().getLabel().getMapping().getValues().toArray(new String[0]);

		// accumulate split benefits of all inner nodes of all trees, keyed by attribute name
		Criterion criterion = AbstractCriterion.createCriterion(this, 0);
		Map<String, Double> attributeBenefitMap = new HashMap<String, Double>();
		for (Model model : forest.getModels()) {
			TreeModel treeModel = (TreeModel) model;
			extractWeights(attributeBenefitMap, criterion, treeModel.getRoot(), labelValues);
		}

		// average the accumulated benefit over the number of trees; guard against an empty
		// forest, which would otherwise produce NaN weights from a 0/0 division
		AttributeWeights weights = new AttributeWeights();
		int numberOfModels = forest.getModels().size();
		if (numberOfModels > 0) {
			for (Entry<String, Double> entry : attributeBenefitMap.entrySet()) {
				weights.setWeight(entry.getKey(), entry.getValue().doubleValue() / numberOfModels);
			}
		}
		if (getParameterAsBoolean(PARAMETER_NORMALIZE_WEIGHTS)) {
			weights.normalize();
		}

		weightsOutput.deliver(weights);
		forestOutput.deliver(forest);
	}

	/**
	 * Recursively descends the given (sub)tree and, for every inner node, adds the benefit of
	 * its split (as computed by the given criterion from the label distributions of the child
	 * nodes) to the map entry of the attribute the split was performed on.
	 *
	 * @param attributeBenefitMap accumulator mapping attribute name to summed benefit; modified in place
	 * @param criterion           the criterion used to compute the benefit of a split
	 * @param root                the root of the (sub)tree to process
	 * @param labelValues         all possible label values, fixing the column order of the weight matrix
	 */
	private void extractWeights(Map<String, Double> attributeBenefitMap, Criterion criterion, Tree root, String[] labelValues) {
		if (!root.isLeaf()) {
			int numberOfChildren = root.getNumberOfChildren();
			// one row per child: the label distribution (counts) within that child's subtree
			double[][] weights = new double[numberOfChildren][];
			String attributeName = null;
			Iterator<Edge> childIterator = root.childIterator();
			int i = 0;
			while (childIterator.hasNext()) {
				Edge edge = childIterator.next();
				// all outgoing edges of a node test the same attribute
				attributeName = edge.getCondition().getAttributeName();
				// retrieve the label counts in the child's subtree; absent labels count as 0
				Map<String, Integer> subtreeCounterMap = edge.getChild().getSubtreeCounterMap();
				weights[i] = new double[labelValues.length];
				for (int j = 0; j < labelValues.length; j++) {
					Integer weight = subtreeCounterMap.get(labelValues[j]);
					double weightValue = 0;
					if (weight != null)
						weightValue = weight;
					weights[i][j] = weightValue;
				}
				i++;
			}

			// compute the benefit of this split and add it to the attribute's running total
			double benefit = criterion.getBenefit(weights);
			Double knownBenefit = attributeBenefitMap.get(attributeName);
			if (knownBenefit != null)
				benefit += knownBenefit;
			attributeBenefitMap.put(attributeName, benefit);

			// recursively descend to all children
			childIterator = root.childIterator();
			while (childIterator.hasNext()) {
				Tree child = childIterator.next().getChild();
				extractWeights(attributeBenefitMap, criterion, child, labelValues);
			}
		}
	}

	/**
	 * Returns the parameter types of this operator: the weighting criterion and the
	 * normalization flag inherited from the weighting convention.
	 */
	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		ParameterTypeStringCategory type = new ParameterTypeStringCategory(PARAMETER_CRITERION, "Specifies the used criterion for weighting attributes.", CRITERIA_NAMES, CRITERIA_NAMES[CRITERION_GAIN_RATIO], false);
		type.setExpert(false);
		types.add(type);
		types.add(new ParameterTypeBoolean(PARAMETER_NORMALIZE_WEIGHTS, "Activates the normalization of all weights.", true, false));
		return types;
	}
}