package de.berlin.hu.chemspot;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.uima.jcas.JCas;

import de.berlin.hu.util.Constants;
import de.berlin.hu.util.Constants.ChemicalType;
import de.berlin.hu.wbi.common.research.Evaluator;

/**
 * Accumulates evaluation statistics (true positives, false positives, false negatives and
 * normalization results) over one or more documents and writes summary and detail reports.
 *
 * <p>Mutable state is guarded by {@code evaluationLock}; the normalization lists are additionally
 * guarded by {@code normalizationLock}. Where both are needed, {@code evaluationLock} is always
 * acquired first. The getters and setters are not synchronized.
 */
public class ChemicalNEREvaluator {

    /** Matches the tail of a string made of five whitespace-separated tokens plus a partial token. */
    private static final Pattern CONTEXT_START_PATTERN = Pattern.compile("(\\S+\\s+){5}\\S*$");

    /** Matches the head of a string made of a partial token plus five whitespace-separated tokens. */
    private static final Pattern CONTEXT_STOP_PATTERN = Pattern.compile("^\\S*(\\s+\\S+){5}");

    /** Maximum number of characters inspected on either side of a mention when looking for context. */
    private static final int MAX_CONTEXT_LENGTH = 100;

    /** Number of context characters used when the token patterns do not match. */
    private static final int FALLBACK_CONTEXT_LENGTH = 30;

    private int TP = 0;
    private int FP = 0;
    private int FN = 0;
    private final Object evaluationLock = new Object();

    private List<Mention> truePositives = new ArrayList<Mention>();
    private List<Mention> falsePositives = new ArrayList<Mention>();
    private List<Mention> falseNegatives = new ArrayList<Mention>();
    private List<Mention> predictions = new ArrayList<Mention>();
    private List<Mention> goldstandard = new ArrayList<Mention>();

    private List<Mention> normalizedAll = new ArrayList<Mention>();
    private List<Mention> normalized = new ArrayList<Mention>();
    private List<Mention> normalizedCorrect = new ArrayList<Mention>();
    private final Object normalizationLock = new Object();

    /**
     * Formats the counts held by an already evaluated {@link Evaluator}.
     *
     * @param evaluator evaluator to read the TP/FP/FN counts from
     * @return multi-line report, see {@link #getEvaluationResult(int, int, int)}
     */
    public static String getEvaluationResult(Evaluator<?, ?> evaluator) {
        return getEvaluationResult(evaluator.getNumberOfTP(), evaluator.getNumberOfFP(), evaluator.getNumberOfFN());
    }

    /**
     * Formats counts together with the derived precision, recall and F1 score (as percentages).
     * A ratio whose denominator is zero is reported as 0.
     *
     * @param tps number of true positives
     * @param fps number of false positives
     * @param fns number of false negatives
     * @return multi-line report, one metric per line
     */
    public static String getEvaluationResult(int tps, int fps, int fns) {
        double precision = tps + fps > 0 ? (double) tps / ((double) tps + fps) : 0;
        double recall = tps + fns > 0 ? (double) tps / ((double) tps + fns) : 0;
        double fscore = precision + recall > 0 ? 2 * (precision * recall) / (precision + recall) : 0;

        StringBuilder buffer = new StringBuilder();
        buffer.append(String.format("True Positives : %d%n", tps));
        buffer.append(String.format("False Positives : %d%n", fps));
        buffer.append(String.format("False Negatives : %d%n", fns));
        buffer.append(String.format("Precision : %3.2f %%%n", precision * 100.0));
        buffer.append(String.format("Recall : %3.2f %%%n", recall * 100.0));
        buffer.append(String.format("F1 Score : %3.2f %%%n", fscore * 100.0));
        return buffer.toString();
    }

