/* * OzaBagASHT.java * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand * @author Caglar * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ package tr.gov.ulakbim.jDenetX.classifiers; import tr.gov.ulakbim.jDenetX.core.DoubleVector; import tr.gov.ulakbim.jDenetX.core.MiscUtils; import tr.gov.ulakbim.jDenetX.core.VotedInstancePool; import tr.gov.ulakbim.jDenetX.options.FlagOption; import tr.gov.ulakbim.jDenetX.options.IntOption; import weka.core.Instance; import weka.core.Utils; import java.util.ArrayList; public class CoOzaBagASHT extends OzaBag { private static final long serialVersionUID = 1L; public IntOption firstClassifierSizeOption = new IntOption( "firstClassifierSize", 'f', "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE); public FlagOption useWeightOption = new FlagOption("useWeight", 'u', "Enable weight classifiers."); public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r', "Reset trees when size is higher than the max."); protected double[] error; protected ArrayList<Instance> centroids; protected double alpha = 0.01; private static VotedInstancePool instConfPool = new VotedInstancePool(); public static int instConfCount = 0; private final static double confidenceThreshold = 9.7; @Override public void resetLearningImpl() { this.ensemble = new Classifier[this.ensembleSizeOption.getValue()]; this.error = new double[this.ensembleSizeOption.getValue()]; instConfPool = new VotedInstancePool(); instConfCount = 0; Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); int pow = this.firstClassifierSizeOption.getValue(); // EXTENSION TO ASHT for (int i = 0; i < this.ensemble.length; i++) { this.ensemble[i] = baseLearner.copy(); this.error[i] = 0.0; ((ASHoeffdingOptionTree) this.ensemble[i]).setMaxSize(pow); // EXTENSION TO ASHT if ((this.resetTreesOption != null) && this.resetTreesOption.isSet()) { ((ASHoeffdingOptionTree) this.ensemble[i]).setResetTree(); } pow *= 2; // EXTENSION TO ASHT } } public double getEntropyForArray(double votes[]) { double entropy = 0.0; for (int i = 0; i < votes.length; i++) { votes[i] -= votes[i] * (Math.log(votes[i]) / Math.log(2)); // By Default Java computes Math.log for base e, to compute base 2 we should divide by log(2) } return entropy; } public double getQBCEntropy(double vote, int success) { double entropy = 0.0; entropy -= (vote / success) * Utils.log2(vote / success); //(Math.log(vote) / Math.log(2)); // Default Java log function computes // Math.log for base e, to compute base 2 we // should divide by log(2) return entropy; } /** * Query By Comittee algorithm * This measures the vote entropy. * xve = argmax-Sigma (V(yi)/C)*log(V(yi)/C) * xve is the vote entropy. * C is the comittee size * V(yi) is the number of the votes that a label recieves among the comittee members' votes. */ public double queryByCommitee(double[] ensembleVotes, int noOfClasses, int success) { double entropyQBC = 0.0; double qbc = 0.0; if (noOfClasses != 0) { for (int j = 0; j < noOfClasses; j++) { if (ensembleVotes[j] != 0) { qbc = (double) ensembleVotes[j] / ((double) ensemble.length); //System.out.println("qbc is : " + qbc); entropyQBC -= getQBCEntropy(qbc, success); } } } return entropyQBC; } @Override public void trainOnInstanceImpl(Instance inst) { int trueClass = (int) inst.classValue(); for (int i = 0; i < this.ensemble.length; i++) { int k = MiscUtils.poisson(1.0, this.classifierRandom); if (k > 0) { Instance weightedInst = (Instance) inst.copy(); weightedInst.setWeight(inst.weight() * k); if (Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst)) == trueClass) { // Here we used the getVotesForInstanceFunction of HoeffdingTree this.error[i] += alpha * (0.0 - this.error[i]); // EWMA } else { this.error[i] += alpha * (1.0 - this.error[i]); // EWMA } this.ensemble[i].trainOnInstance(weightedInst); } } } /** * This is the main classification function that is used by the GUI */ public double[] getVotesForInstance(Instance inst) { DoubleVector combinedVote = new DoubleVector(); DoubleVector confidenceVec = new DoubleVector(); double[] ensembleVotes = new double[inst.numClasses()]; double qbcEntropy = 0.0; int success = 0; int alpha1 = 1; int alpha2 = 1; for (int i = 0; i < this.ensemble.length; i++) { DoubleVector vote = new DoubleVector(this.ensemble[i] .getVotesForInstance(inst)); if (vote.sumOfValues() > 0.0) { vote.normalize(); confidenceVec.addValues(vote); if ((this.useWeightOption != null) && this.useWeightOption.isSet()) { vote.scaleValues(1.0 / (this.error[i] * this.error[i])); //System.out.println("Ensemble : " + i + " Error: " + this.error[i]); } combinedVote.addValues(vote); } // //Ignore the classifiers which have high error ratio // if (this.error[i] < 0.23) { // // this is the votes of the ensembles for the classes // success++; ensembleVotes[combinedVote.maxIndex()] += combinedVote.getValue(combinedVote.maxIndex()); } } //For confidence measure add to the pool and in order to fit the confidence value between 0 and 1 divide by success val //System.out.println("Confidence " + combinedVote.getValue(combinedVote.maxIndex())); if ((confidenceVec.getValue(combinedVote.maxIndex())) >= confidenceThreshold) { qbcEntropy = queryByCommitee(ensembleVotes, inst.numClasses(), success); double activeLearningRatio = (qbcEntropy) * (combinedVote.getValue(combinedVote.maxIndex()) / this.ensemble.length); inst.setClassValue(combinedVote.maxIndex()); //Set the class value of the instance instConfPool.addVotedInstance(inst, combinedVote .getValue(combinedVote.maxIndex()), activeLearningRatio); instConfCount++; } return combinedVote.getArrayRef(); } /** * This is the main classification function that is used by the GUI */ public double[] getVotesForInstanceOrig(Instance inst) { DoubleVector combinedVote = new DoubleVector(); double[] ensembleVotes = new double[inst.numClasses()]; double qbcEntropy = 0.0; int success = 0; for (int i = 0; i < this.ensemble.length; i++) { DoubleVector vote = new DoubleVector(this.ensemble[i] .getVotesForInstance(inst)); // This will call the HoeffdingTree's getVotesForInstance Function if (vote.sumOfValues() > 0.0) { vote.normalize(); if ((this.useWeightOption != null) && this.useWeightOption.isSet()) { vote.scaleValues(1.0 / (this.error[i] * this.error[i])); System.out.println("Ensemble : " + i + " Error: " + this.error[i]); } // //Ignore the ensembles which have high error ratio // if (this.error[i] < 0.3) { combinedVote.addValues(vote); } } // // this is the votes of the ensembles for the classes // if (this.error[i] < 0.3) { success++; ensembleVotes[combinedVote.maxIndex()] += combinedVote.getValue(combinedVote.maxIndex()); } } // For confidence measure add to the pool and in order to fit the confidence value between 0 and 1 divide by success val if ((combinedVote.getValue(combinedVote.maxIndex()) / success) >= confidenceThreshold) { qbcEntropy = queryByCommitee(ensembleVotes, inst.numClasses(), 0); System.out.println("QBC Entropy: " + qbcEntropy); double activeLearningRatio = (qbcEntropy) + (combinedVote.getValue(combinedVote.maxIndex()) / this.ensemble.length); inst.setClassValue(combinedVote.maxIndex()); instConfPool.addVotedInstance(inst, combinedVote .getValue(combinedVote.maxIndex()), activeLearningRatio); } return combinedVote.getArrayRef(); } public static VotedInstancePool getVotedInstancePool() { return instConfPool; } @Override public void getModelDescription(StringBuilder out, int indent) { // TODO Auto-generated method stub super.getModelDescription(out, indent); } }