/*********************************************************************** This file is part of KEEL-software, the Data Mining tool for regression, classification, clustering, pattern mining and so on. Copyright (C) 2004-2010 F. Herrera (herrera@decsai.ugr.es) L. S�nchez (luciano@uniovi.es) J. Alcal�-Fdez (jalcala@decsai.ugr.es) S. Garc�a (sglopez@ujaen.es) A. Fern�ndez (alberto.fernandez@ujaen.es) J. Luengo (julianlm@decsai.ugr.es) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/ **********************************************************************/ package keel.Algorithms.Decision_Trees.C45; /** * Class to handle the classifier tree * * <p> * @author Written by Cristobal Romero (Universidad de C�rdoba) 10/10/2007 * @author Modified by Alberto Fernandez (UGR) * @version 1.2 (29-04-10) * @since JDK 1.5 *</p> */ public class Tree { /** Total number of Nodes in the tree */ public static int NumberOfNodes; /** Number of Leafs in the tree */ public static int NumberOfLeafs; /** The selected model. */ protected SelectCut model; /** The model of the node. */ protected Cut nodeModel; /** Sons of the node. */ protected Tree[] sons; /** Is this node leaf or not. */ protected boolean isLeaf; /** Is this node empty or not. */ protected boolean isEmpty; /** The dataset. */ protected Dataset train; /** Is pruned the tree or not. */ protected boolean prune = false; /** The confidence factor for pruning. */ protected float confidence = 0.25f; /** To compute the average attributes per rule */ public static int global = 0; /** Constructor. * * @param selectNodeModel The cut model. * @param pruneTree Prune the tree or not. * @param cf Minimum confidence. */ public Tree(SelectCut selectNodeModel, boolean pruneTree, float cf) { model = selectNodeModel; prune = pruneTree; confidence = cf; NumberOfNodes = 0; NumberOfLeafs = 0; } /** Adds one new node. * * @param data The dataset. * * @throws Exception If the node cannot be built. */ public void buildNode(Dataset data) throws Exception { Dataset[] localItemsets; train = data; isLeaf = false; isEmpty = false; sons = null; nodeModel = model.selectModel(data); if (nodeModel.numSubsets() > 1) { localItemsets = nodeModel.cutDataset(data); data = null; sons = new Tree[nodeModel.numSubsets()]; for (int i = 0; i < sons.length; i++) { sons[i] = getNewTree(localItemsets[i]); localItemsets[i] = null; } } else { isLeaf = true; if (data.sumOfWeights() == 0) { isEmpty = true; } data = null; } } /** Function to build the classifier tree. * * @param data The dataset. * * @throws Exception If the tree cannot be built. */ public void buildTree(Dataset data) throws Exception { data = new Dataset(data); data.deleteWithMissing(data.getClassIndex()); buildNode(data); collapse(); if (prune) { prune(); } } /** Function to collapse a tree to a node if training error doesn't increase. * */ public final void collapse() { double errorsOfSubtree, errorsOfTree; int i; if (!isLeaf) { errorsOfSubtree = getErrors(); errorsOfTree = nodeModel.classification().numIncorrect(); if (errorsOfSubtree >= errorsOfTree - 1E-3) { // Free adjacent trees sons = null; isLeaf = true; // Get NoCut Model for tree. nodeModel = new Cut(nodeModel.classification()); } else { for (i = 0; i < sons.length; i++) { son(i).collapse(); } } } } /** Function to prune a tree. * * @throws Exception If the prune cannot be made. */ public void prune() throws Exception { double errorsLargestBranch, errorsLeaf, errorsTree; int indexOfLargestBranch, i; Tree largestBranch; if (!isLeaf) { // Prune all subtrees. for (i = 0; i < sons.length; i++) { son(i).prune(); } // Compute error for largest branch indexOfLargestBranch = nodeModel.classification().maxValue(); errorsLargestBranch = son(indexOfLargestBranch). getEstimatedErrorsForBranch((Dataset) train); // Compute error if this Tree would be leaf errorsLeaf = getEstimatedErrorsForLeaf(nodeModel.classification()); // Compute error for the whole subtree errorsTree = getEstimatedErrors(); // Decide if leaf is best choice. if (errorsLeaf <= errorsTree + 0.1 && errorsLeaf <= errorsLargestBranch + 0.1) { // Free son Trees sons = null; isLeaf = true; // Get NoCut Model for node. nodeModel = new Cut(nodeModel.classification()); return; } // Decide if largest branch is better choice // than whole subtree. if (errorsLargestBranch <= errorsTree + 0.1) { largestBranch = son(indexOfLargestBranch); sons = largestBranch.sons; nodeModel = largestBranch.nodeModel; isLeaf = largestBranch.isLeaf; newClassification(train); prune(); } } } /** Function to get the classification of classes. * * @param itemset The itemset to classify. * * @return The classification of class values for the itemset. * * @throws Exception If the probabilities cannot be computed. */ public final double[] classificationForItemset(Itemset itemset) throws Exception { double[] doubles = new double[itemset.numClasses()]; for (int i = 0; i < doubles.length; i++) { doubles[i] = getProbabilities(i, itemset, 1); } return doubles; } /** Function to compute the class probabilities of a given itemset. * * @param classIndex The index of the class attribute. * @param itemset The itemset. * @param weight The weight. * * @return The probability of the class. * * @throws Exception If the probabilities cannot be computed. */ private double getProbabilities(int classIndex, Itemset itemset, double weight) throws Exception { double[] weights; double prob = 0; int treeIndex, i, j; if (isLeaf) { return weight * nodeModel.classProbability(classIndex, itemset, -1); } else { treeIndex = nodeModel.whichSubset(itemset); if (treeIndex == -1) { weights = nodeModel.weights(itemset); for (i = 0; i < sons.length; i++) { if (!son(i).isEmpty) { prob += son(i).getProbabilities(classIndex, itemset, weights[i] * weight); } } return prob; } else { if (son(treeIndex).isEmpty) { return weight * nodeModel.classProbability(classIndex, itemset, treeIndex); } else { return son(treeIndex).getProbabilities(classIndex, itemset, weight); } } } } /** Function to print the tree. * */ public String toString() { try { StringBuffer text = new StringBuffer(); if (!isLeaf) { NumberOfNodes++; printTree(0, text); } return text.toString(); } catch (Exception e) { return "Can not print the tree."; } } /** * Function to print the tree (OVO code). */ public String toStringOVO() { try { StringBuffer text = new StringBuffer(); if (!isLeaf) { NumberOfNodes++; printTreeOVO(0, text); } return text.toString(); } catch (Exception e) { return "Can not print the tree."; } } /** Function to print the tree. * * @param depth Depth of the node in the tree. * @param text The tree. * * @throws Exception If the tree cannot be printed. */ private void printTree(int depth, StringBuffer text) throws Exception { int i, j; String aux = ""; for (int k = 0; k < depth; k++) { aux += "\t"; } for (i = 0; i < sons.length; i++) { text.append(aux); if (i == 0) { text.append("if ( " + nodeModel.leftSide(train) + nodeModel.rightSide(i, train) + " ) then\n" + aux + "{\n"); } else { text.append("elseif ( " + nodeModel.leftSide(train) + nodeModel.rightSide(i, train) + " ) then\n" + aux + "{\n"); } if (sons[i].isLeaf) { NumberOfLeafs++; text.append(aux + "\t" + train.getClassAttribute().name() + " = \"" + nodeModel.label(i, train) + "\"\n"); } else { NumberOfNodes++; sons[i].printTree(depth + 1, text); } text.append(aux + "}\n"); } } /** Function to print the tree. * * @param depth Depth of the node in the tree. * @param text The tree. * * @throws Exception If the tree cannot be printed. */ private void printTreeOVO(int depth, StringBuffer text) throws Exception { int i, j; String aux = ""; for (i = 0; i < sons.length; i++) { text.append(aux); if (i == 0) { text.append("if ( " + nodeModel.leftSideOVO(train) + nodeModel.rightSide(i, train) + " ) then\n" + aux + ""); } else { text.append("elseif "+depth+" ( " + nodeModel.leftSideOVO(train) + nodeModel.rightSide(i, train) + " ) then\n" + aux + ""); } if (sons[i].isLeaf) { NumberOfLeafs++; text.append(aux + "\t" + train.getClassAttribute().name() + " = " + nodeModel.label(i, train) + " \n"); } else { NumberOfNodes++; sons[i].printTreeOVO(depth + 1, text); } } } /** * Function to compute the number of attributes of the tree. * @param depth Depth of the node in the tree. * */ private void attributesPerRule(int depth){ depth++; for (int i = 0; i < sons.length; i++) { if (sons[i].isLeaf) { global += depth; } else { sons[i].attributesPerRule(depth); } } } /** * Function to compute the number of attributes of the tree * @return int number of attributes of the tree */ public int getAttributesPerRule(){ global = 0; if (!isLeaf) { this.attributesPerRule(0); } return global; } /** Returns the son with the given index. * * @param index The index of the son. */ private Tree son(int index) { return (Tree) sons[index]; } /** Function to create a new tree. * * @param data The dataset. * * @return The new tree. * * @throws Exception If the new tree cannot be created. */ protected Tree getNewTree(Dataset data) throws Exception { Tree newNode = new Tree(model, prune, confidence); newNode.buildNode((Dataset) data); return newNode; } /** Function to compute the estimated errors. * * @return The estimated errors. */ private double getEstimatedErrors() { double errors = 0; int i; if (isLeaf) { return getEstimatedErrorsForLeaf(nodeModel.classification()); } else { for (i = 0; i < sons.length; i++) { errors = errors + son(i).getEstimatedErrors(); } return errors; } } /** Function to compute the estimated errors for one branch. * * @param data The dataset over the errors has to be computed. * * @return The error computed. * * @throws Exception If the errors cannot be computed. */ private double getEstimatedErrorsForBranch(Dataset data) throws Exception { Dataset[] localItemsets; double errors = 0; int i; if (isLeaf) { return getEstimatedErrorsForLeaf(new Classification(data)); } else { Classification savedDist = nodeModel.classification; nodeModel.resetClassification(data); localItemsets = (Dataset[]) nodeModel.cutDataset(data); nodeModel.classification = savedDist; for (i = 0; i < sons.length; i++) { errors += son(i).getEstimatedErrorsForBranch(localItemsets[i]); } return errors; } } /** Function to compute the estimated errors for leaf. * * @param theClassification The classification of the classes. * * @return The estimated errors for the leaf. */ private double getEstimatedErrorsForLeaf(Classification theClassification) { if (theClassification.getTotal() == 0) { return 0; } else { return theClassification.numIncorrect() + errors(theClassification.getTotal(), theClassification.numIncorrect(), confidence); } } /** Function to compute the errors on training data. * * @return The errors. */ private double getErrors() { double errors = 0; int i; if (isLeaf) { return nodeModel.classification().numIncorrect(); } else { for (i = 0; i < sons.length; i++) { errors += son(i).getErrors(); } return errors; } } /** Function to create a new classification. * * @param data The dataset. * * @throws Exception If the classification cannot be built. */ private void newClassification(Dataset data) throws Exception { Dataset[] localItemsets; nodeModel.resetClassification(data); train = data; if (!isLeaf) { localItemsets = (Dataset[]) nodeModel.cutDataset(data); for (int i = 0; i < sons.length; i++) { son(i).newClassification(localItemsets[i]); } } } /** Function to compute estimated extra error for given total number of itemsets and errors. * * @param N The weight of all the itemsets. * @param e The weight of the itemsets incorrectly classified. * @param CF Minimum confidence. * * @return The errors. */ private static double errors(double N, double e, float CF) { // Some constants for the interpolation. double Val[] = {0, 0.000000001, 0.00000001, 0.0000001, 0.000001, 0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.10, 0.20, 0.40, 1.00}; double Dev[] = {100, 6.0, 5.61, 5.2, 4.75, 4.26, 3.89, 3.72, 3.29, 3.09, 2.58, 2.33, 1.65, 1.28, 0.84, 0.25, 0.00}; double Val0, Pr, Coeff = 0; int i = 0; while (CF > Val[i]) { i++; } Coeff = Dev[i - 1] + (Dev[i] - Dev[i - 1]) * (CF - Val[i - 1]) / (Val[i] - Val[i - 1]); Coeff = Coeff * Coeff; if (e == 0) { return N * (1 - Math.exp(Math.log(CF) / N)); } else { if (e < 0.9999) { Val0 = N * (1 - Math.exp(Math.log(CF) / N)); return Val0 + e * (errors(N, 1.0, CF) - Val0); } else { if (e + 0.5 >= N) { return 0.67 * (N - e); } else { Pr = (e + 0.5 + Coeff / 2 + Math.sqrt(Coeff * ((e + 0.5) * (1 - (e + 0.5) / N) + Coeff / 4))) / (N + Coeff); return (N * Pr - e); } } } } }