package tr.gov.ulakbim.jDenetX.classifiers;
/*
* HoeffdingAdaptiveTree.java
* Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
* @author Albert Bifet
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
import tr.gov.ulakbim.jDenetX.classifiers.conditionals.InstanceConditionalTest;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.MiscUtils;
import tr.gov.ulakbim.jDenetX.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
/**
 * Hoeffding tree variant for evolving data streams: each split node and leaf
 * carries an ADWIN change detector over its own error stream, and split nodes
 * grow alternate subtrees that replace the original when they become more
 * accurate (see the AdaSplitNode / AdaLearningNode inner classes).
 */
public class HoeffdingAdaptiveTree extends HoeffdingTreeNBAdaptive {
private static final long serialVersionUID = 1L;
// Prediction strategy used at the leaves; default index 2 = Naive Bayes Adaptive.
public MultiChoiceOption leafpredictionOption = new MultiChoiceOption(
"leafprediction", 'l', "Leaf prediction to use.", new String[]{
"MC", "NB", "NBAdaptive"}, new String[]{
"Majority class",
"Naive Bayes",
"Naive Bayes Adaptive"}, 2);
/**
 * Common contract implemented by the adaptive node types (AdaSplitNode,
 * AdaLearningNode) so the tree can treat them uniformly when training,
 * estimating error, pruning, and collecting option votes.
 */
public interface NewNode {
// Change for adwin
//public boolean getErrorChange();
// Number of leaves in the subtree rooted at this node.
public int numberLeaves();
// Current ADWIN error estimate for this node's prediction stream.
public double getErrorEstimation();
// Width of the ADWIN window backing the error estimate (0 if none).
public double getErrorWidth();
// True while no ADWIN detector has been created for this node yet.
public boolean isNullError();
// Recursively discard this node's subtree, updating ht's leaf counters.
public void killTreeChilds(HoeffdingAdaptiveTree ht);
// Train this node (and its subtree/alternates) on one instance.
public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch);
// Collect every leaf this instance reaches, including alternate-tree leaves.
public void filterInstanceToLeaves(Instance inst, SplitNode myparent, int parentBranch, List<FoundNode> foundNodes,
boolean updateSplitterCounts);
}
/**
 * Split node with drift handling: monitors its own error with ADWIN, grows an
 * alternate subtree when error rises, and swaps the alternate in when it is
 * provably better (or prunes it when it is provably worse).
 */
public static class AdaSplitNode extends SplitNode implements NewNode {

    private static final long serialVersionUID = 1L;

    // Root of the alternate subtree grown while drift is suspected (null when none).
    protected Node alternateTree;

    // ADWIN change detector fed with this subtree's 0/1 error stream.
    protected ADWIN estimationErrorWeight;

    //public boolean isAlternateTree = false;

    // Latest ADWIN verdict: true when a significant error increase was detected.
    public boolean ErrorChange = false;

    protected int randomSeed = 1;

    // Per-node RNG driving the Poisson(1) re-weighting draw.
    protected Random classifierRandom;

    //public boolean getErrorChange() {
    //    return ErrorChange;
    //}

    /** Byte size of this node plus its children, alternate tree and detector. */
    @Override
    public int calcByteSizeIncludingSubtree() {
        int byteSize = calcByteSize();
        if (alternateTree != null) {
            byteSize += alternateTree.calcByteSizeIncludingSubtree();
        }
        if (estimationErrorWeight != null) {
            byteSize += estimationErrorWeight.measureByteSize();
        }
        for (Node child : this.children) {
            if (child != null) {
                byteSize += child.calcByteSizeIncludingSubtree();
            }
        }
        return byteSize;
    }

    public AdaSplitNode(InstanceConditionalTest splitTest,
            double[] classObservations) {
        super(splitTest, classObservations);
        this.classifierRandom = new Random(this.randomSeed);
    }

    /** Leaves in this subtree, counting this node itself as one. */
    public int numberLeaves() {
        int numLeaves = 0;
        for (Node child : this.children) {
            if (child != null) {
                numLeaves += ((NewNode) child).numberLeaves();
            }
        }
        return numLeaves + 1;
    }

    public double getErrorEstimation() {
        return this.estimationErrorWeight.getEstimation();
    }

    /** ADWIN window width, or 0 while no detector exists yet. */
    public double getErrorWidth() {
        double w = 0.0;
        if (!isNullError()) {
            w = this.estimationErrorWeight.getWidth();
        }
        return w;
    }

    public boolean isNullError() {
        return (this.estimationErrorWeight == null);
    }

