package edu.fudan.nlp.pipe;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.alphabet.AlphabetFactory;
import edu.fudan.ml.types.alphabet.IFeatureAlphabet;
import edu.fudan.ml.types.alphabet.LabelAlphabet;
import edu.fudan.ml.types.sv.HashSparseVector;
/**
 * Converts string-list data into a sparse vector.
 * Data type: {@code List<String> -> SparseVector}
 *
 * <p>Each string in the instance data is looked up in the feature alphabet and the
 * resulting index is set to {@code 1.0f} in a {@link HashSparseVector}; the index of a
 * constant term is set as well.
 *
 * @author xpqiu
 */
public class StringArray2SV extends Pipe implements Serializable {

    private static final long serialVersionUID = 358834035189351765L;

    /**
     * Index value that is never written into the vector. Presumably the alphabet's
     * "not found" result; confirm against {@code IFeatureAlphabet.lookupIndex}.
     */
    private static final int NO_INDEX = -1;

    /** Feature alphabet used to map tokens to vector indices. */
    protected IFeatureAlphabet features;

    /** Label alphabet obtained from the factory; not used by this class itself. */
    protected LabelAlphabet label;

    /** Token whose alphabet index is used for the constant term. */
    protected static final String constant = "!#@$";

    /**
     * Index of the constant term. Declared transient (not serialized) so that it cannot
     * go stale if the feature dictionary is changed by optimization; it is looked up
     * again when the object is deserialized.
     */
    protected transient int constIndex;

    /**
     * Whether the features are ordered. When true, each token is suffixed with
     * {@code "@" + position} before the alphabet lookup.
     */
    protected boolean isSorted = false;

    /**
     * Creates a pipe without alphabets; {@link #features} and {@link #label} stay null.
     */
    public StringArray2SV() {
    }

    /**
     * Creates a pipe backed by the default alphabets of the given factory.
     *
     * @param af factory providing the feature and label alphabets
     */
    public StringArray2SV(AlphabetFactory af) {
        init(af);
    }

    /**
     * Creates a pipe backed by the default alphabets of the given factory.
     *
     * @param af factory providing the feature and label alphabets
     * @param b  whether the features are ordered (see {@link #isSorted})
     */
    public StringArray2SV(AlphabetFactory af, boolean b) {
        init(af);
        isSorted = b;
    }

    /**
     * Fetches the default alphabets from the factory and registers the constant term.
     *
     * @param af factory providing the feature and label alphabets
     */
    protected void init(AlphabetFactory af) {
        this.features = af.DefaultFeatureAlphabet();
        this.label = af.DefaultLabelAlphabet();
        // Add the constant term.
        constIndex = features.lookupIndex(constant);
    }

    /**
     * Replaces the instance's {@code List<String>} data with a sparse vector.
     *
     * <p>Every token whose looked-up index is not {@code -1} gets the value
     * {@code 1.0f}. The constant term is added under the same condition, so an
     * index of {@code -1} is never written into the vector.
     *
     * @param inst instance whose data is a {@code List<String>}; its data is replaced
     * @throws Exception declared by the overridden pipe method
     */
    @Override
    public void addThruPipe(Instance inst) throws Exception {
        // The pipe contract for this class is that the data is a List<String>.
        @SuppressWarnings("unchecked")
        List<String> data = (List<String>) inst.getData();
        HashSparseVector sv = new HashSparseVector();

        int position = 0;
        for (String element : data) {
            // The position advances for every element, including skipped ones.
            String token = isSorted ? element + "@" + position : element;
            position++;

            int id = features.lookupIndex(token);
            if (id != NO_INDEX) {
                sv.put(id, 1.0f);
            }
        }

        // Treat the constant term like any other feature: skip it when it has no index.
        if (constIndex != NO_INDEX) {
            sv.put(constIndex, 1.0f);
        }
        inst.setData(sv);
    }

    /**
     * Restores the serialized fields and recomputes the transient constant index.
     *
     * <p>An instance created with the no-argument constructor has no feature alphabet;
     * in that case the lookup is skipped and {@link #constIndex} keeps its default
     * value of 0, exactly as on a freshly constructed instance.
     *
     * @param ois stream to read the object from
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be found
     */
    private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        ois.defaultReadObject();
        if (features != null) {
            constIndex = features.lookupIndex(constant);
        }
    }
}