/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * C45PruneableClassifierTreeG.java * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand * Copyright (C) 2007 Geoff Webb & Janice Boughton * */ package weka.classifiers.trees.j48; import weka.core.Capabilities; import weka.core.Instances; import weka.core.Instance; import weka.core.RevisionUtils; import weka.core.Utils; import weka.core.Capabilities.Capability; import java.util.ArrayList; import java.util.Collections; /** * Class for handling a tree structure that can * be pruned using C4.5 procedures and have nodes grafted on. * * @author Janice Boughton (based on code by Eibe Frank) * @version $Revision: 5535 $ */ public class C45PruneableClassifierTreeG extends ClassifierTree{ /** for serialization */ static final long serialVersionUID = 66981207374331964L; /** True if the tree is to be pruned. */ boolean m_pruneTheTree = false; /** The confidence factor for pruning. */ float m_CF = 0.25f; /** Is subtree raising to be performed? */ boolean m_subtreeRaising = true; /** Cleanup after the tree has been built. */ boolean m_cleanup = true; /** flag for using relabelling when grafting */ boolean m_relabel = false; /** binomial probability critical value */ double m_BiProbCrit = 1.64; boolean m_Debug = false; /** * Constructor for pruneable tree structure. Stores reference * to associated training data at each node. * * @param toSelectLocModel selection method for local splitting model * @param pruneTree true if the tree is to be pruned * @param cf the confidence factor for pruning * @param raiseTree * @param cleanup * @throws Exception if something goes wrong */ public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel, boolean pruneTree,float cf, boolean raiseTree, boolean relabel, boolean cleanup) throws Exception { super(toSelectLocModel); m_pruneTheTree = pruneTree; m_CF = cf; m_subtreeRaising = raiseTree; m_cleanup = cleanup; m_relabel = relabel; } /** * Returns default capabilities of the classifier tree. * * @return the capabilities of this classifier tree */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * Constructor for pruneable tree structure. Used to create new nodes * in the tree during grafting. * * @param toSelectLocModel selection method for local splitting model * @param data the dta used to produce split model * @param gs the split model * @param prune true if the tree is to be pruned * @param cf the confidence factor for pruning * @param raise * @param isLeaf if this node is a leaf or not * @param relabel whether relabeling occured * @param cleanup * @throws Exception if something goes wrong */ public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel, Instances data, ClassifierSplitModel gs, boolean prune, float cf, boolean raise, boolean isLeaf, boolean relabel, boolean cleanup) { super(toSelectLocModel); m_relabel = relabel; m_cleanup = cleanup; m_localModel = gs; m_train = data; m_test = null; m_isLeaf = isLeaf; if(gs.distribution().total() > 0) m_isEmpty = false; else m_isEmpty = true; m_pruneTheTree = prune; m_CF = cf; m_subtreeRaising = raise; } /** * Method for building a pruneable classifier tree. * * @param datathe data for building the tree * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { // can classifier tree handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); buildTree(data, m_subtreeRaising); collapse(); if (m_pruneTheTree) { prune(); } doGrafting(data); if (m_cleanup) { cleanup(new Instances(data, 0)); } } /** * Collapses a tree to a node if training error doesn't increase. */ public final void collapse(){ double errorsOfSubtree; double errorsOfTree; int i; if (!m_isLeaf){ errorsOfSubtree = getTrainingErrors(); errorsOfTree = localModel().distribution().numIncorrect(); if (errorsOfSubtree >= errorsOfTree-1E-3){ // Free adjacent trees m_sons = null; m_isLeaf = true; // Get NoSplit Model for tree. m_localModel = new NoSplit(localModel().distribution()); }else for (i=0;i<m_sons.length;i++) son(i).collapse(); } } /** * Prunes a tree using C4.5's pruning procedure. * * @throws Exception if something goes wrong */ public void prune() throws Exception { double errorsLargestBranch; double errorsLeaf; double errorsTree; int indexOfLargestBranch; C45PruneableClassifierTreeG largestBranch; int i; if (!m_isLeaf){ // Prune all subtrees. for (i=0;i<m_sons.length;i++) son(i).prune(); // Compute error for largest branch indexOfLargestBranch = localModel().distribution().maxBag(); if (m_subtreeRaising) { errorsLargestBranch = son(indexOfLargestBranch). getEstimatedErrorsForBranch((Instances)m_train); } else { errorsLargestBranch = Double.MAX_VALUE; } // Compute error if this Tree would be leaf errorsLeaf = getEstimatedErrorsForDistribution(localModel().distribution()); // Compute error for the whole subtree errorsTree = getEstimatedErrors(); // Decide if leaf is best choice. if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) && Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){ // Free son Trees m_sons = null; m_isLeaf = true; // Get NoSplit Model for node. m_localModel = new NoSplit(localModel().distribution()); return; } // Decide if largest branch is better choice // than whole subtree. if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){ largestBranch = son(indexOfLargestBranch); m_sons = largestBranch.m_sons; m_localModel = largestBranch.localModel(); m_isLeaf = largestBranch.m_isLeaf; newDistribution(m_train); prune(); } } } /** * Returns a newly created tree. * * @param data the data to work with * @return the new tree * @throws Exception if something goes wrong */ protected ClassifierTree getNewTree(Instances data) throws Exception { C45PruneableClassifierTreeG newTree = new C45PruneableClassifierTreeG(m_toSelectModel, m_pruneTheTree, m_CF, m_subtreeRaising, m_relabel, m_cleanup); // ATBOP Modification // m_subtreeRaising, m_cleanup); newTree.buildTree((Instances)data, m_subtreeRaising); return newTree; } /** * Computes estimated errors for tree. * * @return the estimated errors */ private double getEstimatedErrors(){ double errors = 0; int i; if (m_isLeaf) return getEstimatedErrorsForDistribution(localModel().distribution()); else{ for (i=0;i<m_sons.length;i++) errors = errors+son(i).getEstimatedErrors(); return errors; } } /** * Computes estimated errors for one branch. * * @param data the data to work with * @return the estimated errors * @throws Exception if something goes wrong */ private double getEstimatedErrorsForBranch(Instances data) throws Exception { Instances [] localInstances; double errors = 0; int i; if (m_isLeaf) return getEstimatedErrorsForDistribution(new Distribution(data)); else{ Distribution savedDist = localModel().m_distribution; localModel().resetDistribution(data); localInstances = (Instances[])localModel().split(data); localModel().m_distribution = savedDist; for (i=0;i<m_sons.length;i++) errors = errors+ son(i).getEstimatedErrorsForBranch(localInstances[i]); return errors; } } /** * Computes estimated errors for leaf. * * @param theDistribution the distribution to use * @return the estimated errors */ private double getEstimatedErrorsForDistribution(Distribution theDistribution){ if (Utils.eq(theDistribution.total(),0)) return 0; else return theDistribution.numIncorrect()+ Stats.addErrs(theDistribution.total(), theDistribution.numIncorrect(),m_CF); } /** * Computes errors of tree on training data. * * @return the training errors */ private double getTrainingErrors(){ double errors = 0; int i; if (m_isLeaf) return localModel().distribution().numIncorrect(); else{ for (i=0;i<m_sons.length;i++) errors = errors+son(i).getTrainingErrors(); return errors; } } /** * Method just exists to make program easier to read. * * @return the local split model */ private ClassifierSplitModel localModel(){ return (ClassifierSplitModel)m_localModel; } /** * Computes new distributions of instances for nodes * in tree. * * @param data the data to compute the distributions for * @throws Exception if something goes wrong */ private void newDistribution(Instances data) throws Exception { Instances [] localInstances; localModel().resetDistribution(data); m_train = data; if (!m_isLeaf){ localInstances = (Instances [])localModel().split(data); for (int i = 0; i < m_sons.length; i++) son(i).newDistribution(localInstances[i]); } else { // Check whether there are some instances at the leaf now! if (!Utils.eq(data.sumOfWeights(), 0)) { m_isEmpty = false; } } } /** * Method just exists to make program easier to read. */ private C45PruneableClassifierTreeG son(int index){ return (C45PruneableClassifierTreeG)m_sons[index]; } /** * Initializes variables for grafting. * sets up limits array (for numeric attributes) and calls * the recursive function traverseTree. * * @param data the data for the tree * @throws Exception if anything goes wrong */ public void doGrafting(Instances data) throws Exception { // 2d array for the limits double [][] limits = new double[data.numAttributes()][2]; // 2nd dimension: index 0 == lower limit, index 1 == upper limit // initialise to no limit for(int i = 0; i < data.numAttributes(); i++) { limits[i][0] = Double.NEGATIVE_INFINITY; limits[i][1] = Double.POSITIVE_INFINITY; } // use an index instead of creating new Insances objects all the time // instanceIndex[0] == array for weights at leaf // instanceIndex[1] == array for weights in atbop double [][] instanceIndex = new double[2][data.numInstances()]; // initialize the weight for each instance for(int x = 0; x < data.numInstances(); x++) { instanceIndex[0][x] = 1; instanceIndex[1][x] = 1; // leaf instances are in atbop } // first call to graft traverseTree(data, instanceIndex, limits, this, 0, -1); } /** * recursive function. * if this node is a leaf then calls findGraft, otherwise sorts * the two sets of instances (tracked in iindex array) and calls * sortInstances for each of the child nodes (which then calls * this method). * * @param fulldata all instances * @param iindex array the tracks the weight of each instance in * the atbop and at the leaf (0.0 if not present) * @param limits array specifying current upper/lower limits for numeric atts * @param parent the node immediately before the current one * @param pL laplace for node, as calculated by parent (in case leaf is empty) * @param nodeClass class of node, determined by parent (in case leaf empty) */ private void traverseTree(Instances fulldata, double [][] iindex, double[][] limits, C45PruneableClassifierTreeG parent, double pL, int nodeClass) throws Exception { if(m_isLeaf) { findGraft(fulldata, iindex, limits, (ClassifierTree)parent, pL, nodeClass); } else { // traverse each branch for(int i = 0; i < localModel().numSubsets(); i++) { double [][] newiindex = new double[2][fulldata.numInstances()]; for(int x = 0; x < 2; x++) System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length); sortInstances(fulldata, newiindex, limits, i); } } } /** * sorts/deletes instances into/from node and atbop according to * the test for subset, then calls traverseTree for subset's node. * * @param fulldata all instances * @param iindex array the tracks the weight of each instance in * the atbop and at the leaf (0.0 if not present) * @param limits array specifying current upper/lower limits for numeric atts * @param subset the subset for which to sort instances into inode & iatbop */ private void sortInstances(Instances fulldata, double [][] iindex, double [][] limits, int subset) throws Exception { C45Split test = (C45Split)localModel(); // update the instances index for subset double knownCases = 0; double thisSubsetCount = 0; for(int x = 0; x < iindex[0].length; x++) { if(iindex[0][x] == 0 && iindex[1][x] == 0) // skip "discarded" instances continue; if(!fulldata.instance(x).isMissing(test.attIndex())) { knownCases += iindex[0][x]; if(test.whichSubset(fulldata.instance(x)) != subset) { if(iindex[0][x] > 0) { // move to atbop, delete from leaf iindex[1][x] = iindex[0][x]; iindex[0][x] = 0; } else { if(iindex[1][x] > 0) { // instance is now "discarded" iindex[1][x] = 0; } } } else { thisSubsetCount += iindex[0][x]; } } } // work out proportions of weight for missing values for leaf and atbop double lprop = (knownCases == 0) ? (1.0 / (double)test.numSubsets()) : (thisSubsetCount / (double)knownCases); // add in the instances that have missing value for attIndex for(int x = 0; x < iindex[0].length; x++) { if(iindex[0][x] == 0 && iindex[1][x] == 0) continue; // skip "discarded" instances if(fulldata.instance(x).isMissing(test.attIndex())) { iindex[1][x] -= (iindex[1][x] - iindex[0][x]) * (1-lprop); iindex[0][x] *= lprop; } } int nodeClass = localModel().distribution().maxClass(subset); double pL = (localModel().distribution().perClass(nodeClass) + 1.0) / (localModel().distribution().total() + 2.0); // call traerseTree method for the child node son(subset).traverseTree(fulldata, iindex, test.minsAndMaxs(fulldata, limits, subset), this, pL, nodeClass); } /** * finds new nodes that improve accuracy and grafts them onto the tree * * @param fulldata the instances in whole trainset * @param iindex records num tests each instance has failed up to this node * @param limits the upper/lower limits for numeric attributes * @param parent the node immediately before the current one * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty) * @param pLeafClass class of leaf, determined by parent (in case leaf empty) */ private void findGraft(Instances fulldata, double [][] iindex, double [][] limits, ClassifierTree parent, double pLaplace, int pLeafClass) throws Exception { // get the class for this leaf int leafClass = (m_isEmpty) ? pLeafClass : localModel().distribution().maxClass(); // get the laplace value for this leaf double leafLaplace = (m_isEmpty) ? pLaplace : laplaceLeaf(leafClass); // sort the instances into those at the leaf, those in atbop, and discarded Instances l = new Instances(fulldata, fulldata.numInstances()); Instances n = new Instances(fulldata, fulldata.numInstances()); int lcount = 0; int acount = 0; for(int x = 0; x < fulldata.numInstances(); x++) { if(iindex[0][x] <= 0 && iindex[1][x] <= 0) continue; if(iindex[0][x] != 0) { l.add(fulldata.instance(x)); l.instance(lcount).setWeight(iindex[0][x]); // move instance's weight in iindex to same index as in l iindex[0][lcount++] = iindex[0][x]; } if(iindex[1][x] > 0) { n.add(fulldata.instance(x)); n.instance(acount).setWeight(iindex[1][x]); // move instance's weight in iindex to same index as in n iindex[1][acount++] = iindex[1][x]; } } boolean graftPossible = false; double [] classDist = new double[n.numClasses()]; for(int x = 0; x < n.numInstances(); x++) { if(iindex[1][x] > 0 && !n.instance(x).classIsMissing()) classDist[(int)n.instance(x).classValue()] += iindex[1][x]; } for(int cVal = 0; cVal < n.numClasses(); cVal++) { double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0); if(cVal != leafClass && (theLaplace > leafLaplace) && (biprob(classDist[cVal], classDist[cVal], leafLaplace) > m_BiProbCrit)) { graftPossible = true; break; } } if(!graftPossible) { return; } // 1. Initialize to {} a set of tuples t containing potential tests ArrayList t = new ArrayList(); // go through each attribute for(int a = 0; a < n.numAttributes(); a++) { if(a == n.classIndex()) continue; // skip the class // sort instances in atbop by $a int [] sorted = sortByAttribute(n, a); // 2. For each continuous attribute $a: if(n.attribute(a).isNumeric()) { // find min and max values for this attribute at the leaf boolean prohibited = false; double minLeaf = Double.POSITIVE_INFINITY; double maxLeaf = Double.NEGATIVE_INFINITY; for(int i = 0; i < l.numInstances(); i++) { if(l.instance(i).isMissing(a)) { if(l.instance(i).classValue() == leafClass) { prohibited = true; break; } } double value = l.instance(i).value(a); if(!m_relabel || l.instance(i).classValue() == leafClass) { if(value < minLeaf) minLeaf = value; if(value > maxLeaf) maxLeaf = value; } } if(prohibited) { continue; } // (a) find values of // $n: instances in atbop (already have that, actually) // $v: a value for $a that exists for a case in the atbop, where // $v is < the min value for $a for a case at the leaf which // has the class $c, and $v is > the lowerlimit of $a at // the leaf. // (note: error in original paper stated that $v must be // smaller OR EQUAL TO the min value). // $k: $k is a class // that maximize L' = Laplace({$x: $x contained in cases($n) // & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k). double minBestClass = Double.NaN; double minBestLaplace = leafLaplace; double minBestVal = Double.NaN; double minBestPos = Double.NaN; double minBestTotal = Double.NaN; double [][] minBestCounts = null; double [][] counts = new double[2][n.numClasses()]; for(int x = 0; x < n.numInstances(); x++) { if(n.instance(sorted[x]).isMissing(a)) break; // missing are sorted to end: no more valid vals double theval = n.instance(sorted[x]).value(a); if(m_Debug) System.out.println("\t " + theval); if(theval <= limits[a][0]) { if(m_Debug) System.out.println("\t <= lowerlim: continuing..."); continue; } // note: error in paper would have this read "theVal > minLeaf) if(theval >= minLeaf) { if(m_Debug) System.out.println("\t >= minLeaf; breaking..."); break; } counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; if(x != n.numInstances() - 1) { int z = x + 1; while(z < n.numInstances() && n.instance(sorted[z]).value(a) == theval) { z++; x++; counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; } } // work out the best laplace/class (for <= theval) double total = Utils.sum(counts[0]); for(int c = 0; c < n.numClasses(); c++) { double temp = (counts[0][c]+1.0)/(total+2.0); if(temp > minBestLaplace) { minBestPos = counts[0][c]; minBestTotal = total; minBestLaplace = temp; minBestClass = c; minBestCounts = copyCounts(counts); minBestVal = (x == n.numInstances()-1) ? theval : ((theval + n.instance(sorted[x+1]).value(a)) / 2.0); } } } // (b) add to t tuple <n,a,v,k,L',"<="> if(!Double.isNaN(minBestVal) && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) { GraftSplit gsplit = null; try { gsplit = new GraftSplit(a, minBestVal, 0, leafClass, minBestCounts); } catch (Exception e) { System.err.println("graftsplit error: "+e.getMessage()); System.exit(1); } t.add(gsplit); } // free space minBestCounts = null; // (c) find values of // n: instances in atbop (already have that, actually) // $v: a value for $a that exists for a case in the atbop, where // $v is > the max value for $a for a case at the leaf which // has the class $c, and $v is <= the upperlimit of $a at // the leaf. // k: k is a class // that maximize L' = Laplace({x: x contained in cases(n) // & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k). double maxBestClass = -1; double maxBestLaplace = leafLaplace; double maxBestVal = Double.NaN; double maxBestPos = Double.NaN; double maxBestTotal = Double.NaN; double [][] maxBestCounts = null; for(int c = 0; c < n.numClasses(); c++) { // zero the counts counts[0][c] = 0; counts[1][c] = 0; // shouldn't need to do this ... } // check smallest val for a in atbop is < upper limit if(n.numInstances() >= 1 && n.instance(sorted[0]).value(a) < limits[a][1]) { for(int x = n.numInstances() - 1; x >= 0; x--) { if(n.instance(sorted[x]).isMissing(a)) continue; double theval = n.instance(sorted[x]).value(a); if(m_Debug) System.out.println("\t " + theval); if(theval > limits[a][1]) { if(m_Debug) System.out.println("\t >= upperlim; continuing..."); continue; } if(theval <= maxLeaf) { if(m_Debug) System.out.println("\t < maxLeaf; breaking..."); break; } // increment counts counts[1][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; if(x != 0 && !n.instance(sorted[x-1]).isMissing(a)) { int z = x - 1; while(z >= 0 && n.instance(sorted[z]).value(a) == theval) { z--; x--; counts[1][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; } } // work out best laplace for > theval double total = Utils.sum(counts[1]); for(int c = 0; c < n.numClasses(); c++) { double temp = (counts[1][c]+1.0)/(total+2.0); if(temp > maxBestLaplace ) { maxBestPos = counts[1][c]; maxBestTotal = total; maxBestLaplace = temp; maxBestClass = c; maxBestCounts = copyCounts(counts); maxBestVal = (x == 0) ? theval : ((theval + n.instance(sorted[x-1]).value(a)) / 2.0); } } } // (d) add to t tuple <n,a,v,k,L',">"> if(!Double.isNaN(maxBestVal) && biprob(maxBestPos,maxBestTotal,leafLaplace) > m_BiProbCrit) { GraftSplit gsplit = null; try { gsplit = new GraftSplit(a, maxBestVal, 1, leafClass, maxBestCounts); } catch (Exception e) { System.err.println("graftsplit error:" + e.getMessage()); System.exit(1); } t.add(gsplit); } } } else { // must be a nominal attribute // 3. for each discrete attribute a for which there is no // test at an ancestor of l // skip if this attribute has already been used if(limits[a][1] == 1) { continue; } boolean [] prohibit = new boolean[l.attribute(a).numValues()]; for(int aval = 0; aval < n.attribute(a).numValues(); aval++) { for(int x = 0; x < l.numInstances(); x++) { if((l.instance(x).isMissing(a) || l.instance(x).value(a) == aval) && (!m_relabel || (l.instance(x).classValue() == leafClass))) { prohibit[aval] = true; break; } } } // (a) find values of // $n: instances in atbop (already have that, actually) // $v: $v is a value for $a // $k: $k is a class // that maximize L' = Laplace({$x: $x contained in cases($n) // & value($a,$x) = $v}, $k). double bestVal = Double.NaN; double bestClass = Double.NaN; double bestLaplace = leafLaplace; double [][] bestCounts = null; double [][] counts = new double[2][n.numClasses()]; for(int x = 0; x < n.numInstances(); x++) { if(n.instance(sorted[x]).isMissing(a)) continue; // zero the counts for(int c = 0; c < n.numClasses(); c++) counts[0][c] = 0; double theval = n.instance(sorted[x]).value(a); counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; if(x != n.numInstances() - 1) { int z = x + 1; while(z < n.numInstances() && n.instance(sorted[z]).value(a) == theval) { z++; x++; counts[0][(int)n.instance(sorted[x]).classValue()] += iindex[1][sorted[x]]; } } if(!prohibit[(int)theval]) { // work out best laplace for > theval double total = Utils.sum(counts[0]); bestLaplace = leafLaplace; bestClass = Double.NaN; for(int c = 0; c < n.numClasses(); c++) { double temp = (counts[0][c]+1.0)/(total+2.0); if(temp > bestLaplace && biprob(counts[0][c],total,leafLaplace) > m_BiProbCrit) { bestLaplace = temp; bestClass = c; bestVal = theval; bestCounts = copyCounts(counts); } } // add to graft list if(!Double.isNaN(bestClass)) { GraftSplit gsplit = null; try { gsplit = new GraftSplit(a, bestVal, 2, leafClass, bestCounts); } catch (Exception e) { System.err.println("graftsplit error: "+e.getMessage()); System.exit(1); } t.add(gsplit); } } } // (b) add to t tuple <n,a,v,k,L',"="> // done this already } } // 4. remove from t all tuples <n,a,v,c,L,x> such that L <= // Laplace(cases(l),c) or prob(x,n,Laplace(cases(l),c) <= 0.05 // -- checked this constraint prior to adding a tuple -- // *** step six done before step five for efficiency *** // 6. for each <n,a,v,k,L,x> in t ordered on L from highest to lowest // order the tuples from highest to lowest laplace // (this actually orders lowest to highest) Collections.sort(t); // 5. remove from t all tuples <n,a,v,c,L,x> such that there is // no tuple <n',a',v',k',L',x'> such that k' != c & L' < L. for(int x = 0; x < t.size(); x++) { GraftSplit gs = (GraftSplit)t.get(x); if(gs.maxClassForSubsetOfInterest() != leafClass) { break; // reached a graft with class != leafClass, so stop deleting } else { t.remove(x); x--; } } // if no potential grafts were found, do nothing and return if(t.size() < 1) { return; } // create the distributions for each graft for(int x = t.size()-1; x >= 0; x--) { GraftSplit gs = (GraftSplit)t.get(x); try { gs.buildClassifier(l); gs.deleteGraftedCases(l); // so they don't go down the other branch } catch (Exception e) { System.err.println("graftsplit build error: " + e.getMessage()); } } // add this stuff to the tree ((C45PruneableClassifierTreeG)parent).setDescendents(t, this); } /** * sorts the int array in ascending order by attribute indexed * by a in dataset data. * @param the data the indices represent * @param the index of the attribute to sort by * @return array of sorted indicies */ private int [] sortByAttribute(Instances data, int a) { double [] attList = data.attributeToDoubleArray(a); int [] temp = Utils.sort(attList); return temp; } /** * deep copy the 2d array of counts * * @param src the array to copy * @return a copy of src */ private double [][] copyCounts(double [][] src) { double [][] newArr = new double[src.length][0]; for(int x = 0; x < src.length; x++) { newArr[x] = new double[src[x].length]; for(int y = 0; y < src[x].length; y++) { newArr[x][y] = src[x][y]; } } return newArr; } /** * Help method for computing class probabilities of * a given instance. * * @throws Exception if something goes wrong */ private double getProbsLaplace(int classIndex, Instance instance, double weight) throws Exception { double [] weights; double prob = 0; int treeIndex; int i,j; if (m_isLeaf) { return weight * localModel().classProbLaplace(classIndex, instance, -1); } else { treeIndex = localModel().whichSubset(instance); if (treeIndex == -1) { weights = localModel().weights(instance); for (i = 0; i < m_sons.length; i++) { if (!son(i).m_isEmpty) { if (!son(i).m_isLeaf) { prob += son(i).getProbsLaplace(classIndex, instance, weights[i] * weight); } else { prob += weight * weights[i] * localModel().classProbLaplace(classIndex, instance, i); } } } return prob; } else { if (son(treeIndex).m_isLeaf) { return weight * localModel().classProbLaplace(classIndex, instance, treeIndex); } else { return son(treeIndex).getProbsLaplace(classIndex,instance,weight); } } } } /** * Help method for computing class probabilities of * a given instance. * * @throws Exception if something goes wrong */ private double getProbs(int classIndex, Instance instance, double weight) throws Exception { double [] weights; double prob = 0; int treeIndex; int i,j; if (m_isLeaf) { return weight * localModel().classProb(classIndex, instance, -1); } else { treeIndex = localModel().whichSubset(instance); if (treeIndex == -1) { weights = localModel().weights(instance); for (i = 0; i < m_sons.length; i++) { if (!son(i).m_isEmpty) { prob += son(i).getProbs(classIndex, instance, weights[i] * weight); } } return prob; } else { if (son(treeIndex).m_isEmpty) { return weight * localModel().classProb(classIndex, instance, treeIndex); } else { return son(treeIndex).getProbs(classIndex, instance, weight); } } } } /** * add the grafted nodes at originalLeaf's position in tree. * a recursive function that terminates when t is empty. * * @param t the list of nodes to graft * @param originalLeaf the leaf that the grafts are replacing */ public void setDescendents(ArrayList t, C45PruneableClassifierTreeG originalLeaf) { Instances headerInfo = new Instances(m_train, 0); boolean end = false; ClassifierSplitModel splitmod = null; C45PruneableClassifierTreeG newNode; if(t.size() > 0) { splitmod = (ClassifierSplitModel)t.remove(t.size() - 1); newNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, splitmod, m_pruneTheTree, m_CF, m_subtreeRaising, false, m_relabel, m_cleanup); } else { // get the leaf for one of newNode's children NoSplit kLeaf = ((GraftSplit)localModel()).getOtherLeaf(); newNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising, true, m_relabel, m_cleanup); end = true; } // behave differently for parent of original leaf, since we don't // want to destroy any of its other branches if(m_sons != null) { for(int x = 0; x < m_sons.length; x++) { if(son(x).equals(originalLeaf)) { m_sons[x] = newNode; // replace originalLeaf with newNode } } } else { // allocate space for the children m_sons = new C45PruneableClassifierTreeG[localModel().numSubsets()]; // get the leaf for one of newNode's children NoSplit kLeaf = ((GraftSplit)localModel()).getLeaf(); C45PruneableClassifierTreeG kNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising, true, m_relabel, m_cleanup); // figure where to put the new node if(((GraftSplit)localModel()).subsetOfInterest() == 0) { m_sons[0] = kNode; m_sons[1] = newNode; } else { m_sons[0] = newNode; m_sons[1] = kNode; } } if(!end) ((C45PruneableClassifierTreeG)newNode).setDescendents (t, (C45PruneableClassifierTreeG)originalLeaf); } /** * class prob with laplace correction (assumes binary class) */ private double laplaceLeaf(double classIndex) { double l = (localModel().distribution().perClass((int)classIndex) + 1.0) / (localModel().distribution().total() + 2.0); return l; } /** * Significance test * @param double x, double n, double r. * @return returns the probability of obtaining x or MORE out of n * if r proportion of n are positive. * * z for normal estimation of binomial probability of obtaining x * or more out of n, if r proportion of n are positive */ public double biprob(double x, double n, double r) throws Exception { return ((((x) - 0.5) - (n) * (r)) / Math.sqrt((n) * (r) * (1.0 - (r)))); } /** * Prints tree structure. */ public String toString() { try { StringBuffer text = new StringBuffer(); if(m_isLeaf) { text.append(": "); if(m_localModel instanceof GraftSplit) text.append(((GraftSplit)m_localModel).dumpLabelG(0,m_train)); else text.append(m_localModel.dumpLabel(0,m_train)); } else dumpTree(0,text); text.append("\n\nNumber of Leaves : \t"+numLeaves()+"\n"); text.append("\nSize of the tree : \t"+numNodes()+"\n"); return text.toString(); } catch (Exception e) { return "Can't print classification tree."; } } /** * Help method for printing tree structure. * * @throws Exception if something goes wrong */ protected void dumpTree(int depth,StringBuffer text) throws Exception { int i,j; for(i=0;i<m_sons.length;i++) { text.append("\n");; for(j=0;j<depth;j++) text.append("| "); text.append(m_localModel.leftSide(m_train)); text.append(m_localModel.rightSide(i, m_train)); if(m_sons[i].m_isLeaf) { text.append(": "); if(m_localModel instanceof GraftSplit) text.append(((GraftSplit)m_localModel).dumpLabelG(i,m_train)); else text.append(m_localModel.dumpLabel(i,m_train)); } else ((C45PruneableClassifierTreeG)m_sons[i]).dumpTree(depth+1,text); } } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5535 $"); } }