package tr.gov.ulakbim.jDenetX.evaluation;
import org.apache.commons.lang3.ArrayUtils;
import tr.gov.ulakbim.jDenetX.AbstractMOAObject;
import tr.gov.ulakbim.jDenetX.core.Measurement;
import weka.core.Utils;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
/**
 * Classification performance evaluator that reports, in addition to the
 * fraction of correctly classified weight and the Kappa statistic, a squared
 * error measure accumulated via {@code Evaluation.getSqError} and (when class
 * names are supplied) the percentage of processed instances seen per class name.
 *
 * <p>Originally created by caglar on 10/19/11, 3:30 PM.
 */
public class SelfOzaBoostClassificationPerformanceEvaluator extends AbstractMOAObject
        implements ClassificationPerformanceEvaluator {

    private static final long serialVersionUID = 1L;

    /** Sum of the weights of all instances observed since the last reset. */
    protected double weightObserved;

    /** Sum of the weights of the correctly classified instances. */
    protected double weightCorrect;

    /** Weight accumulated per true class index (Kappa marginal). */
    protected double[] columnKappa;

    /** Weight accumulated per predicted class index (Kappa marginal). */
    protected double[] rowKappa;

    /** Unweighted number of instances seen per true class index. */
    protected int[] instanceClassesMap;

    /**
     * Unweighted number of instances seen per class name; only filled by the
     * name-aware {@code addClassificationAttempt} overload. Initialised eagerly
     * so reporting before the first attempt does not throw, and backed by a
     * LinkedHashMap so reports list classes in first-seen order.
     */
    protected HashMap<String, Integer> ClassesCountMap = new LinkedHashMap<String, Integer>();

    /** Number of class slots currently allocated in the per-class arrays. */
    protected int numClasses;

    /** Running sum of the values returned by {@code Evaluation.getSqError}. */
    private double SE = 0.0;

    /** Number of positive-weight instances processed since the last reset. */
    private int NoOfProcessedInstances = 0;

    /** Clears all statistics, keeping the current number of class slots. */
    public void reset() {
        reset(this.numClasses);
    }

    /**
     * Clears all statistics and allocates per-class storage.
     *
     * @param numClasses number of class slots to allocate
     */
    public void reset(int numClasses) {
        this.numClasses = numClasses;
        // New Java arrays are zero-filled, so no explicit clearing loop is needed.
        this.rowKappa = new double[numClasses];
        this.columnKappa = new double[numClasses];
        this.instanceClassesMap = new int[numClasses];
        this.ClassesCountMap = new LinkedHashMap<String, Integer>();
        this.SE = 0.0;
        this.NoOfProcessedInstances = 0;
        this.weightObserved = 0.0;
        this.weightCorrect = 0.0;
    }

    /**
     * Records one classification attempt. Attempts with non-positive weight
     * are ignored.
     *
     * @param trueClass  index of the actual class
     * @param classVotes vote distribution; its arg-max is the predicted class
     * @param weight     instance weight
     */
    public void addClassificationAttempt(int trueClass,
                                         double[] classVotes,
                                         double weight) {
        recordAttempt(trueClass, classVotes, weight);
    }

    /**
     * Records one classification attempt and updates the per-class-name count
     * for {@code className}. Attempts with non-positive weight are ignored.
     *
     * @param trueClass  index of the actual class
     * @param className  name under which the true class is counted
     * @param classVotes vote distribution; its arg-max is the predicted class
     * @param weight     instance weight
     */
    public void addClassificationAttempt(int trueClass,
                                         String className,
                                         double[] classVotes,
                                         double weight) {
        if (recordAttempt(trueClass, classVotes, weight)) {
            ClassesCountMap.put(className, instanceClassesMap[trueClass]);
        }
    }

    /**
     * Shared bookkeeping for both overloads.
     *
     * @return {@code true} if the attempt was recorded (weight > 0)
     */
    private boolean recordAttempt(int trueClass, double[] classVotes, double weight) {
        if (!(weight > 0.0)) {
            return false;
        }
        if (this.weightObserved == 0) {
            reset(classVotes.length > 1 ? classVotes.length : 2);
        }
        // Count only after the possible reset above: reset() zeroes this
        // counter, which previously caused the first instance to be dropped.
        NoOfProcessedInstances++;
        this.weightObserved += weight;
        int predictedClass = Utils.maxIndex(classVotes);
        // The arrays are sized from the first vote vector; later instances may
        // carry a larger class index, so grow instead of overflowing.
        ensureClassCapacity(Math.max(trueClass, predictedClass) + 1);
        if (predictedClass == trueClass) {
            this.weightCorrect += weight;
        }
        this.SE += Evaluation.getSqError(trueClass, classVotes, weight);
        this.rowKappa[predictedClass] += weight;
        this.columnKappa[trueClass] += weight;
        this.instanceClassesMap[trueClass]++;
        return true;
    }

    /** Grows the per-class arrays so that every index below {@code required} is valid. */
    private void ensureClassCapacity(int required) {
        if (required > this.numClasses) {
            this.rowKappa = Arrays.copyOf(this.rowKappa, required);
            this.columnKappa = Arrays.copyOf(this.columnKappa, required);
            this.instanceClassesMap = Arrays.copyOf(this.instanceClassesMap, required);
            this.numClasses = required;
        }
    }

    /** Converts a per-class instance count to a percentage of processed instances. */
    private double classRatioPercent(int count) {
        return NoOfProcessedInstances > 0
                ? ((double) count / (double) NoOfProcessedInstances) * 100
                : 0.0;
    }

    /**
     * Returns one line per class name of the form {@code "name: pct% \n"}.
     *
     * @return human-readable per-class percentages (empty if none recorded)
     */
    public String getClassesRatioMap() {
        StringBuilder message = new StringBuilder();
        for (Map.Entry<String, Integer> entry : ClassesCountMap.entrySet()) {
            message.append(entry.getKey())
                   .append(": ")
                   .append(classRatioPercent(entry.getValue()))
                   .append("% \n");
        }
        return message.toString();
    }

    /**
     * Returns one measurement per class name whose value is the percentage of
     * processed instances counted under that name.
     */
    public Measurement[] getClassesRatioMeasurements() {
        Measurement[] measurements = new Measurement[ClassesCountMap.size()];
        int i = 0;
        for (Map.Entry<String, Integer> entry : ClassesCountMap.entrySet()) {
            measurements[i++] = new Measurement(entry.getKey(),
                    classRatioPercent(entry.getValue()));
        }
        return measurements;
    }

    /** Returns the basic measurements followed by the per-class-name ratios. */
    public Measurement[] getPerformanceMeasurements() {
        Measurement[] basicMeasurements = new Measurement[]{
                new Measurement("classified instances",
                        getTotalWeightObserved()),
                new Measurement("classifications correct (percent)",
                        getFractionCorrectlyClassified() * 100.0),
                new Measurement("Kappa Statistic (percent)",
                        getKappaStatistic() * 100.0),
                new Measurement("Mean Square Error ",
                        getMSE()),
                new Measurement("Root Mean Square Error ",
                        getRMSE())
        };
        return ArrayUtils.addAll(basicMeasurements, getClassesRatioMeasurements());
    }

    /** Returns the total weight of the recorded attempts. */
    public double getTotalWeightObserved() {
        return this.weightObserved;
    }

    /**
     * Returns the accumulated squared error divided by the number of processed
     * instances, or 0 when nothing has been processed.
     * NOTE(review): the divisor is the instance count, not the total weight —
     * confirm this is intended when instance weights differ from 1.
     */
    public double getMSE() {
        return NoOfProcessedInstances > 0
                ? SE / (double) NoOfProcessedInstances
                : 0.0;
    }

    /** Returns the square root of {@link #getMSE()}. */
    public double getRMSE() {
        return Math.sqrt(getMSE());
    }

    /** Returns correctly classified weight / observed weight, or 0 if none observed. */
    public double getFractionCorrectlyClassified() {
        return this.weightObserved > 0.0
                ? this.weightCorrect / this.weightObserved
                : 0.0;
    }

    /** Returns the number of positive-weight instances processed since the last reset. */
    public int getNoOfProcessedInstances() {
        return NoOfProcessedInstances;
    }

    /** Returns the live (not copied) per-class-index instance counts. */
    public int[] getInstancesClassesCount() {
        return this.instanceClassesMap;
    }

    /** Returns the live (not copied) per-class-name instance counts. */
    public HashMap<String, Integer> getClassesCountMap() {
        return ClassesCountMap;
    }

    /** Returns {@code 1 - getFractionCorrectlyClassified()}. */
    public double getFractionIncorrectlyClassified() {
        return 1.0 - getFractionCorrectlyClassified();
    }

    /**
     * Returns Cohen's Kappa computed from the weighted predicted/true marginals:
     * {@code (p0 - pc) / (1 - pc)}, where p0 is the weighted accuracy and pc the
     * chance agreement. Returns 0 when nothing has been observed.
     * NOTE(review): yields NaN when pc == 1 (a single class both predicted and
     * observed throughout).
     */
    public double getKappaStatistic() {
        if (this.weightObserved > 0.0) {
            double p0 = getFractionCorrectlyClassified();
            double pc = 0.0;
            for (int i = 0; i < this.numClasses; i++) {
                pc += (this.rowKappa[i] / this.weightObserved)
                        * (this.columnKappa[i] / this.weightObserved);
            }
            return (p0 - pc) / (1.0 - pc);
        } else {
            return 0;
        }
    }

    /** Appends a textual description of the current performance measurements to {@code sb}. */
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(getPerformanceMeasurements(),
                sb, indent);
    }
}