    // SplitNodes can have alternative trees, but LearningNodes can't
    // LearningNodes can split, but SplitNodes can't
    // Parent nodes are always SplitNodes
    public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) {
        int trueClass = (int) inst.classValue();
        //New option vote
        int k = MiscUtils.poisson(1.0, this.classifierRandom);
        Instance weightedInst = (Instance) inst.copy();
        if (k > 0) {
            //weightedInst.setWeight(inst.weight() * k);
        }
        // Compute the current prediction once and reuse the found leaf for both
        // the null check and the vote (previously filterInstanceToLeaf ran twice).
        int ClassPrediction = 0;
        FoundNode foundNode = filterInstanceToLeaf(inst, parent, parentBranch);
        if (foundNode.node != null) {
            ClassPrediction = Utils.maxIndex(foundNode.node.getClassVotes(inst, ht));
        }
        boolean blCorrect = (trueClass == ClassPrediction);
        if (this.estimationErrorWeight == null) {
            this.estimationErrorWeight = new ADWIN();
        }
        double oldError = this.getErrorEstimation();
        this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect ? 0.0 : 1.0);
        if (this.ErrorChange && oldError > this.getErrorEstimation()) {
            // if error is decreasing, don't do anything
            this.ErrorChange = false;
        }
        // Check condition to build a new alternate tree
        if (this.ErrorChange) {
            // Start a new alternative tree: learning node
            this.alternateTree = ht.newLearningNode();
            ht.alternateTrees++;
        }
        // Check condition to replace tree.
        // FIX: guard against alternateTree == null — without it this branch
        // throws a NullPointerException whenever ADWIN reports no change and
        // no alternate tree has been started yet.
        else if (this.alternateTree != null && !((NewNode) this.alternateTree).isNullError()) {
            // Require both windows to be wide enough for a meaningful comparison.
            if (this.getErrorWidth() > 300 && ((NewNode) this.alternateTree).getErrorWidth() > 300) {
                double oldErrorRate = this.getErrorEstimation();
                double altErrorRate = ((NewNode) this.alternateTree).getErrorEstimation();
                double fDelta = .05;
                //if (gNumAlts>0) fDelta=fDelta/gNumAlts;
                double fN = 1.0 / ((double) ((NewNode) this.alternateTree).getErrorWidth()) + 1.0 / ((double) this.getErrorWidth());
                double Bound = (double) Math.sqrt((double) 2.0 * oldErrorRate * (1.0 - oldErrorRate) * Math.log(2.0 / fDelta) * fN);
                if (Bound < oldErrorRate - altErrorRate) {
                    // Alternate is significantly better: switch it in.
                    ht.activeLeafNodeCount -= this.numberLeaves();
                    ht.activeLeafNodeCount += ((NewNode) this.alternateTree).numberLeaves();
                    killTreeChilds(ht);
                    if (parent != null) {
                        parent.setChild(parentBranch, this.alternateTree);
                        //((AdaSplitNode) parent.getChild(parentBranch)).alternateTree = null;
                    } else {
                        // Switch root tree
                        ht.treeRoot = ((AdaSplitNode) ht.treeRoot).alternateTree;
                    }
                    ht.switchedAlternateTrees++;
                } else if (Bound < altErrorRate - oldErrorRate) {
                    // Alternate is significantly worse: erase it.
                    if (this.alternateTree instanceof ActiveLearningNode) {
                        this.alternateTree = null;
                        ht.activeLeafNodeCount--;
                    } else if (this.alternateTree instanceof InactiveLearningNode) {
                        this.alternateTree = null;
                        ht.inactiveLeafNodeCount--;
                    } else {
                        ((AdaSplitNode) this.alternateTree).killTreeChilds(ht);
                    }
                    ht.prunedAlternateTrees++;
                }
            }
        }
        // Train the alternate tree and the matching child on the same instance.
        if (this.alternateTree != null) {
            ((NewNode) this.alternateTree).learnFromInstance(weightedInst, ht, parent, parentBranch);
        }
        int childBranch = this.instanceChildIndex(inst);
        Node child = this.getChild(childBranch);
        if (child != null) {
            ((NewNode) child).learnFromInstance(weightedInst, ht, this, childBranch);
        }
    }

    /** Recursively tear down this node's subtree, updating ht's leaf counters. */
    public void killTreeChilds(HoeffdingAdaptiveTree ht) {
        for (Node child : this.children) {
            if (child != null) {
                // Delete alternate tree if it exists
                if (child instanceof AdaSplitNode && ((AdaSplitNode) child).alternateTree != null) {
                    ((NewNode) ((AdaSplitNode) child).alternateTree).killTreeChilds(ht);
                    ht.prunedAlternateTrees++;
                }
                // Recursive delete of SplitNodes
                if (child instanceof AdaSplitNode) {
                    ((NewNode) child).killTreeChilds(ht);
                }
                // NOTE(review): `child = null` only clears the loop variable, not
                // the entry in this.children — same as upstream MOA; only the
                // counters take effect here. Confirm before "fixing".
                if (child instanceof ActiveLearningNode) {
                    child = null;
                    ht.activeLeafNodeCount--;
                } else if (child instanceof InactiveLearningNode) {
                    child = null;
                    ht.inactiveLeafNodeCount--;
                }
            }
        }
    }

    // New for option votes: route the instance to every reachable leaf,
    // including leaves of the alternate tree (marked with branch -999).
    public void filterInstanceToLeaves(Instance inst, SplitNode myparent,
            int parentBranch, List<FoundNode> foundNodes,
            boolean updateSplitterCounts) {
        if (updateSplitterCounts) {
            this.observedClassDistribution.addToValue((int) inst
                    .classValue(), inst.weight());
        }
        int childIndex = instanceChildIndex(inst);
        if (childIndex >= 0) {
            Node child = getChild(childIndex);
            if (child != null) {
                ((NewNode) child).filterInstanceToLeaves(inst, this, childIndex,
                        foundNodes, updateSplitterCounts);
            } else {
                foundNodes.add(new FoundNode(null, this, childIndex));
            }
        }
        if (this.alternateTree != null) {
            ((NewNode) this.alternateTree).filterInstanceToLeaves(inst, this, -999,
                    foundNodes, updateSplitterCounts);
        }
    }
}
/**
 * Adaptive leaf: a Naive-Bayes-adaptive learning node that additionally tracks
 * its own error with ADWIN and scales its class votes by the squared error
 * estimate when contributing to option votes.
 */
