/* * OzaBagASHT.java * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand * @author Albert Bifet * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ package tr.gov.ulakbim.jDenetX.classifiers; import tr.gov.ulakbim.jDenetX.core.DoubleVector; import tr.gov.ulakbim.jDenetX.core.MiscUtils; import tr.gov.ulakbim.jDenetX.options.FlagOption; import tr.gov.ulakbim.jDenetX.options.IntOption; import weka.core.Instance; import weka.core.Utils; public class OzaBagASHOT extends OzaBag { private static final long serialVersionUID = 1L; public IntOption firstClassifierSizeOption = new IntOption("firstClassifierSize", 'f', "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE); public FlagOption useWeightOption = new FlagOption("useWeight", 'u', "Enable weight classifiers."); public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r', "Reset trees when size is higher than the max."); protected double[] error; protected double alpha = 0.01; /** * Initializes the variables and components of the algorithm */ @Override public void resetLearningImpl() { System.out.println("Learning is resetted"); StackTraceElement[] stackTraceElements = Thread.currentThread().getStackTrace(); System.out.println("Caller method name " + stackTraceElements[2].getMethodName()); this.ensemble = new Classifier[this.ensembleSizeOption.getValue()]; this.error = new double[this.ensembleSizeOption.getValue()]; Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); int pow = this.firstClassifierSizeOption.getValue(); //EXTENSION TO ASHT for (int i = 0; i < this.ensemble.length; i++) { this.ensemble[i] = baseLearner.copy(); this.error[i] = 0.0; ((ASHoeffdingOptionTree) this.ensemble[i]).setMaxSize(pow); //EXTENSION TO ASHT if ((this.resetTreesOption != null) && this.resetTreesOption.isSet()) { ((ASHoeffdingOptionTree) this.ensemble[i]).setResetTree(); } pow *= 2; //EXTENSION TO ASHT } } @Override public void trainOnInstanceImpl(Instance inst) { final int trueClass = (int) inst.classValue(); //System.out.println("Ensemble Length " + this.ensemble.length); for (int i = 0; i < this.ensemble.length; i++) { final int k = MiscUtils.poisson(1.0, this.classifierRandom); if (k > 0) { final Instance weightedInst = (Instance) inst.copy(); weightedInst.setWeight(inst.weight() * k); if (Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst)) == trueClass) { this.error[i] += alpha * (0.0 - this.error[i]); //EWMA } else { this.error[i] += alpha * (1.0 - this.error[i]); //EWMA } this.ensemble[i].trainOnInstance(weightedInst); } // System.out.println("ClassifierRandom: " + k); //System.out.println("EWMA Error Ensemble "+i+" "+ this.error[i]); if (this.error[i] > 0.6) { System.out.println("Error is " + this.error[i]); System.out.println("Ensemble " + i); System.err.println("Warning!!!!!"); } } } public double[] getVotesForInstance(Instance inst) { DoubleVector combinedVote = new DoubleVector(); for (int i = 0; i < this.ensemble.length; i++) { DoubleVector vote = new DoubleVector(this.ensemble[i] .getVotesForInstance(inst)); if (vote.sumOfValues() > 0.0) { vote.normalize(); if ((this.useWeightOption != null) && this.useWeightOption.isSet()) { vote.scaleValues(1.0 / (this.error[i] * this.error[i])); } combinedVote.addValues(vote); } } return combinedVote.getArrayRef(); } @Override public void getModelDescription(StringBuilder out, int indent) { // TODO Auto-generated method stub } }