/**
* Copyright (c) 2011, Regents of the University of Colorado All rights
* reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer. Redistributions in binary
* form must reproduce the above copyright notice, this list of conditions and
* the following disclaimer in the documentation and/or other materials provided
* with the distribution. Neither the name of the University of Colorado at
* Boulder nor the names of its contributors may be used to endorse or promote
* products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package clear.dep.srl;
import clear.dep.DepNode;
import clear.dep.DepTree;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
/**
 * Compares two dependency-based semantic role labeling outputs (gold vs.
 * system) and accumulates the counts needed to report precision, recall and
 * F1, both overall and per argument label.
 *
 * <p>
 * Counts are stored in a map from label to an {@code int[]}:
 * <ul>
 * <li>the {@code "TOTAL"} entry holds {@code [head matches, head+label
 * matches, system heads, gold heads]};</li>
 * <li>every other entry holds {@code [head+label matches, system
 * occurrences, gold occurrences]} for that label.</li>
 * </ul>
 * Every reported percentage is {@code 0} (rather than {@code NaN} or
 * {@code Infinity}) when its denominator count is zero.
 *
 * @author Jinho D. Choi <b>Last update:</b> 4/19/2011
 */
public class SRLEval {

    /** Map key of the entry that aggregates the counts over all labels. */
    private static final String TOTAL = "TOTAL";

    /**
     * Labels starting with this prefix are tallied under the label that follows
     * the prefix (presumably PropBank-style continuation arguments -- confirm).
     */
    private static final String C_PREFIX = "C-";

    /** Horizontal rule used to frame the printed table. */
    private static final String RULE = "--------------------------------------------------";

    /** Label (or {@link #TOTAL}) to its counters; see the class comment for the layout. */
    private final HashMap<String, int[]> m_score;

    /** Creates an evaluator whose overall counters are all zero. */
    public SRLEval() {
        m_score = new HashMap<>();
        m_score.put(TOTAL, new int[4]);
    }

    /**
     * Adds the comparison of one gold tree and one system tree to the running
     * counts.
     *
     * @param gold the gold-standard tree; its size determines how many nodes are
     *             compared
     * @param sys  the system tree; must provide a node at every index below
     *             {@code gold.size()}
     */
    public void evaluate(DepTree gold, DepTree sys) {
        // Index 0 is skipped -- presumably the artificial root node; confirm against DepTree.
        for (int i = 1; i < gold.size(); i++) {
            measure(gold.get(i), sys.get(i));
        }
    }

    /**
     * Compares the semantic heads of one gold node with those of the
     * corresponding system node and updates the overall and per-label counters.
     *
     * @param gNode the gold node
     * @param sNode the system node at the same position
     */
    private void measure(DepNode gNode, DepNode sNode) {
        ArrayList<SRLHead> gHeads = gNode.srlInfo.heads;
        ArrayList<SRLHead> sHeads = sNode.srlInfo.heads;
        int[] total = m_score.get(TOTAL);

        for (SRLHead gHead : gHeads) {
            int[] gArg = getArray(baseLabel(gHead.label));

            for (SRLHead sHead : sHeads) {
                // NOTE(review): relies on SRLHead offering an equals overload that
                // accepts a head id -- confirm; Object.equals would never match here.
                if (sHead.equals(gHead.headId)) {
                    total[0]++; // same head, label ignored
                    if (sHead.label.equals(gHead.label)) {
                        total[1]++; // same head and same (unstripped) label
                        gArg[0]++;
                    }
                }
            }

            gArg[2]++; // gold occurrence of the label: recall denominator
        }

        total[2] += sHeads.size(); // precision denominator
        total[3] += gHeads.size(); // recall denominator

        for (SRLHead sHead : sHeads) {
            getArray(baseLabel(sHead.label))[1]++; // system occurrence: precision denominator
        }
    }

    /**
     * Returns the label under which a head is tallied.
     *
     * @param label the raw label of a semantic head
     * @return {@code label} without its leading {@code "C-"}, or {@code label}
     *         itself when it does not start with that prefix
     */
    private static String baseLabel(String label) {
        return label.startsWith(C_PREFIX) ? label.substring(C_PREFIX.length()) : label;
    }