public static class AdaLearningNode extends LearningNodeNBAdaptive implements NewNode {

    private static final long serialVersionUID = 1L;

    // ADWIN change detector fed with this leaf's 0/1 error stream.
    protected ADWIN estimationErrorWeight;

    // Latest ADWIN verdict: true when a significant error increase was detected.
    public boolean ErrorChange = false;

    protected int randomSeed = 1;

    // Per-leaf RNG driving the Poisson(1) re-weighting draw.
    protected Random classifierRandom;

    @Override
    public int calcByteSize() {
        int byteSize = super.calcByteSize();
        if (estimationErrorWeight != null) {
            byteSize += estimationErrorWeight.measureByteSize();
        }
        return byteSize;
    }

    public AdaLearningNode(double[] initialClassObservations) {
        super(initialClassObservations);
        this.classifierRandom = new Random(this.randomSeed);
    }

    /** A leaf is its own single leaf. */
    public int numberLeaves() {
        return 1;
    }

    public double getErrorEstimation() {
        if (this.estimationErrorWeight != null) {
            return this.estimationErrorWeight.getEstimation();
        }
        return 0;
    }

    /**
     * ADWIN window width, or 0 while no detector exists yet.
     * FIX: previously dereferenced estimationErrorWeight unconditionally and
     * threw a NullPointerException before the first instance was seen; now
     * guarded, consistent with getErrorEstimation() and AdaSplitNode.
     */
    public double getErrorWidth() {
        double w = 0.0;
        if (!isNullError()) {
            w = this.estimationErrorWeight.getWidth();
        }
        return w;
    }

    public boolean isNullError() {
        return (this.estimationErrorWeight == null);
    }

    /** Leaves have no subtree to tear down. */
    public void killTreeChilds(HoeffdingAdaptiveTree ht) {
    }

    public void learnFromInstance(Instance inst, HoeffdingAdaptiveTree ht, SplitNode parent, int parentBranch) {
        int trueClass = (int) inst.classValue();
        // New option vote: online-bagging-style Poisson(1) re-weighting.
        int k = MiscUtils.poisson(1.0, this.classifierRandom);
        Instance weightedInst = (Instance) inst.copy();
        if (k > 0) {
            weightedInst.setWeight(inst.weight() * k);
        }
        // Evaluate the current prediction before learning from the instance.
        int ClassPrediction = Utils.maxIndex(this.getClassVotes(inst, ht));
        boolean blCorrect = (trueClass == ClassPrediction);
        if (this.estimationErrorWeight == null) {
            this.estimationErrorWeight = new ADWIN();
        }
        double oldError = this.getErrorEstimation();
        this.ErrorChange = this.estimationErrorWeight.setInput(blCorrect ? 0.0 : 1.0);
        if (this.ErrorChange && oldError > this.getErrorEstimation()) {
            // if error is decreasing, don't do anything
            this.ErrorChange = false;
        }
        // Update sufficient statistics with the weighted copy.
        learnFromInstance(weightedInst, ht);
        // Check for split condition once enough weight has accumulated.
        double weightSeen = this.getWeightSeen();
        if (weightSeen
                - this.getWeightSeenAtLastSplitEvaluation() >= ht.gracePeriodOption
                .getValue()) {
            ht.attemptToSplit(this, parent,
                    parentBranch);
            this.setWeightSeenAtLastSplitEvaluation(weightSeen);
        }
    }