    /**
     * Evaluates the annotation results of one document and adds them to the running totals.
     *
     * <p>If the document has no gold-standard annotations every mention counts as a false
     * positive; if it has no mentions every gold-standard annotation counts as a false negative.
     * Otherwise an {@link Evaluator} is run, normalization is evaluated on its true positives and
     * the cumulative result is printed to standard output.
     *
     * @param jcas the document whose mentions and gold-standard annotations are compared
     */
    public void evaluate(JCas jcas) {
        System.out.println("Starting evaluation...");

        List<Mention> mentions = ChemSpot.getMentions(jcas);
        List<Mention> goldstandardAnnotations = ChemSpot.getGoldstandardAnnotations(jcas);

        synchronized (evaluationLock) {
            // These lists are read under evaluationLock by writeDetailedEvaluationResults,
            // so they must be written under the same lock.
            predictions.addAll(mentions);
            goldstandard.addAll(goldstandardAnnotations);

            if (goldstandardAnnotations.isEmpty()) {
                FP += mentions.size();
                falsePositives.addAll(mentions);
            } else if (mentions.isEmpty()) {
                FN += goldstandardAnnotations.size();
                falseNegatives.addAll(goldstandardAnnotations);
            } else {
                Evaluator<Mention, Mention> evaluator = new Evaluator<Mention, Mention>(mentions, goldstandardAnnotations);
                evaluator.evaluate();

                TP += evaluator.getTruePositives().size();
                FP += evaluator.getFalsePositives().size();
                FN += evaluator.getFalseNegatives().size();

                truePositives.addAll(evaluator.getTruePositives());
                falsePositives.addAll(evaluator.getFalsePositives());
                falseNegatives.addAll(evaluator.getFalseNegatives());

                evaluateNormalization(new ArrayList<Mention>(evaluator.getTruePositives()), goldstandardAnnotations);

                System.out.println(getEvaluationResult(TP, FP, FN));

                if (!normalized.isEmpty()) {
                    double correctAllRatio = !normalizedAll.isEmpty()
                            ? (double) normalizedCorrect.size() / (double) normalizedAll.size() : 0;
                    double correctNormalizedRatio = (double) normalizedCorrect.size() / (double) normalized.size();
                    System.out.format("%d of %d entities were normalized, %d correctly (%.2f %% of all and %.2f %% of normalized)%n",
                            normalized.size(), normalizedAll.size(), normalizedCorrect.size(),
                            correctAllRatio * 100.0, correctNormalizedRatio * 100);
                }
            }
        }
    }

    /**
     * Reports, per predicted {@link ChemicalType}, how many predictions are contained in
     * {@code tps} and the resulting precision. Predictions without a type are ignored.
     * Also prints one empty line to standard output.
     *
     * @param tps         mentions counted as true positives
     * @param predictions all predicted mentions
     * @return multi-line report, one line per type in the natural order of the type
     */
    public static String evaluateByPredictionType(List<Mention> tps, List<Mention> predictions) {
        Map<ChemicalType, List<Mention>> mapTypeToPredictions = new HashMap<ChemicalType, List<Mention>>();
        for (Mention prediction : predictions) {
            ChemicalType type = prediction.getType();
            if (type == null) {
                continue;
            }
            List<Mention> bucket = mapTypeToPredictions.get(type);
            if (bucket == null) {
                bucket = new ArrayList<Mention>();
                mapTypeToPredictions.put(type, bucket);
            }
            bucket.add(prediction);
        }

        System.out.println();

        StringBuilder buffer = new StringBuilder();
        buffer.append(String.format("Evaluation by Prediction Type:%n"));

        List<ChemicalType> types = new ArrayList<ChemicalType>(mapTypeToPredictions.keySet());
        Collections.sort(types);
        for (ChemicalType type : types) {
            List<Mention> typePredictions = mapTypeToPredictions.get(type);
            List<Mention> typeCorrect = new ArrayList<Mention>(typePredictions);
            typeCorrect.retainAll(tps);

            float precision = typePredictions.size() > 0
                    ? (float) typeCorrect.size() / (float) typePredictions.size() : 0;
            buffer.append(String.format("%s - %d correct, %d found, precision: %3.2f%%%n",
                    type, typeCorrect.size(), typePredictions.size(), precision * 100.0));
        }

        return buffer.toString();
    }