    /**
     * Returns the counters of a label, registering a fresh zeroed array of
     * length 3 the first time the label is seen.
     *
     * @param label the (prefix-stripped) label
     * @return the live counter array stored for {@code label}
     */
    private int[] getArray(String label) {
        int[] value = m_score.get(label);
        if (value == null) {
            value = new int[3];
            m_score.put(label, value);
        }
        return value;
    }

    /**
     * Prints the evaluation table to standard output: the overall UAS/LAS rows
     * followed by one row per label in sorted order.
     */
    public void print() {
        System.out.println(RULE);
        System.out.printf("%10s%10s%10s%10s%10s\n", "Label", "Dist", "P", "R", "F1");
        System.out.println(RULE);

        int total = printTotal();
        System.out.println(RULE);

        ArrayList<String> labels = new ArrayList<>(m_score.keySet());
        Collections.sort(labels);
        for (String label : labels) {
            if (!TOTAL.equals(label)) {
                printLocal(label, total);
            }
        }

        System.out.println(RULE);
    }

    /**
     * Prints the overall rows: "UAS" from the head-only matches and "LAS" from
     * the head-and-label matches.
     *
     * @return the total number of gold heads counted so far
     */
    private int printTotal() {
        int[] value = m_score.get(TOTAL);

        double precision = percent(value[0], value[2]);
        double recall = percent(value[0], value[3]);
        printEach("UAS", 100, precision, recall, getF1(precision, recall));

        precision = percent(value[1], value[2]);
        recall = percent(value[1], value[3]);
        printEach("LAS", 100, precision, recall, getF1(precision, recall));

        return value[3];
    }

    /**
     * Prints the row of a single label.
     *
     * @param label the label whose counters are reported
     * @param total the total number of gold heads, used for the distribution
     *              column
     * @return the number of gold occurrences of {@code label}
     */
    private int printLocal(String label, int total) {
        int[] value = m_score.get(label);

        double dist = percent(value[2], total);
        double precision = percent(value[0], value[1]);
        double recall = percent(value[0], value[2]);
        printEach(label, dist, precision, recall, getF1(precision, recall));

        return value[2];
    }

    /**
     * Prints one formatted table row.
     *
     * @param label     the row name
     * @param dist      the distribution column, in percent
     * @param precision the precision, in percent
     * @param recall    the recall, in percent
     * @param f1        the F1 score
     */
    private void printEach(String label, double dist, double precision, double recall, double f1) {
        System.out.printf("%10s%10.2f%10.2f%10.2f%10.2f\n", label, dist, precision, recall, f1);
    }

    /**
     * Computes {@code 100 * numerator / denominator} without ever dividing by
     * zero.
     *
     * @param numerator   the count in the numerator
     * @param denominator the count in the denominator
     * @return the percentage, or {@code 0} when {@code denominator} is zero
     */
    private static double percent(int numerator, int denominator) {
        return (denominator == 0) ? 0d : 100d * numerator / denominator;
    }

    /**
     * Returns the overall labeled F1 score (head and label must both match).
     *
     * @return the F1 of the labeled precision and recall in percent; {@code 0}
     *         when nothing has been counted or nothing matched
     */
    public double getF1() {
        int[] value = m_score.get(TOTAL);
        return getF1(percent(value[1], value[2]), percent(value[1], value[3]));
    }

    /**
     * Returns the harmonic mean of precision and recall.
     *
     * @param precision the precision
     * @param recall    the recall
     * @return {@code 2 * precision * recall / (precision + recall)}, or
     *         {@code 0} when both arguments are zero
     */
    public static double getF1(double precision, double recall) {
        // 0/0 would yield NaN; with no correct predictions the F1 is defined as 0 here.
        if (precision == 0d && recall == 0d) {
            return 0d;
        }
        return 2 * (precision * recall) / (precision + recall);
    }
}