package tr.gov.ulakbim.jDenetX.classifiers;

import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import tr.gov.ulakbim.jDenetX.core.ObjectRepository;
import tr.gov.ulakbim.jDenetX.options.ClassOption;
import tr.gov.ulakbim.jDenetX.options.FloatOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import tr.gov.ulakbim.jDenetX.tasks.TaskMonitor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

import java.util.Random;

/**
 * The Accuracy Weighted Ensemble classifier as proposed by Wang et al. in
 * "Mining concept-drifting data streams using ensemble classifiers", KDD 2003.
 */
public class AccuracyWeightedEnsemble extends AbstractClassifier {

    /**
     * Simple weight comparator. Needed for sorting component classifiers.
     */
    private static final class ClassifierWeightComparator implements java.util.Comparator<double[]> {

        @Override
        public int compare(double[] o1, double[] o2) {
            if (o1[0] > o2[0]) {
                return 1;
            } else if (o1[0] < o2[0]) {
                return -1;
            }
            return 0;
        }
    }

    private static final long serialVersionUID = 1L;

    /**
     * Simple weight comparator.
     */
    protected static java.util.Comparator<double[]> weightComparator = new ClassifierWeightComparator();

    /**
     * Type of classifier to use as a component classifier.
     */
    public ClassOption learnerOption = new ClassOption("learner", 'l',
            "Classifier to train.", Classifier.class,
            "HoeffdingTreeNB -e 1000 -g 100 -c 0.01");

    /**
     * Number of component classifiers.
     */
    public FloatOption memberCountOption = new FloatOption("memberCount", 'n',
            "The maximum number of classifiers in an ensemble.", 15, 1, Integer.MAX_VALUE);

    /**
     * Number of classifiers remembered and available for ensemble construction.
     */
    public FloatOption storedCountOption = new FloatOption("storedCount", 'r',
            "The maximum number of classifiers to store and choose from when creating an ensemble.",
            30, 1, Integer.MAX_VALUE);

    /**
     * Chunk size.
     */
    public IntOption chunkSizeOption = new IntOption("chunkSize", 'c',
            "The chunk size used for classifier creation and evaluation.", 500, 1, Integer.MAX_VALUE);

    /**
     * Number of folds in candidate classifier cross-validation.
     */
    public IntOption numFoldsOption = new IntOption("numFolds", 'f',
            "Number of cross-validation folds for candidate classifier testing.", 10, 1, Integer.MAX_VALUE);

    /**
     * Class value counts accumulated over the current chunk.
     */
    protected long[] classDistributions;

    /**
     * The current ensemble: the top weighted classifiers selected from the store.
     */
    protected Classifier[] ensemble;

    /**
     * All stored classifiers available for ensemble construction.
     */
    protected Classifier[] storedLearners;

    /**
     * Weights of the current ensemble members.
     */
    protected double[] ensembleWeights;
    /**
     * The weights of stored classifiers:
     * storedWeights[x][0] = weight, storedWeights[x][1] = index of the classifier in storedLearners.
     */
    protected double[][] storedWeights;

    protected int processedInstances;

    protected int chunkSize;

    protected int numFolds;

    protected int maxMemberCount;

    protected int maxStoredCount;

    protected Classifier candidateClassifier;

    protected Instances currentChunk;

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        this.maxMemberCount = (int) memberCountOption.getValue();
        this.maxStoredCount = (int) storedCountOption.getValue();

        if (this.maxMemberCount > this.maxStoredCount) {
            this.maxStoredCount = this.maxMemberCount;
        }

        this.chunkSize = this.chunkSizeOption.getValue();
        this.numFolds = this.numFoldsOption.getValue();
        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();

        super.prepareForUseImpl(monitor, repository);
    }

    @Override
    public void resetLearningImpl() {
        this.currentChunk = null;
        this.classDistributions = null;
        this.processedInstances = 0;
        this.ensemble = new Classifier[0];
        this.storedLearners = new Classifier[0];

        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        this.initVariables();

        this.classDistributions[(int) inst.classValue()]++;
        this.currentChunk.add(inst);
        this.processedInstances++;

        if (this.processedInstances % this.chunkSize == 0) {
            this.processChunk();
        }
    }

    /**
     * Initializes the current chunk and class distribution variables.
     */
    private void initVariables() {
        if (this.currentChunk == null) {
            this.currentChunk = new Instances(this.getModelContext());
        }

        if (this.classDistributions == null) {
            this.classDistributions = new long[this.getModelContext().classAttribute().numValues()];

            for (int i = 0; i < this.classDistributions.length; i++) {
                this.classDistributions[i] = 0;
            }
        }
    }
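    /*
     * Chunk-processing cycle (summary of processChunk() below): every chunkSize
     * instances, the candidate classifier is scored by numFolds-fold
     * cross-validation on the new chunk, all stored classifiers are re-weighted
     * on that chunk, the candidate is added to the store (or replaces the
     * weakest stored classifier when the store is full and the candidate
     * outperforms it), and the top maxMemberCount stored classifiers by weight
     * form the voting ensemble.
     */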
    /**
     * Processes a chunk.
     */
    protected void processChunk() {
        // Compute weights
        double candidateClassifierWeight = this.computeCandidateWeight(this.candidateClassifier, this.currentChunk, this.numFolds);

        for (int i = 0; i < this.storedLearners.length; i++) {
            this.storedWeights[i][0] = this.computeWeight(this.storedLearners[(int) this.storedWeights[i][1]], this.currentChunk);
        }

        if (this.storedLearners.length < this.maxStoredCount) {
            // Train and add classifier
            for (int num = 0; num < this.chunkSize; num++) {
                this.candidateClassifier.trainOnInstance(this.currentChunk.instance(num));
            }

            this.addToStored(this.candidateClassifier, candidateClassifierWeight);
        } else {
            // Substitute poorest classifier
            java.util.Arrays.sort(this.storedWeights, weightComparator);

            if (this.storedWeights[0][0] < candidateClassifierWeight) {
                for (int num = 0; num < this.chunkSize; num++) {
                    this.candidateClassifier.trainOnInstance(this.currentChunk.instance(num));
                }

                this.storedWeights[0][0] = candidateClassifierWeight;
                this.storedLearners[(int) this.storedWeights[0][1]] = this.candidateClassifier.copy();
            }
        }

        int ensembleSize = java.lang.Math.min(this.storedLearners.length, this.maxMemberCount);
        this.ensemble = new Classifier[ensembleSize];
        this.ensembleWeights = new double[ensembleSize];

        // Sort learners according to their weights
        java.util.Arrays.sort(this.storedWeights, weightComparator);

        // Select top k classifiers to construct the ensemble
        int storeSize = this.storedLearners.length;

        for (int i = 0; i < ensembleSize; i++) {
            this.ensembleWeights[i] = this.storedWeights[storeSize - i - 1][0];
            this.ensemble[i] = this.storedLearners[(int) this.storedWeights[storeSize - i - 1][1]];
        }

        this.classDistributions = null;
        this.currentChunk = null;
        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    /**
     * Computes the weight of a candidate classifier.
     *
     * @param candidate Candidate classifier.
     * @param chunk Data chunk of examples.
     * @param numFolds Number of folds in candidate classifier cross-validation.
     * @return Candidate classifier weight.
     */
    protected double computeCandidateWeight(Classifier candidate, Instances chunk, int numFolds) {
        double candidateWeight = 0.0;
        Random random = new Random(1);
        Instances randData = new Instances(chunk);
        randData.randomize(random);

        if (randData.classAttribute().isNominal()) {
            randData.stratify(numFolds);
        }

        for (int n = 0; n < numFolds; n++) {
            Instances train = randData.trainCV(numFolds, n, random);
            Instances test = randData.testCV(numFolds, n);

            Classifier learner = candidate.copy();

            for (int num = 0; num < train.numInstances(); num++) {
                learner.trainOnInstance(train.instance(num));
            }

            candidateWeight += computeWeight(learner, test);
        }

        double resultWeight = candidateWeight / numFolds;

        if (Double.isInfinite(resultWeight)) {
            return Double.MAX_VALUE;
        } else {
            return resultWeight;
        }
    }
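    /*
     * Weighting scheme (Wang et al., KDD 2003): a classifier C_i receives weight
     *   w_i = max(MSE_r - MSE_i, 0)
     * where
     *   MSE_i = (1 / |S|) * sum over (x, c) in chunk S of (1 - f_c^i(x))^2
     * is the mean squared error of C_i's class-probability estimates on the chunk
     * (f_c^i(x) is the probability C_i assigns to the true class c of x), and
     *   MSE_r = sum over classes c of p(c) * (1 - p(c))^2
     * is the error of a classifier that predicts at random according to the chunk's
     * class distribution. A classifier no better than random therefore gets weight 0.
     * computeWeight() below implements MSE_i (normalizing by chunkSize) and
     * computeMseR() implements MSE_r.
     */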
    /**
     * Computes the weight of a given classifier.
     *
     * @param learner Classifier to calculate the weight for.
     * @param chunk Data chunk of examples.
     * @return The given classifier's weight.
     */
    protected double computeWeight(Classifier learner, Instances chunk) {
        double mse_i = 0;
        double mse_r = 0;
        double f_ci;
        double voteSum;

        for (int i = 0; i < chunk.numInstances(); i++) {
            try {
                voteSum = 0;

                for (double element : learner.getVotesForInstance(chunk.instance(i))) {
                    voteSum += element;
                }

                if (voteSum > 0) {
                    f_ci = learner.getVotesForInstance(chunk.instance(i))[(int) chunk.instance(i).classValue()] / voteSum;
                    mse_i += (1 - f_ci) * (1 - f_ci);
                } else {
                    mse_i += 1;
                }
            } catch (Exception e) {
                mse_i += 1;
            }
        }

        mse_i /= this.chunkSize;
        mse_r = this.computeMseR();

        return java.lang.Math.max(mse_r - mse_i, 0);
    }

    /**
     * Computes the MSEr threshold.
     *
     * @return The MSEr threshold.
     */
    protected double computeMseR() {
        double p_c;
        double mse_r = 0;

        for (int i = 0; i < this.classDistributions.length; i++) {
            p_c = (double) this.classDistributions[i] / (double) this.chunkSize;
            mse_r += p_c * ((1 - p_c) * (1 - p_c));
        }

        return mse_r;
    }

    /**
     * Predicts a class for an example.
     */
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();

        if (this.trainingWeightSeenByModel > 0.0) {
            for (int i = 0; i < this.ensemble.length; i++) {
                if (this.ensembleWeights[i] > 0.0) {
                    DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));

                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        // Scale weight and prevent overflow
                        vote.scaleValues(this.ensembleWeights[i] / (1.0 * this.ensemble.length + 1));
                        combinedVote.addValues(vote);
                    }
                }
            }
        }

        combinedVote.normalize();
        return combinedVote.getArrayRef();
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Adds ensemble weights to the measurements.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = new Measurement[this.maxStoredCount];

        for (int m = 0; m < this.maxMemberCount; m++) {
            measurements[m] = new Measurement("Member weight " + (m + 1), -1);
        }

        for (int s = this.maxMemberCount; s < this.maxStoredCount; s++) {
            measurements[s] = new Measurement("Stored member weight " + (s + 1), -1);
        }

        if (this.storedWeights != null) {
            int storeSize = this.storedWeights.length;

            for (int i = 0; i < storeSize; i++) {
                if (i < this.ensemble.length) {
                    measurements[i] = new Measurement("Member weight " + (i + 1), this.storedWeights[storeSize - i - 1][0]);
                } else {
                    measurements[i] = new Measurement("Stored member weight " + (i + 1), this.storedWeights[storeSize - i - 1][0]);
                }
            }
        }

        return measurements;
    }

    /**
     * Determines whether the classifier is randomizable.
     */
    public boolean isRandomizable() {
        return false;
    }

    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble.clone();
    }
    /**
     * Adds a classifier to the storage.
     *
     * @param newClassifier The classifier to add.
     * @param newClassifiersWeight The new classifier's weight.
     * @return The stored copy of the added classifier.
     */
    protected Classifier addToStored(Classifier newClassifier, double newClassifiersWeight) {
        Classifier addedClassifier = null;
        Classifier[] newStored = new Classifier[this.storedLearners.length + 1];
        double[][] newStoredWeights = new double[newStored.length][2];

        for (int i = 0; i < newStored.length; i++) {
            if (i < this.storedLearners.length) {
                newStored[i] = this.storedLearners[i];
                newStoredWeights[i][0] = this.storedWeights[i][0];
                newStoredWeights[i][1] = this.storedWeights[i][1];
            } else {
                newStored[i] = addedClassifier = newClassifier.copy();
                newStoredWeights[i][0] = newClassifiersWeight;
                newStoredWeights[i][1] = i;
            }
        }

        this.storedLearners = newStored;
        this.storedWeights = newStoredWeights;

        return addedClassifier;
    }

    /**
     * Removes the poorest classifier from the model, thus decreasing the model's size.
     *
     * @return The byte size of the removed classifier.
     */
    protected int removePoorestModelBytes() {
        int poorestIndex = Utils.minIndex(this.ensembleWeights);
        int byteSize = this.ensemble[poorestIndex].measureByteSize();
        discardModel(poorestIndex);

        return byteSize;
    }

    /**
     * Removes the classifier at a given index from the model, thus decreasing the model's size.
     *
     * @param index Index of the classifier to remove.
     */
    protected void discardModel(int index) {
        Classifier[] newEnsemble = new Classifier[this.ensemble.length - 1];
        double[] newEnsembleWeights = new double[newEnsemble.length];
        int oldPos = 0;

        for (int i = 0; i < newEnsemble.length; i++) {
            if (oldPos == index) {
                oldPos++;
            }

            newEnsemble[i] = this.ensemble[oldPos];
            newEnsembleWeights[i] = this.ensembleWeights[oldPos];
            oldPos++;
        }

        this.ensemble = newEnsemble;
        this.ensembleWeights = newEnsembleWeights;
    }
}
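/*
 * Usage sketch (illustrative only; assumes the jDenetX/MOA-style stream API
 * inherited from AbstractClassifier, e.g. setModelContext(), prepareForUse(),
 * trainOnInstance() and getVotesForInstance(), and a hypothetical `stream` source):
 *
 *   AccuracyWeightedEnsemble awe = new AccuracyWeightedEnsemble();
 *   awe.chunkSizeOption.setValue(500);          // rebuild the ensemble every 500 instances
 *   awe.setModelContext(stream.getHeader());    // attribute/class structure of the stream
 *   awe.prepareForUse();
 *
 *   while (stream.hasMoreInstances()) {
 *       Instance inst = stream.nextInstance();
 *       double[] votes = awe.getVotesForInstance(inst);  // test first (prequential)...
 *       awe.trainOnInstance(inst);                       // ...then train
 *   }
 */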