    @Override
    public double[] getClassVotes(Instance inst, HoeffdingTree ht) {
        double[] dist;
        int predictionOption = ((HoeffdingAdaptiveTree) ht).leafpredictionOption.getChosenIndex();
        if (predictionOption == 0) { // MC
            dist = this.observedClassDistribution.getArrayCopy();
        } else if (predictionOption == 1) { // NB
            dist = NaiveBayes.doNaiveBayesPrediction(inst,
                    this.observedClassDistribution, this.attributeObservers);
        } else { // NBAdaptive: pick whichever of MC/NB has been more accurate here.
            if (this.mcCorrectWeight > this.nbCorrectWeight) {
                dist = this.observedClassDistribution.getArrayCopy();
            } else {
                dist = NaiveBayes.doNaiveBayesPrediction(inst,
                        this.observedClassDistribution, this.attributeObservers);
            }
        }
        // New for option votes: weight this leaf's vote by its squared error estimate.
        double distSum = Utils.sum(dist);
        if (distSum * this.getErrorEstimation() * this.getErrorEstimation() > 0.0) {
            Utils.normalize(dist, distSum * this.getErrorEstimation() * this.getErrorEstimation());
        }
        return dist;
    }

    // New for option votes: a leaf simply reports itself.
    public void filterInstanceToLeaves(Instance inst,
            SplitNode splitparent, int parentBranch,
            List<FoundNode> foundNodes, boolean updateSplitterCounts) {
        foundNodes.add(new FoundNode(this, splitparent, parentBranch));
    }
}
// Count of active (still-splitting) leaves in the tree.
protected int activeLeafNodeCount;
// Count of deactivated leaves.
protected int inactiveLeafNodeCount;
// Total alternate trees ever started.
protected int alternateTrees;
// Alternate trees discarded for being significantly worse.
protected int prunedAlternateTrees;
// Alternate trees promoted to replace their original subtree.
protected int switchedAlternateTrees;
/** Creates a leaf for this tree; adaptive leaves carry their own ADWIN detector. */
@Override
protected LearningNode newLearningNode(double[] initialClassObservations) {
    // IDEA: to choose different learning nodes depending on predictionOption
    LearningNode leaf = new AdaLearningNode(initialClassObservations);
    return leaf;
}
/** Creates a split node for this tree; adaptive split nodes can grow alternates. */
//@Override
protected SplitNode newSplitNode(InstanceConditionalTest splitTest,
        double[] classObservations) {
    SplitNode node = new AdaSplitNode(splitTest, classObservations);
    return node;
}
/** Trains the tree on one instance, lazily creating the root leaf first. */
@Override
public void trainOnInstanceImpl(Instance inst) {
    Node root = this.treeRoot;
    if (root == null) {
        root = newLearningNode();
        this.treeRoot = root;
        this.activeLeafNodeCount = 1;
    }
    ((NewNode) root).learnFromInstance(inst, this, null, -1);
}
// New for option votes.
/**
 * Collects every leaf the instance reaches, including alternate-tree leaves.
 * FIX: guard against a null root — this public method previously threw a
 * NullPointerException when called before any training; it now returns an
 * empty array (the existing caller only invokes it on a non-null root).
 */
public FoundNode[] filterInstanceToLeaves(Instance inst,
        SplitNode parent, int parentBranch, boolean updateSplitterCounts) {
    List<FoundNode> nodes = new LinkedList<FoundNode>();
    if (this.treeRoot != null) {
        ((NewNode) this.treeRoot).filterInstanceToLeaves(inst, parent, parentBranch, nodes,
                updateSplitterCounts);
    }
    return nodes.toArray(new FoundNode[nodes.size()]);
}
/**
 * Aggregates class votes from every leaf the instance reaches, skipping
 * alternate-tree leaves (marked with parentBranch == -999). Votes are already
 * error-weighted by AdaLearningNode.getClassVotes, so no normalization is
 * applied here. Returns an empty array if the tree has not been trained.
 * (Removed the unused predictionPaths counter and dead commented-out code.)
 */
@Override
public double[] getVotesForInstance(Instance inst) {
    if (this.treeRoot == null) {
        return new double[0];
    }
    FoundNode[] foundNodes = filterInstanceToLeaves(inst,
            null, -1, false);
    DoubleVector result = new DoubleVector();
    for (FoundNode foundNode : foundNodes) {
        if (foundNode.parentBranch != -999) {
            Node leafNode = foundNode.node;
            if (leafNode == null) {
                // Instance fell off an empty branch: vote with the parent split node.
                leafNode = foundNode.parent;
            }
            double[] dist = leafNode.getClassVotes(inst, this);
            result.addValues(dist);
        }
    }
    return result.getArrayRef();
}
}