/*
* OzaBagASHOT.java
* Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
* @author Albert Bifet
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.MiscUtils;
import tr.gov.ulakbim.jDenetX.options.FlagOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Instance;
import weka.core.Utils;
/**
 * Online bagging ensemble whose members are {@link ASHoeffdingOptionTree}s with
 * geometrically increasing maximum sizes.
 *
 * <p>Member {@code i} is limited to {@code firstClassifierSize * 2^i} (saturating at
 * {@link Integer#MAX_VALUE}). On every training instance each member draws a weight
 * {@code k} from a Poisson(1) distribution and, when {@code k > 0}, learns from a copy
 * of the instance whose weight is multiplied by {@code k}. Each member also keeps an
 * exponentially weighted moving average (EWMA) of its own 0/1 prediction error. When
 * {@code useWeight} is set, that error is used to weight the member's vote.
 */
public class OzaBagASHOT extends OzaBag {

    private static final long serialVersionUID = 1L;

    /**
     * Lower bound applied to a member's error before it is turned into a vote weight.
     *
     * <p>Without it, a member whose error is (or decays to) 0.0 gets weight
     * {@code 1 / 0 = Infinity}. {@code Infinity * 0.0} on the classes it does not vote
     * for then turns the whole combined vote into NaN. The floor caps a single member's
     * weight at 1e12.
     */
    private static final double MIN_WEIGHTING_ERROR = 1.0e-6;

    public IntOption firstClassifierSizeOption = new IntOption("firstClassifierSize", 'f',
            "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE);

    public FlagOption useWeightOption = new FlagOption("useWeight",
            'u', "Enable weight classifiers.");

    public FlagOption resetTreesOption = new FlagOption("resetTrees",
            'r', "Reset trees when size is higher than the max.");

    /** Per-member EWMA of the 0/1 prediction error; starts at 0.0 and stays in [0, 1]. */
    protected double[] error;

    /** Smoothing factor of the error EWMA: the weight given to the newest outcome. */
    protected double alpha = 0.01;

    /**
     * Initializes the variables and components of the algorithm.
     *
     * <p>Builds {@code ensembleSize} copies of the prepared base learner and clears their
     * error estimates. It then gives member {@code i} a maximum size of
     * {@code firstClassifierSize * 2^i}, and optionally enables tree resetting on every
     * member.
     *
     * <p>NOTE(review): assumes the configured base learner is an
     * {@link ASHoeffdingOptionTree}; any other learner makes the cast below throw
     * {@link ClassCastException}.
     */
    @Override
    public void resetLearningImpl() {
        final int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[ensembleSize];
        this.error = new double[ensembleSize];
        final boolean resetTrees = (this.resetTreesOption != null)
                && this.resetTreesOption.isSet();

        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();

        int maxSize = this.firstClassifierSizeOption.getValue(); // EXTENSION TO ASHT
        for (int i = 0; i < ensembleSize; i++) {
            this.ensemble[i] = baseLearner.copy();
            this.error[i] = 0.0;
            ASHoeffdingOptionTree member = (ASHoeffdingOptionTree) this.ensemble[i];
            member.setMaxSize(maxSize); // EXTENSION TO ASHT
            if (resetTrees) {
                member.setResetTree();
            }
            maxSize = doubleSaturating(maxSize); // EXTENSION TO ASHT
        }
    }

    /**
     * Trains every member on a Poisson(1)-weighted copy of {@code inst}.
     *
     * <p>Before a member learns from the instance, its current prediction for
     * {@code inst} is compared with the true class. Its error EWMA is then moved towards
     * 0.0 (correct) or 1.0 (wrong) by a factor of {@link #alpha}. Members that draw a
     * weight of 0 are neither trained nor have their error updated.
     *
     * @param inst the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        final int trueClass = (int) inst.classValue();
        for (int i = 0; i < this.ensemble.length; i++) {
            final int k = MiscUtils.poisson(1.0, this.classifierRandom);
            if (k > 0) {
                // Score the member on the instance before it learns from it.
                final boolean correct =
                        Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst)) == trueClass;
                final double outcome = correct ? 0.0 : 1.0;
                this.error[i] += alpha * (outcome - this.error[i]); // EWMA

                final Instance weightedInst = (Instance) inst.copy();
                weightedInst.setWeight(inst.weight() * k);
                this.ensemble[i].trainOnInstance(weightedInst);
            }
        }
    }

    /**
     * Combines the members' votes for {@code inst}.
     *
     * <p>Members whose votes do not sum to a positive value are skipped. Every other
     * member's vote is normalized and, when {@code useWeight} is set, scaled by
     * {@code 1 / error^2} so that members with lower error count more. The error is
     * floored at {@link #MIN_WEIGHTING_ERROR} so the weight is always finite.
     *
     * @param inst the instance to classify
     * @return the summed (and possibly weighted) class votes of the ensemble
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        final boolean weighted = (this.useWeightOption != null)
                && this.useWeightOption.isSet();
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; i++) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                if (weighted) {
                    vote.scaleValues(memberWeight(i));
                }
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /** Returns the inverse-squared-error vote weight of member {@code i} (error floored). */
    private double memberWeight(int i) {
        final double e = Math.max(this.error[i], MIN_WEIGHTING_ERROR);
        return 1.0 / (e * e);
    }

    /** Returns {@code 2 * size}, clamped to {@link Integer#MAX_VALUE} instead of overflowing. */
    private static int doubleSaturating(int size) {
        return (size > Integer.MAX_VALUE / 2) ? Integer.MAX_VALUE : size * 2;
    }

    /** Writes no model-specific description; the output buffer is left unchanged. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Intentionally empty: no description is produced for this ensemble.
    }
}