    /**
     * Evaluates predictions separately for every {@link ChemicalType} that occurs in the gold
     * standard, pairing gold annotations of a type with predictions of the same type, and appends
     * an "ALL" line with the summed counts. Gold annotations without a type are ignored.
     * Also prints one empty line to standard output.
     *
     * @param predictions  all predicted mentions
     * @param goldstandard all gold-standard mentions
     * @return one report line per type plus the "ALL" line, or the empty string when the gold
     *         standard contains fewer than two distinct types
     */
    public static String evaluateByGoldstandardType(List<Mention> predictions, List<Mention> goldstandard) {
        Map<ChemicalType, List<Mention>> mapTypeToGoldstandard = new HashMap<ChemicalType, List<Mention>>();
        for (Mention goldstandardMention : goldstandard) {
            ChemicalType type = goldstandardMention.getType();
            if (type == null) {
                continue;
            }
            List<Mention> bucket = mapTypeToGoldstandard.get(type);
            if (bucket == null) {
                bucket = new ArrayList<Mention>();
                mapTypeToGoldstandard.put(type, bucket);
            }
            bucket.add(goldstandardMention);
        }

        System.out.println();

        if (mapTypeToGoldstandard.size() <= 1) {
            return "";
        }

        StringBuilder buffer = new StringBuilder();
        buffer.append(String.format("Evaluation by Goldstandard Type:%n"));

        int tps = 0;
        int fps = 0;
        int fns = 0;

        List<ChemicalType> types = new ArrayList<ChemicalType>(mapTypeToGoldstandard.keySet());
        Collections.sort(types);
        for (ChemicalType type : types) {
            List<Mention> typeGoldstandard = mapTypeToGoldstandard.get(type);
            List<Mention> typePredictions = new ArrayList<Mention>();
            for (Mention prediction : predictions) {
                if (type.equals(prediction.getType())) {
                    typePredictions.add(prediction);
                }
            }

            Evaluator<Mention, Mention> evaluator = new Evaluator<Mention, Mention>(typePredictions, typeGoldstandard);
            evaluator.evaluate();
            tps += evaluator.getNumberOfTP();
            fps += evaluator.getNumberOfFP();
            fns += evaluator.getNumberOfFN();

            // Collapse the multi-line report onto one line; the final line break is kept.
            buffer.append(type + " - " + getEvaluationResult(evaluator).replaceAll("\r?\n(?!$)", ", ").replaceAll(" +", " "));
        }

        buffer.append("ALL - " + getEvaluationResult(tps, fps, fns).replaceAll("\r?\n(?!$)", ", ").replaceAll(" +", " "));
        return buffer.toString();
    }

    /**
     * Compares the ChEBI ids of true positives with those of the gold-standard annotations that
     * start at the same offset and records the outcome in the normalization lists. A true positive
     * whose id differs from the gold id has its id rewritten to
     * {@code "<predicted> (correct: <gold>)"}.
     *
     * <p>Both inputs are sorted on private copies, so the caller's lists keep their order. The
     * merge-style scan assumes the natural ordering of {@code Mention} is ascending by start
     * offset (TODO confirm against {@code Mention.compareTo}).
     *
     * @param tps          true positives of the current document
     * @param goldStandard gold-standard annotations of the current document
     * @return cumulative ratio of correctly normalized mentions to all gold mentions carrying an
     *         id, or 0 if there are none
     */
    private double evaluateNormalization(List<Mention> tps, List<Mention> goldStandard) {
        List<Mention> sortedTps = new ArrayList<Mention>(tps);
        List<Mention> sortedGold = new ArrayList<Mention>(goldStandard);
        Collections.sort(sortedTps);
        Collections.sort(sortedGold);

        synchronized (normalizationLock) {
            int i = 0;
            for (Mention m : sortedTps) {
                while (i < sortedGold.size() && sortedGold.get(i).getStart() < m.getStart()) {
                    i++;
                }
                if (i >= sortedGold.size()) {
                    // No gold annotation starts at or after this mention; indexing further
                    // would run past the end of the list.
                    break;
                }

                Mention s = sortedGold.get(i);
                if (s.getStart() != m.getStart() || s.getCHEB() == null) {
                    continue;
                }

                normalizedAll.add(s);
                if (m.getCHEB() != null && !m.getCHEB().isEmpty()) {
                    normalized.add(m);
                    if (m.getCHEB().equals(s.getCHEB())) {
                        normalizedCorrect.add(m);
                    } else {
                        m.setCHEB(String.format("%s (correct: %s)", m.getCHEB(), s.getCHEB()));
                    }
                }
            }

            return !normalizedAll.isEmpty() ? (double) normalizedCorrect.size() / (double) normalizedAll.size() : 0;
        }
    }

