package edu.fudan.nlp.pipe.seq;
import java.io.Serializable;
import java.util.Arrays;
import edu.fudan.ml.types.Dictionary;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.alphabet.LabelAlphabet;
import edu.fudan.nlp.pipe.Pipe;
/**
 * Pipe that annotates a token sequence with dictionary-derived BMES label
 * constraints.
 * <p>
 * For each token position the pipe looks up dictionary words that start at
 * that position. Every word found that covers tokens {@code i..i+k-1} is
 * marked with BMES labels: {@code S} for a one-token word, otherwise {@code B}
 * on the first token, {@code M} on inner tokens and {@code E} on the last.
 * <p>
 * The resulting {@code int[length][labels.size()]} matrix is stored with
 * {@code Instance.setDicData}. For a position covered by at least one
 * dictionary word, the labels assigned by a match hold {@code 0} and all other
 * labels hold {@code 1}. Positions not covered by any word are all {@code 0}.
 * NOTE(review): this looks like a per-label penalty/constraint mask for the
 * decoder; confirm against the consumer of the dic data.
 *
 * @author xpqiu
 */
public class DictLabel extends Pipe {

    /** A piece of text cut from the token sequence plus the number of tokens it spans. */
    static class WordInfo {
        /** Concatenated token text, truncated to the requested character length. */
        String word;
        /** Number of tokens consumed to build {@link #word} (a partially used last token counts). */
        int len;

        public WordInfo(String string, int n) {
            word = string;
            len = n;
        }
    }

    private static final long serialVersionUID = -8634966199670429510L;

    /** Marker written into the scratch matrix for a label permitted by a dictionary match. */
    private static final int MARK = -1;

    protected Dictionary dict;
    protected LabelAlphabet labels;

    // Indices of the B/M/E/S labels in {@code labels}.
    int idxB;
    int idxM;
    int idxE;
    int idxS;

    /**
     * Mirrors {@code dict.isAmbiguity()}. When false, scanning jumps past the
     * first word matched at a position. The (misspelled) name is kept on
     * purpose: this class declares a serialVersionUID, and renaming a field
     * would change its serialized form.
     */
    private boolean mutiple;

    /**
     * @param dict   dictionary used for word lookup
     * @param labels label alphabet; the indices of "B", "M", "E" and "S" are
     *               obtained through {@code lookupIndex}
     */
    public DictLabel(Dictionary dict, LabelAlphabet labels) {
        this.dict = dict;
        this.mutiple = dict.isAmbiguity();
        this.labels = labels;
        idxB = labels.lookupIndex("B");
        idxM = labels.lookupIndex("M");
        idxE = labels.lookupIndex("E");
        idxS = labels.lookupIndex("S");
    }

    /**
     * Replaces the dictionary and refreshes the ambiguity flag so that it
     * matches the new dictionary.
     *
     * @param dict the new dictionary
     */
    public void setDict(Dictionary dict) {
        this.dict = dict;
        // Previously the flag kept the value of the constructor's dictionary.
        this.mutiple = dict != null && dict.isAmbiguity();
    }

