package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.classifiers.ensemble.QBC;
import tr.gov.ulakbim.jDenetX.core.*;
import tr.gov.ulakbim.jDenetX.options.ClassOption;
import tr.gov.ulakbim.jDenetX.options.FlagOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import java.util.ArrayList;
import java.util.Collections;
/**
 * Online boosting ensemble that additionally self-labels the instances it is
 * asked to classify.
 *
 * <p>Training follows the usual online-boosting scheme: every member is trained
 * with a weight drawn from a Poisson distribution (or with the raw weight when
 * {@code pureBoost} is set) and the per-member correct/wrong weight sums
 * ({@link #scms}/{@link #swms}) drive both the next member's weight and the
 * member's voting weight.
 *
 * <p>Classification keeps side statistics: the first {@code ReservoirSize}
 * classified instances are buffered together with their committee disagreement
 * (QBC), confidence and predicted class. Once the buffer is full it is flushed
 * into the shared {@link VotedInstancePool}; later instances are offered to the
 * pool directly. Only instances whose confidence exceeds
 * {@code ConfidenceThreshold} are accepted by the pool.
 *
 * <p>Note that the voted-instance pool is {@code static}, i.e. shared by every
 * instance of this class in the JVM.
 *
 * <p>Created by IntelliJ IDEA.
 * User: caglar
 * Date: Oct 18, 2010
 * Time: 9:23:26 AM
 */
public class SelfOzaBoostID extends AbstractClassifier {

