/*
* WeightedMajorityAlgorithm.java
* Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import tr.gov.ulakbim.jDenetX.core.ObjectRepository;
import tr.gov.ulakbim.jDenetX.options.*;
import tr.gov.ulakbim.jDenetX.tasks.TaskMonitor;
import weka.core.Instance;
import weka.core.Utils;
/**
 * Weighted majority ensemble classifier (Littlestone &amp; Warmuth style).
 * <p>
 * Maintains a fixed set of base learners, each with a weight. On every
 * training instance, members that misclassify it are punished by
 * multiplying their weight by {@code beta}; weights are then renormalized
 * to sum to 1. A member whose weight falls below {@code gamma / n} is no
 * longer punished and, if {@code prune} is set, is removed from the
 * ensemble entirely. Votes are the weighted, normalized votes of all
 * members with positive weight.
 */
public class WeightedMajorityAlgorithm extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    /** The base learners to combine, configurable from the command line. */
    public ListOption learnerListOption = new ListOption(
            "learners",
            'l',
            "The learners to combine.",
            new ClassOption("learner", ' ', "", Classifier.class,
                    "HoeffdingTree"),
            new Option[]{
                    new ClassOption("", ' ', "", Classifier.class,
                            "HoeffdingTree"),
                    new ClassOption("", ' ', "", Classifier.class,
                            "HoeffdingTreeNB"),
                    new ClassOption("", ' ', "", Classifier.class,
                            "HoeffdingTreeNBAdaptive"),
                    new ClassOption("", ' ', "", Classifier.class, "NaiveBayes")},
            ',');

    /** Multiplicative penalty applied to a member's weight on a mistake. */
    public FloatOption betaOption = new FloatOption("beta", 'b',
            "Factor to punish mistakes by.", 0.9, 0.0, 1.0);

    /**
     * Weight floor, expressed as a fraction: a member is not punished below
     * gamma / ensembleSize (and may be pruned there instead).
     */
    public FloatOption gammaOption = new FloatOption("gamma", 'g',
            "Minimum fraction of weight per model.", 0.01, 0.0, 0.5);

    /** If set, members at the weight floor that still err are removed. */
    public FlagOption pruneOption = new FlagOption("prune", 'p',
            "Prune poorly performing models from ensemble.");

    /** The ensemble members; parallel to {@link #ensembleWeights}. */
    protected Classifier[] ensemble;

    /** Per-member weights, normalized to sum to 1 after each training step. */
    protected double[] ensembleWeights;

    /**
     * Materializes and prepares each configured learner, checking for task
     * abort between steps so a cancel request is honored promptly.
     */
    @Override
    public void prepareForUseImpl(TaskMonitor monitor,
                                  ObjectRepository repository) {
        Option[] learnerOptions = this.learnerListOption.getList();
        this.ensemble = new Classifier[learnerOptions.length];
        for (int i = 0; i < learnerOptions.length; i++) {
            monitor.setCurrentActivity("Materializing learner " + (i + 1)
                    + "...", -1.0);
            this.ensemble[i] = (Classifier) ((ClassOption) learnerOptions[i])
                    .materializeObject(monitor, repository);
            if (monitor.taskShouldAbort()) {
                return;
            }
            monitor.setCurrentActivity("Preparing learner " + (i + 1) + "...",
                    -1.0);
            this.ensemble[i].prepareForUse(monitor, repository);
            if (monitor.taskShouldAbort()) {
                return;
            }
        }
        super.prepareForUseImpl(monitor, repository);
    }

    /** Resets every member and restores uniform (unnormalized) weights. */
    @Override
    public void resetLearningImpl() {
        this.ensembleWeights = new double[this.ensemble.length];
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i].resetLearning();
            this.ensembleWeights[i] = 1.0;
        }
    }

    /**
     * Trains every (surviving) member on the instance and updates weights:
     * mistakes above the weight floor are punished by beta; mistakes at the
     * floor trigger pruning when enabled; finally weights are renormalized.
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        double totalWeight = 0.0;
        for (int i = 0; i < this.ensemble.length; i++) {
            boolean prune = false;
            if (!this.ensemble[i].correctlyClassifies(inst)) {
                if (this.ensembleWeights[i] > this.gammaOption.getValue()
                        / this.ensembleWeights.length) {
                    // Punish the mistake, scaled by the instance weight.
                    // NOTE(review): an instance weight > 1/beta would
                    // *increase* the member weight — assumed intentional,
                    // matches the original formulation for weight <= 1.
                    this.ensembleWeights[i] *= this.betaOption.getValue()
                            * inst.weight();
                } else if (this.pruneOption.isSet()) {
                    prune = true;
                    discardModel(i);
                    i--; // revisit this index: it now holds the next model
                }
            }
            if (!prune) {
                totalWeight += this.ensembleWeights[i];
                this.ensemble[i].trainOnInstance(inst);
            }
        }
        // Normalize weights to sum to 1. Guard against totalWeight == 0
        // (e.g. beta == 0, zero-weight instances, or everything pruned):
        // dividing by zero would fill the weights with NaN/Infinity and
        // permanently corrupt the ensemble.
        if (totalWeight > 0.0) {
            for (int i = 0; i < this.ensembleWeights.length; i++) {
                this.ensembleWeights[i] /= totalWeight;
            }
        }
    }

    /**
     * Returns the weighted combination of the members' normalized votes;
     * members with zero weight or empty votes contribute nothing. Before
     * any training, returns an empty vote.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0) {
            for (int i = 0; i < this.ensemble.length; i++) {
                if (this.ensembleWeights[i] > 0.0) {
                    DoubleVector vote = new DoubleVector(this.ensemble[i]
                            .getVotesForInstance(inst));
                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        vote.scaleValues(this.ensembleWeights[i]);
                        combinedVote.addValues(vote);
                    }
                }
            }
        }
        return combinedVote.getArrayRef();
    }

    /** No textual model description is produced for this ensemble. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // Intentionally empty: no description implemented.
    }

    /** Reports each member's current weight as a measurement. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = null;
        if (this.ensembleWeights != null) {
            measurements = new Measurement[this.ensembleWeights.length];
            for (int i = 0; i < this.ensembleWeights.length; i++) {
                measurements[i] = new Measurement("member weight " + (i + 1),
                        this.ensembleWeights[i]);
            }
        }
        return measurements;
    }

    /** The ensemble itself uses no randomization (members may). */
    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Returns a defensive copy of the member array. */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble.clone();
    }

    /**
     * Removes the member at {@code index} (and its weight) from the
     * ensemble, shrinking both parallel arrays by one.
     *
     * @param index position of the member to remove
     */
    public void discardModel(int index) {
        int newLength = this.ensemble.length - 1;
        Classifier[] newEnsemble = new Classifier[newLength];
        double[] newEnsembleWeights = new double[newLength];
        // Copy everything before index, then everything after it.
        System.arraycopy(this.ensemble, 0, newEnsemble, 0, index);
        System.arraycopy(this.ensemble, index + 1, newEnsemble, index,
                newLength - index);
        System.arraycopy(this.ensembleWeights, 0, newEnsembleWeights, 0, index);
        System.arraycopy(this.ensembleWeights, index + 1, newEnsembleWeights,
                index, newLength - index);
        this.ensemble = newEnsemble;
        this.ensembleWeights = newEnsembleWeights;
    }

    /**
     * Discards the lowest-weighted member and returns the byte size it
     * occupied, for memory-management bookkeeping.
     *
     * @return measured byte size of the removed member
     */
    protected int removePoorestModelBytes() {
        int poorestIndex = Utils.minIndex(this.ensembleWeights);
        int byteSize = this.ensemble[poorestIndex].measureByteSize();
        discardModel(poorestIndex);
        return byteSize;
    }
}