package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.core.*;
import tr.gov.ulakbim.jDenetX.options.ClassOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Instance;
import weka.core.Instances;
import java.util.Collections;
public class ActiveClusterBagging extends AbstractClassifier {
private static final long serialVersionUID = 1L;
public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
"Classifier to train.", Classifier.class, "ASHoeffdingTreeNB");
public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
"The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
public IntOption windowSizeOption = new IntOption(
"windowSize", 'w',
"The size of window.", 1000, 10, Integer.MAX_VALUE);
protected InstanceClassesPool ClassPool = new InstanceClassesPool();
protected Classifier[] ensemble;
@Override
public int measureByteSize() {
int size = (int) SizeOf.sizeOf(this);
for (Classifier classifier : this.ensemble) {
size += classifier.measureByteSize();
}
return size;
}
@Override
public void resetLearningImpl() {
this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
ClassPool = new InstanceClassesPool();
Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
baseLearner.resetLearning();
for (int i = 0; i < this.ensemble.length; i++) {
this.ensemble[i] = baseLearner.copy();
}
}
@Override
public void trainOnInstanceImpl(Instance inst) {
int trueClass = (int) inst.classValue();
int windowSize = windowSizeOption.getValue();
if (!ClassPool.isInitialized()) {
ClassPool.initialize(inst.numClasses(), Collections.list(inst.enumerateAttributes()), windowSize);
}
ClassPool.addInstance(inst);
if (ClassPool.checkPoolSize(windowSize)) {
ClusterTrainingDataHarvester ctdh = new ClusterTrainingDataHarvester(this.ensemble.length);
for (int c = 0; c < inst.numClasses(); c++) {
Instances[] insts = ctdh.getEnsembleTrainingData(ClassPool.getInstancesInClass(c), ClassPool.getNoOfClasses());
for (int i = 0; i < this.ensemble.length; i++) {
for (int j = 0; j < insts[i].size(); j++) {
this.ensemble[i].trainOnInstance(insts[i].get(j));
}
}
}
ClassPool.clear();
}
}
public double[] getVotesForInstance(Instance inst) {
DoubleVector combinedVote = new DoubleVector();
for (int i = 0; i < this.ensemble.length; i++) {
DoubleVector vote = new DoubleVector(this.ensemble[i]
.getVotesForInstance(inst));
if (vote.sumOfValues() > 0.0) {
vote.normalize();
combinedVote.addValues(vote);
}
}
return combinedVote.getArrayRef();
}
public boolean isRandomizable() {
return true;
}
@Override
public void getModelDescription(StringBuilder out, int indent) {
// TODO Auto-generated method stub
}
@Override
protected Measurement[] getModelMeasurementsImpl() {
return new Measurement[]{new Measurement("ensemble size",
this.ensemble != null ? this.ensemble.length : 0)};
}
@Override
public Classifier[] getSubClassifiers() {
return this.ensemble.clone();
}
}