/* * InfoGainSplitCriterion.java * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ package tr.gov.ulakbim.jDenetX.classifiers.splits; import tr.gov.ulakbim.jDenetX.core.ObjectRepository; import tr.gov.ulakbim.jDenetX.options.AbstractOptionHandler; import tr.gov.ulakbim.jDenetX.options.FloatOption; import tr.gov.ulakbim.jDenetX.tasks.TaskMonitor; import weka.core.Utils; public class InfoGainSplitCriterion extends AbstractOptionHandler implements SplitCriterion { private static final long serialVersionUID = 1L; private static final double MINVAL = 1.6009E-16; public FloatOption minBranchFracOption = new FloatOption("minBranchFrac", 'f', "Minimum fraction of weight required down at least two branches.", 0.01, 0.0, 0.5); public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { if (numSubsetsGreaterThanFrac(postSplitDists, this.minBranchFracOption .getValue()) < 2) { return Double.NEGATIVE_INFINITY; } return computeEntropy(preSplitDist) - computeEntropy(postSplitDists); } public double getRangeOfMerit(double[] preSplitDist) { int numClasses = preSplitDist.length > 2 ? preSplitDist.length : 2; return Utils.log2(numClasses); } public static double computeEntropy(double[] dist) { double entropy = 0.0; double sum = 0.0; for (double d : dist) { if (d > MINVAL) { entropy -= d * Utils.log2(d); sum += d; } } return sum > 0.0 ? (entropy + sum * Utils.log2(sum)) / sum : 0.0; } /** * Compute the weighted sums of distributions. * @param dists * @return entropy of the matrix */ public static double computeEntropy(double[][] dists) { double totalWeight = 0.0; double entropy = 0.0; for (double []dist: dists) { double distWeight = Utils.sum(dist); entropy += distWeight * computeEntropy(dist); totalWeight += distWeight; } return entropy / totalWeight; } public static int numSubsetsGreaterThanFrac(double[][] distributions, double minFrac) { double totalWeight = 0.0; double[] distSums = new double[distributions.length]; int numGreater = 0; for (int i = 0; i < distSums.length; i++) { distSums[i] = Utils.sum(distributions[i]); totalWeight += distSums[i]; } for (double d : distSums) { double frac = d / totalWeight; if (frac > minFrac) { numGreater++; } } return numGreater; } public void getDescription(StringBuilder sb, int indent) {} @Override protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {} }