package tr.gov.ulakbim.jDenetX.classifiers;

/*
 *    ActiveClassifier.java
 *    Copyright (C) 2011 University of Waikato, Hamilton, New Zealand
 *    @author Indre Zliobaite (zliobaite at gmail dot com)
 *    @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import tr.gov.ulakbim.jDenetX.options.ClassOption;
import tr.gov.ulakbim.jDenetX.options.FloatOption;
import tr.gov.ulakbim.jDenetX.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * Active learning setting for evolving data streams.
 *
 * <p>Active learning focuses on learning an accurate model with as few labels
 * as possible. Streaming data poses additional challenges for active learning,
 * since the data distribution may change over time (concept drift) and classifiers
 * need to adapt. Conventional active learning strategies concentrate on querying
 * the most uncertain instances, which are typically concentrated around the
 * decision boundary. If changes do not occur close to the boundary, they will
 * be missed and classifiers will fail to adapt. This class contains four active
 * learning strategies for streaming data that explicitly handle concept drift.
 * They are based on randomization, fixed uncertainty, dynamic allocation of
 * labeling efforts over time and randomization of the search space [ZBPH].
 * It also contains the Selective Sampling strategy, which is adapted from [CGZ];
 * it uses a variable labeling threshold.</p>
 *
 * <p>[ZBPH] Indre Zliobaite, Albert Bifet, Bernhard Pfahringer, Geoff Holmes:
 * Active Learning with Evolving Streaming Data. ECML/PKDD (3) 2011: 597-612</p>
 * <p>[CGZ] N. Cesa-Bianchi, C. Gentile, and L. Zaniboni. Worst-case analysis of selective
 * sampling for linear classification. J. Mach. Learn. Res. (7) 2006: 1205-1230</p>
 *
 * <p>Parameters:</p>
 * <ul>
 * <li>-l : Classifier to train</li>
 * <li>-d : Strategy to use: Random, FixedUncertainty, VarUncertainty, RandVarUncertainty, SelSampling</li>
 * <li>-b : Budget to use</li>
 * <li>-u : Fixed threshold</li>
 * <li>-s : Floating budget step</li>
 * <li>-n : Number of instances at beginning without active learning</li>
 * </ul>
 *
 * @author Indre Zliobaite (zliobaite at gmail dot com)
 * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz)
 * @version $Revision: 7 $
 */
