/* * Concept profile generation tool suite * Copyright (C) 2015 Biosemantics Group, Erasmus University Medical Center, * Rotterdam, The Netherlands * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published * by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> */ package org.erasmusmc.classification; import java.awt.BasicStroke; import java.awt.Color; import java.awt.Font; import java.awt.FontMetrics; import java.awt.Graphics2D; import java.awt.RenderingHints; import java.awt.geom.Line2D; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import javax.imageio.ImageIO; import org.erasmusmc.utilities.StringUtilities; import org.erasmusmc.utilities.WriteTextFile; public class ClassifierEvaluator { public static int width = 500; public static int height = 500; public static int borderWidth = 50; public static int ticSize = 5; public static boolean includeAuCinROC = true; public static boolean outputROCcsv = true; public static void main(String[] args){ ClassifierOutput output = new ClassifierOutput(); /*output.scoreLabelPairs.add(new ScoreLabelPair(5, true)); output.scoreLabelPairs.add(new ScoreLabelPair(0, true)); output.scoreLabelPairs.add(new ScoreLabelPair(9, true)); output.scoreLabelPairs.add(new ScoreLabelPair(8, true)); output.scoreLabelPairs.add(new ScoreLabelPair(5, true)); output.scoreLabelPairs.add(new ScoreLabelPair(0, false)); output.scoreLabelPairs.add(new ScoreLabelPair(0, false)); output.scoreLabelPairs.add(new ScoreLabelPair(0, false)); output.scoreLabelPairs.add(new ScoreLabelPair(5, false)); */ output.scoreLabelPairs.add(new ScoreLabelPair(5, false)); output.scoreLabelPairs.add(new ScoreLabelPair(4, true)); output.scoreLabelPairs.add(new ScoreLabelPair(3, true)); output.scoreLabelPairs.add(new ScoreLabelPair(2, true)); output.scoreLabelPairs.add(new ScoreLabelPair(1, true)); output.scoreLabelPairs.add(new ScoreLabelPair(ScoreLabelPair.LOW_NUMBER, false)); output.scoreLabelPairs.add(new ScoreLabelPair(ScoreLabelPair.LOW_NUMBER, false)); output.scoreLabelPairs.add(new ScoreLabelPair(ScoreLabelPair.LOW_NUMBER, false)); output.scoreLabelPairs.add(new ScoreLabelPair(ScoreLabelPair.LOW_NUMBER, true)); output.scoreLabelPairs.add(new ScoreLabelPair(ScoreLabelPair.LOW_NUMBER, true)); System.out.println(ClassifierEvaluator.areaUnderCurve(output)); System.out.println(ClassifierEvaluator.areaUnderCurveConfidenceInterval(output).pointEstimate); System.out.println(ClassifierEvaluator.areaUnderCurveConfidenceInterval(output).lowerBound); System.out.println(ClassifierEvaluator.areaUnderCurveConfidenceInterval(output).upperBound); System.out.println(ClassifierEvaluator.bpref(output)); } public static void createROC(ClassifierOutput classifierOutput, String filename){ WriteTextFile out = null; if (outputROCcsv){ out = new WriteTextFile(filename.replace(".gif", ".csv")); out.writeln("FPR,TPR"); } BufferedImage img = new BufferedImage(width, height,BufferedImage.TYPE_INT_ARGB); Graphics2D g2d = (Graphics2D)img.getGraphics(); g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); g2d.setColor(Color.WHITE); g2d.fillRect(0, 0, width, height); renderAxes(g2d, 0.2f,1f); //renderDiagonal(g2d); renderROC(g2d,classifierOutput,out); if (includeAuCinROC){ double auc = ClassifierEvaluator.areaUnderCurveConfidenceInterval(classifierOutput).pointEstimate; g2d.setColor(Color.black); g2d.setFont(new Font("Arial",Font.PLAIN,50)); FontMetrics metrics = g2d.getFontMetrics(); String text = "AUC = " + StringUtilities.formatNumber("0.00", auc); int textWidth = metrics.stringWidth(text); int textHeight = metrics.getHeight(); g2d.drawString(text, Math.round(width-textWidth-borderWidth), Math.round(height-textHeight)+7); } try { ImageIO.write(img, "GIF", new File(filename)); } catch (IOException e) { e.printStackTrace(); } if (out != null) out.close(); } public static void createFPCurve(ClassifierOutput classifierOutput, String filename, int count){ BufferedImage img = new BufferedImage(width, height,BufferedImage.TYPE_INT_ARGB); Graphics2D g2d = (Graphics2D)img.getGraphics(); g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); g2d.setColor(Color.WHITE); g2d.fillRect(0, 0, width, height); renderAxes(g2d, 10,count); renderDiagonal(g2d); renderFPCurve(g2d,classifierOutput,count); try { ImageIO.write(img, "GIF", new File(filename)); } catch (IOException e) { e.printStackTrace(); } } private static void renderDiagonal(Graphics2D g2d) { g2d.setColor(Color.LIGHT_GRAY); g2d.draw(new Line2D.Float(borderWidth, height-borderWidth, width-borderWidth, borderWidth)); } public static double averagePrecision(ClassifierOutput classifierOutput){ Collections.sort(classifierOutput.scoreLabelPairs); int tpCount = 0; double score = -1; double sumP = 0; int startInterval = 1; int tpCountInterval = 0; for (int i = 1; i <= classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i-1); if (pair.score != score) {// In case of a tie, just make a linear interpolation if (tpCountInterval != 0){ int intervalSize = (i-startInterval); double step = (intervalSize+1) / (double)(tpCountInterval + 1); for (int j = 1; j <= tpCountInterval; j++){ double index = startInterval-1 + j*step; double precision = (tpCount + j) / index; sumP += precision; } tpCount += tpCountInterval; } score = pair.score; startInterval = i; tpCountInterval = 0; } if (pair.label){ tpCountInterval++; } } int i = classifierOutput.scoreLabelPairs.size()+1; if (tpCountInterval != 0){ int intervalSize = (i-startInterval); double step = (intervalSize+1) / (double)(tpCountInterval + 1); for (int j = 1; j <= tpCountInterval; j++){ double index = startInterval-1 + j*step; double precision = (tpCount + j) / index; sumP += precision; } tpCount += tpCountInterval; } return sumP/(double)tpCount; } /** * Calculates the MAP according to TREC. Assumes a score of 0 means not ranked by the system! * @param classifierOutput * @return */ public static double trecMAP(ClassifierOutput classifierOutput){ Collections.sort(classifierOutput.scoreLabelPairs); int tpCount = 0; int pCount = 0; double sumP = 0; for (int i = 1; i <= classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i-1); if (pair.label){ pCount++; if (pair.score != Integer.MIN_VALUE){ tpCount++; sumP += tpCount/(double)i; } } } return sumP/(double)pCount; } public static double omopMAP(ClassifierOutput classifierOutput){ Collections.sort(classifierOutput.scoreLabelPairs); /*Collections.sort(classifierOutput.scoreLabelPairs, new Comparator<ScoreLabelPair>(){ @Override public int compare(ScoreLabelPair arg0, ScoreLabelPair arg1) { int result = Double.compare(arg1.score, arg0.score); if (result == 0) if (arg0.label == arg1.label) return 0; else if (arg0.label && !arg1.label) return 1; else return -1; return result; }}); */ int tpCount = 0; int tpTieCount = 0; double score = -1; double sumP = 0; for (int i = 1; i <= classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i-1); if (pair.score != score){ if (tpTieCount != 0){ tpCount += tpTieCount; double precision = tpCount / (double)(i-1); sumP += tpTieCount * precision; tpTieCount = 0; } } if (pair.label) tpTieCount++; score = pair.score; } if (tpTieCount != 0){ tpCount += tpTieCount; double precision = tpCount / (double)classifierOutput.scoreLabelPairs.size(); sumP += tpTieCount * precision; tpTieCount = 0; } return sumP/(double)tpCount; } public static double p10(ClassifierOutput classifierOutput){ if (classifierOutput.scoreLabelPairs.size() < 10) return Double.NaN; Collections.sort(classifierOutput.scoreLabelPairs); int pos = 0; for (int i = 0; i < 10; i++) if (classifierOutput.scoreLabelPairs.get(i).label) pos++; return pos / 10d; } public static double bpref(ClassifierOutput classifierOutput){ Collections.sort(classifierOutput.scoreLabelPairs); int R = 0; int N = 0; for (ScoreLabelPair pair : classifierOutput.scoreLabelPairs) if (pair.label) R++; else N++; int minNR = Math.min(R,N); double bpref = 0; int neg = 0; for (int i = 0; i < classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i); if (pair.label){ if (pair.score != ScoreLabelPair.LOW_NUMBER){ bpref += (N-neg)/(double)minNR; } } else neg++; } return bpref/R; } public static double areaUnderCurve(ClassifierOutput classifierOutput){ if (classifierOutput.scoreLabelPairs.size() == 0) return Double.NaN; Collections.sort(classifierOutput.scoreLabelPairs); int postiveCount = countPositives(classifierOutput); int tpCount = 0; int fpCount = 0; float fpr = 0; float tpr = 0; double score = classifierOutput.scoreLabelPairs.get(0).score; double auc = 0; for (int i = 0; i < classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i); if (pair.label) tpCount++; else fpCount++; float newFpr = (float)fpCount / (float)(classifierOutput.scoreLabelPairs.size()-postiveCount); float newTpr = (float)tpCount / (float)postiveCount; if (pair.score != score || i == classifierOutput.scoreLabelPairs.size()-1) {// In case of a tie, just make a linear interpolation auc += (newFpr-fpr)*tpr + 0.5*((newFpr-fpr)*(newTpr-tpr)); tpr = newTpr; fpr = newFpr; score = pair.score; } } return auc; } public static double omopAUC(ClassifierOutput classifierOutput){ if (classifierOutput.scoreLabelPairs.size() == 0) return Double.NaN; Collections.sort(classifierOutput.scoreLabelPairs); int postiveCount = countPositives(classifierOutput); int tpCount = 0; int fpCount = 0; float fpr = 0; float tpr = 0; double score = classifierOutput.scoreLabelPairs.get(0).score; double auc = 0; for (int i = 0; i < classifierOutput.scoreLabelPairs.size(); i++){ ScoreLabelPair pair = classifierOutput.scoreLabelPairs.get(i); if (pair.label) tpCount++; else fpCount++; float newFpr = (float)fpCount / (float)(classifierOutput.scoreLabelPairs.size()-postiveCount); float newTpr = (float)tpCount / (float)postiveCount; if (pair.score != score || i == classifierOutput.scoreLabelPairs.size()-1) {// In case of a tie, take worst case scenario: auc += (newFpr-fpr)*tpr ; tpr = newTpr; fpr = newFpr; score = pair.score; } } return auc; } /** * Calculates AuC with CI using DeLong method (DeLong et al., 1988). Based on R pROC package. * @param classifierOutput * @return */ public static ConfidenceInterval areaUnderCurveConfidenceInterval(ClassifierOutput classifierOutput){ List<Double> cases = new ArrayList<Double>(); List<Double> controls = new ArrayList<Double>(); for (ScoreLabelPair pair : classifierOutput.scoreLabelPairs) if (pair.label) cases.add(pair.score); else controls.add(pair.score); int m = cases.size(); int n = controls.size(); int mn = m*n; double[][] mwMatrix = new double[m][n]; double mean = 0; for (int i = 0; i < m; i++) for (int j = 0; j < n; j++){ double mw = mannWhitneyKernel(cases.get(i), controls.get(j)); mwMatrix[i][j] = mw; mean += mw; } mean /= (double)mn; double vr10[] = new double[m]; for (int i = 0; i < m; i++){ double sum = 0; for (int j = 0; j < n; j++) sum += mwMatrix[i][j]; vr10[i] = sum/(double)n; } double vr01[] = new double[n]; for (int i = 0; i < n; i++){ double sum = 0; for (int j = 0; j < m; j++) sum += mwMatrix[j][i]; vr01[i] = sum/(double)m; } double s10 = 0; for (double vr : vr10) s10 += (vr-mean)*(vr-mean); s10 /= (double)(m-1); double s01 = 0; for (double vr : vr01) s01 += (vr-mean)*(vr-mean); s01 /= (double)(n-1); double s = s10/(double)m + s01/(double)n; double sd = Math.sqrt(s); ConfidenceInterval ci = new ConfidenceInterval(); ci.pointEstimate = mean; ci.lowerBound = mean - (1.96*sd); ci.upperBound = mean + (1.96*sd); return ci; } private static double mannWhitneyKernel(double x, double y){ if (y < x) return 1; if (y == x) return 0.5; return -0; } private static void renderROC(Graphics2D g2d, ClassifierOutput classifierOutput,WriteTextFile out) { Collections.sort(classifierOutput.scoreLabelPairs); int postiveCount = countPositives(classifierOutput); g2d.setColor(new Color(200,0,0)); g2d.setStroke(new BasicStroke(4)); float areaWidth = width - 2*borderWidth; float areaHeight = height - 2*borderWidth; float x = borderWidth; float y = height-borderWidth; float previousX = x; float previousY = y; double score = classifierOutput.scoreLabelPairs.get(0).score; int fpCount = 0; int tpCount = 0; for (ScoreLabelPair pair : classifierOutput.scoreLabelPairs){ if (pair.score != score) {// In case of a tie, just make a linear interpolation g2d.draw(new Line2D.Float(previousX,previousY,x,y)); previousX = x; previousY = y; score = pair.score; } if (pair.label) tpCount++; else fpCount++; float fpr = (float)fpCount / (float)(classifierOutput.scoreLabelPairs.size()-postiveCount); float tpr = (float)tpCount / (float)postiveCount; if (out != null) out.writeln(Float.toString(fpr)+"," +Float.toString(tpr)); x = borderWidth + fpr * areaWidth; y = height-borderWidth - tpr * areaHeight; } g2d.draw(new Line2D.Float(previousX,previousY,x,y)); if (out != null) out.writeln("1,1"); } private static void renderFPCurve(Graphics2D g2d, ClassifierOutput classifierOutput, int count) { Collections.sort(classifierOutput.scoreLabelPairs); g2d.setColor(Color.RED); float areaWidth = width - 2*borderWidth; float areaHeight = height - 2*borderWidth; float x = borderWidth; float y = height-borderWidth; int fpCount = 0; for (int i = 0; i < count; i++){ if (!classifierOutput.scoreLabelPairs.get(i).label) fpCount++; float fracX = (i+1)/(float)count; float fracY = fpCount/(float)count; float newX = borderWidth + fracX * areaWidth; float newY = height-borderWidth - fracY * areaHeight; g2d.draw(new Line2D.Float(x,y,newX,newY)); x = newX; y = newY; } } private static int countPositives(ClassifierOutput classifierOutput) { int postiveCount = 0; for (ScoreLabelPair pair : classifierOutput.scoreLabelPairs) if (pair.label) postiveCount++; return postiveCount; } private static void renderAxes(Graphics2D g2d, float tic, float max) { /* g2d.setColor(Color.BLACK); g2d.setStroke(new BasicStroke(3)); g2d.setFont(new Font("Arial",Font.PLAIN,20)); int areaWidth = width - 2*borderWidth; int areaHeight = height - 2*borderWidth; g2d.drawLine(borderWidth, borderWidth, borderWidth, height-borderWidth); g2d.drawLine(borderWidth, height-borderWidth, width-borderWidth, height-borderWidth); int fontHeight = g2d.getFontMetrics().getHeight(); int fontWidth = g2d.getFontMetrics().stringWidth("0.0"); for (float f = 0; f <= max; f+= tic){ int ticY = Math.round((height-borderWidth) - (f/max) * (areaHeight)); g2d.drawString(StringUtilities.formatNumber("0.0", f), borderWidth-ticSize-fontWidth, fontHeight / 2 + ticY); g2d.drawLine(borderWidth - ticSize, ticY, borderWidth, ticY); int ticX = Math.round(borderWidth + (f/max) * (areaWidth)); g2d.drawString(StringUtilities.formatNumber("0.0", f), ticX - fontWidth/2,height-borderWidth+ticSize+fontHeight); g2d.drawLine(ticX, height-borderWidth, ticX, height-borderWidth+ticSize); } */ g2d.setStroke(new BasicStroke(1)); g2d.setFont(new Font("Arial",Font.PLAIN,30)); int areaWidth = width - 2*borderWidth; int areaHeight = height - 2*borderWidth; g2d.setColor(new Color(238, 238, 238)); g2d.fillRect(borderWidth, borderWidth, areaWidth, areaHeight); int fontHeight = g2d.getFontMetrics().getHeight(); int fontWidth = g2d.getFontMetrics().stringWidth("0.0"); for (float f = 0; f <= max; f+= tic){ int ticY = Math.round((height-borderWidth) - (f/max) * (areaHeight)); g2d.setColor(Color.BLACK); g2d.drawString(StringUtilities.formatNumber("0.0", f), borderWidth-ticSize-fontWidth-3, ticY+(fontHeight/2)-5); g2d.setColor(new Color(170, 170, 170)); if (f != 0 && f != max) g2d.drawLine(borderWidth, ticY, width-borderWidth, ticY); int ticX = Math.round(borderWidth + (f/max) * (areaWidth)); g2d.setColor(Color.BLACK); g2d.drawString(StringUtilities.formatNumber("0.0", f), ticX - fontWidth/2,height-borderWidth+ticSize+fontHeight); g2d.setColor(new Color(170, 170, 170)); if (f != 0 && f != max) g2d.drawLine(ticX, borderWidth, ticX, height-borderWidth); } } public static class ConfidenceInterval { public double lowerBound; public double upperBound; public double pointEstimate; } public static String printStats(ClassifierOutput output) { int pos = 0; int neg = 0; for (ScoreLabelPair pair : output.scoreLabelPairs) if (pair.label) pos++; else neg++; return "Postives: " + pos + ", negatives: " + neg; } }