/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* InfoGainSplitMetric.java
* Copyright (C) 2013 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees.ht;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import weka.core.ContingencyTables;
import weka.core.Utils;
/**
* Implements the info gain splitting criterion
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 9720 $
*/
public class InfoGainSplitMetric extends SplitMetric implements Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 2173840581308675428L;

  /**
   * Fraction of the total post-split weight that a branch must exceed in order
   * to be counted; at least two branches have to exceed it for a split to be
   * scored by {@link #evaluateSplit(Map, List)}.
   */
  protected double m_minFracWeightForTwoBranches;

  /**
   * Creates a new info gain metric.
   *
   * @param minFracWeightForTwoBranches the fraction of the total weight that
   *          at least two branches must each exceed
   */
  public InfoGainSplitMetric(double minFracWeightForTwoBranches) {
    m_minFracWeightForTwoBranches = minFracWeightForTwoBranches;
  }

  /**
   * Scores a candidate split as the entropy of the pre-split distribution
   * minus the weight-averaged entropy of the post-split distributions
   * (entropies are obtained from {@code ContingencyTables.entropy}).
   *
   * @param preDist the distribution before the split
   * @param postDist the distributions after the split, one per branch
   * @return the difference between pre- and post-split entropy, or
   *         {@code Double.NEGATIVE_INFINITY} when fewer than two branches
   *         carry more than {@code m_minFracWeightForTwoBranches} of the
   *         total weight
   */
  @Override
  public double evaluateSplit(Map<String, WeightMass> preDist,
    List<Map<String, WeightMass>> postDist) {

    double entropyBefore = ContingencyTables.entropy(weightsOf(preDist));

    // weight of every branch, plus the overall weight across all branches
    int numBranches = postDist.size();
    double[] branchWeights = new double[numBranches];
    double weightSum = 0.0;
    int branch = 0;
    for (Map<String, WeightMass> branchDist : postDist) {
      branchWeights[branch] = SplitMetric.sum(branchDist);
      weightSum += branchWeights[branch];
      branch++;
    }

    // count the branches whose share of the weight exceeds the threshold
    int substantialBranches = 0;
    for (int i = 0; i < numBranches; i++) {
      double share = branchWeights[i] / weightSum;
      if (share > m_minFracWeightForTwoBranches) {
        substantialBranches++;
      }
    }
    if (substantialBranches < 2) {
      return Double.NEGATIVE_INFINITY;
    }

    // accumulate the branch entropies, each scaled by its branch weight
    double entropyAfter = 0;
    branch = 0;
    for (Map<String, WeightMass> branchDist : postDist) {
      double branchEntropy = ContingencyTables.entropy(weightsOf(branchDist));
      entropyAfter += branchWeights[branch] * branchEntropy;
      branch++;
    }
    if (weightSum > 0) {
      entropyAfter /= weightSum;
    }

    return entropyBefore - entropyAfter;
  }

  /**
   * Returns the range of this metric: log base 2 of the number of entries in
   * the supplied distribution, where fewer than two entries are treated as two.
   *
   * @param preDist the distribution before the split
   * @return log2 of the (adjusted) number of entries in {@code preDist}
   */
  @Override
  public double getMetricRange(Map<String, WeightMass> preDist) {
    return Utils.log2(Math.max(preDist.size(), 2));
  }

  /**
   * Copies the {@code m_weight} of every entry in a distribution into an
   * array, following the map's iteration order.
   *
   * @param dist the distribution to read
   * @return an array holding one weight per entry of {@code dist}
   */
  private static double[] weightsOf(Map<String, WeightMass> dist) {
    double[] weights = new double[dist.size()];
    int next = 0;
    for (WeightMass mass : dist.values()) {
      weights[next++] = mass.m_weight;
    }
    return weights;
  }
}