    /**
     * Groups mentions by source or by lower-cased text and orders the groups by descending size.
     *
     * @param list     mentions to group
     * @param bySource {@code true} to group by {@code getSource()}, {@code false} to group by
     *                 lower-cased {@code getText()}
     * @return the groups, largest first
     */
    private static List<List<Mention>> sortMentionListsBySize(List<Mention> list, boolean bySource) {
        Map<String, List<Mention>> annotationMap = new HashMap<String, List<Mention>>();
        for (Mention mention : list) {
            String key = bySource ? mention.getSource() : mention.getText().toLowerCase();
            List<Mention> group = annotationMap.get(key);
            if (group == null) {
                group = new ArrayList<Mention>();
                annotationMap.put(key, group);
            }
            group.add(mention);
        }

        List<List<Mention>> result = new ArrayList<List<Mention>>(annotationMap.values());

        Comparator<List<Mention>> bySize = new Comparator<List<Mention>>() {
            public int compare(List<Mention> o1, List<Mention> o2) {
                int a = o1.size();
                int b = o2.size();
                return a < b ? -1 : (a == b ? 0 : 1);
            }
        };
        Collections.sort(result, Collections.reverseOrder(bySize));

        return result;
    }

    /**
     * Finds where the left-hand context of a span starting at {@code begin} should start:
     * five tokens back if that fits into the inspection window, otherwise a fixed number of
     * characters back.
     */
    private static int contextStart(String text, int begin) {
        int windowStart = Math.max(begin - MAX_CONTEXT_LENGTH, 0);
        Matcher matcher = CONTEXT_START_PATTERN.matcher(text.substring(windowStart, begin));
        return matcher.find() ? windowStart + matcher.start() : Math.max(begin - FALLBACK_CONTEXT_LENGTH, 0);
    }

    /**
     * Finds where the right-hand context of a span ending at {@code end} should stop:
     * five tokens ahead if that fits into the inspection window, otherwise a fixed number of
     * characters ahead.
     */
    private static int contextStop(String text, int end) {
        int windowEnd = Math.min(end + MAX_CONTEXT_LENGTH, text.length());
        Matcher matcher = CONTEXT_STOP_PATTERN.matcher(text.substring(end, windowEnd));
        return matcher.find() ? end + matcher.end() : Math.min(end + FALLBACK_CONTEXT_LENGTH, text.length());
    }

    /**
     * Writes every pair of overlapping mentions from the two lists with surrounding context,
     * marking mentions of {@code list1} with {@code <...>} and mentions of {@code list2} with
     * {@code [...]}. Pairs whose document texts differ are skipped. The stream is flushed but
     * not closed.
     *
     * @param s     stream to write to
     * @param name1 label of the first list
     * @param list1 first list; not modified
     * @param name2 label of the second list
     * @param list2 second list; not modified
     * @throws IOException if writing fails
     */
    private static void writeOverlapping(OutputStream s, String name1, List<Mention> list1, String name2, List<Mention> list2) throws IOException {
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(s));

        List<Mention> sorted1 = new ArrayList<Mention>(list1);
        Collections.sort(sorted1);
        List<Mention> sorted2 = new ArrayList<Mention>(list2);
        Collections.sort(sorted2);

        writer.write(String.format("Overlapping occurrences of <%s> and [%s]:%n%n", name1, name2));

        int i = 0;
        for (Mention m1 : sorted1) {
            while (i < sorted2.size() && sorted2.get(i).getEnd() < m1.getStart()) {
                i++;
            }

            int j = i;
            while (j < sorted2.size() && sorted2.get(j).overlaps(m1)) {
                Mention m2 = sorted2.get(j++);
                if (!m1.getDocumentText().equals(m2.getDocumentText())) {
                    continue;
                }

                String text = m1.getDocumentText();
                int begin = Math.min(m1.getStart(), m2.getStart());
                int end = Math.max(m1.getEnd(), m2.getEnd());
                int start = contextStart(text, begin);
                int stop = contextStop(text, end);

                StringBuilder sb = new StringBuilder(text.substring(start, stop));
                sb.insert(m1.getStart() - start, '<');
                sb.insert(m1.getEnd() - start + 1, '>');

                // The extra offsets account for the '<' and '>' already inserted before the
                // position of each bracket.
                int openShift = (m1.getStart() < m2.getStart() ? 1 : 0)
                        + (m1.getStart() == m2.getStart() && m1.getEnd() > m2.getEnd() ? 1 : 0);
                sb.insert(m2.getStart() - start + openShift, "[");

                int closeShift = (m1.getEnd() < m2.getEnd()
                        || (m1.getEnd() == m2.getEnd() && m1.getStart() > m2.getStart())) ? 1 : 0;
                sb.insert(m2.getEnd() - start + 2 + closeShift, "]");

                writer.write("..." + sb.toString().replaceAll("\r?\n", "\\\\n") + "...");
                writer.newLine();
            }
        }

