package tr.gov.ulakbim.jDenetX.classifiers;

/*
 *    HoeffdingAdaptiveTree.java
 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
 *    @author Albert Bifet
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

import tr.gov.ulakbim.jDenetX.classifiers.conditionals.InstanceConditionalTest;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.MiscUtils;
import tr.gov.ulakbim.jDenetX.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;

import java.util.LinkedList;
import java.util.List;
import java.util.Random;

/**
 * Hoeffding Adaptive Tree: a Hoeffding tree that keeps an ADWIN change
 * detector at each node to monitor its error rate, grows an alternate subtree
 * when the error increases, and replaces the original subtree once the
 * alternate is significantly more accurate.
 */
public class HoeffdingAdaptiveTree extends HoeffdingTreeNBAdaptive {

    private static final long serialVersionUID = 1L;

    public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
            "leafprediction", 'l', "Leaf prediction to use.",
            new String[]{"MC", "NB", "NBAdaptive"},
            new String[]{"Majority class", "Naive Bayes", "Naive Bayes Adaptive"},
            2);

    public interface NewNode {

        // Change for adwin
        //public boolean getErrorChange();

        public int numberLeaves();

        public double getErrorEstimation();

        public double getErrorWidth();

        public boolean isNullError();

        public void killTreeChilds(HoeffdingAdaptiveTree ht);

        public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht,
                SplitNode parent, int parentBranch);

        public void filterInstanceToLeaves(Instance inst, SplitNode myparent,
                int parentBranch, List<FoundNode> foundNodes,
                boolean updateSplitterCounts);
    }

    public static class AdaSplitNode extends SplitNode implements NewNode {

        private static final long serialVersionUID = 1L;

        protected Node alternateTree;

        protected ADWIN estimationErrorWeight;

        //public boolean isAlternateTree = false;

        public boolean ErrorChange = false;

        protected int randomSeed = 1;

        protected Random classifierRandom;

        //public boolean getErrorChange() {
        //    return ErrorChange;
        //}

        @Override
        public int calcByteSizeIncludingSubtree() {
            int byteSize = calcByteSize();
            if (alternateTree != null) {
                byteSize += alternateTree.calcByteSizeIncludingSubtree();
            }
            if (estimationErrorWeight != null) {
                byteSize += estimationErrorWeight.measureByteSize();
            }
            for (Node child : this.children) {
                if (child != null) {
                    byteSize += child.calcByteSizeIncludingSubtree();
                }
            }
            return byteSize;
        }

        public AdaSplitNode(InstanceConditionalTest splitTest,
                double[] classObservations) {
            super(splitTest, classObservations);
            this.classifierRandom = new Random(this.randomSeed);
        }

        public int numberLeaves() {
            int numLeaves = 0;
            for (Node child : this.children) {
                if (child != null) {
                    numLeaves += ((NewNode) child).numberLeaves();
                }
            }
            return numLeaves + 1;
        }

        public double getErrorEstimation() {
            return this.estimationErrorWeight.getEstimation();
        }

        public double getErrorWidth() {
            double w = 0.0;
            if (!isNullError()) {
                w = this.estimationErrorWeight.getWidth();
            }
            return w;
        }

        public boolean isNullError() {
            return (this.estimationErrorWeight == null);
        }
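
        // Drift handling at an internal node (see learnFromInstance below): an
        // ADWIN estimator tracks the 0/1 error of the subtree rooted here. When
        // ADWIN signals a change and the error estimate has grown, a new
        // alternate subtree is started from a fresh learning node. Once both the
        // node and its alternate have accumulated enough examples (window width
        // > 300), their error rates are compared against the bound
        //     sqrt(2 * e_old * (1 - e_old) * ln(2 / delta) * (1/n_alt + 1/n_old))
        // with delta = 0.05: the alternate replaces this subtree if it is
        // significantly better, or is pruned if it is significantly worse.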
        // SplitNodes can have alternate trees, but LearningNodes can't.
        // LearningNodes can split, but SplitNodes can't.
        // Parent nodes are always SplitNodes.
        public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht,
                SplitNode parent, int parentBranch) {
            int trueClass = (int) inst.classValue();
            //New option vote
            int k = MiscUtils.poisson(1.0, this.classifierRandom);
            Instance weightedInst = (Instance) inst.copy();
            if (k > 0) {
                //weightedInst.setWeight(inst.weight() * k);
            }
            //Compute ClassPrediction using filterInstanceToLeaf
            //int ClassPrediction = Utils.maxIndex(filterInstanceToLeaf(inst, null, -1).node.getClassVotes(inst, ht));
            int ClassPrediction = 0;
            Node leaf = filterInstanceToLeaf(inst, parent, parentBranch).node;
            if (leaf != null) {
                ClassPrediction = Utils.maxIndex(leaf.getClassVotes(inst, ht));
            }
            boolean blCorrect = (trueClass == ClassPrediction);

            if (this.estimationErrorWeight == null) {
                this.estimationErrorWeight = new ADWIN();
            }
            double oldError = this.getErrorEstimation();
            this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect ? 0.0 : 1.0);
            if (this.ErrorChange && oldError > this.getErrorEstimation()) {
                // If the error is decreasing, don't do anything
                this.ErrorChange = false;
            }

            // Check condition to build a new alternate tree
            //if (this.isAlternateTree == false) {
            if (this.ErrorChange) { //&& this.alternateTree == null) {
                // Start a new alternate tree: learning node
                this.alternateTree = ht.newLearningNode();
                //this.alternateTree.isAlternateTree = true;
                ht.alternateTrees++;
            }
            // Check condition to replace tree (guard against a missing alternate tree)
            else if (this.alternateTree != null
                    && !((NewNode) this.alternateTree).isNullError()) {
                if (this.getErrorWidth() > 300
                        && ((NewNode) this.alternateTree).getErrorWidth() > 300) {
                    double oldErrorRate = this.getErrorEstimation();
                    double altErrorRate = ((NewNode) this.alternateTree).getErrorEstimation();
                    double fDelta = .05;
                    //if (gNumAlts > 0) fDelta = fDelta / gNumAlts;
                    double fN = 1.0 / ((double) ((NewNode) this.alternateTree).getErrorWidth())
                            + 1.0 / ((double) this.getErrorWidth());
                    double Bound = Math.sqrt(2.0 * oldErrorRate * (1.0 - oldErrorRate)
                            * Math.log(2.0 / fDelta) * fN);
                    if (Bound < oldErrorRate - altErrorRate) {
                        // Switch alternate tree
                        ht.activeLeafNodeCount -= this.numberLeaves();
                        ht.activeLeafNodeCount += ((NewNode) this.alternateTree).numberLeaves();
                        killTreeChilds(ht);
                        if (parent != null) {
                            parent.setChild(parentBranch, this.alternateTree);
                            //((AdaSplitNode) parent.getChild(parentBranch)).alternateTree = null;
                        } else {
                            // Switch root tree
                            ht.treeRoot = ((AdaSplitNode) ht.treeRoot).alternateTree;
                        }
                        ht.switchedAlternateTrees++;
                    } else if (Bound < altErrorRate - oldErrorRate) {
                        // Erase alternate tree
                        if (this.alternateTree instanceof ActiveLearningNode) {
                            this.alternateTree = null;
                            ht.activeLeafNodeCount--;
                        } else if (this.alternateTree instanceof InactiveLearningNode) {
                            this.alternateTree = null;
                            ht.inactiveLeafNodeCount--;
                        } else {
                            ((AdaSplitNode) this.alternateTree).killTreeChilds(ht);
                        }
                        ht.prunedAlternateTrees++;
                    }
                }
            }
            //}

            // Train the alternate tree and the child on the current branch
            if (this.alternateTree != null) {
                ((NewNode) this.alternateTree).learnFromInstance(weightedInst, ht, parent, parentBranch);
            }
            int childBranch = this.instanceChildIndex(inst);
            Node child = this.getChild(childBranch);
            if (child != null) {
                ((NewNode) child).learnFromInstance(weightedInst, ht, this, childBranch);
            }
        }
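
        // Called before an alternate subtree replaces this node's subtree:
        // recursively prunes the alternate trees of descendant split nodes
        // (counting them as pruned) and decrements the tree-wide active/inactive
        // leaf counters for the leaves being discarded.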
        public void killTreeChilds(HoeffdingAdaptiveTree ht) {
            for (Node child : this.children) {
                if (child != null) {
                    // Delete alternate tree if it exists
                    if (child instanceof AdaSplitNode
                            && ((AdaSplitNode) child).alternateTree != null) {
                        ((NewNode) ((AdaSplitNode) child).alternateTree).killTreeChilds(ht);
                        ht.prunedAlternateTrees++;
                    }
                    // Recursive delete of SplitNodes
                    if (child instanceof AdaSplitNode) {
                        ((NewNode) child).killTreeChilds(ht);
                    }
                    if (child instanceof ActiveLearningNode) {
                        child = null;
                        ht.activeLeafNodeCount--;
                    } else if (child instanceof InactiveLearningNode) {
                        child = null;
                        ht.inactiveLeafNodeCount--;
                    }
                }
            }
        }

        // New for option votes
        //@Override
        public void filterInstanceToLeaves(Instance inst, SplitNode myparent,
                int parentBranch, List<FoundNode> foundNodes,
                boolean updateSplitterCounts) {
            if (updateSplitterCounts) {
                this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight());
            }
            int childIndex = instanceChildIndex(inst);
            if (childIndex >= 0) {
                Node child = getChild(childIndex);
                if (child != null) {
                    ((NewNode) child).filterInstanceToLeaves(inst, this, childIndex,
                            foundNodes, updateSplitterCounts);
                } else {
                    foundNodes.add(new FoundNode(null, this, childIndex));
                }
            }
            if (this.alternateTree != null) {
                // Leaves reached through the alternate tree are tagged with
                // parentBranch = -999 so that prediction can skip them
                ((NewNode) this.alternateTree).filterInstanceToLeaves(inst, this, -999,
                        foundNodes, updateSplitterCounts);
            }
        }
    }

    public static class AdaLearningNode extends LearningNodeNBAdaptive implements NewNode {

        private static final long serialVersionUID = 1L;

        protected ADWIN estimationErrorWeight;

        public boolean ErrorChange = false;

        protected int randomSeed = 1;

        protected Random classifierRandom;

        @Override
        public int calcByteSize() {
            int byteSize = super.calcByteSize();
            if (estimationErrorWeight != null) {
                byteSize += estimationErrorWeight.measureByteSize();
            }
            return byteSize;
        }

        public AdaLearningNode(double[] initialClassObservations) {
            super(initialClassObservations);
            this.classifierRandom = new Random(this.randomSeed);
        }

        public int numberLeaves() {
            return 1;
        }

        public double getErrorEstimation() {
            if (this.estimationErrorWeight != null) {
                return this.estimationErrorWeight.getEstimation();
            } else {
                return 0;
            }
        }

        public double getErrorWidth() {
            return this.estimationErrorWeight.getWidth();
        }

        public boolean isNullError() {
            return (this.estimationErrorWeight == null);
        }

        public void killTreeChilds(HoeffdingAdaptiveTree ht) {
        }
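
        // Learning at a leaf: the instance weight is resampled with Poisson(1)
        // (online-bagging style), the leaf's own ADWIN estimator is updated with
        // the 0/1 error of its current prediction, the sufficient statistics are
        // updated, and a split is attempted once gracePeriodOption more weight
        // has been seen since the last evaluation.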
        //@Override
        public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht,
                SplitNode parent, int parentBranch) {
            int trueClass = (int) inst.classValue();
            //New option vote
            int k = MiscUtils.poisson(1.0, this.classifierRandom);
            Instance weightedInst = (Instance) inst.copy();
            if (k > 0) {
                weightedInst.setWeight(inst.weight() * k);
            }
            // Compute ClassPrediction using this leaf's votes
            int ClassPrediction = Utils.maxIndex(this.getClassVotes(inst, ht));
            boolean blCorrect = (trueClass == ClassPrediction);

            if (this.estimationErrorWeight == null) {
                this.estimationErrorWeight = new ADWIN();
            }
            double oldError = this.getErrorEstimation();
            this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect ? 0.0 : 1.0);
            if (this.ErrorChange && oldError > this.getErrorEstimation()) {
                // If the error is decreasing, don't signal a change
                this.ErrorChange = false;
            }

            // Update statistics
            learnFromInstance(weightedInst, ht);

            // Check for split condition
            double weightSeen = this.getWeightSeen();
            if (weightSeen - this.getWeightSeenAtLastSplitEvaluation() >= ht.gracePeriodOption.getValue()) {
                ht.attemptToSplit(this, parent, parentBranch);
                this.setWeightSeenAtLastSplitEvaluation(weightSeen);
            }
        }

        @Override
        public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
            double[] dist;
            int predictionOption = ((HoeffdingAdaptiveTree) ht).leafpredictionOption.getChosenIndex();
            if (predictionOption == 0) { // MC
                dist = this.observedClassDistribution.getArrayCopy();
            } else if (predictionOption == 1) { // NB
                dist = NaiveBayes.doNaiveBayesPrediction(inst,
                        this.observedClassDistribution, this.attributeObservers);
            } else { // NBAdaptive
                if (this.mcCorrectWeight > this.nbCorrectWeight) {
                    dist = this.observedClassDistribution.getArrayCopy();
                } else {
                    dist = NaiveBayes.doNaiveBayesPrediction(inst,
                            this.observedClassDistribution, this.attributeObservers);
                }
            }
            // New for option votes: scale the votes by the leaf's error estimate
            double distSum = Utils.sum(dist);
            if (distSum * this.getErrorEstimation() * this.getErrorEstimation() > 0.0) {
                Utils.normalize(dist,
                        distSum * this.getErrorEstimation() * this.getErrorEstimation()); // Adding weight
            }
            return dist;
        }

        // New for option votes
        public void filterInstanceToLeaves(Instance inst, SplitNode splitparent,
                int parentBranch, List<FoundNode> foundNodes,
                boolean updateSplitterCounts) {
            foundNodes.add(new FoundNode(this, splitparent, parentBranch));
        }
    }

    protected int activeLeafNodeCount;

    protected int inactiveLeafNodeCount;

    protected int alternateTrees;

    protected int prunedAlternateTrees;

    protected int switchedAlternateTrees;

    @Override
    protected LearningNode newLearningNode(double[] initialClassObservations) {
        // IDEA: to choose different learning nodes depending on predictionOption
        return new AdaLearningNode(initialClassObservations);
    }

    //@Override
    protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
            double[] classObservations) {
        return new AdaSplitNode(splitTest, classObservations);
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (this.treeRoot == null) {
            this.treeRoot = newLearningNode();
            this.activeLeafNodeCount = 1;
        }
        ((NewNode) this.treeRoot).learnFromInstance(inst, this, null, -1);
    }

    // New for option votes
    public FoundNode[] filterInstanceToLeaves(Instance inst, SplitNode parent,
            int parentBranch, boolean updateSplitterCounts) {
        List<FoundNode> nodes = new LinkedList<FoundNode>();
        ((NewNode) this.treeRoot).filterInstanceToLeaves(inst, parent, parentBranch,
                nodes, updateSplitterCounts);
        return nodes.toArray(new FoundNode[nodes.size()]);
    }
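
    // Prediction aggregates the (error-weighted) class votes of every leaf the
    // instance reaches in the main tree; leaves reached only through alternate
    // trees were tagged with parentBranch = -999 by filterInstanceToLeaves and
    // are excluded from the vote.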
    @Override
    public double[] getVotesForInstance(Instance inst) {
        if (this.treeRoot != null) {
            FoundNode[] foundNodes = filterInstanceToLeaves(inst, null, -1, false);
            DoubleVector result = new DoubleVector();
            int predictionPaths = 0;
            for (FoundNode foundNode : foundNodes) {
                if (foundNode.parentBranch != -999) {
                    Node leafNode = foundNode.node;
                    if (leafNode == null) {
                        leafNode = foundNode.parent;
                    }
                    double[] dist = leafNode.getClassVotes(inst, this);
                    //Albert: changed for weights
                    //double distSum = Utils.sum(dist);
                    //if (distSum > 0.0) {
                    //    Utils.normalize(dist, distSum);
                    //}
                    result.addValues(dist);
                    //predictionPaths++;
                }
            }
            //if (predictionPaths > this.maxPredictionPaths) {
            //    this.maxPredictionPaths++;
            //}
            return result.getArrayRef();
        }
        return new double[0];
    }
}