package tr.gov.ulakbim.jDenetX.classifiers; /** * Created by IntelliJ IDEA. * User: caglar * Date: 10/19/11 * Time: 1:56 PM * To change this template use File | Settings | File Templates. */ import tr.gov.ulakbim.jDenetX.core.ObjectRepository; import tr.gov.ulakbim.jDenetX.tasks.TaskMonitor; import weka.core.Instances; /** * The Accuracy Updated Ensemble classifier as proposed by Brzezinski and * Stefanowski in "Accuracy Updated Ensemble for Data Streams with Concept * Drift", HAIS 2011. */ public class AccuracyUpdatedEnsemble extends AccuracyWeightedEnsemble { private static final long serialVersionUID = 1L; @Override public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { super.prepareForUseImpl(monitor, repository); } @Override protected void processChunk() { Classifier addedClassifier = null; // Compute weights double candidateClassifierWeight = this.computeCandidateWeight( this.candidateClassifier, this.currentChunk, this.numFolds); for (int i = 0; i < this.storedLearners.length; i++) { this.storedWeights[i][0] = this.computeWeight( this.storedLearners[(int) this.storedWeights[i][1]], this.currentChunk); } if (this.storedLearners.length < this.maxStoredCount) { // Train and add classifier this.trainOnChunk(this.candidateClassifier); addedClassifier = this.addToStored(this.candidateClassifier, candidateClassifierWeight); } else { // Substitute poorest classifier java.util.Arrays.sort(this.storedWeights, weightComparator); if (this.storedWeights[0][0] < candidateClassifierWeight) { this.trainOnChunk(this.candidateClassifier); this.storedWeights[0][0] = candidateClassifierWeight; addedClassifier = this.candidateClassifier.copy(); this.storedLearners[(int) this.storedWeights[0][1]] = addedClassifier; } } int ensembleSize = java.lang.Math.min(this.storedLearners.length, this.maxMemberCount); this.ensemble = new Classifier[ensembleSize]; this.ensembleWeights = new double[ensembleSize]; // Sort learners according to their weights java.util.Arrays.sort(this.storedWeights, weightComparator); double mse_r = this.computeMseR(); // Select top k classifiers to construct the ensemble int storeSize = this.storedLearners.length; for (int i = 0; i < ensembleSize; i++) { this.ensembleWeights[i] = this.storedWeights[storeSize - i - 1][0]; this.ensemble[i] = this.storedLearners[(int) this.storedWeights[storeSize - i - 1][1]]; if (this.ensemble[i] != addedClassifier) { if (mse_r > 0 && this.ensembleWeights[i] > 1 / mse_r) { this.trainOnChunk(this.ensemble[i]); } } } this.classDistributions = null; this.currentChunk = null; this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption); this.candidateClassifier.resetLearning(); } @Override protected double computeWeight(Classifier learner, Instances chunk) { double mse_i = 0; double f_ci; double voteSum; for (int i = 0; i < chunk.numInstances(); i++) { try { voteSum = 0; for (double element : learner.getVotesForInstance(chunk .instance(i))) { voteSum += element; } if (voteSum > 0) { f_ci = learner.getVotesForInstance(chunk.instance(i))[(int) chunk .instance(i).classValue()] / voteSum; mse_i += (1 - f_ci) * (1 - f_ci); } else { mse_i += 1; } } catch (Exception e) { mse_i += 1; } } mse_i /= this.chunkSize; if (mse_i > 0) { return 1.0 / mse_i; } else { return Double.MAX_VALUE; } } /** * Trains a component classifier on the most recent chunk of data. * * @param classifierToTrain Classifier being trained. */ private void trainOnChunk(Classifier classifierToTrain) { for (int num = 0; num < this.chunkSize; num++) { classifierToTrain.trainOnInstance(this.currentChunk.instance(num)); } } /** * Determines whether the classifier is randomizable. */ public boolean isRandomizable() { return false; } }