        writer.flush();
    }

    /**
     * Writes a table with one row per distinct (case-insensitive) mention text: number of
     * occurrences, text, sources and the distinct ChEBI ids. Rows are ordered by descending
     * number of occurrences. The stream is flushed but not closed.
     *
     * @param s    stream to write to
     * @param name heading of the table
     * @param list mentions to summarize
     * @throws IOException if writing fails
     */
    private static void writeList(OutputStream s, String name, List<Mention> list) throws IOException {
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(s));
        List<List<Mention>> listBySize = sortMentionListsBySize(list, false);

        writer.write(String.format("%n%n%n%s:%n", name));
        writer.write(String.format("%8s\t%25s\t%30s\t%s%n", "#", "CHEMICAL", "SOURCE", "ChEBI ID"));

        for (List<Mention> annotationList : listBySize) {
            List<List<Mention>> listBySource = sortMentionListsBySize(annotationList, true);

            StringBuilder sources = new StringBuilder();
            for (List<Mention> sourceList : listBySource) {
                String source = !sourceList.isEmpty() ? sourceList.get(0).getSource() : "";
                if (source == null || source.isEmpty()) {
                    source = Constants.UNKNOWN;
                }

                if (listBySource.size() == 1) {
                    source = (sourceList.size() > 1 ? "all " : "") + source;
                } else {
                    source = sourceList.size() + " " + source;
                }

                if (sources.length() > 0) {
                    sources.append(", ");
                }
                sources.append(source);
            }

            Set<String> ids = new HashSet<String>();
            for (Mention m : annotationList) {
                String id = m.getCHEB();
                if (id == null) {
                    continue;
                }
                String trimmed = id.trim();
                if (!trimmed.isEmpty() && !"null".equals(trimmed)) {
                    ids.add(trimmed);
                }
            }

            StringBuilder idString = new StringBuilder();
            for (String id : ids) {
                if (idString.length() > 0) {
                    idString.append(", ");
                }
                idString.append(id);
            }

            String annotation = !annotationList.isEmpty() ? annotationList.get(0).getText() : "";
            writer.write(String.format("%8d\t%25s\t%30s\t%s%n",
                    annotationList.size(), annotation, sources.toString(), idString.toString()));
        }

        writer.flush();
    }

    /**
     * Writes, for every distinct (case-insensitive) mention text, the occurrences of that text
     * with surrounding document context; the mention itself is wrapped in {@code <...>}. Output
     * per text is cut off with a {@code ...} line after 31 occurrences. The stream is flushed but
     * not closed.
     *
     * @param s    stream to write to
     * @param name heading of the report
     * @param list mentions to report
     * @throws IOException if writing fails
     */
    private static void writeContext(OutputStream s, String name, List<Mention> list) throws IOException {
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(s));
        List<List<Mention>> listBySize = sortMentionListsBySize(list, false);

        writer.write(String.format("%s:%n", name));

        for (List<Mention> annotationList : listBySize) {
            if (annotationList.isEmpty()) {
                continue;
            }

            writer.write(String.format("%n%n%s (%d):%n", annotationList.get(0).getText(), annotationList.size()));

            int shown = 0;
            for (Mention mention : annotationList) {
                if (shown > 30) {
                    writer.write("...");
                    writer.newLine();
                    break;
                }
                shown++;

                String text = mention.getDocumentText();
                int begin = mention.getStart();
                int end = mention.getEnd();
                int start = contextStart(text, begin);
                int stop = contextStop(text, end);

                String output = String.format("...%s<%s>%s...",
                        text.substring(start, begin), text.substring(begin, end), text.substring(end, stop));
                output = output.replaceAll("\r?\n", "\\\\n");
                writer.write(output);
                writer.newLine();
            }
        }

        writer.flush();
    }

    /**
     * Writes normalization statistics followed by tables of the correctly normalized, incorrectly
     * normalized and not normalized mentions. The stream is flushed but not closed.
     *
     * <p>NOTE(review): "recall" is computed as normalized / total, not correct / total; this looks
     * like coverage rather than recall and also feeds the F1 score. Kept as is, confirm intent.
     *
     * @param s                 stream to write to
     * @param normalizedAll     all gold mentions carrying an id
     * @param normalized        predicted mentions carrying an id
     * @param normalizedCorrect predicted mentions whose id equals the gold id
     * @throws IOException if writing fails
     */
    public void writeNormalizations(OutputStream s, List<Mention> normalizedAll, List<Mention> normalized, List<Mention> normalizedCorrect) throws IOException {
        BufferedWriter w = new BufferedWriter(new OutputStreamWriter(s));

        int normalizedAllCount = normalizedAll.size();
        int normalizedCount = normalized.size();
        int normalizedCorrectCount = normalizedCorrect.size();

        double correctAllRatio = normalizedAllCount != 0 ? (double) normalizedCorrectCount / (double) normalizedAllCount : 0;
        double precision = normalizedCount != 0 ? (double) normalizedCorrectCount / normalizedCount : 0;
        // Guard on the actual denominator so an empty normalizedAll cannot yield Infinity/NaN.
        double recall = normalizedAllCount != 0 ? (double) normalizedCount / normalizedAllCount : 0;
        double fScore = (precision != 0 || recall != 0) ? 2 * precision * recall / (precision + recall) : 0;

        w.write(String.format("entities total : %d%n", normalizedAllCount));
        w.write(String.format("entities normalized : %d%n", normalizedCount));
        w.write(String.format("normalized correct : %d%n", normalizedCorrectCount));
        w.write(String.format("percent correct (all) : %.2f %%%n%n", correctAllRatio * 100.0));
        w.write(String.format("precision: %.2f %%%n", precision * 100.0));
        w.write(String.format("recall: %.2f %%%n", recall * 100.0));
        w.write(String.format("f1 score: %.2f %%%n", fScore * 100.0));
        // Flush before writeList wraps the same stream in its own writer.
        w.flush();

        writeList(s, "correct", normalizedCorrect);

        List<Mention> normalizedIncorrect = new ArrayList<Mention>(normalized);
        normalizedIncorrect.removeAll(normalizedCorrect);
        writeList(s, "incorrect", normalizedIncorrect);

        List<Mention> notNormalized = new ArrayList<Mention>(normalizedAll);
        notNormalized.removeAll(normalizedCorrect);
        notNormalized.removeAll(normalizedIncorrect);
        writeList(s, "not normalized", notNormalized);
    }

    /**
     * Writes the accumulated results to files whose names are formed by appending fixed file
     * names to {@code outputPath}: the overall evaluation, false-positive contexts, false-negative
     * contexts, overlapping false positives/negatives and, if any mention was normalized, the
     * normalization report. The per-type evaluations are also printed to standard output.
     *
     * @param outputPath prefix prepended verbatim to each file name; {@code null} is treated as
     *                   the empty string
     * @throws IOException if a file cannot be created or written
     */
    public void writeDetailedEvaluationResults(String outputPath) throws IOException {
        synchronized (evaluationLock) {
            if (outputPath == null) {
                outputPath = "";
            }

            File evaluationFile = new File(outputPath + "evaluation.txt");
            OutputStream writer = new FileOutputStream(evaluationFile);
            try {
                BufferedWriter w = new BufferedWriter(new OutputStreamWriter(writer));
                w.write(getEvaluationResult(TP, FP, FN));

                String predictionTypeEvaluation = evaluateByPredictionType(truePositives, predictions);
                System.out.printf("%s", predictionTypeEvaluation);
                w.write(String.format("%n%n%s", predictionTypeEvaluation));

                String goldstandardTypeEvaluation = evaluateByGoldstandardType(predictions, goldstandard);
                if (!goldstandardTypeEvaluation.isEmpty()) {
                    System.out.printf("%s%n%n", goldstandardTypeEvaluation);
                    w.write(String.format("%n%n%s", goldstandardTypeEvaluation));
                }
                // Flush before writeList wraps the same stream in its own writer.
                w.flush();

                writeList(writer, "true positives", truePositives);
                writeList(writer, "false negatives", falseNegatives);
                writeList(writer, "false positives", falsePositives);
            } finally {
                writer.close();
            }
            System.out.println("Evaluation results written to: " + evaluationFile.getName());

            File falsePositivesFile = new File(outputPath + "evaluation-FPs.txt");
            writer = new FileOutputStream(falsePositivesFile);
            try {
                writeContext(writer, "false positives contexts", falsePositives);
            } finally {
                writer.close();
            }
            System.out.println("False positive contexts written to: " + falsePositivesFile.getName());

            File falseNegativesFile = new File(outputPath + "evaluation-FNs.txt");
            writer = new FileOutputStream(falseNegativesFile);
            try {
                writeContext(writer, "false negatives contexts", falseNegatives);
            } finally {
                writer.close();
            }
            System.out.println("False negative contexts written to: " + falseNegativesFile.getName());

            File falsePositivesNegativesFile = new File(outputPath + "evaluation-overlappings-FPs-FNs.txt");
            writer = new FileOutputStream(falsePositivesNegativesFile);
            try {
                writeOverlapping(writer, "false negatives", falseNegatives, "false positives", falsePositives);
            } finally {
                writer.close();
            }
            System.out.println("Overlapping occurrences of false positives and negatives written to: " + falsePositivesNegativesFile.getName());

            synchronized (normalizationLock) {
                if (!normalized.isEmpty()) {
                    File normalizedFile = new File(outputPath + "normalizations.txt");
                    writer = new FileOutputStream(normalizedFile);
                    try {
                        writeNormalizations(writer, normalizedAll, normalized, normalizedCorrect);
                    } finally {
                        writer.close();
                    }
                    System.out.println("Normalized entities written to: " + normalizedFile.getName());
                }
            }
        }
    }

    /** @return accumulated number of true positives */
    public int getTP() {
        return TP;
    }

    /** @param tP new number of true positives */
    public void setTP(int tP) {
        TP = tP;
    }

    /** @return accumulated number of false positives */
    public int getFP() {
        return FP;
    }

    /** @param fP new number of false positives */
    public void setFP(int fP) {
        FP = fP;
    }

    /** @return accumulated number of false negatives */
    public int getFN() {
        return FN;
    }

    /** @param fN new number of false negatives */
    public void setFN(int fN) {
        FN = fN;
    }

    /** @return the internal list of true positives (not a copy) */
    public List<Mention> getTruePositives() {
        return truePositives;
    }

    /** @param truePositives list that replaces the internal list of true positives */
    public void setTruePositives(List<Mention> truePositives) {
        this.truePositives = truePositives;
    }

    /** @return the internal list of false positives (not a copy) */
    public List<Mention> getFalsePositives() {
        return falsePositives;
    }

    /** @param falsePositives list that replaces the internal list of false positives */
    public void setFalsePositives(List<Mention> falsePositives) {
        this.falsePositives = falsePositives;
    }

    /** @return the internal list of false negatives (not a copy) */
    public List<Mention> getFalseNegatives() {
        return falseNegatives;
    }

    /** @param falseNegatives list that replaces the internal list of false negatives */
    public void setFalseNegatives(List<Mention> falseNegatives) {
        this.falseNegatives = falseNegatives;
    }

    /** @return the internal list of gold mentions carrying an id (not a copy) */
    public List<Mention> getNormalizedAll() {
        return normalizedAll;
    }

    /** @param normalizedAll list that replaces the internal list of gold mentions carrying an id */
    public void setNormalizedAll(List<Mention> normalizedAll) {
        this.normalizedAll = normalizedAll;
    }

    /** @return the internal list of predicted mentions carrying an id (not a copy) */
    public List<Mention> getNormalized() {
        return normalized;
    }

    /** @param normalized list that replaces the internal list of predicted mentions carrying an id */
    public void setNormalized(List<Mention> normalized) {
        this.normalized = normalized;
    }

    /** @return the internal list of correctly normalized mentions (not a copy) */
    public List<Mention> getNormalizedCorrect() {
        return normalizedCorrect;
    }

    /** @param normalizedCorrect list that replaces the internal list of correctly normalized mentions */
    public void setNormalizedCorrect(List<Mention> normalizedCorrect) {
        this.normalizedCorrect = normalizedCorrect;
    }
}