/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    NBTreeSplit.java
 *    Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import java.util.Random;

import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;

/**
 * Class implementing an NBTree split on an attribute.
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 8034 $
 */
public class NBTreeSplit extends ClassifierSplitModel {

  /** for serialization */
  private static final long serialVersionUID = 8922627123884975070L;

  /** Desired number of branches. */
  private int m_complexityIndex;

  /** Attribute to split on. */
  private int m_attIndex;

  /** Minimum number of objects in a split. */
  private int m_minNoObj;

  /** Value of split point. */
  private double m_splitPoint;

  /** The sum of the weights of the instances. */
  private double m_sumOfWeights;

  /**
   * The weight of the instances incorrectly classified by the naive Bayes
   * models arising from this split.
   */
  private double m_errors;

  /** The C4.5 split used to partition the data at this node. */
  private C45Split m_c45S;

  /** The global naive Bayes model for this node. */
  NBTreeNoSplit m_globalNB;

  /**
   * Initializes the split model.
   */
  public NBTreeSplit(int attIndex, int minNoObj, double sumOfWeights) {

    // Get index of attribute to split on.
    m_attIndex = attIndex;

    // Set minimum number of objects.
    m_minNoObj = minNoObj;

    // Set the sum of the weights.
    m_sumOfWeights = sumOfWeights;
  }

  /**
   * Creates an NBTree-type split on the given data. Assumes that none of the
   * class values is missing.
   *
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances trainInstances) throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_errors = 0;
    if (m_globalNB != null) {
      m_errors = m_globalNB.getErrors();
    }

    // Different treatment for enumerated and numeric attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      handleEnumeratedAttribute(trainInstances);
    } else {
      m_complexityIndex = 2;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }

  /**
   * Returns index of attribute for which split was generated.
   */
  public final int attIndex() {

    return m_attIndex;
  }
  /**
   * Creates split on enumerated attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
    throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights, true);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;
    Instance instance;

    Instances[] trainingSets = new Instances[m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      trainingSets[i] = new Instances(trainInstances, 0);
    }
    /*
     * m_distribution = new Distribution(m_complexityIndex,
     * trainInstances.numClasses());
     */
    int subset;
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset > -1) {
        trainingSets[subset].add((Instance) instance.copy());
      } else {
        // Instance has a missing split value: distribute a weighted copy
        // to every branch.
        double[] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          try {
            Instance temp = (Instance) instance.copy();
            if (weights.length == m_complexityIndex) {
              temp.setWeight(temp.weight() * weights[j]);
            } else {
              temp.setWeight(temp.weight() / m_complexityIndex);
            }
            trainingSets[j].add(temp);
          } catch (Exception ex) {
            ex.printStackTrace();
            System.err.println("*** " + m_complexityIndex);
            System.err.println(weights.length);
            System.exit(1);
          }
        }
      }
    }

    /*
     * // compute weights (weights of instances per subset)
     * m_weights = new double[m_complexityIndex];
     * for (int i = 0; i < m_complexityIndex; i++) {
     *   m_weights[i] = trainingSets[i].sumOfWeights();
     * }
     * Utils.normalize(m_weights);
     */

    /*
     * // Only Instances with known values are relevant.
     * Enumeration enu = trainInstances.enumerateInstances();
     * while (enu.hasMoreElements()) {
     *   instance = (Instance) enu.nextElement();
     *   if (!instance.isMissing(m_attIndex)) {
     *     // m_distribution.add((int)instance.value(m_attIndex),instance);
     *     trainingSets[(int) instances.value(m_attIndex)].add(instance);
     *   } else {
     *     // add these to the error count
     *     m_errors += instance.weight();
     *   }
     * }
     */

    Random r = new Random(1);
    int minNumCount = 0;
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() >= 5) {
        minNumCount++;
        // Discretize the sets
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

        trainingSets[i].randomize(r);
        trainingSets[i].stratify(5);
        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(trainingSets[i]);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
        // if fewer than min obj then just count them as errors
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }

    // Check if there are at least five instances in at least two of the
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }
  /**
   * Creates split on numeric attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
    throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights, true);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;

    Instances[] trainingSets = new Instances[m_complexityIndex];
    trainingSets[0] = new Instances(trainInstances, 0);
    trainingSets[1] = new Instances(trainInstances, 0);
    int subset = -1;

    // populate the subsets
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      Instance instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset != -1) {
        trainingSets[subset].add((Instance) instance.copy());
      } else {
        // Instance has a missing split value: distribute a weighted copy
        // to both branches.
        double[] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          Instance temp = (Instance) instance.copy();
          if (weights.length == m_complexityIndex) {
            temp.setWeight(temp.weight() * weights[j]);
          } else {
            temp.setWeight(temp.weight() / m_complexityIndex);
          }
          trainingSets[j].add(temp);
        }
      }
    }

    /*
     * // compute weights (weights of instances per subset)
     * m_weights = new double[m_complexityIndex];
     * for (int i = 0; i < m_complexityIndex; i++) {
     *   m_weights[i] = trainingSets[i].sumOfWeights();
     * }
     * Utils.normalize(m_weights);
     */

    Random r = new Random(1);
    int minNumCount = 0;
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() > 5) {
        minNumCount++;
        // Discretize the sets
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

        trainingSets[i].randomize(r);
        trainingSets[i].stratify(5);
        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(trainingSets[i]);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
        // if fewer than min obj then just count them as errors
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }

    // Check if minimum number of Instances in at least two subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Returns index of subset instance is assigned to. Returns -1 if instance
   * is assigned to more than one subset.
   *
   * @exception Exception if something goes wrong
   */
  public final int whichSubset(Instance instance) throws Exception {

    return m_c45S.whichSubset(instance);
  }

  /**
   * Returns weights if instance is assigned to more than one subset. Returns
   * null if instance is only assigned to one subset.
   */
  public final double[] weights(Instance instance) {

    return m_c45S.weights(instance);
    // return m_weights;
  }

  /**
   * Returns a string containing java source code equivalent to the test made
   * at this node. The instance being tested is called "i".
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {

    return m_c45S.sourceExpression(index, data);
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   *
   * @param index index of subset
   * @param data training set.
   */
  public final String rightSide(int index, Instances data) {

    return m_c45S.rightSide(index, data);
  }
  /**
   * Prints left side of condition.
   *
   * @param data training set.
   */
  public final String leftSide(Instances data) {

    return m_c45S.leftSide(data);
  }

  /**
   * Return the probability for a class value.
   *
   * @param classIndex the index of the class value
   * @param instance the instance to generate a probability for
   * @param theSubset the subset to consider
   * @return a probability
   * @exception Exception if an error occurs
   */
  public double classProb(int classIndex, Instance instance, int theSubset)
    throws Exception {

    // use the global naive bayes model
    if (theSubset > -1) {
      return m_globalNB.classProb(classIndex, instance, theSubset);
    } else {
      throw new Exception("This shouldn't happen!!!");
    }
  }

  /**
   * Return the global naive bayes model for this node.
   *
   * @return a <code>NBTreeNoSplit</code> value
   */
  public NBTreeNoSplit getGlobalModel() {

    return m_globalNB;
  }

  /**
   * Set the global naive bayes model for this node.
   *
   * @param global a <code>NBTreeNoSplit</code> value
   */
  public void setGlobalModel(NBTreeNoSplit global) {

    m_globalNB = global;
  }

  /**
   * Return the errors made by the naive bayes models arising from this split.
   *
   * @return a <code>double</code> value
   */
  public double getErrors() {

    return m_errors;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {

    return RevisionUtils.extract("$Revision: 8034 $");
  }
}
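/*
 * The class below is a minimal usage sketch and is not part of the Weka
 * distribution; the class name and the dataset path are placeholders. It
 * assumes an ARFF file whose last attribute is a nominal class. The global
 * NBTreeNoSplit model is built and attached before buildClassifier() is
 * called, because buildClassifier() above falls back to the global model's
 * error count when the C4.5 split yields no subsets.
 */
class NBTreeSplitExample {

  public static void main(String[] args) throws Exception {

    // Load the training data (path is a placeholder) and set the class.
    Instances data = new Instances(new java.io.BufferedReader(
      new java.io.FileReader("/path/to/dataset.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Train the global naive Bayes model used as the no-split baseline.
    NBTreeNoSplit global = new NBTreeNoSplit();
    global.buildClassifier(data);

    // Evaluate a candidate split on attribute 0 (must not be the class),
    // requiring a minimum of 2 instances per subset.
    NBTreeSplit split = new NBTreeSplit(0, 2, data.sumOfWeights());
    split.setGlobalModel(global);
    split.buildClassifier(data);

    if (split.numSubsets() > 0) {
      System.out.println("Cross-validated errors for this split: "
        + split.getErrors());
      System.out.println("Errors of the global model: " + global.getErrors());
    } else {
      System.out.println("No admissible split on this attribute.");
    }
  }
}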