public class ActiveClassifier extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "SingleClassifierDrift");

    public MultiChoiceOption activeLearningStrategyOption = new MultiChoiceOption(
            "activeLearningStrategy", 'd', "Active Learning Strategy to use.",
            new String[]{
                    "Random", "FixedUncertainty", "VarUncertainty",
                    "RandVarUncertainty", "SelSampling"},
            new String[]{
                    "Random strategy",
                    "Fixed uncertainty strategy",
                    "Uncertainty strategy with variable threshold",
                    "Uncertainty strategy with randomized variable threshold",
                    "Selective Sampling"},
            0);

    public FloatOption budgetOption = new FloatOption("budget", 'b',
            "Budget to use.", 0.1, 0.0, 1.0);

    public FloatOption fixedThresholdOption = new FloatOption("fixedThreshold", 'u',
            "Fixed threshold.", 0.9, 0.00, 1.00);

    public FloatOption stepOption = new FloatOption("step", 's',
            "Floating budget step.", 0.01, 0.00, 1.00);

    public FloatOption numInstancesInitOption = new FloatOption("numInstancesInit", 'n',
            "Number of instances at beginning without active learning.",
            0.0, 0.00, Integer.MAX_VALUE);

    // Wrapped base learner that receives the queried (labeled) instances.
    public Classifier classifier;

    // Total number of labels queried so far (all strategies).
    public int costLabeling;

    // Labels queried by the purely random strategy component.
    public int costLabelingRandom;

    // Number of instances seen so far; drives the budget-spending rate.
    public int iterationControl;

    // Adaptive uncertainty threshold used by the Var/RandVar strategies.
    public double newThreshold;

    // Last maximum posterior probability observed (exported as a measurement).
    public double maxPosterior;

    // Accumulator for base-learner accuracy reporting.
    // NOTE(review): never updated in this class — presumably maintained by a
    // subclass or left as future work; the derived measurement is 0 until then.
    public double accuracyBaseLearner;

    /**
     * Returns the maximum posterior probability from a vote array,
     * normalizing the votes to sum to 1 first.
     *
     * @param incomingPrediction raw class votes from the base learner
     * @return the largest normalized vote, or 0 if fewer than two classes voted
     */
    private double getMaxPosterior(double[] incomingPrediction) {
        double outPosterior;
        if (incomingPrediction.length > 1) {
            DoubleVector vote = new DoubleVector(incomingPrediction);
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
            }
            incomingPrediction = vote.getArrayRef();
            outPosterior = (incomingPrediction[Utils.maxIndex(incomingPrediction)]);
        } else {
            // A single-entry (or empty) vote array carries no posterior information.
            outPosterior = 0;
        }
        return outPosterior;
    }

    /**
     * Random strategy: query the label with probability equal to the budget,
     * independently of the classifier's uncertainty.
     */
    private void labelRandom(Instance inst) {
        if (this.classifierRandom.nextDouble() < this.budgetOption.getValue()) {
            this.classifier.trainOnInstance(inst);
            this.costLabeling++;
            this.costLabelingRandom++;
        }
    }

    /**
     * Fixed-uncertainty strategy: query the label whenever the maximum
     * posterior falls below a fixed confidence threshold.
     */
    private void labelFixed(double incomingPosterior, Instance inst) {
        if (incomingPosterior < this.fixedThresholdOption.getValue()) {
            this.classifier.trainOnInstance(inst);
            this.costLabeling++;
        }
    }

    /**
     * Variable-uncertainty strategy: query when the posterior is below an
     * adaptive threshold; tighten the threshold after each query and relax
     * it after each skip, so labeling effort tracks the budget over time.
     */
    private void labelVar(double incomingPosterior, Instance inst) {
        if (incomingPosterior < this.newThreshold) {
            this.classifier.trainOnInstance(inst);
            this.costLabeling++;
            this.newThreshold *= (1 - this.stepOption.getValue());
        } else {
            this.newThreshold *= (1 + this.stepOption.getValue());
        }
    }

    /**
     * Selective-sampling strategy [CGZ]: query with probability b / (b + p),
     * where p is the margin of the posterior over the uniform prior — the
     * closer the prediction is to chance, the likelier the query.
     */
    private void labelSelSampling(double incomingPosterior, Instance inst) {
        double p = Math.abs(incomingPosterior - 1.0 / (inst.numClasses()));
        double budget = this.budgetOption.getValue()
                / (this.budgetOption.getValue() + p);
        if (this.classifierRandom.nextDouble() < budget) {
            this.classifier.trainOnInstance(inst);
            this.costLabeling++;
        }
    }

    @Override
    public void resetLearningImpl() {
        this.classifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
        this.classifier.resetLearning();
        this.costLabeling = 0;
        this.costLabelingRandom = 0;
        this.iterationControl = 0;
        this.newThreshold = 1.0;
        this.accuracyBaseLearner = 0;
    }

    /**
     * Processes one stream instance: if the labeling budget spent so far is
     * below the configured budget, applies the selected active learning
     * strategy to decide whether to query this instance's label and train.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        this.iterationControl++;

        // Fraction of the post-warm-up stream that has been labeled so far.
        double costNow;
        if (this.iterationControl <= this.numInstancesInitOption.getValue()) {
            // Warm-up phase: no active learning, cost counts as zero.
            costNow = 0;
        } else {
            costNow = (this.costLabeling - this.numInstancesInitOption.getValue())
                    / ((double) this.iterationControl - this.numInstancesInitOption.getValue());
        }

        if (costNow < this.budgetOption.getValue()) { // allow to label
            switch (this.activeLearningStrategyOption.getChosenIndex()) {
                case 0: // Random
                    labelRandom(inst);
                    break;
                case 1: // fixed
                    maxPosterior = getMaxPosterior(this.classifier.getVotesForInstance(inst));
                    labelFixed(maxPosterior, inst);
                    break;
                case 2: // variable
                    maxPosterior = getMaxPosterior(this.classifier.getVotesForInstance(inst));
                    labelVar(maxPosterior, inst);
                    break;
                case 3: // randomized
                    maxPosterior = getMaxPosterior(this.classifier.getVotesForInstance(inst));
                    // Perturb the posterior so the effective threshold is randomized.
                    maxPosterior = maxPosterior / (this.classifierRandom.nextGaussian() + 1.0);
                    labelVar(maxPosterior, inst);
                    break;
                case 4: // selective-sampling
                    maxPosterior = getMaxPosterior(this.classifier.getVotesForInstance(inst));
                    labelSelSampling(maxPosterior, inst);
                    break;
            }
        }
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        return this.classifier.getVotesForInstance(inst);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        ((AbstractClassifier) this.classifier).getModelDescription(out, indent);
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        List<Measurement> measurementList = new LinkedList<Measurement>();
        measurementList.add(new Measurement("labeling cost", this.costLabeling));
        measurementList.add(new Measurement("newThreshold", this.newThreshold));
        measurementList.add(new Measurement("maxPosterior", this.maxPosterior));
        // Guard against division by zero: before the first query costLabeling
        // is 0 and the unguarded ratio would report NaN.
        double accuracyPercent = this.costLabeling > 0
                ? 100.0 * this.accuracyBaseLearner / this.costLabeling
                : 0.0;
        measurementList.add(new Measurement("accuracyBaseLearner (percent)", accuracyPercent));
        Measurement[] modelMeasurements =
                ((AbstractClassifier) this.classifier).getModelMeasurementsImpl();
        if (modelMeasurements != null) {
            Collections.addAll(measurementList, modelMeasurements);
        }
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }
}