/* Copyright 2006, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.ranking; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.LineNumberReader; import java.io.PrintStream; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.TreeMap; import java.util.TreeSet; import javax.swing.JComponent; import edu.cmu.minorthird.classify.BinaryClassifier; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.MutableInstance; import edu.cmu.minorthird.util.BasicCommandLineProcessor; import edu.cmu.minorthird.util.Saveable; import edu.cmu.minorthird.util.StringUtil; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.LineCharter; import edu.cmu.minorthird.util.gui.ParallelViewer; import edu.cmu.minorthird.util.gui.VanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerFrame; import edu.cmu.minorthird.util.gui.Visible; /** Evaluate a classifier as a ranker */ public class RankingEvaluation implements Visible, Saveable { private final static int GRAPHS_PER_PAGE = 10; private final static int NUM_TOP_TO_SHOW = 50; private TreeMap<String,List<Example>> rankedListMap = new TreeMap<String,List<Example>>(); private TreeMap<String,Set<String>> unrankedMap = new TreeMap<String,Set<String>>(); private TreeMap<String,List<Double>> scoreMap = new TreeMap<String,List<Double>>(); private TreeMap<String,Integer> numPosExamples = new TreeMap<String,Integer>(); private boolean guiFlag = false; private String loadedFile = null; public void extend(String rankingId, List<Example> ranking, BinaryClassifier classifier) { extend(rankingId,ranking,classifier,Collections.EMPTY_SET); } public void extend(String rankingId, List<Example> ranking, BinaryClassifier classifier, Set<String> unrankedPos) { BatchRankingLearner.sortByScore( classifier, ranking ); rankedListMap.put( rankingId, ranking ); List<Double> scores = new ArrayList<Double>(ranking.size()); int k=0; for (Iterator<Example> i=ranking.iterator(); i.hasNext(); ) { Example ex = i.next(); if (ex.getLabel().isPositive()) increment(numPosExamples,rankingId,1); scores.set(k++,classifier.score(ex)); } scoreMap.put(rankingId,scores); unrankedMap.put(rankingId,unrankedPos); increment(numPosExamples,rankingId,unrankedPos.size()); } private void increment(TreeMap<String,Integer> map,String key,int delta) { Integer i = map.get(key); if (i==null) map.put(key,new Integer(delta)); else map.put(key,new Integer(i.intValue()+delta)); } // // accessors // private List<Example> getRanking(String rankingId) { return rankedListMap.get(rankingId); } private double getScore(String rankingId,int rank) { return scoreMap.get(rankingId).get(rank-1); } private Iterator<String> getRankingIterator() { return rankedListMap.keySet().iterator(); } private int numPosExamples(String rankingId) {return (numPosExamples.get(rankingId)).intValue();} private boolean isPositive(String rankingId,Example ex) { return ex.getLabel().isPositive(); } private int numRankings() { return rankedListMap.keySet().size(); } private Set<Example> getUnrankedPositives(String rankingId) { return Collections.EMPTY_SET; } // // split the examples into groups of K // private String[][] exampleGroups(int groupSize) { int remainder = numRankings() % groupSize; int numRemainderGroups = remainder>0 ? 1 : 0; String[][] group = new String[(numRankings()/groupSize) + numRemainderGroups][]; for (int i=0; i<group.length-numRemainderGroups; i++) { group[i] = new String[groupSize]; } if (numRemainderGroups>0) { group[group.length-1] = new String[ remainder ]; } int j=0, k=0; for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); group[j][k++] = name; if (k>=group[j].length) { j++; k=0; } } return group; } // // useful subroutine - returns an array a such that a[0] is recall // at each rank, and a[1] is precision at each rank. // private double[][] recallAndPrecisionForEachK(String rankingId) { List<Example> ranking = getRanking(rankingId); int totalPos = numPosExamples(rankingId); double[] recall = new double[ranking.size()+1]; double[] precision = new double[ranking.size()+1]; int rank=0; double numPosAboveRank=0; for (Iterator<Example> j=ranking.iterator(); j.hasNext(); ) { Example ex = j.next(); rank++; if (isPositive(rankingId,ex)) numPosAboveRank++; if (totalPos>0) { recall[rank] = numPosAboveRank/totalPos; precision[rank] = numPosAboveRank/rank; } else { recall[rank] = precision[rank] = 1.0; } } double[][] result = new double[2][]; result[0] = recall; result[1] = precision; return result; } // // public functions // /** Non-interpolated average precision. */ public double averagePrecision(String rankingId) { if (numPosExamples(rankingId)==0) return 1.0; double rank = 0, numPosAboveRank = 0, totPrec = 0; List<Example> ranking = getRanking(rankingId); for (Iterator<Example> i=ranking.iterator(); i.hasNext(); ) { Example ex = i.next(); rank++; if (isPositive(rankingId,ex)) { numPosAboveRank++; totPrec += numPosAboveRank/rank; } } return totPrec/numPosExamples(rankingId); } /** Max value of F1 over all possible thresholds. */ public double maxF1(String rankingId) { if (numPosExamples(rankingId)==0) return 1.0; double rank=0, numPosAboveRank=0, maxF1=0; List<Example> ranking = getRanking(rankingId); for (Iterator<Example> i=ranking.iterator(); i.hasNext(); ) { Example id = i.next(); rank++; if (isPositive(rankingId,id)) { numPosAboveRank++; } double precision = numPosAboveRank/rank; double recall = numPosAboveRank/numPosExamples(rankingId); if (precision+recall>0) { double f1 = 2*precision*recall/(precision+recall); maxF1 = Math.max( maxF1, f1 ); } } return maxF1; } public double maxRecall(String rankingId) { if (numPosExamples(rankingId)==0) return 1.0; double numRanked = 0; List<Example> ranking = getRanking(rankingId); for (Iterator<Example> i=ranking.iterator(); i.hasNext(); ) { Example ex = i.next(); if (ex.getLabel().isPositive()) numRanked++; } return numRanked/numPosExamples(rankingId); } /** Interpolated precision at eleven recall levels, averaged over all examples. */ public double[] averageElevenPointPrecision() { double[] averagePrecision = new double[11]; for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); double[] precision = elevenPointPrecision(name); for (int j=0; j<=10; j++) { averagePrecision[j] += precision[j]; } } for (int j=0; j<=10; j++) { averagePrecision[j] /= numRankings(); } return averagePrecision; } /** Interpolated precision at eleven recall levels: 0.0, ... ,1.0 */ public double[] elevenPointPrecision(String rankingId) { double[][] a = recallAndPrecisionForEachK(rankingId); double[] recall = a[0]; double[] precision = a[1]; double[] interpolatedPrecision = new double[11]; for (int k=1; k<recall.length; k++) { double r = recall[k]; double p = precision[k]; for (int j=0; j<=10; j++) { if (r >= j/10.0) { interpolatedPrecision[j] = Math.max( interpolatedPrecision[j], p ); } } } return interpolatedPrecision; } /** A summary table. Columns are: avgpr, the * non-interpolated average precision of the ranking (the average * of this is thus mean average precision); maxF1, the maximum F1 * value for the ranking; maxRec, the maximum recall achieved * (i.e., the fraction of relevant nodes appearing in the * ranking); and #pos, the number of positive/relevant nodes. */ public String toTable() { if (rankedListMap.keySet().size()==0) { return "no examples?\n"; } StringBuffer buf = new StringBuffer(); DecimalFormat fmt = new DecimalFormat("0.000"); DecimalFormat fmt2 = new DecimalFormat("0.0"); buf.append("avgPr\tmaxF1\tmaxRec\t#pos\n"); double totMaxF1=0, totAvgPrec=0, totPos=0, totMaxRec=0; for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); double ap = averagePrecision(name); double maxf = maxF1(name); double maxr = maxRecall(name); int np = numPosExamples(name); buf.append(fmt.format(ap) + "\t"); buf.append(fmt.format(maxf) + "\t"); buf.append(fmt.format(maxr) + "\t"); buf.append(np + "\t"); buf.append(name + "\n"); totAvgPrec += ap; totMaxF1 += maxf; totMaxRec += maxr; totPos += np; } buf.append("\n"); buf.append(fmt.format(totAvgPrec/numRankings())+"\t"); buf.append(fmt.format(totMaxF1/numRankings())+"\t"); buf.append(fmt.format(totMaxRec/numRankings())+"\t"); buf.append(fmt2.format(totPos/numRankings())+"\t"); buf.append("average" + "\n"); return buf.toString(); } private double[] averageRecallAtEachK() { int longestRankedList = 0; for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); longestRankedList = Math.max( getRanking(name).size(), longestRankedList ); } // first have recall[k] be total recall over all examples at rank k double[] recall = new double[longestRankedList+1]; for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); List<Example> ranking = getRanking(name); int rank=0; double numPosAboveRank=0; for (Iterator<Example> j=ranking.iterator(); j.hasNext(); ) { Example id = j.next(); rank++; if (isPositive(name,id)) numPosAboveRank++; if (numPosExamples(name)>0) { recall[rank] += numPosAboveRank/numPosExamples(name); } else { recall[rank] = 1.0; } } // extend the last recorded recall level to the end of this list for (int k=rank+1; k<recall.length; k++) { recall[k] = recall[rank]; } } // now scale recall to average recall at K for (int k=1; k<recall.length; k++) { recall[k] /= numRankings(); } recall[0] = -1; // convenient return recall; } /** Recall as function of K, averaged over all examples. */ public String averageRecallAsFunctionOfK() { DecimalFormat fmt = new DecimalFormat("0.000"); StringBuffer buf = new StringBuffer(""); buf.append("K\tAvgRecall\n"); double[] recall = averageRecallAtEachK(); for (int k=1; k<recall.length; k++) { if (recall[k]!=recall[k-1]) { buf.append(k+"\t"+fmt.format(recall[k])+"\n"); } } return buf.toString(); } public String toTable(String name,int numToShowAllEntries) { List<Example> ranking = getRanking(name); StringBuffer buf = new StringBuffer(); DecimalFormat fmt = new DecimalFormat("0.000"); int rank = 0; for (Iterator<Example> i=ranking.iterator(); i.hasNext(); ) { Example id = i.next(); ++rank; double score = getScore(name,rank); String tag = isPositive(name,id) ? "+" : "-"; // print the entry if it's positive, or if it's near the top if (rank<numToShowAllEntries || tag.startsWith("+")) { buf.append(rank+"\t"+fmt.format(score)+"\t"+tag+"\t"+id+"\n"); } } // now print the false negatives - ie the unranked positives for (Iterator<Example> i=getUnrankedPositives(name).iterator(); i.hasNext(); ) { Example id = i.next(); String tag = "+"; buf.append(">"+rank+"\t0\t"+tag+"\t"+id+"\n"); } return buf.toString(); } @Override public Viewer toGUI() { ParallelViewer v = new ParallelViewer(); v.addSubView( "Summary Table", new ComponentViewer() { static final long serialVersionUID=20080206L; @Override public JComponent componentFor(Object o) { RankingEvaluation gsEval = (RankingEvaluation)o; return new VanillaViewer( gsEval.toTable() ); } }); ParallelViewer v2 = new ParallelViewer(); v.addSubView( "11-Pt Precision", v2 ); v2.addSubView( "Averaged", new ComponentViewer() { static final long serialVersionUID=20080206L; @Override public JComponent componentFor(Object o) { RankingEvaluation gsEval = (RankingEvaluation)o; double[] avgPrec = gsEval.averageElevenPointPrecision(); LineCharter lc = new LineCharter(); lc.startCurve("11-Pt Avg Prec"); for (int j=0; j<=10; j++) { lc.addPoint( j/10.0, avgPrec[j] ); } return lc.getPanel("11-Pt Average Interpolated Precision", "Recall", "Precision"); } }); String[][] groups = exampleGroups(GRAPHS_PER_PAGE); for (int i=0; i<groups.length; i++) { final String tag = groups.length==1 ? "Details" : ("Details: Group "+(i+1)); final String[] group = groups[i]; v2.addSubView( tag, new ComponentViewer() { static final long serialVersionUID=20080206L; @Override public JComponent componentFor(Object o) { RankingEvaluation gsEval = (RankingEvaluation)o; LineCharter lc = new LineCharter(); for (int i=0; i<group.length; i++) { String name = group[i]; double[] avgPrec = gsEval.elevenPointPrecision(name); lc.startCurve(name); for (int j=0; j<=10; j++) { lc.addPoint( j/10.0, avgPrec[j] ); } } return lc.getPanel("11-Pt Interpolated Precision", "Recall", "Precision"); } }); } v.addSubView( "AvgRecall vs Rank", new ComponentViewer() { static final long serialVersionUID=20080206L; @Override public JComponent componentFor(Object o) { //RankingEvaluation gsEval = (RankingEvaluation)o; double[] avgRec = averageRecallAtEachK(); LineCharter lc = new LineCharter(); lc.startCurve("Recall vs Rank"); for (int i=1; i<avgRec.length; i++) { lc.addPoint( i, avgRec[i] ); } return lc.getPanel("AvgRecall vs Rank", "Rank", "AvgRecall"); } }); ParallelViewer v3 = new ParallelViewer(); v3.putTabsOnLeft(); v.addSubView( "Details", v3 ); for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { final String name = i.next(); v3.addSubView( name, new ComponentViewer() { static final long serialVersionUID=20080206L; @Override public JComponent componentFor(Object o) { return new VanillaViewer( toTable(name,NUM_TOP_TO_SHOW) ); } }); } v.setContent(this); return v; } // // implement Saveable // final static private String EVAL_FORMAT_NAME = "Graph Searcher Evaluation"; final static private String EVAL_EXT = ".gsev"; @Override public String[] getFormatNames() { return new String[]{EVAL_FORMAT_NAME}; } @Override public String getExtensionFor(String format) { return EVAL_EXT; } @Override public void saveAs(File file,String formatName) throws IOException { save(file); } @Override public Object restore(File file) throws IOException { return load(file); } // final static private StringEncoder encoder = new StringEncoder('%',"/\\:;$ \t\n"); // final static private String evalExt = Evaluation.EVAL_EXT; private void save(File file) throws IOException { PrintStream out = new PrintStream(new FileOutputStream(file)); for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String name = i.next(); List<Example> ranking = getRanking(name); int rank = 0; for (Iterator<Example> j=ranking.iterator(); j.hasNext(); ) { Example id = j.next(); rank++; double weight = getScore(name,rank); out.println(name +"\t"+ id.getSource() +"\t"+ rank +"\t" + weight); } for (Iterator<Example> j=ranking.iterator(); j.hasNext(); ) { Example id = j.next(); if (isPositive(name,id)) { out.println(name +"\t" + id.getSource()); } } Set<Example> pos = getUnrankedPositives(name); for (Iterator<Example> j=pos.iterator(); j.hasNext(); ) { Example id = j.next(); out.println(name +"\t" + id.getSource()); } } out.close(); } static public RankingEvaluation load(File file) throws IOException { RankingEvaluation eval = new RankingEvaluation(); eval.loadFromFile(file); return eval; } private void loadFromFile(File file) throws IOException { TreeMap<String,List<String>> tempListMap=new TreeMap<String,List<String>>(); LineNumberReader in = new LineNumberReader(new InputStreamReader(new FileInputStream(file))); String line = null; while ((line = in.readLine())!=null) { String[] parts = line.split("\t"); if (parts.length==2) { // rankingId positiveExample Set<String> pos = unrankedMap.get(parts[0]); if (pos==null) unrankedMap.put(parts[0], (pos = new TreeSet<String>())); pos.add( parts[1] ); increment(numPosExamples,parts[0],1); } else if (parts.length==4) { // rankingId graphId rank weight List<String> ranking = tempListMap.get(parts[0]); if (ranking==null) tempListMap.put(parts[0], (ranking = new ArrayList<String>())); List<Double> scores = scoreMap.get(parts[0]); if (scores==null) scoreMap.put( parts[0], (scores = new ArrayList<Double>())); scores.add( new Double(StringUtil.atof(parts[3])) ); ranking.add( parts[1] ); } else { throw new IllegalArgumentException(file+" line "+in.getLineNumber()+": illegal format"); } } for (Iterator<String> i=getRankingIterator(); i.hasNext(); ) { String rankingId = i.next(); //System.out.println("unbuffering: "+rankingId); Set<String> pos = unrankedMap.get( rankingId ); if (pos==null) pos = Collections.EMPTY_SET; List<Double> scores = scoreMap.get( rankingId ); List<String> tempRanking = tempListMap.get( rankingId ); List<Example> ranking = new ArrayList<Example>(tempRanking.size()); double[] newScores = new double[scores.size()]; for (int j=0; j<tempRanking.size(); j++) { String exId = tempRanking.get(j); if (pos.contains(exId)) { ranking.set(j, new Example( new MutableInstance(exId), ClassLabel.binaryLabel(+1) )); pos.remove( exId ); } else { ranking.set(j, new Example( new MutableInstance(exId), ClassLabel.binaryLabel(-1) )); } newScores[j] = (scores.get(j)).doubleValue(); } List<Double> newScoresList = new ArrayList<Double>(scores.size()); for(int j=0;j<newScores.length;j++){ newScoresList.add(j,newScores[j]); } scoreMap.put( rankingId, newScoresList ); } } // // test // public class MyCLP extends BasicCommandLineProcessor { //public void graph(String s) { graph=new TextGraph(s,'r'); } public void gui() { guiFlag = true; } public void loadFrom(String s) { loadedFile=s; try { loadFromFile(new File(s)); } catch (IOException ex) { ex.printStackTrace(); } } } public void processArguments(String[] args) { new MyCLP().processArguments(args); } static public void main(String[] args) throws IOException { RankingEvaluation eval = new RankingEvaluation(); eval.processArguments(args); if (eval.guiFlag) new ViewerFrame(eval.loadedFile, eval.toGUI()); else System.out.println(eval.toTable()); } }