/*
* RapidMiner
*
* Copyright (C) 2001-2011 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree.criterions;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.learner.tree.FrequencyCalculator;
/**
* The gain ratio divides the information gain by the split information (the entropy of the partition weights) in
* order to prevent ID-like attributes, which split the data into many small partitions, from being selected as the best.
*
* @author Sebastian Land, Ingo Mierswa
*/
public class GainRatioCriterion extends InfoGainCriterion {

	/** Multiplying a natural logarithm by this factor converts it to a base-2 logarithm. */
	private static final double LOG_FACTOR = 1d / Math.log(2);

	/** Computes the weight-count tables for nominal and numerical splits. */
	private final FrequencyCalculator frequencyCalculator = new FrequencyCalculator();

	/** Creates a gain ratio criterion via the no-argument constructor of {@link InfoGainCriterion}. */
	public GainRatioCriterion() {
	}

	/**
	 * Creates a gain ratio criterion.
	 *
	 * @param minimalGain passed unchanged to the {@link InfoGainCriterion} constructor
	 */
	public GainRatioCriterion(double minimalGain) {
		super(minimalGain);
	}

	/**
	 * Computes the gain ratio of splitting the example set on the given nominal attribute.
	 */
	@Override
	public double getNominalBenefit(ExampleSet exampleSet, Attribute attribute) {
		double[][] weightCounts = frequencyCalculator.getNominalWeightCounts(exampleSet, attribute);
		return getBenefit(weightCounts);
	}

	/**
	 * Computes the gain ratio of splitting the example set on the given numerical attribute at
	 * {@code splitValue}.
	 */
	@Override
	public double getNumericalBenefit(ExampleSet exampleSet, Attribute attribute, double splitValue) {
		double[][] weightCounts = frequencyCalculator.getNumericalWeightCounts(exampleSet, attribute, splitValue);
		return getBenefit(weightCounts);
	}

	/**
	 * Computes the gain ratio: the information gain from {@link InfoGainCriterion#getBenefit(double[][])}
	 * divided by the split information of the partition.
	 *
	 * @param weightCounts weight table whose first index is the partition; the entries of each row are
	 *        summed to obtain that partition's weight (the second index is presumably the label value)
	 * @return the gain ratio, or the plain gain if the split information is zero
	 */
	@Override
	public double getBenefit(double[][] weightCounts) {
		double gain = super.getBenefit(weightCounts);
		return toGainRatio(gain, getSplitInfo(weightCounts));
	}

	/**
	 * Computes the split information (base-2 entropy of the partition weights) of a weight table.
	 *
	 * @param weightCounts weight table whose rows are summed to obtain each partition's weight
	 * @return the split information in bits; 0 if at most one partition has positive weight
	 */
	protected double getSplitInfo(double[][] weightCounts) {
		double[] partitionWeights = new double[weightCounts.length];
		double total = 0.0d;
		for (int v = 0; v < weightCounts.length; v++) {
			for (double weight : weightCounts[v]) {
				partitionWeights[v] += weight;
			}
			total += partitionWeights[v];
		}
		return computeSplitInfo(partitionWeights, total);
	}

	/**
	 * Computes the split information (base-2 entropy) of the given partition weights.
	 *
	 * @param partitionWeights the weight of each partition; non-positive entries are ignored
	 * @param totalWeight the sum of all partition weights, used to normalize them to proportions
	 * @return the split information in bits
	 */
	protected double getSplitInfo(double[] partitionWeights, double totalWeight) {
		return computeSplitInfo(partitionWeights, totalWeight);
	}

	/** Returns {@code true}: the gain ratio can be computed from incrementally maintained weights. */
	@Override
	public boolean supportsIncrementalCalculation() {
		return true;
	}

	/**
	 * Computes the gain ratio of the current binary (left/right) split from the incrementally
	 * maintained label weights inherited from the superclass.
	 */
	@Override
	public double getIncrementalBenefit() {
		double gain = getEntropy(totalLabelWeights, totalWeight);
		gain -= getEntropy(leftLabelWeights, leftWeight) * leftWeight / totalWeight;
		gain -= getEntropy(rightLabelWeights, rightWeight) * rightWeight / totalWeight;
		double splitInfo = getSplitInfo(new double[] { leftWeight, rightWeight }, totalWeight);
		return toGainRatio(gain, splitInfo);
	}

	/**
	 * Shared split-information formula: {@code -sum(p * log2(p))} over the proportions
	 * {@code p = weight / total} of all partitions with positive weight. Kept private so that
	 * overriding one protected {@code getSplitInfo} overload does not affect the other.
	 */
	private static double computeSplitInfo(double[] partitionWeights, double total) {
		double sum = 0.0d;
		for (double weight : partitionWeights) {
			if (weight > 0) {
				double proportion = weight / total;
				sum += proportion * Math.log(proportion) * LOG_FACTOR;
			}
		}
		return -sum;
	}

	/**
	 * Divides the gain by the split information. A split information of zero (e.g. all weight in a
	 * single partition) would make the ratio undefined, so the plain gain is returned instead.
	 */
	private static double toGainRatio(double gain, double splitInfo) {
		return splitInfo == 0 ? gain : gain / splitInfo;
	}
}