package edu.fudan.ml.eval;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
import edu.fudan.util.MyCollection;
/**
* Computes precision, recall and FB1 statistics for sequence-labelling experiment results.
* @author fxx
*/
public class SeqEval {
/** Percent sign appended to percentage cells of the report tables. */
private String STRPERCENT = "%";
/** Terminator appended after every table row. */
private String STRLINE = "\n";
/** Column separator placed between the tabs of adjacent cells (empty for plain text). */
private String sep = "";
/**
* Entity types seen in the evaluated file (reset and filled by read).
*/
private static Set<String> entityType;
/**
* Gold (correct) entities, one list per sequence read from the file.
*/
private ArrayList<LinkedList<Entity>> entityCs = new ArrayList<LinkedList<Entity>>();
/**
* Predicted entities, one list per sequence, index-aligned with entityCs.
*/
private ArrayList<LinkedList<Entity>> entityPs = new ArrayList<LinkedList<Entity>>();
/**
* Predicted entities that are also correct, one list per sequence.
*/
private ArrayList<LinkedList<Entity>> entityCinPs = new ArrayList<LinkedList<Entity>> ();
/**
* Dictionary of known strings (loaded by readOOV); entity strings not in it are counted as OOV.
*/
HashSet<String> dict;
/** Whether the report tables use the LaTeX row/column markers. */
private boolean latex = false;
/**
 * Creates an evaluator that renders its tables as plain tab-separated text.
 */
public SeqEval() {
    this(false);
}

/**
 * Creates an evaluator with a selectable table format.
 * Previously the LaTeX branch could never run because the flag was always
 * {@code false} when the constructor checked it.
 *
 * @param latex {@code true} to escape the percent sign, end each row with
 *              {@code \\\hline} and separate columns with {@code &};
 *              {@code false} for plain text
 */
public SeqEval(boolean latex) {
    this.latex = latex;
    if (latex) {
        STRPERCENT = "\\%";
        STRLINE = "\\\\\\hline\n";
        sep = "&";
    }
}
/**
 * Builds every statistics table (by length, type, OOV count, continuous OOV
 * count, INV/OOV and OOV rate) and optionally writes the combined report to a file.
 * Each table is also printed to standard output by the table formatter.
 *
 * @param outputPath file the report is written to, or {@code null} to skip writing
 * @throws IOException if the output file cannot be created or written
 */
public void NeEvl(String outputPath) throws IOException{
    StringBuilder report = new StringBuilder();
    report.append(calcByLength()).append("\n");
    report.append(calcByType()).append("\n");
    report.append(calcByOOV()).append("\n");
    report.append(calcByCOOV()).append("\n");
    report.append(calcByOOV2()).append("\n");
    report.append(calcByOOVRate()).append("\n");
    if (outputPath == null) {
        return;
    }
    // Write as UTF-8 (the encoding the input is read with) instead of the platform
    // default, so non-ASCII row keys survive on every machine.
    Writer out = new OutputStreamWriter(new FileOutputStream(new File(outputPath)), "UTF-8");
    try {
        out.write(report.toString());
    } finally {
        // Always release the file handle, even when the write fails.
        out.close();
    }
}
/**
 * Builds the statistics table bucketed by entity length
 * (end index - start index + 1).
 *
 * @return the formatted table
 */
private String calcByLength() {
    // Correct predictions per length.
    Map<Integer,Double> mpc = new TreeMap<Integer,Double>();
    // Predicted entities per length.
    Map<Integer,Double> mp = new TreeMap<Integer,Double>();
    // Gold entities per length.
    Map<Integer,Double> mc = new TreeMap<Integer,Double>();
    // Gold entities per length whose string is not in the dictionary.
    Map<Integer,Double> oov = new TreeMap<Integer,Double>();
    boolean hasDict = dict != null && dict.size() > 0;
    for (int i = 0; i < entityCs.size(); i++) {
        LinkedList<Entity> cList = entityCs.get(i);
        LinkedList<Entity> pList = entityPs.get(i);
        LinkedList<Entity> cpList = entityCinPs.get(i);
        for (Entity entity : cList) {
            int len = entity.getEndIndex() - entity.getStartIndex() + 1;
            adjust(mc, len, 1.0);
            boolean isOov = hasDict && !dict.contains(entity.getEntityStr());
            // Always create the bucket (adding 0 for in-vocabulary entities) so the
            // table formatter finds an OOV count for every gold length.
            adjust(oov, len, isOov ? 1.0 : 0.0);
        }
        for (Entity entity : pList) {
            int len = entity.getEndIndex() - entity.getStartIndex() + 1;
            adjust(mp, len, 1.0);
        }
        for (Entity entity : cpList) {
            int len = entity.getEndIndex() - entity.getStartIndex() + 1;
            adjust(mpc, len, 1.0);
        }
    }
    return toString("Length", mpc, mp, mc, oov);
}
/**
 * Adds {@code d} to the counter stored under {@code key}, creating the
 * counter with value {@code d} when the key is not present yet.
 *
 * @param mc  counter map (values are stored as {@code Double})
 * @param key bucket to update
 * @param d   amount to add
 */
private void adjust(Map mc, Object key, double d) {
    double updated = mc.containsKey(key) ? (Double) mc.get(key) + d : d;
    mc.put(key, updated);
}
/**
 * Builds the statistics table with two buckets: "INV" for entities whose
 * string is in the dictionary and "OOV" for those that are not.
 * Without a loaded dictionary no entity is counted and only the header is produced.
 *
 * @return the formatted table
 */
private String calcByOOV2() {
    // The buckets are the strings "INV"/"OOV", so the maps are keyed by String.
    // Correct predictions per bucket.
    Map<String,Double> mpc = new TreeMap<String,Double>();
    // Predicted entities per bucket.
    Map<String,Double> mp = new TreeMap<String,Double>();
    // Gold entities per bucket.
    Map<String,Double> mc = new TreeMap<String,Double>();
    // Out-of-vocabulary gold entities per bucket (0 for "INV").
    Map<String,Double> oov = new TreeMap<String,Double>();
    boolean hasDict = dict != null && dict.size() > 0;
    if (hasDict) {
        for (int i = 0; i < entityCs.size(); i++) {
            for (Entity entity : entityCs.get(i)) {
                boolean known = dict.contains(entity.getEntityStr());
                String type = known ? "INV" : "OOV";
                // The INV bucket is seeded with 0 so it has an OOV count as well.
                adjust(oov, type, known ? 0.0 : 1.0);
                adjust(mc, type, 1.0);
            }
            for (Entity entity : entityPs.get(i)) {
                String type = dict.contains(entity.getEntityStr()) ? "INV" : "OOV";
                adjust(mp, type, 1.0);
            }
            for (Entity entity : entityCinPs.get(i)) {
                String type = dict.contains(entity.getEntityStr()) ? "INV" : "OOV";
                adjust(mpc, type, 1.0);
            }
        }
    }
    return toString("INV/OOV", mpc, mp, mc, oov);
}
/**
 * Builds the statistics table bucketed by the number of out-of-vocabulary
 * gold entities found in each sequence. Every sequence contributes its
 * entity counts to the bucket of its own OOV count.
 *
 * @return the formatted table
 */
private String calcByOOV() {
    Map<Integer,Double> correctCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> predictedCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> goldCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> oovCounts = new TreeMap<Integer,Double>();
    int seq = 0;
    for (LinkedList<Entity> goldEntities : entityCs) {
        LinkedList<Entity> predictedEntities = entityPs.get(seq);
        LinkedList<Entity> correctEntities = entityCinPs.get(seq);
        seq++;
        // Bucket key: how many gold entities of this sequence are missing from the dictionary.
        int bucket = findOOV(goldEntities).size();
        adjust(goldCounts, bucket, goldEntities.size());
        adjust(oovCounts, bucket, bucket);
        adjust(predictedCounts, bucket, predictedEntities.size());
        adjust(correctCounts, bucket, correctEntities.size());
    }
    return toString("OOV", correctCounts, predictedCounts, goldCounts, oovCounts);
}
/**
 * Builds the statistics table bucketed by the longest run of consecutive
 * out-of-vocabulary gold entities in each sequence.
 *
 * @return the formatted table
 */
private String calcByCOOV() {
    Map<Integer,Double> correctCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> predictedCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> goldCounts = new TreeMap<Integer,Double>();
    Map<Integer,Double> oovCounts = new TreeMap<Integer,Double>();
    int total = entityCs.size();
    int seq = 0;
    while (seq < total) {
        LinkedList<Entity> goldEntities = entityCs.get(seq);
        LinkedList<Entity> predictedEntities = entityPs.get(seq);
        LinkedList<Entity> correctEntities = entityCinPs.get(seq);
        // Bucket key: length of the longest consecutive OOV run in this sequence.
        int bucket = findContinousOOV(goldEntities);
        adjust(goldCounts, bucket, goldEntities.size());
        adjust(oovCounts, bucket, bucket);
        adjust(predictedCounts, bucket, predictedEntities.size());
        adjust(correctCounts, bucket, correctEntities.size());
        seq++;
    }
    return toString("COOV", correctCounts, predictedCounts, goldCounts, oovCounts);
}
/**
 * Builds the statistics table bucketed by each sequence's OOV rate
 * (OOV gold entities / gold entities), rounded up to the next multiple of 0.05.
 * Sequences without any OOV entity go into bucket 0.
 *
 * @return the formatted table
 */
private String calcByOOVRate() {
    Map<Double,Double> correctCounts = new TreeMap<Double,Double>();
    Map<Double,Double> predictedCounts = new TreeMap<Double,Double>();
    Map<Double,Double> goldCounts = new TreeMap<Double,Double>();
    Map<Double,Double> oovCounts = new TreeMap<Double,Double>();
    int seq = 0;
    for (LinkedList<Entity> goldEntities : entityCs) {
        LinkedList<Entity> predictedEntities = entityPs.get(seq);
        LinkedList<Entity> correctEntities = entityCinPs.get(seq);
        seq++;
        int oovCount = findOOV(goldEntities).size();
        // ceil(rate * 20) / 20 snaps the rate up to a 5% step.
        double bucket = (oovCount == 0)
                ? 0
                : Math.ceil(oovCount / (double) goldEntities.size() * 20) / 20;
        adjust(goldCounts, bucket, goldEntities.size());
        adjust(oovCounts, bucket, oovCount);
        adjust(predictedCounts, bucket, predictedEntities.size());
        adjust(correctCounts, bucket, correctEntities.size());
    }
    return toString("OOVRate", correctCounts, predictedCounts, goldCounts, oovCounts);
}
/**
 * Renders one statistics table, prints it to standard output and returns it.
 * One row is produced per key of {@code mc}. A bucket that is missing from
 * {@code mpc}, {@code mp} or {@code oov} counts as 0, and any ratio with a
 * zero denominator is reported as 0 instead of failing or printing NaN.
 *
 * @param mark title of the first column
 * @param mpc  correct predictions per bucket
 * @param mp   predicted entities per bucket
 * @param mc   gold entities per bucket (defines the rows)
 * @param oov  out-of-vocabulary gold entities per bucket
 * @return the formatted table
 */
private String toString(String mark, Map mpc,
        Map mp, Map mc, Map oov) {
    DecimalFormat percentFormat = new DecimalFormat("0.00");
    DecimalFormat countFormat = new DecimalFormat("0");
    String cell = "\t" + sep + "\t";
    StringBuilder table = new StringBuilder();
    table.append(mark)
         .append(cell).append("Precision")
         .append(cell).append("Recall")
         .append(cell).append("FB1")
         .append(cell).append("PCount")
         .append(cell).append("CCount")
         .append(cell).append("Correct")
         .append(cell).append("OOVRate")
         .append(STRLINE);
    for (Object key : mc.keySet()) {
        double predicted = countOf(mp, key);
        double gold = countOf(mc, key);
        double correct = countOf(mpc, key);
        double precision = safeRatio(correct, predicted);
        double recall = safeRatio(correct, gold);
        double fb1 = safeRatio(precision * recall * 2, recall + precision);
        double oovRate = safeRatio(countOf(oov, key), gold);
        table.append(key)
             .append(cell).append(formatPercent(percentFormat, precision))
             // The precision cell is followed by one extra tab, as in the header layout.
             .append("\t").append(cell).append(formatPercent(percentFormat, recall))
             .append(cell).append(formatPercent(percentFormat, fb1))
             .append(cell).append(countFormat.format(predicted))
             .append(cell).append(countFormat.format(gold))
             .append(cell).append(countFormat.format(correct))
             .append(cell).append(formatPercent(percentFormat, oovRate))
             .append(STRLINE);
    }
    String rendered = table.toString();
    System.out.println(rendered);
    return rendered;
}

/** Returns the counter stored under {@code key}, or 0 when the bucket is absent. */
private static double countOf(Map counts, Object key) {
    Object value = counts.get(key);
    return value == null ? 0.0 : ((Number) value).doubleValue();
}

/** Divides two counters, yielding 0 when the denominator is 0. */
private static double safeRatio(double numerator, double denominator) {
    return denominator == 0.0 ? 0.0 : numerator / denominator;
}

/** Formats a ratio as a percentage, dropping a trailing ".00", followed by the percent marker. */
private String formatPercent(DecimalFormat format, double ratio) {
    return format.format(ratio * 100).replaceAll("\\.00$", "") + STRPERCENT;
}
/**
 * Returns the length of the longest run of consecutive entities whose string
 * is not in the dictionary.
 *
 * @param cList entities of one sequence, in order
 * @return the longest consecutive OOV run, or 0 when no dictionary is loaded
 */
private int findContinousOOV(LinkedList<Entity> cList) {
    if (dict == null || dict.isEmpty()) {
        return 0;
    }
    int run = 0;
    int longest = 0;
    for (Entity e : cList) {
        if (dict.contains(e.getEntityStr())) {
            // A known entity breaks the current run.
            run = 0;
        } else {
            run++;
            if (run > longest) {
                longest = run;
            }
        }
    }
    return longest;
}
/**
 * Collects the strings of all entities that are not in the dictionary.
 *
 * @param cList entities of one sequence
 * @return the out-of-vocabulary strings in list order (duplicates kept);
 *         empty when no dictionary is loaded
 */
private ArrayList<String> findOOV(LinkedList<Entity> cList) {
    ArrayList<String> unknown = new ArrayList<String>();
    if (dict == null || dict.size() <= 0) {
        return unknown;
    }
    for (Entity candidate : cList) {
        String text = candidate.getEntityStr();
        if (dict.contains(text)) {
            continue;
        }
        unknown.add(text);
    }
    return unknown;
}
/**
 * Builds the statistics table bucketed by entity type.
 *
 * @return the formatted table
 */
private String calcByType() {
    // Correct predictions per type.
    Map<String,Double> mpc = new TreeMap<String,Double>();
    // Predicted entities per type.
    Map<String,Double> mp = new TreeMap<String,Double>();
    // Gold entities per type.
    Map<String,Double> mc = new TreeMap<String,Double>();
    // Gold entities per type whose string is not in the dictionary.
    Map<String,Double> oov = new TreeMap<String,Double>();
    boolean hasDict = dict != null && dict.size() > 0;
    for (int i = 0; i < entityCs.size(); i++) {
        LinkedList<Entity> cList = entityCs.get(i);
        LinkedList<Entity> pList = entityPs.get(i);
        LinkedList<Entity> cpList = entityCinPs.get(i);
        for (Entity entity : cList) {
            String type = entity.getType();
            adjust(mc, type, 1.0);
            boolean isOov = hasDict && !dict.contains(entity.getEntityStr());
            // Always create the bucket (adding 0 for in-vocabulary entities) so the
            // table formatter finds an OOV count for every gold type.
            adjust(oov, type, isOov ? 1.0 : 0.0);
        }
        for (Entity entity : pList) {
            adjust(mp, entity.getType(), 1.0);
        }
        for (Entity entity : cpList) {
            adjust(mpc, entity.getType(), 1.0);
        }
    }
    return toString("Type", mpc, mp, mc, oov);
}
/**
 * Reads a result file and extracts the predicted and gold entities of every sequence.
 * Each non-blank line holds three whitespace-separated columns: the token, the
 * predicted label and the correct label. A blank line ends the current sequence.
 * After reading, the correctly predicted entities are derived as well.
 *
 * @param filePath path of the UTF-8 result file; {@code null} makes the call a no-op
 * @throws IOException if the file cannot be read or a line has fewer than three columns
 */
public void read(String filePath) throws IOException{
    ArrayList<String> words = new ArrayList<String>();
    ArrayList<String> markP = new ArrayList<String>();
    ArrayList<String> typeP = new ArrayList<String>();
    ArrayList<String> markC = new ArrayList<String>();
    ArrayList<String> typeC = new ArrayList<String>();
    if (filePath == null) {
        return;
    }
    entityType = new HashSet<String>();
    BufferedReader reader = new BufferedReader(
            new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"));
    try {
        String line;
        int lineNo = 0;
        while ((line = reader.readLine()) != null) {
            lineNo++;
            // Whitespace-only lines are sequence separators too; they used to crash the parser.
            if (line.trim().length() == 0) {
                newextract(words, markP, typeP, markC, typeC);
                continue;
            }
            String[] toks = line.split("\\s+");
            if (toks.length < 3) {
                throw new IOException("Malformed line " + lineNo + " in " + filePath
                        + ": expected <token> <predicted label> <correct label> but got \""
                        + line + "\"");
            }
            words.add(toks[0]);
            // Predicted label.
            String[] marktype = getMarkType(toks[1]);
            markP.add(marktype[0]);
            typeP.add(marktype[1]);
            entityType.add(marktype[1]);
            // Correct label.
            marktype = getMarkType(toks[2]);
            markC.add(marktype[0]);
            typeC.add(marktype[1]);
            entityType.add(marktype[1]);
        }
    } finally {
        // Release the file handle even when parsing fails.
        reader.close();
    }
    // Flush the last sequence when the file does not end with a blank line.
    if (words.size() > 0) {
        newextract(words, markP, typeP, markC, typeC);
    }
    extractCInPEntity();
}
/**
 * Finishes the current sequence: extracts its predicted and gold entities,
 * appends them to the per-sequence containers and empties all buffers so
 * they can be reused for the next sequence.
 *
 * @param words tokens of the sequence
 * @param markP predicted boundary marks
 * @param typeP predicted types
 * @param markC gold boundary marks
 * @param typeC gold types
 */
private void newextract(ArrayList<String> words, ArrayList<String> markP,
        ArrayList<String> typeP, ArrayList<String> markC,
        ArrayList<String> typeC) {
    // Predicted entities first, then gold ones, so both containers stay index-aligned.
    entityPs.add(extract(words, markP, typeP));
    entityCs.add(extract(words, markC, typeC));

    words.clear();
    markP.clear();
    typeP.clear();
    markC.clear();
    typeC.clear();
}
/**
 * Extracts the entities of one sequence from its boundary marks and types.
 * An entity starts where isStart reports a start and ends where isEnd reports
 * an end (or where the next token starts a new entity). Tokens that are not
 * inside an entity, such as those marked "O", are skipped. Previously they
 * were emitted as bogus entities carrying a stale start index (-1 at the
 * beginning of a sequence) and stale text.
 *
 * @param words tokens of the sequence
 * @param marks boundary mark of every token
 * @param types type of every token
 * @return the entities in sequence order; each takes the type of its last token
 */
private LinkedList<Entity> extract(ArrayList<String> words, ArrayList<String> marks,
        ArrayList<String> types) {
    LinkedList<Entity> entitylist = new LinkedList<Entity>();
    int startIndex = -1;
    // True only between an entity's first and last token.
    boolean inside = false;
    StringBuilder text = new StringBuilder();
    for (int i = 0; i < words.size(); i++) {
        if (isStart(marks, types, i)) {
            inside = true;
            text = new StringBuilder();
            startIndex = i;
        }
        if (!inside) {
            continue;
        }
        text.append(words.get(i));
        // isEnd is always true on the last token, so isStart is never asked about index size().
        if (isEnd(marks, types, i) || isStart(marks, types, i + 1)) {
            Entity entity = new Entity(startIndex, i, text.toString().trim());
            entity.setType(types.get(i));
            entitylist.add(entity);
            inside = false;
        }
    }
    return entitylist;
}
/**
 * Tells whether the token at position {@code i} begins a new entity.
 * The token before position 0 is treated as mark "O" with an empty type.
 *
 * @param marks boundary mark of every token
 * @param types type of every token
 * @param i     position to test
 * @return {@code true} when the mark is B or S, when an M or E follows an
 *         E, S or O, or when a non-O token has a different type than its predecessor
 */
private boolean isStart(ArrayList<String> marks, ArrayList<String> types,
        int i) {
    String before = (i == 0) ? "O" : marks.get(i - 1);
    String here = marks.get(i);
    String typeBefore = (i == 0) ? "" : types.get(i - 1);
    String typeHere = types.get(i);

    if (here.equalsIgnoreCase("B") || here.equalsIgnoreCase("S")) {
        return true;
    }
    // A continuation mark (M/E) that follows a closed position (E/S/O) opens a new entity.
    boolean continuation = here.equalsIgnoreCase("E") || here.equalsIgnoreCase("M");
    boolean closedBefore = before.equalsIgnoreCase("E")
            || before.equalsIgnoreCase("S")
            || before.equalsIgnoreCase("O");
    if (continuation && closedBefore) {
        return true;
    }
    // A type change on a non-O token also opens a new entity.
    return !here.equalsIgnoreCase("O") && !typeHere.equalsIgnoreCase(typeBefore);
}
/**
 * Tells whether the token at position {@code i} closes the current entity.
 * The token after the last position is treated as mark "O" with an empty type.
 *
 * @param marks boundary mark of every token
 * @param types type of every token
 * @param i     position to test
 * @return {@code true} when the mark is E or S, when the next mark is O, B
 *         or S, or when the next token has a different type
 */
private boolean isEnd(ArrayList<String> marks, ArrayList<String> types,
        int i) {
    String after = (i == marks.size() - 1) ? "O" : marks.get(i + 1);
    String here = marks.get(i);
    String typeAfter = (i == types.size() - 1) ? "" : types.get(i + 1);
    String typeHere = types.get(i);

    if (here.equalsIgnoreCase("E") || here.equalsIgnoreCase("S")) {
        return true;
    }
    if (after.equalsIgnoreCase("O")
            || after.equalsIgnoreCase("B")
            || after.equalsIgnoreCase("S")) {
        return true;
    }
    return !typeHere.equalsIgnoreCase(typeAfter);
}
/**
 * Splits a label at its first '-' into a boundary mark and a type.
 *
 * @param label label such as {@code B-PER} or {@code O}
 * @return a two-element array {@code {mark, type}}; the type is the empty
 *         string when the label contains no '-'
 */
private String[] getMarkType(String label) {
    int dash = label.indexOf('-');
    if (dash == -1) {
        return new String[] { label, "" };
    }
    return new String[] { label.substring(0, dash), label.substring(dash + 1) };
}
/**
 * Derives, for every sequence, the predicted entities that equal a gold
 * entity and stores them in entityCinPs.
 * Both lists are walked in order: gold entities that start before the
 * current prediction are skipped, and a gold entity with the same start is
 * consumed whether or not it matches.
 * The gold lists are only read, so their order is preserved. Sequences that
 * already have a result are skipped, which keeps entityCinPs index-aligned
 * with the other containers when this method runs more than once.
 */
public void extractCInPEntity(){
    for (int i = entityCinPs.size(); i < entityPs.size(); i++) {
        LinkedList<Entity> predicted = entityPs.get(i);
        LinkedList<Entity> goldList = entityCs.get(i);
        // Snapshot for cheap indexed access; the list itself is left untouched.
        Entity[] gold = goldList.toArray(new Entity[goldList.size()]);
        LinkedList<Entity> correct = new LinkedList<Entity>();
        int next = 0;
        for (Entity entityp : predicted) {
            while (next < gold.length) {
                Entity entityc = gold[next];
                if (entityp.equals(entityc)) {
                    correct.add(entityp);
                    next++;
                    break;
                }
                if (entityp.getStartIndex() == entityc.getStartIndex()) {
                    // Same start but not equal: this gold entity was missed.
                    next++;
                    break;
                }
                if (entityp.getStartIndex() > entityc.getStartIndex()) {
                    // The gold entity lies before the prediction; skip it and keep looking.
                    next++;
                } else {
                    // The gold entity lies after the prediction; try the next prediction.
                    break;
                }
            }
        }
        entityCinPs.add(correct);
    }
}
/**
 * Loads the dictionary used to decide whether an entity is out-of-vocabulary.
 * Every non-empty line of the file becomes one dictionary entry; any
 * previously loaded dictionary is replaced.
 *
 * @param path path of the UTF-8 dictionary file, one entry per line
 * @return the loaded dictionary (the same set that is kept in this object)
 * @throws IOException if the file cannot be read
 */
public HashSet<String> readOOV(String path) throws IOException{
    dict = new HashSet<String>();
    BufferedReader bfr = new BufferedReader(
            new InputStreamReader(new FileInputStream(path), "UTF-8"));
    try {
        String line;
        while ((line = bfr.readLine()) != null) {
            if (line.length() == 0) {
                continue;
            }
            dict.add(line);
        }
    } finally {
        // Release the file handle even when reading fails.
        bfr.close();
    }
    return dict;
}
/**
 * Command-line entry point: evaluates a three-column result file.
 * Usage: {@code <result file> [output file]}, or {@code -h} for help.
 * Without arguments the bundled CTB6 segmentation result is evaluated.
 * The dictionary and the context-pattern output file are fixed paths.
 *
 * @param args optional result file path and optional report output path
 * @throws IOException if any input or output file cannot be processed
 */
public static void main(String[] args) throws IOException{
    // Defaults, used only when the corresponding argument is not supplied.
    String filePath = "./paperdata/ctb6-seg/work/ctb_三列式结果_0.txt";
    String outputPath = null;
    String dictpath = "./paperdata/ctb6-seg/train-dict.txt";
    if (args.length > 0) {
        if (args[0].equals("-h")) {
            System.out.println("NeSatistic.jar 要评测的文件 [输出到文件]");
            // Help was requested; do not run an evaluation.
            return;
        }
        filePath = args[0];
    }
    if (args.length >= 2) {
        outputPath = args[1];
    }
    SeqEval ne1 = new SeqEval();
    ne1.readOOV(dictpath);
    ne1.read(filePath);
    ne1.getRightOOV("./paperdata/ctb6-seg/right-pattern.txt");
    // Honour the requested output file instead of always discarding the report.
    ne1.NeEvl(outputPath);
}
/**
 * Writes the sorted set of predicted entity strings that were not predicted
 * correctly. Within a sequence, a string that also occurs among the correct
 * predictions is left out. Nothing is written when no dictionary is loaded.
 * NOTE(review): despite the name, the result is not filtered by dictionary
 * membership; the original branches on the dictionary were identical. Confirm
 * whether only OOV strings were meant.
 *
 * @param string path of the file to write
 */
public void getWrongOOV(String string) {
    if (dict == null || dict.size() == 0) {
        return;
    }
    TreeSet<String> set = new TreeSet<String>();
    for (int i = 0; i < entityCs.size(); i++) {
        TreeSet<String> wrong = new TreeSet<String>();
        for (Entity entity : entityPs.get(i)) {
            wrong.add(entity.getEntityStr());
        }
        for (Entity entity : entityCinPs.get(i)) {
            wrong.remove(entity.getEntityStr());
        }
        set.addAll(wrong);
    }
    MyCollection.write(set, string);
}
/**
 * Counts the contexts in which entities were predicted correctly and writes
 * them out. For every correct prediction the context is
 * {@code "<previous gold string> ... <next gold string>"}, with a side left
 * empty when there is no neighbour or the entity is not found in the gold list.
 * Nothing is written when no dictionary is loaded.
 *
 * @param string path of the file to write
 */
private void getRightOOV(String string) {
    if (dict == null || dict.size() == 0) {
        return;
    }
    // NOTE(review): declared with String values, but adjust() stores Double counters
    // in it - confirm against MyCollection.sort before tightening the type.
    TreeMap<String,String> contexts = new TreeMap<String,String>();
    for (int i = 0; i < entityCs.size(); i++) {
        LinkedList<Entity> gold = entityCs.get(i);
        for (Entity hit : entityCinPs.get(i)) {
            int pos = gold.indexOf(hit);
            String left = (pos > 0) ? gold.get(pos - 1).getEntityStr() : "";
            String right = (pos != -1 && pos < gold.size() - 1)
                    ? gold.get(pos + 1).getEntityStr()
                    : "";
            adjust(contexts, left + " ... " + right, 1);
        }
    }
    List<Entry> ranked = MyCollection.sort(contexts);
    MyCollection.write(ranked, string, true);
}
}