package tr.gov.ulakbim.jDenetX.classifiers; import tr.gov.ulakbim.jDenetX.classifiers.ensemble.QBC; import tr.gov.ulakbim.jDenetX.core.*; import tr.gov.ulakbim.jDenetX.options.ClassOption; import tr.gov.ulakbim.jDenetX.options.FlagOption; import tr.gov.ulakbim.jDenetX.options.IntOption; import weka.core.Instance; /** * Created by IntelliJ IDEA. * User: caglar * Date: Sep 28, 2010 * Time: 12:13:04 PM * To change this template use File | Settings | File Templates. */ public class SelfOzaBoost extends AbstractClassifier { private static final long serialVersionUID = 1L; public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "HoeffdingTree"); public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE); public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson."); private static VotedInstancePool instConfPool = new VotedInstancePool(); protected Classifier[] ensemble; protected double[] scms; protected double[] swms; public static int instConfCount = 0; private static final double confidenceThreshold = 0.98; protected static final double errorRatio = 0.23; @Override public int measureByteSize() { int size = (int) SizeOf.sizeOf(this); for (Classifier classifier : this.ensemble) { size += classifier.measureByteSize(); } return size; } @Override public void resetLearningImpl() { this.ensemble = new Classifier[this.ensembleSizeOption.getValue()]; Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption); baseLearner.resetLearning(); for (int i = 0; i < this.ensemble.length; i++) { this.ensemble[i] = baseLearner.copy(); } instConfCount = 0; this.scms = new double[this.ensemble.length]; this.swms = new double[this.ensemble.length]; } @Override public void trainOnInstanceImpl(Instance inst) { double lambda_d = 1.0; for (int i = 0; i < this.ensemble.length; i++) { double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils .poisson(lambda_d, this.classifierRandom); if (k > 0.0) { Instance weightedInst = (Instance) inst.copy(); weightedInst.setWeight(inst.weight() * k); this.ensemble[i].trainOnInstance(weightedInst); } if (this.ensemble[i].correctlyClassifies(inst)) { this.scms[i] += lambda_d; lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]); } else { this.swms[i] += lambda_d; lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]); } } } protected double getEnsembleMemberWeight(int i) { double em = this.swms[i] / (this.scms[i] + this.swms[i]); if ((em == 0.0) || (em > 0.5)) { return 0.0; } double Bm = em / (1.0 - em); return Math.log(1.0 / Bm); } protected double getEnsembleMemberError(int i) { double em = this.swms[i] / (this.scms[i] + this.swms[i]); return em; } public double getActiveLearningRatio(double qbcEntropy, DoubleVector combinedVote) { int maxIndex = combinedVote.maxIndex(); int ensembleLength = this.ensemble.length; double maxVote = combinedVote.getValue(maxIndex); double activeLearningRatio = (qbcEntropy) * (maxVote / ensembleLength); return activeLearningRatio; } int count = 0; public double[] getVotesForInstance(Instance inst) { DoubleVector combinedVote = new DoubleVector(); DoubleVector confidenceVec = new DoubleVector(); int success = 0; double[] ensembleVotes = new double[inst.numClasses()]; double[] ensMemberWeights = new double[this.ensemble.length]; boolean[] ensMemberFlags = new boolean[this.ensemble.length]; double confidence = 0.0; for (int i = 0; i < this.ensemble.length; i++) { if (!ensMemberFlags[i]) { ensMemberWeights[i] = getEnsembleMemberWeight(i); } if (ensMemberWeights[i] > 0.0) { DoubleVector vote = new DoubleVector(this.ensemble[i] .getVotesForInstance(inst)); if (vote.sumOfValues() > 0.0) { vote.normalize(); vote.scaleValues(ensMemberWeights[i]); combinedVote.addValues(vote); if (getEnsembleMemberError(i) < errorRatio) { // // these are the votes of the ensembles for the classes // success++; //successFlag = true; confidenceVec.addValues(vote); ensembleVotes[confidenceVec.maxIndex()] += confidenceVec.getValue(confidenceVec.maxIndex()); } } } else { break; } } confidenceVec = (DoubleVector) combinedVote.copy(); confidenceVec.normalize(); confidence = confidenceVec.getValue(confidenceVec.maxIndex()); //Reconfigure the activeLearningRatio //For confidence measure add to the pool and in order to fit the confidence value between 0 and 1 divide by success val if (confidence > confidenceThreshold) { double qbcEntropy = QBC.queryByCommitee(ensembleVotes, inst.numClasses(), success, ensemble.length); Math.pow(qbcEntropy, 2); System.out.println("QBC Entropy: " + qbcEntropy); double activeLearningRatio = getActiveLearningRatio(qbcEntropy, combinedVote); inst.setClassValue(combinedVote.maxIndex()); //Set the class value of the instance instConfPool.addVotedInstance(inst, combinedVote .getValue(combinedVote.maxIndex()), activeLearningRatio); instConfCount++; } return combinedVote.getArrayRef(); } public static VotedInstancePool getVotedInstancePool() { return instConfPool; } public boolean isRandomizable() { return true; } @Override public void getModelDescription(StringBuilder out, int indent) { // TODO Auto-generated method stub } @Override protected Measurement[] getModelMeasurementsImpl() { return new Measurement[]{new Measurement("ensemble size", this.ensemble != null ? this.ensemble.length : 0)}; } @Override public Classifier[] getSubClassifiers() { return this.ensemble.clone(); } }