/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * A class for storing prediction data based on indexed instances.
 *
 * <p>Each stored prediction consists of a predicted label, a confidence value
 * and, optionally, the known true label. Whenever a true label is supplied the
 * running confusion-matrix counters (true/false positives/negatives) are
 * updated as well. Predictions are addressed by the order in which they were
 * added (or by sorted position for instances returned from
 * {@link #sortByConfidence()}).
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class Predictions {

    /**
     * Orders predictions by ascending confidence. Comparing a prediction whose
     * confidence is null throws a NullPointerException.
     */
    private static final Comparator<Prediction> BY_CONFIDENCE = new Comparator<Prediction>() {
        @Override
        public int compare(Prediction o1, Prediction o2) {
            return o1.getConfidence().compareTo(o2.getConfidence());
        }
    };

    private List<Prediction> predictions = new ArrayList<Prediction>();
    private int truePositiveCount = 0;
    private int trueNegativeCount = 0;
    private int falsePositiveCount = 0;
    private int falseNegativeCount = 0;

    /**
     * Creates an empty set of predictions with all counters at zero.
     */
    public Predictions() {
    }

    /**
     * Makes a copy of the specified predictions.
     *
     * <p>The list of predictions is copied, so adding to either object
     * afterwards does not affect the other. The individual entries are
     * immutable and are therefore shared.
     *
     * @param parent the predictions to copy
     */
    private Predictions(Predictions parent) {
        this.predictions = new ArrayList<Prediction>(parent.predictions);
        this.truePositiveCount = parent.truePositiveCount;
        this.trueNegativeCount = parent.trueNegativeCount;
        this.falsePositiveCount = parent.falsePositiveCount;
        this.falseNegativeCount = parent.falseNegativeCount;
    }

    /**
     * Gets the predicted label of the prediction at the given position.
     *
     * @param index zero-based position of the prediction
     * @return the predicted label as it was added (may be null)
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    public Boolean getPredictedLabel(int index) {
        return predictions.get(index).getPredictedLabel();
    }

    /**
     * Gets the confidence of the prediction at the given position.
     *
     * @param index zero-based position of the prediction
     * @return the confidence as it was added (may be null)
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    public Double getPredictionConfidence(int index) {
        return predictions.get(index).getConfidence();
    }

    /**
     * Gets the true label of the prediction at the given position.
     *
     * @param index zero-based position of the prediction
     * @return the true label, or null if none was supplied
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    public Boolean getTrueLabel(int index) {
        return predictions.get(index).getTrueLabel();
    }

    /**
     * @return the number of stored predictions
     */
    public int size() {
        return predictions.size();
    }

    /**
     * @return how many predictions were predicted true with a true label of true
     */
    public int getTruePositiveCount() {
        return truePositiveCount;
    }

    /**
     * @return how many predictions were predicted false with a true label of false
     */
    public int getTrueNegativeCount() {
        return trueNegativeCount;
    }

    /**
     * @return how many predictions were predicted true with a true label of false
     */
    public int getFalsePositiveCount() {
        return falsePositiveCount;
    }

    /**
     * @return how many predictions were predicted false with a true label of true
     */
    public int getFalseNegativeCount() {
        return falseNegativeCount;
    }

    /**
     * Add a predicted data point without a known true value. The
     * confusion-matrix counters are not affected.
     *
     * @param predictedLabel the predicted label
     * @param confidence the confidence of the prediction
     */
    public void add(Boolean predictedLabel, Double confidence) {
        this.add(predictedLabel, confidence, null);
    }

    /**
     * Add a predicted data point with a known true value.
     *
     * <p>If {@code trueLabel} is null the prediction is stored but no counter
     * is changed. Otherwise exactly one of the four confusion-matrix counters
     * is incremented.
     *
     * @param predictedLabel the predicted label; must not be null when
     * {@code trueLabel} is given
     * @param confidence the confidence of the prediction
     * @param trueLabel the known true label, or null if unknown
     * @throws NullPointerException if {@code trueLabel} is non-null but
     * {@code predictedLabel} is null; nothing is stored in that case
     */
    public void add(Boolean predictedLabel, Double confidence, Boolean trueLabel) {
        // Validate before touching any state so a failure cannot leave the
        // list and the counters out of step.
        if (trueLabel != null && predictedLabel == null) {
            throw new NullPointerException("predictedLabel must not be null when trueLabel is given");
        }

        this.predictions.add(new Prediction(predictedLabel, trueLabel, confidence));
        if (trueLabel == null) {
            return;
        }

        // Compare primitive values: == on two Boolean objects tests reference
        // identity and misclassifies non-canonical instances.
        boolean predicted = predictedLabel.booleanValue();
        boolean correct = predicted == trueLabel.booleanValue();
        if (predicted) {
            if (correct) {
                truePositiveCount++;
            } else {
                falsePositiveCount++;
            }
        } else {
            if (correct) {
                trueNegativeCount++;
            } else {
                falseNegativeCount++;
            }
        }
    }

    /**
     * Creates a copy of these predictions, sorted by confidence (ascending).
     * This object is left unchanged; the counters are carried over to the
     * copy. Entries with equal confidence keep their relative order.
     *
     * @return a new, independently modifiable instance holding the same
     * predictions in ascending confidence order
     * @throws NullPointerException may be thrown if a stored confidence is null
     */
    Predictions sortByConfidence() {
        Predictions copy = new Predictions(this);
        // The copy owns its own list, so sorting in place is safe.
        Collections.sort(copy.predictions, BY_CONFIDENCE);
        return copy;
    }

    /**
     * Class for storing an individual prediction. Instances are immutable.
     * Declared static because it needs no access to the enclosing instance.
     */
    private static final class Prediction {

        private final Boolean predictedLabel;
        private final Boolean trueLabel;
        private final Double confidence;

        public Prediction(Boolean predictedLabel, Boolean trueLabel, Double confidence) {
            this.predictedLabel = predictedLabel;
            this.trueLabel = trueLabel;
            this.confidence = confidence;
        }

        public Boolean getPredictedLabel() {
            return predictedLabel;
        }

        public Boolean getTrueLabel() {
            return trueLabel;
        }

        public Double getConfidence() {
            return confidence;
        }
    }
}