package edu.stanford.nlp.dcoref; import java.util.*; /** * B^3 scorer * @author heeyoung * */ public class ScorerBCubed extends CorefScorer { protected enum BCubedType {B0, Ball, Brahman, Bcai, Bconll} private final BCubedType type; public ScorerBCubed(BCubedType _type) { super(ScoreType.BCubed); type = _type; } @Override protected void calculatePrecision(Document doc){ switch(type){ case Bcai: calculatePrecisionBcai(doc); break; case Ball: calculatePrecisionBall(doc); break; case Bconll: calculatePrecisionBconll(doc); break; // same as Bcai } } @Override protected void calculateRecall(Document doc){ switch(type){ case Bcai: calculateRecallBcai(doc); break; case Ball: calculateRecallBall(doc); break; case Bconll: calculateRecallBconll(doc); break; } } private void calculatePrecisionBall(Document doc){ int pDen = 0; double pNum = 0.0; Map<Integer, Mention> goldMentions = doc.allGoldMentions; Map<Integer, Mention> predictedMentions = doc.allPredictedMentions; for(Mention m : predictedMentions.values()){ double correct = 0.0; double total = 0.0; for(Mention m2 : doc.corefClusters.get(m.corefClusterID).getCorefMentions()){ if(m==m2 || (goldMentions.containsKey(m.mentionID) && goldMentions.containsKey(m2.mentionID) && goldMentions.get(m.mentionID).goldCorefClusterID == goldMentions.get(m2.mentionID).goldCorefClusterID)) { correct++; } total++; } pNum += correct/total; pDen++; } precisionDenSum += pDen; precisionNumSum += pNum; } private void calculateRecallBall(Document doc){ int rDen = 0; double rNum = 0.0; Map<Integer, Mention> goldMentions = doc.allGoldMentions; Map<Integer, Mention> predictedMentions = doc.allPredictedMentions; for(Mention m : goldMentions.values()){ double correct = 0.0; double total = 0.0; for(Mention m2 : doc.goldCorefClusters.get(m.goldCorefClusterID).getCorefMentions()){ if(m==m2 || (predictedMentions.containsKey(m.mentionID) && predictedMentions.containsKey(m2.mentionID) && predictedMentions.get(m.mentionID).corefClusterID == predictedMentions.get(m2.mentionID).corefClusterID)) { correct++; } total++; } rNum += correct/total; rDen++; } recallDenSum += rDen; recallNumSum += rNum; } private void calculatePrecisionBcai(Document doc) { int pDen = 0; double pNum = 0.0; Map<Integer, Mention> goldMentions = doc.allGoldMentions; Map<Integer, Mention> predictedMentions = doc.allPredictedMentions; for(Mention m : predictedMentions.values()){ if(!goldMentions.containsKey(m.mentionID) && doc.corefClusters.get(m.corefClusterID).getCorefMentions().size()==1){ continue; } double correct = 0.0; double total = 0.0; for(Mention m2 : doc.corefClusters.get(m.corefClusterID).getCorefMentions()){ if(m==m2 || (goldMentions.containsKey(m.mentionID) && goldMentions.containsKey(m2.mentionID) && goldMentions.get(m.mentionID).goldCorefClusterID == goldMentions.get(m2.mentionID).goldCorefClusterID)) { correct++; } total++; } pNum += correct/total; pDen++; } for(int id : goldMentions.keySet()) { if(!predictedMentions.containsKey(id)) { pNum++; pDen++; } } precisionDenSum += pDen; precisionNumSum += pNum; } private void calculateRecallBcai(Document doc) { int rDen = 0; double rNum = 0.0; Map<Integer, Mention> goldMentions = doc.allGoldMentions; Map<Integer, Mention> predictedMentions = doc.allPredictedMentions; for(Mention m : goldMentions.values()){ double correct = 0.0; double total = 0.0; for(Mention m2 : doc.goldCorefClusters.get(m.goldCorefClusterID).getCorefMentions()){ if(m==m2 || (predictedMentions.containsKey(m.mentionID) && predictedMentions.containsKey(m2.mentionID) && predictedMentions.get(m.mentionID).corefClusterID == predictedMentions.get(m2.mentionID).corefClusterID)) { correct++; } total++; } rNum += correct/total; rDen++; } recallDenSum += rDen; recallNumSum += rNum; } private void calculatePrecisionBconll(Document doc) { // same as Bcai calculatePrecisionBcai(doc); } private void calculateRecallBconll(Document doc) { int rDen = 0; double rNum = 0.0; Map<Integer, Mention> goldMentions = doc.allGoldMentions; Map<Integer, Mention> predictedMentions = doc.allPredictedMentions; for(Mention m : goldMentions.values()){ double correct = 0.0; double total = 0.0; for(Mention m2 : doc.goldCorefClusters.get(m.goldCorefClusterID).getCorefMentions()){ if(m==m2 || (predictedMentions.containsKey(m.mentionID) && predictedMentions.containsKey(m2.mentionID) && predictedMentions.get(m.mentionID).corefClusterID == predictedMentions.get(m2.mentionID).corefClusterID)) { correct++; } total++; } rNum += correct/total; rDen++; } // this part is different from Bcai for(Mention m : predictedMentions.values()) { if(!goldMentions.containsKey(m.mentionID) && doc.corefClusters.get(m.corefClusterID).getCorefMentions().size()!=1) { rNum++; rDen++; } } recallDenSum += rDen; recallNumSum += rNum; } }