package edu.stanford.nlp.classify;
import java.util.ArrayList;
import java.util.List;
import java.io.File;
import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PriorityQueue;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
/** A class to create precision-recall curves given classifier scores;
 * also used to fit the best monotonic function for calibrating logistic regression and SVM outputs.
*
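 * <p>A minimal usage sketch (the scores and classes below are illustrative only):
 *
 * <pre>{@code
 * List<Pair<Double, Integer>> dataScores = new ArrayList<>();
 * dataScores.add(new Pair<>(0.9, 1)); // high score, positive example
 * dataScores.add(new Pair<>(0.6, 0)); // middling score, negative example
 * dataScores.add(new Pair<>(0.2, 0)); // low score, negative example
 * PRCurve curve = new PRCurve(dataScores);
 * System.out.println("accuracy " + curve.accuracy());
 * System.out.println("optimal accuracy " + curve.optimalAccuracy());
 * }</pre>
 *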
* @author Kristina Toutanova
* @version May 23, 2005
*/
public class PRCurve {
  double[] scores; // example scores, sorted in increasing order
  int[] classes; // the gold class of example i
  int[] guesses; // the argmax guess of the model for example i (multiclass only)
  int[] numpositive; // number of positive examples among the i highest-scoring examples
  int[] numnegative; // number of negative examples among the i lowest-scoring examples
final static Redwood.RedwoodChannels logger = Redwood.channels(PRCurve.class);
/**
   * Reads scores with classes from a file (each line holds a score, then a class),
   * sorts by score, and creates the arrays.
*/
public PRCurve(String filename) {
try {
ArrayList<Pair<Double, Integer>> dataScores = new ArrayList<>();
for(String line : ObjectBank.getLineIterator(new File(filename))) {
List<String> elems = StringUtils.split(line);
Pair<Double, Integer> p = new Pair<>(Double.valueOf(elems.get(0)), Integer.valueOf(elems.get(1)));
dataScores.add(p);
}
init(dataScores);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
   * Reads classes and scores from a file in SVM format (each line holds a class in
   * {-1, 1}, then a score), maps class -1 to 0, and shifts each score by +0.5 so that
   * the decision boundary falls at 0.5; then sorts by score and creates the arrays.
   *
   * @param svm unused; serves only to select this constructor for SVM-format files
*/
public PRCurve(String filename, boolean svm) {
try {
ArrayList<Pair<Double, Integer>> dataScores = new ArrayList<>();
for(String line : ObjectBank.getLineIterator(new File(filename))) {
List<String> elems = StringUtils.split(line);
int cls = Double.valueOf(elems.get(0)).intValue();
if (cls == -1) {
cls = 0;
}
double score = Double.valueOf(elems.get(1)) + 0.5;
        Pair<Double, Integer> p = new Pair<>(Double.valueOf(score), Integer.valueOf(cls));
dataScores.add(p);
}
init(dataScores);
} catch (Exception e) {
e.printStackTrace();
}
}
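  /**
   * The best accuracy achievable if, over the full sample, guesses are optimally split
   * between top-scoring positives and bottom-scoring negatives (see {@link #precision(int)}).
   */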
public double optimalAccuracy() {
return precision(numSamples()) / (double) numSamples();
}
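  /**
   * The accuracy obtained by treating each score as the probability of class 1
   * and thresholding at 0.5.
   */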
public double accuracy() {
return logPrecision(numSamples()) / (double) numSamples();
}
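  /**
   * Creates a curve directly from a list of (score, class) pairs.
   */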
public PRCurve(List<Pair<Double, Integer>> dataScores) {
init(dataScores);
}
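  /**
   * Sorts the (score, class) pairs by increasing score and fills the parallel
   * scores and classes arrays.
   */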
public void init(List<Pair<Double, Integer>> dataScores) {
PriorityQueue<Pair<Integer, Pair<Double, Integer>>> q = new BinaryHeapPriorityQueue<>();
for (int i = 0; i < dataScores.size(); i++) {
q.add(new Pair<>(Integer.valueOf(i), dataScores.get(i)), -dataScores.get(i).first().doubleValue());
}
List<Pair<Integer, Pair<Double, Integer>>> sorted = q.toSortedList();
scores = new double[sorted.size()];
classes = new int[sorted.size()];
logger.info("incoming size " + dataScores.size() + " resulting " + sorted.size());
for (int i = 0; i < sorted.size(); i++) {
Pair<Double, Integer> next = sorted.get(i).second();
scores[i] = next.first().doubleValue();
classes[i] = next.second().intValue();
}
init();
}
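  /**
   * Initializes the curve from multiclass data: each triple holds an example's score,
   * its gold class, and the model's argmax guess.
   */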
public void initMC(ArrayList<Triple<Double, Integer, Integer>> dataScores) {
PriorityQueue<Pair<Integer, Triple<Double, Integer, Integer>>> q = new BinaryHeapPriorityQueue<>();
for (int i = 0; i < dataScores.size(); i++) {
q.add(new Pair<>(Integer.valueOf(i), dataScores.get(i)), -dataScores.get(i).first().doubleValue());
}
List<Pair<Integer, Triple<Double, Integer, Integer>>> sorted = q.toSortedList();
scores = new double[sorted.size()];
classes = new int[sorted.size()];
guesses = new int[sorted.size()];
logger.info("incoming size " + dataScores.size() + " resulting " + sorted.size());
for (int i = 0; i < sorted.size(); i++) {
Triple<Double, Integer, Integer> next = sorted.get(i).second();
scores[i] = next.first().doubleValue();
classes[i] = next.second().intValue();
guesses[i] = next.third().intValue();
}
init();
}
/**
   * Initialize the numpositive and numnegative arrays.
*/
void init() {
numnegative = new int[numSamples() + 1];
numpositive = new int[numSamples() + 1];
numnegative[0] = 0;
numpositive[0] = 0;
int num = numSamples();
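    // Scores are sorted in increasing order, so numnegative[i] counts negatives among
    // the i lowest-scoring examples, and numpositive[i] counts positives among the
    // i highest-scoring examples.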
for (int i = 1; i <= num; i++) {
numnegative[i] = numnegative[i - 1] + (classes[i - 1] == 0 ? 1 : 0);
}
for (int i = 1; i <= num; i++) {
numpositive[i] = numpositive[i - 1] + (classes[num - i] == 0 ? 0 : 1);
}
logger.info("total positive " + numpositive[num] + " total negative " + numnegative[num] + " total " + num);
}
int numSamples() {
return scores.length;
}
/**
   * The best possible number of correct guesses when making {@code recall} guesses in total.
*
*/
public int precision(int recall) {
int optimum = 0;
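    // Try every split of the guesses: take 'right' from the top of the ranking (guessed
    // positive) and the remainder from the bottom (guessed negative), keeping the split
    // with the most correct guesses.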
for (int right = 0; right <= recall; right++) {
int candidate = numpositive[right] + numnegative[recall - right];
if (candidate > optimum) {
optimum = candidate;
}
}
return optimum;
}
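  /**
   * The balanced F-measure, F1 = 2 * precision * recall / (precision + recall);
   * precision and recall default to 1 when their denominators are 0.
   */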
public static double f1(int tp, int fp, int fn) {
double prec = 1;
double recall = 1;
if (tp + fp > 0) {
prec = tp / (double) (tp + fp);
}
if (tp + fn > 0) {
recall = tp / (double) (tp + fn);
}
return 2 * prec * recall / (prec + recall);
}
/**
   * The F-measure if we guess the {@code numleft} lowest-scoring examples as negative
   * and the {@code numright} highest-scoring examples as positive.
*
*/
public double fmeasure(int numleft, int numright) {
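    // tp: positives among the numright highest scores (guessed positive); fp: the rest
    // of those guesses; fn: positives among the numleft lowest scores (guessed negative).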
int tp = numpositive[numright];
int fp = numright - tp;
int fn = numleft - numnegative[numleft];
return f1(tp, fp, fn);
}
/**
   * The number of correct guesses among the {@code recall} most confident examples,
   * treating each score as the probability of class 1 given x, as if coming from
   * logistic regression, and thresholding at 0.5.
*
*/
public int logPrecision(int recall) {
int totaltaken = 0;
int rightIndex = numSamples() - 1; //next right candidate
int leftIndex = 0; //next left candidate
int totalcorrect = 0;
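    // Sweep inward from both ends of the score-sorted array, repeatedly taking the
    // example whose score is farthest from the 0.5 threshold, i.e. the most confident
    // remaining guess.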
while (totaltaken < recall) {
double confr = Math.abs(scores[rightIndex] - .5);
double confl = Math.abs(scores[leftIndex] - .5);
int chosen = leftIndex;
if (confr > confl) {
chosen = rightIndex;
rightIndex--;
} else {
leftIndex++;
}
//logger.info("chose "+chosen+" score "+scores[chosen]+" class "+classes[chosen]+" correct "+correct(scores[chosen],classes[chosen]));
if ((scores[chosen] >= .5) && (classes[chosen] == 1)) {
totalcorrect++;
}
if ((scores[chosen] < .5) && (classes[chosen] == 0)) {
totalcorrect++;
}
totaltaken++;
}
return totalcorrect;
}
/**
   * The optimal F-measure achievable with {@code recall} guesses,
   * using the optimal monotonic function.
*
*/
public double optFmeasure(int recall) {
double max = 0;
for (int i = 0; i < (recall + 1); i++) {
double f = fmeasure(i, recall - i);
if (f > max) {
max = f;
}
}
return max;
}
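  /**
   * The optimal F-measure over the full sample.
   */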
public double opFmeasure() {
return optFmeasure(numSamples());
}
/**
   * The F-measure at this recall if we treat each score as the probability of class 1
   * given x, as if coming from logistic regression; the same sweep as
   * {@link #logPrecision(int)} but computing the F-measure instead of counting correct guesses.
   *
   * @param recall make this many guesses, choosing the examples we are most confident about
*/
public double fmeasure(int recall) {
int totaltaken = 0;
int rightIndex = numSamples() - 1; //next right candidate
int leftIndex = 0; //next left candidate
int tp = 0, fp = 0, fn = 0;
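    // The same most-confident-first sweep as logPrecision, but tallying tp/fp/fn.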
while (totaltaken < recall) {
double confr = Math.abs(scores[rightIndex] - .5);
double confl = Math.abs(scores[leftIndex] - .5);
int chosen = leftIndex;
if (confr > confl) {
chosen = rightIndex;
rightIndex--;
} else {
leftIndex++;
}
//logger.info("chose "+chosen+" score "+scores[chosen]+" class "+classes[chosen]+" correct "+correct(scores[chosen],classes[chosen]));
if ((scores[chosen] >= .5)) {
if (classes[chosen] == 1) {
tp++;
} else {
fp++;
}
}
if ((scores[chosen] < .5)) {
if (classes[chosen] == 1) {
fn++;
}
}
totaltaken++;
}
return f1(tp, fp, fn);
}
/**
   * The log-likelihood of the data, assuming each score is the probability of class 1 given x.
*
*/
public double logLikelihood() {
double loglik = 0;
for (int i = 0; i < scores.length; i++) {
loglik += Math.log(classes[i] == 0 ? 1 - scores[i] : scores[i]);
}
return loglik;
}
/**
   * Confidence-weighted accuracy: the average, over all recall levels r, of the accuracy
   * of the r most confident guesses, assuming the scores are probabilities with a 0.5 threshold.
*
*/
public double cwa() {
double acc = 0;
for (int recall = 1; recall <= numSamples(); recall++) {
acc += logPrecision(recall) / (double) recall;
}
return acc / numSamples();
}
/**
   * The number of correct guesses at each recall level, assuming the scores are
   * probabilities and using 0.5 as the threshold.
*
*/
public int[] cwaArray() {
int[] arr = new int[numSamples()];
for (int recall = 1; recall <= numSamples(); recall++) {
arr[recall - 1] = logPrecision(recall);
}
return arr;
}
/**
   * The optimal number of correct guesses at each recall level, assuming an optimal
   * monotonic function can be fit for each recall.
*
*/
public int[] optimalCwaArray() {
int[] arr = new int[numSamples()];
for (int recall = 1; recall <= numSamples(); recall++) {
arr[recall - 1] = precision(recall);
}
return arr;
}
/**
   * Optimal confidence-weighted accuracy, assuming that for each recall level we can
   * fit an optimal monotonic function.
*
*/
public double optimalCwa() {
double acc = 0;
for (int recall = 1; recall <= numSamples(); recall++) {
acc += precision(recall) / (double) recall;
}
return acc / numSamples();
}
public static boolean correct(double score, int cls) {
return ((score >= .5) && (cls == 1)) || ((score < .5) && (cls == 0));
}
  public static void main(String[] args) {
    // Quick sanity check of the priority queue used for sorting.
    PriorityQueue<String> q = new BinaryHeapPriorityQueue<>();
    q.add("bla", 2);
    q.add("bla3", 2);
    logger.info("size of q " + q.size());
    // Expects an SVM-format scores file; falls back to the original hard-coded test path.
    String filename = (args.length > 0) ? args[0] : "c:/data0204/precsvm";
    PRCurve pr = new PRCurve(filename, true);
logger.info("acc " + pr.accuracy() + " opt " + pr.optimalAccuracy() + " cwa " + pr.cwa() + " optcwa " + pr.optimalCwa());
for (int r = 1; r <= pr.numSamples(); r++) {
logger.info("optimal precision at recall " + r + " " + pr.precision(r));
logger.info("model precision at recall " + r + " " + pr.logPrecision(r));
}
}
}