/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* MultipleEvaluation.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.evaluation;
import java.util.ArrayList;
import java.util.HashMap;
import mulan.evaluation.measure.Measure;
/**
* Simple class that includes an array, whose elements are lists of evaluation
* evaluations. Used to compute means and standard deviations of multiple
* evaluations (e.g. cross-validation)
*
* @author Grigorios Tsoumakas
*/
public class MultipleEvaluation {
private ArrayList<Evaluation> evaluations;
private HashMap<String, Double> mean;
private HashMap<String, Double> standardDeviation;
/**
* Constructs a new object
*/
public MultipleEvaluation() {
evaluations = new ArrayList<Evaluation>();
}
/**
* Constructs a new object with given array of evaluations and calculates
* statistics
*
* @param someEvaluations
*/
public MultipleEvaluation(Evaluation[] someEvaluations) {
evaluations = new ArrayList<Evaluation>();
for (Evaluation e : someEvaluations) {
evaluations.add(e);
}
calculateStatistics();
}
/**
* Computes mean and standard deviation of all evaluation measures
*/
public void calculateStatistics() {
int size = evaluations.size();
HashMap<String, Double> sums = new HashMap<String, Double>();
// calculate sums of measures
for (int i = 0; i < evaluations.size(); i++) {
for (Measure m : evaluations.get(i).getMeasures()) {
double value = Double.NaN;
try {
value = m.getValue();
} catch (Exception ex) {
}
if (sums.containsKey(m.getName())) {
sums.put(m.getName(), sums.get(m.getName()) + value);
} else {
sums.put(m.getName(), value);
}
}
}
mean = new HashMap<String, Double>();
for (String measureName : sums.keySet()) {
mean.put(measureName, sums.get(measureName) / size);
}
// calculate sums of squared differences from mean
sums = new HashMap<String, Double>();
for (int i = 0; i < evaluations.size(); i++) {
for (Measure m : evaluations.get(i).getMeasures()) {
double value = Double.NaN;
try {
value = m.getValue();
} catch (Exception ex) {
}
if (sums.containsKey(m.getName())) {
sums.put(m.getName(), sums.get(m.getName()) + Math.pow(value - mean.get(m.getName()), 2));
} else {
sums.put(m.getName(), Math.pow(value - mean.get(m.getName()), 2));
}
}
}
standardDeviation = new HashMap<String, Double>();
for (String measureName : sums.keySet()) {
standardDeviation.put(measureName, Math.sqrt(sums.get(measureName) / size));
}
}
/**
* Adds an evaluation results to the list of evaluations
*
* @param evaluation an evaluation result
*/
public void addEvaluation(Evaluation evaluation) {
evaluations.add(evaluation);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (Measure m : evaluations.get(0).getMeasures()) {
String measureName = m.getName();
sb.append(measureName);
sb.append(": ");
sb.append(String.format("%.4f", mean.get(measureName)));
sb.append("\u00B1");
sb.append(String.format("%.4f", standardDeviation.get(measureName)));
sb.append("\n");
}
return sb.toString();
}
/**
* Returns the mean value of a measure
*
* @param measureName the name of the measure
* @return the mean value of the measure
*/
public double getMean(String measureName) {
return mean.get(measureName);
}
/**
* Returns the standard deviation value of a measure
*
* @param measureName the name of the measure
* @return the standard deviation value of the measure
*/
public double getStandardDeviation(String measureName) {
return standardDeviation.get(measureName);
}
/**
* Returns a CSV string representation of the results
*
* @return a CSV string representation of the results
*/
public String toCSV() {
StringBuilder sb = new StringBuilder();
for (Measure m : evaluations.get(0).getMeasures()) {
String measureName = m.getName();
sb.append(String.format("%.4f", mean.get(measureName)));
sb.append("\u00B1");
sb.append(String.format("%.4f", standardDeviation.get(measureName)));
sb.append(";");
}
return sb.toString();
}
}