package tr.gov.ulakbim.jDenetX.classifiers;
import tr.gov.ulakbim.jDenetX.classifiers.ensemble.QBC;
import tr.gov.ulakbim.jDenetX.core.ClusterTrainingDataHarvester;
import tr.gov.ulakbim.jDenetX.core.DoubleVector;
import tr.gov.ulakbim.jDenetX.core.InstanceClassesPool;
import tr.gov.ulakbim.jDenetX.core.VotedInstancePool;
import tr.gov.ulakbim.jDenetX.options.FlagOption;
import tr.gov.ulakbim.jDenetX.options.IntOption;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import java.util.ArrayList;
import java.util.Collections;
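/**
 * Active-learning cluster bagging whose members are ASHoeffdingOptionTree instances with
 * growing size limits: member i is capped at firstClassifierSize * 2^i. Instances that the
 * committee classifies with high confidence are labelled and collected in a shared
 * VotedInstancePool.
 *
 * Created by IntelliJ IDEA.
 * User: caglar
 * Date: Sep 1, 2010
 * Time: 2:55:17 PM
 */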
public class ActiveClusterBaggingASHT extends ActiveClusterBagging {
    private static final long serialVersionUID = 1L;

    /** Maximum size of the first ensemble member; every following member doubles it. */
    public IntOption firstClassifierSizeOption = new IntOption(
            "firstClassifierSize", 'f',
            "The size of first classifier in the bag.", 1, 1, Integer.MAX_VALUE);

    /** When set, member votes are scaled by 1 / error^2 before they enter the committee step. */
    public FlagOption useWeightOption = new FlagOption("useWeight", 'u',
            "Enable weight classifiers.");

    /** Window buffer used by trainByWholePool; created lazily from the first instance's header. */
    private Instances Insts = null;

    /** When set, every member tree is switched to reset mode via setResetTree(). */
    public FlagOption resetTreesOption = new FlagOption("resetTrees", 'r',
            "Reset trees when size is higher than the max.");

    /**
     * Confidently labelled instances gathered by the getVotesForInstance methods.
     * NOTE(review): static, so it is shared by all instances of this class and replaced
     * whenever any instance runs resetLearningImpl().
     */
    private static VotedInstancePool instConfPool = new VotedInstancePool();

    /** Per-member exponentially weighted moving average of the 0/1 error (see updateError). */
    protected double[] error;

    /** Compile-time switch: true trains from per-class pools, false from one shared window. */
    private final static boolean PoolFlag = false;

    /** Only members whose error is below this value take part in the committee step. */
    protected final static double errorRatio = 0.23;

    /** Smoothing factor of the error moving average. */
    protected double alpha = 0.01;

    /** Number of instances added to the voted pool since the last reset (static, shared). */
    public static int instConfCount = 0;

    /** Minimum committee confidence before an instance is added to the voted pool. */
    private final static double confidenceThreshold = 0.97;

    /** Gates the window-full training step in both training modes (static, shared). */
    private static boolean checkSize = true;

    /**
     * Rebuilds the ensemble from the prepared base learner, clears all pools and error
     * estimates, and assigns member i a maximum size of firstClassifierSize * 2^i.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        this.ensemble = new Classifier[size];
        this.error = new double[size];
        instConfPool = new VotedInstancePool();
        instConfCount = 0;
        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        ClassPool = new InstanceClassesPool();
        // Drop any half-filled window collected before the reset; it may even belong to
        // a stream with a different header.
        Insts = null;
        checkSize = true;
        boolean resetTrees = (this.resetTreesOption != null) && this.resetTreesOption.isSet();
        int maxSize = this.firstClassifierSizeOption.getValue(); // EXTENSION TO ASHT
        for (int i = 0; i < size; i++) {
            this.ensemble[i] = baseLearner.copy();
            this.error[i] = 0.0;
            ASHoeffdingOptionTree tree = (ASHoeffdingOptionTree) this.ensemble[i];
            tree.setMaxSize(maxSize); // EXTENSION TO ASHT
            if (resetTrees) {
                tree.setResetTree();
            }
            // Double the limit for the next member, saturating instead of overflowing
            // into negative sizes for large ensembles.
            maxSize = (maxSize > Integer.MAX_VALUE / 2) ? Integer.MAX_VALUE : maxSize * 2; // EXTENSION TO ASHT
        }
    }

    /**
     * Updates the error estimate of one member with a single labelled instance: the estimate
     * moves a fraction {@code alpha} toward 0 when the member's top vote matches the class
     * and toward 1 otherwise.
     *
     * @param inst       labelled instance to test the member on
     * @param ensembleNo index of the member to update
     */
    public void updateError(Instance inst, int ensembleNo) {
        int trueClass = (int) inst.classValue();
        if (Utils.maxIndex(this.ensemble[ensembleNo].getVotesForInstance(inst)) == trueClass) {
            this.error[ensembleNo] += alpha * (0.0 - this.error[ensembleNo]); // EWMA
        } else {
            this.error[ensembleNo] += alpha * (1.0 - this.error[ensembleNo]); // EWMA
        }
    }

    /**
     * Buffers {@code inst} in the per-class pool and, once checkPoolSize(windowSize) reports
     * it full (and checkSize is set), trains every member on a shuffled batch and clears the
     * pool. For each class, member i's batch receives cluster i plus one secondary cluster
     * (see secondaryClusterIndex); every instance appended to member i's batch first updates
     * member i's error estimate.
     */
    public void trainByClassPool(Instance inst, int windowSize) {
        if (!ClassPool.isInitialized()) {
            ArrayList attList = Collections.list(inst.enumerateAttributes());
            attList.add(inst.attribute(inst.classIndex()));
            ClassPool.initialize(inst.numClasses(), attList, windowSize);
        }
        ClassPool.addInstance(inst);
        if (!ClassPool.checkPoolSize(windowSize) || !checkSize) {
            return;
        }
        ClusterTrainingDataHarvester ctdh = new ClusterTrainingDataHarvester(this.ensemble.length);
        Instances[] trainingInstances = new Instances[this.ensemble.length];
        for (int c = 0; c < inst.numClasses(); c++) {
            Instances[] insts = ctdh.getEnsembleTrainingData(
                    ClassPool.getInstancesInClass(c), ClassPool.getNoOfClasses());
            for (int i = 0; i < this.ensemble.length; i++) {
                if (trainingInstances[i] == null) {
                    // Start from an empty set with the same header rather than insts[i] itself:
                    // appending to an alias would also grow insts[i], which may later be picked
                    // as the last member's secondary cluster.
                    trainingInstances[i] = new Instances(insts[i], insts[i].size());
                }
                appendAndUpdateError(insts[i], i, trainingInstances[i]);
                int secondary = secondaryClusterIndex(i, insts.length);
                if (secondary >= 0) {
                    appendAndUpdateError(insts[secondary], i, trainingInstances[i]);
                }
            }
        }
        for (int i = 0; i < trainingInstances.length; i++) {
            Instances batch = trainingInstances[i];
            if (batch == null) {
                continue; // no class produced any data for this member
            }
            Collections.shuffle(batch);
            for (Instance tInst : batch) {
                this.ensemble[i].trainOnInstance(tInst);
            }
        }
        ClassPool.clear();
    }

    /** Updates member {@code member}'s error on each instance of {@code source}, then appends it to {@code target}. */
    private void appendAndUpdateError(Instances source, int member, Instances target) {
        int count = source.size();
        for (int j = 0; j < count; j++) {
            Instance instance = source.get(j);
            updateError(instance, member);
            target.add(instance);
        }
    }

    /**
     * Buffers {@code inst} in a single window and, once the window holds at least
     * {@code windowSize} instances (and checkSize is set), clusters it. Member i is trained
     * on cluster i plus one secondary cluster (see secondaryClusterIndex), updating its error
     * on each instance before training. The window is then cleared.
     */
    public void trainByWholePool(Instance inst, int windowSize) {
        if (Insts == null) {
            ArrayList attList = Collections.list(inst.enumerateAttributes());
            attList.add(inst.attribute(inst.classIndex()));
            Insts = new Instances("instancePool", attList, windowSize);
            Insts.setClassIndex(attList.size() - 1);
        }
        Insts.add(inst);
        if (Insts.size() >= windowSize && checkSize) {
            ClusterTrainingDataHarvester ctdh = new ClusterTrainingDataHarvester(this.ensemble.length);
            Instances[] insts = ctdh.getEnsembleTrainingData(Insts, Insts.numClasses());
            for (int i = 0; i < this.ensemble.length; i++) {
                // Train the member with its own cluster first ...
                trainMemberOn(insts[i], i);
                // ... then with the secondary cluster, if one exists.
                int secondary = secondaryClusterIndex(i, insts.length);
                if (secondary >= 0) {
                    trainMemberOn(insts[secondary], i);
                }
            }
            Insts.clear();
        }
    }

    /** Updates member {@code member}'s error on each instance of {@code source}, then trains it on that instance. */
    private void trainMemberOn(Instances source, int member) {
        int count = source.size();
        for (int j = 0; j < count; j++) {
            Instance instance = source.get(j);
            updateError(instance, member);
            this.ensemble[member].trainOnInstance(instance);
        }
    }

    /**
     * Chooses the cluster that supplements member {@code member}'s own cluster. This is the
     * next cluster, or, for the last member, a pseudo-random earlier cluster derived from
     * System.nanoTime().
     *
     * @return the cluster index, or -1 when there is no other cluster ({@code clusterCount < 2}),
     *         where a plain modulo by {@code clusterCount - 1} would throw ArithmeticException
     */
    private static int secondaryClusterIndex(int member, int clusterCount) {
        if (member < clusterCount - 1) {
            return member + 1;
        }
        int candidates = clusterCount - 1;
        if (candidates < 1) {
            return -1;
        }
        // nanoTime() may be negative, so shift the remainder into [0, candidates).
        long r = System.nanoTime() % candidates;
        if (r < 0) {
            r += candidates;
        }
        return (int) r;
    }

    /** Dispatches one training instance to the class-pool or whole-pool strategy (see PoolFlag). */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        int windowSize = windowSizeOption.getValue();
        if (PoolFlag) {
            trainByClassPool(inst, windowSize);
        } else {
            trainByWholePool(inst, windowSize);
        }
    }

    /**
     * Computes the active-learning ratio as {@code qbcEntropy^2 * (topVote / ensembleSize)},
     * where topVote is the largest entry of {@code combinedVote}.
     */
    public double getActiveLearningRatio(double qbcEntropy, DoubleVector combinedVote) {
        int maxIndex = combinedVote.maxIndex();
        int ensembleLength = this.ensemble.length;
        double maxVote = combinedVote.getValue(maxIndex);
        return Math.pow(qbcEntropy, 2) * (maxVote / ensembleLength);
    }

    /** True when the useWeight option is present and set. */
    private boolean isWeighted() {
        return (this.useWeightOption != null) && this.useWeightOption.isSet();
    }

    /**
     * Adds a copy of {@code inst}, labelled with the committee's top class, to the shared
     * voted pool and increments instConfCount. Labelling a copy leaves the caller's instance,
     * including its true class value, untouched.
     * NOTE(review): the pool now receives a copy rather than the caller's object. Confirm no
     * caller relied on identity with, or in-place relabelling of, its own instance.
     */
    private static void addToConfidencePool(Instance inst, DoubleVector combinedVote,
                                            double activeLearningRatio) {
        int predictedClass = combinedVote.maxIndex();
        Instance labelled = (Instance) inst.copy();
        labelled.setClassValue(predictedClass);
        instConfPool.addVotedInstance(labelled, combinedVote.getValue(predictedClass),
                activeLearningRatio);
        instConfCount++;
    }

    /**
     * This is the main classification function that is used by the GUI.
     *
     * Returns the sum of the members' normalized votes. As a side effect, if at least one member
     * has error below errorRatio and the largest summed vote reaches confidenceThreshold, a
     * labelled copy of {@code inst} is added to the voted pool. Its weight comes from the
     * committee's KL divergence (QBC.getKullbackLeiblerDiv).
     */
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        DoubleVector confidenceVec = new DoubleVector();
        // Rows of members that do not pass the error filter stay all-zero.
        double[][] comitteeVotes = new double[this.ensemble.length][inst.numClasses()];
        int success = 0;
        double confidence = 0.0;
        boolean weighted = isWeighted();
        for (int i = 0; i < this.ensemble.length; i++) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                // NOTE(review): the unweighted vote enters combinedVote, so useWeight only
                // affects the committee votes used for the KL divergence below.
                combinedVote.addValues(vote);
                if (weighted) {
                    // NOTE(review): error[i] starts at 0.0 and stays exactly 0.0 while the
                    // member is always right, so this scale factor can be Infinity.
                    vote.scaleValues(1.0 / (this.error[i] * this.error[i]));
                }
                if (this.error[i] < errorRatio) {
                    success++;
                    confidenceVec.addValues(vote);
                    comitteeVotes[i] = vote.getArrayRef();
                }
            }
        }
        if (confidenceVec.numValues() > 0) {
            confidenceVec = (DoubleVector) combinedVote.copy();
            // NOTE(review): this is the maximum of an unnormalized sum (up to the number of
            // voting members), so confidenceThreshold is not compared against a probability.
            // A former Utils.logs2probs call here discarded its result and was a no-op.
            confidence = confidenceVec.maxValue();
        }
        if (confidence >= confidenceThreshold) {
            double kullLeib = QBC.getKullbackLeiblerDiv(comitteeVotes, confidenceVec, success);
            double activeLearningRatio = getActiveLearningRatio(kullLeib, combinedVote);
            addToConfidencePool(inst, combinedVote, activeLearningRatio);
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Earlier variant of getVotesForInstance, kept for comparison. Members must also have
     * tree depth > 2 to join the committee, confidence is taken from the normalized combined
     * vote, and the pool weight comes from QBC.queryByCommitee. Returns that normalized vote
     * (empty when no member qualified).
     */
    public double[] getVotesForInstance_ori(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        DoubleVector confidenceVec = new DoubleVector();
        double[] ensembleVotes = new double[inst.numClasses()];
        int success = 0;
        double confidence = 0.0;
        boolean weighted = isWeighted();
        for (int i = 0; i < this.ensemble.length; i++) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(inst));
            if (vote.sumOfValues() > 0.0) {
                vote.normalize();
                combinedVote.addValues(vote);
                if (weighted) {
                    vote.scaleValues(1.0 / (this.error[i] * this.error[i]));
                }
                if (this.error[i] < errorRatio
                        && ((ASHoeffdingOptionTree) this.ensemble[i]).measureTreeDepth() > 2) {
                    success++;
                    confidenceVec.addValues(vote);
                    int top = confidenceVec.maxIndex();
                    ensembleVotes[top] += confidenceVec.getValue(top);
                }
            }
        }
        if (confidenceVec.numValues() > 0) {
            confidenceVec = (DoubleVector) combinedVote.copy();
            confidenceVec.normalize();
            confidence = confidenceVec.maxValue();
        }
        if (confidence >= confidenceThreshold) {
            double qbcEntropy = QBC.queryByCommitee(ensembleVotes, inst.numClasses(), success, ensemble.length);
            double activeLearningRatio = getActiveLearningRatio(qbcEntropy, combinedVote);
            addToConfidencePool(inst, combinedVote, activeLearningRatio);
        }
        return confidenceVec.getArrayRef();
    }

    /** Returns the shared pool of confidently labelled instances. */
    public static VotedInstancePool getVotedInstancePool() {
        return instConfPool;
    }

    /** Enables or disables the window-full training step (static flag shared by all instances). */
    public void setCheckSize(boolean sizeControl) {
        checkSize = sizeControl;
    }

    /** Delegates to the superclass description. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        super.getModelDescription(out, indent);
    }
}