    private static final long serialVersionUID = 1L;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "ASHoeffdingOptionTree");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
            "The number of models to boost.", 10, 1, Integer.MAX_VALUE);

    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
            "Boost with weights only; no poisson.");

    public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r',
            "Reset trees when size is higher than the max. Requires ASHoeffdingOptionTree");

    public IntOption firstClassifierSizeOption = new IntOption(
            "firstClassifierSize", 'f',
            "The size of first classifier in the bag. This option will have effect with only ASHoeffdingOptionTree", 80, 1, Integer.MAX_VALUE);

    /** Pool of confidently self-labelled instances; shared by all instances of this class. */
    private static final VotedInstancePool instConfPool = new VotedInstancePool();

    /** Minimum confidence an instance needs to be offered to {@link #instConfPool}. */
    private static final double ConfidenceThreshold = 0.94;

    /** Number of classified instances that are buffered before the first flush. */
    private static final int ReservoirSize = 10000;

    /** Members with an error at or above this ratio do not contribute to the confidence vote. */
    protected static final double ErrorRatio = 0.28;

    protected Classifier[] ensemble;

    /** Per-member sum of lambda over correctly classified training instances. */
    protected double[] scms;

    /** Per-member sum of lambda over wrongly classified training instances. */
    protected double[] swms;

    /** Non-class attributes followed by the class attribute; built lazily from the first instance seen. */
    private ArrayList<Attribute> AttList = null;

    /** Buffer of the first {@code ReservoirSize} classified instances. */
    private Instances ClassificationInstsPool = null;

    EuclideanSimilarityDiscoverer TrainingSimilarity = null;
    EuclideanSimilarityDiscoverer ClassificationSimilarity = null;

    // The three vectors below are index-parallel to ClassificationInstsPool.
    private DoubleVector QBCs = new DoubleVector();
    // Was static in earlier revisions; it is indexed by the per-instance counter
    // ClassifiedInstances, so it has to be per-instance state as well.
    private DoubleVector Confidences = new DoubleVector();
    private DoubleVector PredictedClasses = new DoubleVector();

    /** Number of completed {@link #getVotesForInstance(Instance)} calls since the last reset. */
    private int ClassifiedInstances = 0;

    /**
     * Returns the size reported by {@code SizeOf} for this object plus the
     * reported size of every ensemble member.
     *
     * @return the accumulated size, truncated to {@code int}
     */
    @Override
    public int measureByteSize() {
        int size = (int) SizeOf.sizeOf(this);
        if (this.ensemble != null) { // not yet initialised before the first reset
            for (Classifier member : this.ensemble) {
                size += member.measureByteSize();
            }
        }
        return size;
    }

    /**
     * Rebuilds the ensemble from the configured base learner and clears every
     * piece of state accumulated by training and classification.
     *
     * <p>When {@code resetTrees} is set, each member is cast to
     * {@link ASHoeffdingOptionTree} (a {@link ClassCastException} is raised for
     * any other base learner) and receives a maximum size that doubles from one
     * member to the next, starting at {@code firstClassifierSize}.
     */
    @Override
    public void resetLearningImpl() {
        this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();

        // Drop all bookkeeping together so the index-parallel buffers can never
        // get out of step with ClassifiedInstances.
        QBCs = new DoubleVector();
        Confidences = new DoubleVector();
        PredictedClasses = new DoubleVector();
        ClassifiedInstances = 0;
        ClassificationInstsPool = null;
        ClassificationSimilarity = null;
        TrainingSimilarity = null;

        boolean resetTrees = (this.resetTreesOption != null) && this.resetTreesOption.isSet();
        int maxSize = this.firstClassifierSizeOption.getValue(); // EXTENSION TO ASHT
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i] = baseLearner.copy();
            if (resetTrees) {
                ASHoeffdingOptionTree tree = (ASHoeffdingOptionTree) this.ensemble[i];
                tree.setMaxSize(maxSize); // EXTENSION TO ASHT
                tree.setResetTree();
                // Double for the next member, saturating instead of overflowing
                // for large ensembles.
                maxSize = (maxSize > Integer.MAX_VALUE / 2) ? Integer.MAX_VALUE : maxSize * 2;
            }
        }
        this.scms = new double[this.ensemble.length];
        this.swms = new double[this.ensemble.length];
    }

    /**
     * Feeds one labelled instance to the training-similarity tracker and to
     * every ensemble member, updating the boosting statistics on the way.
     *
     * @param inst the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (TrainingSimilarity == null) {
            TrainingSimilarity = new EuclideanSimilarityDiscoverer(attributeListFor(inst));
        }
        TrainingSimilarity.addInstance(inst);

        double lambda_d = 1.0;
        for (int i = 0; i < this.ensemble.length; i++) {
            double k = this.pureBoostOption.isSet()
                    ? lambda_d
                    : MiscUtils.poisson(lambda_d, this.classifierRandom);
            if (k > 0.0) {
                Instance weightedInst = (Instance) inst.copy();
                weightedInst.setWeight(inst.weight() * k);
                this.ensemble[i].trainOnInstance(weightedInst);
            }
            // Lower lambda after a hit, raise it after a miss, so later members
            // concentrate on the instances earlier members get wrong.
            if (this.ensemble[i].correctlyClassifies(inst)) {
                this.scms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
            } else {
                this.swms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
            }
        }
    }

    /**
     * Returns the weighted error rate of one ensemble member.
     *
     * @param i index of the member
     * @return {@code swms[i] / (scms[i] + swms[i])}; {@code NaN} while the member
     *         has not accumulated any weight
     */
    public double getEnsembleMemberError(int i) {
        return this.swms[i] / (this.scms[i] + this.swms[i]);
    }

    /**
     * Returns the voting weight of one ensemble member.
     *
     * @param i index of the member
     * @return {@code log((1 - em) / em)} for an error {@code em} in (0, 0.5];
     *         {@code 0.0} when the error is exactly 0 or above 0.5;
     *         {@code NaN} when the error itself is {@code NaN}
     */
    protected double getEnsembleMemberWeight(int i) {
        double em = getEnsembleMemberError(i);
        if ((em == 0.0) || (em > 0.5)) {
            return 0.0;
        }
        double Bm = em / (1.0 - em);
        return Math.log(1.0 / Bm);
    }

    /**
     * Returns the inverse distance of {@code inst} to the centroid of the
     * instances classified so far, raised to the power {@code beta}.
     *
     * <p>Requires that at least one classification has initialised
     * {@code ClassificationSimilarity}. A distance of zero yields infinity.
     *
     * @param inst the instance to measure
     * @param beta exponent applied to the inverse distance
     * @return {@code (1 / distance)^beta}
     */
    public double unlabeledSimilarity(Instance inst, double beta) {
        double distance = ClassificationSimilarity.findDistanceToCenteroid(inst);
        return Math.pow(1 / distance, beta);
    }

    /**
     * Returns {@code 10 * exp(-1 / distance)}, where {@code distance} is the
     * distance of {@code inst} to the centroid of the training instances.
     *
     * <p>Requires that at least one training instance has initialised
     * {@code TrainingSimilarity}.
     *
     * @param inst the instance to measure
     * @return the similarity score described above
     */
    public double labeledSimilarity(Instance inst) {
        double distance = TrainingSimilarity.findDistanceToCenteroid(inst);
        return 10 * Math.exp(-(1 / distance));
    }

    /**
     * Combines committee disagreement with both similarity scores.
     *
     * @param inst the instance being scored
     * @param qbc  committee disagreement computed for {@code inst}
     * @return {@code qbc * unlabeledSimilarity(inst, 1) * labeledSimilarity(inst)}
     */
    public double getActiveLearningRatio(Instance inst, double qbc) {
        double beta = 1;
        return qbc * unlabeledSimilarity(inst, beta) * labeledSimilarity(inst);
    }

    /** Offers {@code inst} to the shared pool when its confidence exceeds the threshold. */
    private void addToVotedInstances(Instance inst, double qbc, double confidence) {
        if (confidence > ConfidenceThreshold) {
            double activeLearningRatio = getActiveLearningRatio(inst, qbc);
            instConfPool.addVotedInstance(inst, confidence, activeLearningRatio);
        }
    }

    /**
     * Computes the ensemble vote for {@code inst} and records the instance for
     * self-labelling.
     *
     * <p>Only members whose error is below {@code ErrorRatio} (and, for
     * {@link ASHoeffdingOptionTree} members, whose tree depth exceeds 1)
     * contribute. Iteration stops at the first member whose weight is not
     * positive. This method does not advance the classified-instance counter;
     * {@link #getVotesForInstance(Instance)} does.
     *
     * <p>Side effect: once the buffer is full, the class value of {@code inst}
     * itself is overwritten with the predicted class.
     *
     * @param inst the instance to classify
     * @return the normalised vote of the contributing members; empty when no
     *         member contributed
     */
    public DoubleVector calculateVotesForInstance(Instance inst) {
        // Public entry point: make sure the lazily created state exists even
        // when the caller bypasses getVotesForInstance.
        ensureClassificationState(inst);

        DoubleVector confidenceVote = new DoubleVector();
        double[][] comitteeVotes = new double[this.ensemble.length][inst.numClasses()];
        int success = 0;

        for (int i = 0; i < this.ensemble.length; i++) {
            double memberWeight = getEnsembleMemberWeight(i);
            if (!(memberWeight > 0.0)) { // zero, negative or NaN weight ends the scan
                break;
            }
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                // NOTE(review): normalising after scaling cancels the member
                // weight, so every contributing member counts equally. Kept as
                // is because the disagreement measure below is fed these
                // normalised votes - confirm whether weighting was intended.
                vote.scaleValues(memberWeight);
                vote.normalize();
                if (getEnsembleMemberError(i) < ErrorRatio && hasGrownBeyondRoot(this.ensemble[i])) {
                    success++;
                    confidenceVote.addValues(vote);
                    comitteeVotes[i] = vote.getArrayRef();
                }
            }
        }

        double confidence = 0.0;
        if (confidenceVote.numValues() > 0) {
            confidenceVote.normalize();
            confidence = confidenceVote.maxValue();
        }
        double qbc = QBC.getKullbackLeiblerDiv(comitteeVotes, confidenceVote, success);
        DoubleVector combinedVote = (DoubleVector) confidenceVote.copy();
        int predictedClass = combinedVote.maxIndex();

        //Similar to the reservoir sampling
        if (ClassifiedInstances < ReservoirSize) {
            // Still filling the buffer: remember everything needed to label and
            // score this instance at flush time.
            QBCs.setValue(ClassifiedInstances, qbc);
            Confidences.setValue(ClassifiedInstances, confidence);
            PredictedClasses.setValue(ClassifiedInstances, predictedClass);
            ClassificationInstsPool.add(inst);
        } else {
            if (ClassifiedInstances == ReservoirSize) {
                flushBufferedInstances();
            }
            labelWithPrediction(inst, predictedClass);
            addToVotedInstances(inst, qbc, confidence);
        }
        ClassificationSimilarity.addInstance(inst);
        return combinedVote;
    }

    /**
     * Classifies {@code inst} and counts it as one classified instance.
     *
     * @param inst the instance to classify
     * @return the backing array of the vote returned by
     *         {@link #calculateVotesForInstance(Instance)}
     */
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = calculateVotesForInstance(inst);
        ClassifiedInstances++;
        return combinedVote.getArrayRef();
    }

    public boolean isRandomizable() {
        return true;
    }

    /**
     * Returns the pool of confidently self-labelled instances shared by all
     * instances of this class.
     */
    public static VotedInstancePool getVotedInstancePool() {
        return instConfPool;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // TODO Auto-generated method stub
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[]{new Measurement("ensemble size",
                this.ensemble != null ? this.ensemble.length : 0)};
    }

    /**
     * Returns a copy of the ensemble array, or an empty array when the ensemble
     * has not been built yet.
     */
    @Override
    public Classifier[] getSubClassifiers() {
        return (this.ensemble != null) ? this.ensemble.clone() : new Classifier[0];
    }

    /**
     * Returns the cached attribute list, building it from {@code inst} on first
     * use: every non-class attribute in index order, then the class attribute.
     */
    private ArrayList<Attribute> attributeListFor(Instance inst) {
        if (AttList == null) {
            int classIndex = inst.classIndex();
            ArrayList<Attribute> attributes = new ArrayList<Attribute>(inst.numAttributes());
            for (int a = 0; a < inst.numAttributes(); a++) {
                if (a != classIndex) {
                    attributes.add(inst.attribute(a));
                }
            }
            attributes.add(inst.attribute(classIndex));
            AttList = attributes;
        }
        return AttList;
    }

    /** Lazily creates the classification-similarity tracker and the instance buffer. */
    private void ensureClassificationState(Instance inst) {
        ArrayList<Attribute> attributes = attributeListFor(inst);
        if (ClassificationSimilarity == null) {
            ClassificationSimilarity = new EuclideanSimilarityDiscoverer(attributes);
        }
        if (ClassificationInstsPool == null) {
            ClassificationInstsPool = new Instances("instancePool", attributes, ReservoirSize);
            ClassificationInstsPool.setClassIndex(attributes.size() - 1);
        }
    }

    /**
     * Labels every buffered instance with the class predicted for it at the time
     * it was classified and offers it to the voted-instance pool.
     */
    private void flushBufferedInstances() {
        for (int j = 0; j < ClassificationInstsPool.size(); j++) {
            Instance pooled = ClassificationInstsPool.get(j);
            labelWithPrediction(pooled, (int) PredictedClasses.getValue(j));
            addToVotedInstances(pooled, QBCs.getValue(j), Confidences.getValue(j));
        }
    }

    /**
     * Sets the class value of {@code inst} to {@code predictedClass}; a negative
     * index (no vote was available) leaves the instance untouched.
     */
    private static void labelWithPrediction(Instance inst, int predictedClass) {
        if (predictedClass >= 0) {
            inst.setClassValue(predictedClass);
        }
    }

    /**
     * Tells whether a member may contribute to the confidence vote as far as its
     * structure is concerned: option trees must have a depth above 1, any other
     * learner type is accepted as is.
     */
    private static boolean hasGrownBeyondRoot(Classifier member) {
        if (member instanceof ASHoeffdingOptionTree) {
            return ((ASHoeffdingOptionTree) member).measureTreeDepth() > 1;
        }
        return true;
    }
}