    /**
     * Computes the dictionary-label matrix for {@code instance} and stores it
     * with {@code instance.setDicData}.
     * <p>
     * Tokens are read from row 0 of the {@code String[][]} returned by
     * {@code instance.getData()}. At every token position {@code i} with
     * {@code i + dict.getIndexLen() <= length}, the first
     * {@code dict.getIndexLen()} characters starting at token {@code i} are
     * looked up with {@code dict.getIndex}. Each returned value is tried as a
     * candidate word length in characters (assumption based on how it is
     * passed to {@link #getNextN}; confirm against {@code Dictionary}).
     *
     * @param instance instance whose data is a {@code String[][]}
     * @throws Exception as declared; this body itself only raises runtime
     *                   exceptions (for example a ClassCastException when the
     *                   data is not a {@code String[][]})
     */
    public void addThruPipe(Instance instance) throws Exception {
        String[][] data = (String[][]) instance.getData();
        String[] tokens = data[0];
        int length = tokens.length;
        int[][] dicData = new int[length][labels.size()];
        int indexLen = dict.getIndexLen();

        for (int i = 0; i < length; i++) {
            if (i + indexLen > length) {
                continue;
            }
            WordInfo prefix = getNextN(tokens, i, indexLen);
            int[] index = dict.getIndex(prefix.word);
            if (index == null) {
                continue;
            }
            for (int k = 0; k < index.length; k++) {
                int n = index[k];
                int matched;
                if (n == indexLen) {
                    // Fast path: the looked-up prefix is itself a word, so the
                    // dict.contains check done by check() is unnecessary.
                    label(i, prefix.len, dicData);
                    matched = prefix.len;
                } else {
                    matched = check(i, n, tokens, dicData);
                }
                if (matched > 0 && !mutiple) {
                    // Resume right after the matched word. The outer loop's
                    // i++ supplies the final +1; adding the full length here
                    // would skip the token following the word.
                    i += matched - 1;
                    break;
                }
            }
        }

        for (int[] row : dicData) {
            if (hasWay(row)) {
                // MARK (-1) -> 0 for permitted labels, 0 -> 1 for all others.
                for (int j = 0; j < row.length; j++) {
                    row[j]++;
                }
            }
        }
        instance.setDicData(dicData);
    }

    /**
     * Returns whether any entry of {@code ia} was marked by a dictionary match.
     */
    private boolean hasWay(int[] ia) {
        for (int value : ia) {
            if (value == MARK) {
                return true;
            }
        }
        return false;
    }

    /**
     * Extracts the {@code n}-character string starting at token {@code i}. If
     * the dictionary contains it, the string is labelled.
     *
     * @param i        start token position
     * @param n        candidate word length in characters
     * @param data     token sequence
     * @param tempData scratch matrix receiving the marks
     * @return the number of tokens labelled, or 0 if no dictionary word matched
     */
    private int check(int i, int n, String[] data, int[][] tempData) {
        WordInfo s = getNextN(data, i, n);
        if (dict.contains(s.word)) {
            label(i, s.len, tempData);
            return s.len;
        }
        return 0;
    }

    /**
     * Marks tokens {@code i..i+n-1} in {@code tempData} with the BMES pattern
     * of a single word: S for one token, otherwise B, M..., E.
     *
     * @param i        first token of the word
     * @param n        number of tokens in the word; non-positive spans are ignored
     * @param tempData scratch matrix receiving the marks
     */
    private void label(int i, int n, int[][] tempData) {
        if (n <= 0) {
            return; // nothing to mark; avoids writing to tempData[i - 1]
        }
        if (n == 1) {
            tempData[i][idxS] = MARK;
            return;
        }
        int last = i + n - 1;
        tempData[i][idxB] = MARK;
        for (int j = i + 1; j < last; j++) {
            tempData[j][idxM] = MARK;
        }
        tempData[last][idxE] = MARK;
    }

    /**
     * Returns the string of length {@code N} starting at token {@code index}.
     * <p>
     * Tokens are concatenated until at least {@code N} characters are
     * collected or the sequence ends. The text is then truncated to {@code N}
     * characters. It may be shorter when the sequence runs out.
     *
     * @param data  token sequence
     * @param index start token position
     * @param N     requested length in characters
     * @return the text plus the number of tokens consumed (including a
     *         partially used last token)
     */
    public WordInfo getNextN(String[] data, int index, int N) {
        StringBuilder sb = new StringBuilder();
        int end = index;
        while (sb.length() < N && end < data.length) {
            sb.append(data[end]);
            end++;
        }
        String text = sb.length() <= N ? sb.toString() : sb.substring(0, N);
        return new WordInfo(text, end - index);
    }
}