/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.features.weighting;
import static com.rapidminer.operator.features.weighting.AbstractWeighting.PARAMETER_NORMALIZE_WEIGHTS;
import static com.rapidminer.operator.learner.tree.AbstractTreeLearner.CRITERIA_NAMES;
import static com.rapidminer.operator.learner.tree.AbstractTreeLearner.CRITERION_GAIN_RATIO;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.gui.renderer.RendererService;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.PortUserError;
import com.rapidminer.operator.learner.meta.MetaModel;
import com.rapidminer.operator.learner.tree.ConfigurableRandomForestModel;
import com.rapidminer.operator.learner.tree.Edge;
import com.rapidminer.operator.learner.tree.RandomForestModel;
import com.rapidminer.operator.learner.tree.Tree;
import com.rapidminer.operator.learner.tree.TreeModel;
import com.rapidminer.operator.learner.tree.criterions.AbstractCriterion;
import com.rapidminer.operator.learner.tree.criterions.Criterion;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.CompatibilityLevel;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.ModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeStringCategory;
/**
 * This weighting scheme uses a given random forest to extract the implicit importance of the
 * used attributes. To this end, each node is visited and the benefit created by the respective
 * split is aggregated for the attribute the split was performed on. The mean benefit over all
 * nodes in all trees is used as importance.
 *
 * @author Sebastian Land
 */
@SuppressWarnings("deprecation")
public class ForestBasedWeighting extends Operator {

	/**
	 * The parameter name for "Specifies the used criterion for selecting attributes and
	 * numerical splits."
	 */
	public static final String PARAMETER_CRITERION = "criterion";

	/** Port expecting the (legacy or configurable) random forest model. */
	private InputPort forestInput = getInputPorts().createPort("random forest");

	/** Port delivering the computed attribute weights. */
	private OutputPort weightsOutput = getOutputPorts().createPort("weights");

	/** Pass-through port for the input forest. */
	private OutputPort forestOutput = getOutputPorts().createPort("random forest");

	/**
	 * {@link ModelMetaData} that accepts both {@link RandomForestModel}s and
	 * {@link ConfigurableRandomForestModel}s.
	 *
	 * @author Michael Knopf
	 * @since 7.0.0
	 */
	public static class RandomForestModelMetaData extends ModelMetaData {

		private static final long serialVersionUID = 1L;

		public RandomForestModelMetaData() {
			super(ConfigurableRandomForestModel.class, new ExampleSetMetaData());
		}

		@Override
		public boolean isCompatible(MetaData isData, CompatibilityLevel level) {
			// additionally accept the legacy random forest model type
			if (RandomForestModel.class.isAssignableFrom(isData.getObjectClass())) {
				return true;
			}
			return super.isCompatible(isData, level);
		}
	}

	public ForestBasedWeighting(OperatorDescription description) {
		super(description);
		forestInput.addPrecondition(new SimplePrecondition(forestInput, new RandomForestModelMetaData(), true));
		getTransformer().addPassThroughRule(forestInput, forestOutput);
		getTransformer().addGenerationRule(weightsOutput, AttributeWeights.class);
	}

	@Override
	public void doWork() throws OperatorException {
		// The old and new random forest model implementations are not related (class hierarchy).
		// Thus, Port#getData() would fail for one or the other. For this reason, the implementation
		// below requests the common super-type Model and performs the compatibility check manually.
		Model forest = forestInput.getData(Model.class);
		if (!(forest instanceof MetaModel)
				|| !(forest instanceof ConfigurableRandomForestModel || forest instanceof RandomForestModel)) {
			PortUserError error = new PortUserError(forestInput, 156, RendererService.getName(forest.getClass()),
					forestInput.getName(), RendererService.getName(ConfigurableRandomForestModel.class));
			error.setExpectedType(ConfigurableRandomForestModel.class);
			error.setActualType(forest.getClass());
			throw error;
		}

		String[] labelValues = forest.getTrainingHeader().getAttributes().getLabel().getMapping().getValues()
				.toArray(new String[0]);

		// aggregate the split benefits over all trees of the forest
		Criterion criterion = AbstractCriterion.createCriterion(this, 0);
		Map<String, Double> attributeBenefitMap = new HashMap<>();
		List<? extends Model> models = ((MetaModel) forest).getModels();
		for (Model model : models) {
			TreeModel treeModel = (TreeModel) model;
			extractWeights(attributeBenefitMap, criterion, treeModel.getRoot(), labelValues);
		}

		// the weight of an attribute is its mean benefit over all trees
		AttributeWeights weights = new AttributeWeights();
		int numberOfModels = models.size();
		for (Entry<String, Double> entry : attributeBenefitMap.entrySet()) {
			weights.setWeight(entry.getKey(), entry.getValue() / numberOfModels);
		}
		if (getParameterAsBoolean(PARAMETER_NORMALIZE_WEIGHTS)) {
			weights.normalize();
		}

		weightsOutput.deliver(weights);
		forestOutput.deliver(forest);
	}

	/**
	 * Recursively walks the given (sub-)tree and adds the benefit of each inner node's split to the
	 * accumulated benefit of the attribute the split was performed on.
	 *
	 * @param attributeBenefitMap
	 *            accumulator mapping attribute names to their summed benefits (updated in place)
	 * @param criterion
	 *            the criterion used to compute the benefit of a split
	 * @param root
	 *            the root of the (sub-)tree to inspect
	 * @param labelValues
	 *            all possible label values of the forest's training header
	 */
	private void extractWeights(Map<String, Double> attributeBenefitMap, Criterion criterion, Tree root,
			String[] labelValues) {
		if (root.isLeaf()) {
			// leaves do not split, hence contribute no benefit
			return;
		}

		int numberOfChildren = root.getNumberOfChildren();
		double[][] weights = new double[numberOfChildren][];
		String attributeName = null;
		Iterator<Edge> childIterator = root.childIterator();
		int i = 0;
		while (childIterator.hasNext()) {
			Edge edge = childIterator.next();
			// retrieve attributeName: identical on every edge of the same node
			attributeName = edge.getCondition().getAttributeName();
			// retrieve the label distribution after the split: counts within the child
			Map<String, Integer> subtreeCounterMap = edge.getChild().getSubtreeCounterMap();
			weights[i] = new double[labelValues.length];
			for (int j = 0; j < labelValues.length; j++) {
				// label values absent from the subtree count as zero
				Integer weight = subtreeCounterMap.get(labelValues[j]);
				weights[i][j] = weight != null ? weight : 0;
			}
			i++;
		}

		// calculate the benefit of this node's split and accumulate it for the attribute
		double benefit = criterion.getBenefit(weights);
		Double knownBenefit = attributeBenefitMap.get(attributeName);
		if (knownBenefit != null) {
			benefit += knownBenefit;
		}
		attributeBenefitMap.put(attributeName, benefit);

		// recursively descend to the children
		childIterator = root.childIterator();
		while (childIterator.hasNext()) {
			extractWeights(attributeBenefitMap, criterion, childIterator.next().getChild(), labelValues);
		}
	}

	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		ParameterTypeStringCategory type = new ParameterTypeStringCategory(PARAMETER_CRITERION,
				"Specifies the used criterion for weighting attributes.", CRITERIA_NAMES,
				CRITERIA_NAMES[CRITERION_GAIN_RATIO], false);
		type.setExpert(false);
		types.add(type);
		types.add(new ParameterTypeBoolean(PARAMETER_NORMALIZE_WEIGHTS, "Activates the normalization of all weights.", false,
				false));
		return types;
	}
}