/** FeatureGenImpl.java
*
* @author Sunita Sarawagi
* @since 1.0
* @version 1.3
*/
package iitb.Model;
import gnu.trove.set.hash.TIntHashSet;
import iitb.CRF.DataIter;
import iitb.CRF.DataSequence;
import iitb.CRF.Feature;
import iitb.CRF.FeatureGeneratorNested;
import iitb.CRF.SegmentDataSequence;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.StringTokenizer;
import java.util.Vector;
/**
* The FeatureGenerator is an aggregator over all these different
* feature types. You can inherit from the FeatureGenImpl class and,
* after calling one of the constructors that does not invoke
* addFeatures(), implement your own addFeatures
* method. There you will typically add the EdgeFeatures feature first
* and then the rest. So, for example if you wanted to add some
* parameter for each label (like a prior), you can create a new
* FeatureTypes class that will create as many featureids as the
* number of labels. You will have to create a new class that is
* derived from FeatureGenImpl and just have a different
* implementation of the addFeatures subroutine. The rest will be
* handled by the parent class.
* This class is responsible for converting the
* string-ids that the FeatureTypes assign to their features into
* distinct numbers. It has an inner class called FeatureMap that will
* make one pass over the training data and create the map of
* featurenames->integer id and as a side effect count the number of
* features.
*
* @author Sunita Sarawagi
* */
public class FeatureGenImpl implements FeatureGeneratorNested {
    private static final long serialVersionUID = 7651911985442866611L;
    /** All registered feature generators, scanned in insertion order. */
    ArrayList<FeatureTypes> features;
    transient Iterator<FeatureTypes> featureIter;
    protected FeatureTypes currentFeatureType;
    protected FeatureImpl featureToReturn, feature;
    public Model model;
    int numFeatureTypes=0;
    int totalFeatures;
    // Stays true only while every registered FeatureTypes reports fixed transition features.
    boolean _fixedTransitions=true;
    public boolean generateOnlyXFeatures=false;
    public boolean addOnlyTrainFeatures=true;
    TIntHashSet retainedFeatureTypes=new TIntHashSet(); // all features of this type are retained.
    transient DataSequence data;
    // Boundaries of the segment currently being scanned: [cposStart, cposEnd].
    int cposEnd;
    int cposStart;
    protected WordsInTrain dict;
    Vector<WordsInTrain> otherDicts = new Vector<WordsInTrain>();
    /** Registers a feature type whose features are subject to the usual train-time pruning. */
    public void addFeature(FeatureTypes fType) {
        addFeature(fType,false);
    }
    /**
     * Registers a feature type.
     * @param fType      the feature generator to add
     * @param retainThis when true, all features of this type are kept even when
     *                   addOnlyTrainFeatures would otherwise prune them
     */
    public void addFeature(FeatureTypes fType, boolean retainThis) {
        features.add(fType);
        if (retainThis) retainedFeatureTypes.add(fType.getTypeId()+1);
        if (!fType.fixedTransitionFeatures())
            _fixedTransitions = false;
    }
    public void setDict(WordsInTrain d) {
        dict = d;
    }
    /** Returns the training-word dictionary, creating it lazily if needed. */
    public WordsInTrain getDict(){
        if (dict == null)
            dict = new WordsInTrain();
        return dict;
    }
    /**
     * Installs the default battery of feature types. Subclasses may override;
     * conventionally EdgeFeatures is added first (see class javadoc).
     */
    protected void addFeatures() {
        addFeature(new EdgeFeatures(this));
        addFeature(new StartFeatures(this));
        addFeature(new EndFeatures(this));
        dict = new WordsInTrain();
        addFeature(new UnknownFeature(this,dict));
        addFeature(new WordFeatures(this, dict));
        addFeature(new FeatureTypesEachLabel(this,new ConcatRegexFeatures(this,0,0)));
    }
    protected FeatureTypes getFeature(int i) {
        return features.get(i);
    }
    /**
     * Decides whether a candidate feature is added to the feature map during
     * training: always kept if its type was registered with retainThis=true,
     * otherwise delegated to retainFeature.
     */
    protected boolean keepFeature(DataSequence seq, FeatureImpl f) {
        if ((retainedFeatureTypes != null) && (retainedFeatureTypes.contains(currentFeatureType.getTypeId()+1)))
            return true;
        return retainFeature(seq,f);
    }
    /**
     * Keeps a feature only when its labels are consistent with the training
     * labels at the current position (negative labels act as wildcards).
     */
    protected boolean retainFeature(DataSequence seq, FeatureImpl f) {
        return ((seq.y(cposEnd) == f.y() || seq.y(cposEnd) < 0)
                && ((cposStart == 0) || (f.yprev() < 0) || (seq.y(cposStart-1) == f.yprev()) || seq.y(cposStart-1) < 0));
    }
    // True while the feature map is still collecting new feature identifiers.
    protected boolean featureCollectMode = false;
    /**
     * Maps feature identifiers (as produced by FeatureTypes) to dense integer
     * ids, built in one pass over the training data (see class javadoc).
     */
    class FeatureMap implements Serializable {
        private static final long serialVersionUID = -2268366275560581428L;
        Hashtable<FeatureIdentifier, FeatureImpl> strToInt = new Hashtable<FeatureIdentifier, FeatureImpl>();
        FeatureIdentifier idToName[];
        FeatureMap(){
            featureCollectMode = true;
        }
        /**
         * Returns the dense id of f, adding it to the map when in collect mode
         * and the feature passes the keep/prune test; -1 when unknown.
         */
        public int getId(FeatureImpl f) {
            int id = getId(f.identifier());
            if ((id >= 0) && featureCollectMode) {
                // Sanity check: two features hashing to the same identifier must agree on name.
                FeatureIdentifier storedFIdentifier = (strToInt.get(f.identifier())).identifier();
                if (!storedFIdentifier.name.equals(f.identifier().name)) {
                    System.out.println("WARNING: same feature-id for different feature names?: "
                            + storedFIdentifier + ":" + f.identifier());
                }
            }
            if ((id < 0) && featureCollectMode && (!addOnlyTrainFeatures || keepFeature(data,f))) {
                return add(f);
            }
            return id;
        }
        private int getId(Object key) {
            if (strToInt.get(key) != null) {
                return ((FeatureImpl)strToInt.get(key)).index();
            }
            return -1;
        }
        public int getIndex(FeatureIdentifier fId) {return getId(fId);}
        /** Clones the feature, assigns it the next dense id and stores it. */
        public int add(FeatureImpl feature) {
            int newId = strToInt.size();
            FeatureImpl newFeature = (FeatureImpl) feature.clone();
            newFeature.id = newId;
            strToInt.put(newFeature.identifier(),newFeature);
            return newId;
        }
        /** Ends collect mode and builds the reverse id-to-identifier index. */
        void freezeFeatures() {
            featureCollectMode = false;
            idToName = new FeatureIdentifier[strToInt.size()];
            for (Enumeration<FeatureIdentifier> e = strToInt.keys() ; e.hasMoreElements() ;) {
                FeatureIdentifier key = e.nextElement();
                idToName[getId(key)] = key;
            }
            totalFeatures = strToInt.size();
        }
        /**
         * One pass over the training data to harvest all feature identifiers.
         * @param maxMem currently unused — TODO confirm whether a memory cap was intended
         * @return number of distinct features collected
         */
        public int collectFeatureIdentifiers(DataIter trainData, int maxMem) throws Exception {
            for (trainData.startScan(); trainData.hasNext();) {
                DataSequence seq = trainData.next();
                addTrainRecord(seq);
            }
            freezeFeatures();
            return strToInt.size();
        }
        /** Writes a header line, the feature count, then one "identifier id" line per feature. */
        public void write(PrintWriter out) throws IOException {
            out.println("******* Features ************");
            out.println(strToInt.size());
            for (Enumeration<FeatureIdentifier> e = strToInt.keys() ; e.hasMoreElements() ;) {
                Object key = e.nextElement();
                out.println(key + " " + getId(key));
            }
        }
        /** Reads the map back in the format produced by write(PrintWriter). */
        public int read(BufferedReader in) throws IOException {
            in.readLine(); // skip the "******* Features ************" header
            int len = Integer.parseInt(in.readLine());
            String line;
            for(int l = 0; (l < len) && ((line=in.readLine())!=null); l++) {
                StringTokenizer entry = new StringTokenizer(line," ");
                FeatureIdentifier key = new FeatureIdentifier(entry.nextToken());
                int pos = Integer.parseInt(entry.nextToken());
                strToInt.put(key,new FeatureImpl(pos,key));
            }
            freezeFeatures();
            return strToInt.size();
        }
        public FeatureIdentifier getIdentifier(int id) {return idToName[id];}
        public String getName(int id) {return idToName[id].toString();}
    }
    FeatureMap featureMap;
    static Model getModel(String modelSpecs, int numLabels) throws Exception {
        return Model.getNewModel(numLabels,modelSpecs);
    }
    public FeatureGenImpl(String modelSpecs, int numLabels) throws Exception {
        this(modelSpecs,numLabels,true);
    }
    public FeatureGenImpl(String modelSpecs, int numLabels, boolean addFeatureNow) throws Exception {
        this(getModel(modelSpecs,numLabels),numLabels,addFeatureNow);
    }
    /**
     * @param addFeatureNow when false, the caller (typically a subclass) is
     *                      responsible for invoking addFeatures() itself
     */
    public FeatureGenImpl(Model m, int numLabels, boolean addFeatureNow) throws Exception {
        model = m;
        features = new ArrayList<FeatureTypes>();
        featureToReturn = new FeatureImpl();
        feature = new FeatureImpl();
        featureMap = new FeatureMap();
        if (addFeatureNow) addFeatures();
    }
    /**
     * Maps training labels to model states when the model has more states than
     * labels; returns true iff any mapping was performed.
     */
    public boolean stateMappings(DataIter trainData) throws Exception {
        if (model.numStates() == model.numberOfLabels())
            return false;
        for (trainData.startScan(); trainData.hasNext();) {
            DataSequence seq = trainData.next();
            if (seq instanceof SegmentDataSequence) {
                model.stateMappings((SegmentDataSequence)seq);
            } else {
                model.stateMappings(seq);
            }
        }
        return true;
    }
    /** Inverse of stateMappings for a single sequence; returns true iff applied. */
    public boolean mapStatesToLabels(DataSequence data) {
        if (model.numStates() == model.numberOfLabels())
            return false;
        if (data instanceof SegmentDataSequence) {
            model.mapStatesToLabels((SegmentDataSequence)data);
        } else {
            for (int i = 0; i < data.length(); i++) {
                data.set_y(i, label(data.y(i)));
            }
        }
        return true;
    }
    public void labelsToSegments(SegmentDataSequence data) {
        model.mapStatesToLabels(data);
    }
    public int maxMemory() {return 1;}
    public boolean train(DataIter trainData) throws Exception {
        return train(trainData,true);
    }
    public boolean train(DataIter trainData, boolean cachedLabels) throws Exception {
        return train(trainData,cachedLabels,true);
    }
    public boolean labelMappingNeeded() {return model.numStates() != model.numberOfLabels();}
    /**
     * Trains dictionaries and any feature types that need training, then
     * (optionally) collects all feature identifiers from the training data.
     * @return true iff training labels were mapped to model states
     */
    public boolean train(DataIter trainData, boolean cachedLabels, boolean collectIds) throws Exception {
        // map the y-values in the training set.
        boolean labelsMapped = false;
        if (cachedLabels) {
            labelsMapped = stateMappings(trainData);
        }
        if (dict != null) dict.train(trainData,model.numStates());
        for (WordsInTrain d : otherDicts) {
            d.train(trainData, model.numStates());
        }
        boolean requiresTraining = false;
        for (int f = 0; f < features.size(); f++) {
            if (getFeature(f).requiresTraining()) {
                requiresTraining = true;
                break;
            }
        }
        if (requiresTraining) {
            for (trainData.startScan(); trainData.hasNext();) {
                DataSequence seq = trainData.next();
                for (int f = 0; f < features.size(); f++) {
                    if (getFeature(f).requiresTraining()) {
                        trainFeatureType(getFeature(f),seq);
                    }
                }
            }
        }
        if (collectIds) totalFeatures = featureMap.collectFeatureIdentifiers(trainData,maxMemory());
        return labelsMapped;
    }
    /** Trains one feature type on every position of a sequence. */
    protected void trainFeatureType(FeatureTypes featureType, DataSequence seq) {
        for (int l = 0; l < seq.length(); l++) {
            featureType.train(seq,l);
        }
    }
    /**
     * Scans every position of the sequence so the feature map can collect ids.
     * @return the number of features fired on this sequence
     */
    public int addTrainRecord(DataSequence seq) {
        int numF = 0;
        for (int l = 0; l < seq.length(); l++) {
            for (startScanFeaturesAt(seq,l); hasNext(); numF++) {
                next();
            }
        }
        return numF;
    }
    public void printStats() {
        System.out.println("Num states " + model.numStates());
        System.out.println("Num edges " + model.numEdges());
        if (dict != null) System.out.println("Num words in dictionary " + dict.dictionaryLength());
        System.out.println("Num features " + numFeatures());
    }
    protected FeatureImpl nextNoId() {
        feature.copy(featureToReturn);
        advance(false);
        return feature;
    }
    protected void advance() {
        advance(!featureCollectMode);
    }
    /**
     * Advances featureToReturn to the next valid feature, pulling from the
     * current feature type and moving to the next type when one is exhausted.
     * Sets featureToReturn.id to -1 when no features remain.
     * NOTE(review): the returnWithId flag is currently unused — confirm intent.
     */
    protected void advance(boolean returnWithId) {
        while (true) {
            for (;((currentFeatureType == null) || !currentFeatureType.hasNext()) && featureIter.hasNext();) {
                currentFeatureType = featureIter.next();
            }
            // Guard against an empty feature list, which would otherwise NPE here.
            if ((currentFeatureType == null) || !currentFeatureType.hasNext())
                break;
            while (currentFeatureType.hasNext()) {
                featureToReturn.init();
                copyNextFeature(featureToReturn);
                featureToReturn.id = featureMap.getId(featureToReturn);
                if (featureToReturn.id < 0){
                    continue; // pruned or unknown feature; keep scanning
                }
                if (featureValid(data, cposStart, cposEnd, featureToReturn, model, _fixedTransitions))
                    return;
            }
        }
        featureToReturn.id = -1;
    }
    protected void copyNextFeature(FeatureImpl featureToReturn) {
        currentFeatureType.next(featureToReturn);
    }
    public boolean featureValid(DataSequence data, int cposStart, int cposEnd, FeatureImpl featureToReturn, Model model, boolean cacheEdgeFeatures) {
        return featureValidStatic(data, cposStart, cposEnd, featureToReturn, model, cacheEdgeFeatures);
    }
    /**
     * Validity test for a generated feature: interior positions are always
     * valid; boundary positions additionally require start/end states of the
     * model to match.
     */
    public static boolean featureValidStatic(DataSequence data, int cposStart, int cposEnd, FeatureImpl featureToReturn, Model model, boolean cacheEdgeFeatures) {
        if (((cposStart > 0) && (cposEnd < data.length()-1))
                || (featureToReturn.y() >= model.numStates())
                || (featureToReturn.yprev() >= model.numStates())
                || ((featureToReturn.yprev() >= 0) && cacheEdgeFeatures))
            return true;
        if ((cposStart == 0) && (model.isStartState(featureToReturn.y()))
                && ((data.length()>1) || (model.isEndState(featureToReturn.y()))))
            return true;
        if ((cposEnd == data.length()-1) && (model.isEndState(featureToReturn.y())))
            return true;
        return false;
    }
    /** Resets the iteration state and positions it on the first feature. */
    protected void initScanFeaturesAt(DataSequence d) {
        data = d;
        currentFeatureType = null;
        featureIter = features.iterator();
        advance();
    }
    /** Starts a scan over features of the segment (prev+1 .. p]. */
    public void startScanFeaturesAt(DataSequence d, int prev, int p) {
        cposEnd = p;
        cposStart = prev+1;
        for (int i = 0; i < features.size(); i++) {
            getFeature(i).startScanFeaturesAt(d,prev,cposEnd);
        }
        initScanFeaturesAt(d);
    }
    /** As startScanFeaturesAt(d,prev,p) but skips feature types that cache. */
    public void startScanFeaturesAtOnlyNonCached(DataSequence d, int prev, int p) {
        cposEnd = p;
        cposStart = prev+1;
        for (int i = 0; i < features.size(); i++) {
            if (!getFeature(i).needsCaching())
                getFeature(i).startScanFeaturesAt(d,prev,cposEnd);
        }
        initScanFeaturesAt(d);
    }
    /** Starts a scan over features at a single position p. */
    public void startScanFeaturesAt(DataSequence d, int p) {
        cposEnd = p;
        cposStart = p;
        for (int i = 0; i < features.size(); i++) {
            getFeature(i).startScanFeaturesAt(d,cposEnd);
        }
        initScanFeaturesAt(d);
    }
    /** As startScanFeaturesAt(d,p) but skips feature types that cache. */
    public void startScanFeaturesAtOnlyNonCached(DataSequence d, int p) {
        cposEnd = p;
        cposStart = p;
        for (int i = 0; i < features.size(); i++) {
            if (!getFeature(i).needsCaching())
                getFeature(i).startScanFeaturesAt(d,cposEnd);
        }
        initScanFeaturesAt(d);
    }
    public boolean hasNext() {
        return (featureToReturn.id >= 0);
    }
    /** Returns the current feature and advances to the next one. */
    public Feature next() {
        feature.copy(featureToReturn);
        advance();
        return feature;
    }
    public void freezeFeatures() {
        if (featureCollectMode)
            featureMap.freezeFeatures();
    }
    public int numFeatures() {
        return totalFeatures;
    }
    public FeatureIdentifier featureIdentifier(int id) {return featureMap.getIdentifier(id);}
    public String featureName(int featureIndex) {
        return featureMap.getName(featureIndex);
    }
    public int featureIndex(FeatureIdentifier fId) {return featureMap.getIndex(fId);}
    public int numStates() {
        return model.numStates();
    }
    /** Maps a model state to its label; negative states pass through unchanged. */
    public int label(int stateNum) {
        return (stateNum >= 0)?model.label(stateNum):stateNum;
    }
    protected int numFeatureTypes() {
        return features.size();
    }
    /**
     * Loads the dictionary (if any) and feature map from fileName.
     * The reader is always closed, even on failure.
     */
    public void read(String fileName) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(fileName));
        try {
            if (dict != null) dict.read(in, model.numStates());
            totalFeatures = featureMap.read(in);
        } finally {
            in.close();
        }
    }
    /**
     * Saves the dictionary (if any) and feature map to fileName.
     * The writer is always closed, even on failure.
     */
    public void write(String fileName) throws IOException {
        PrintWriter out = new PrintWriter(new FileOutputStream(fileName));
        try {
            if (dict != null) dict.write(out);
            featureMap.write(out);
        } finally {
            out.close();
        }
    }
    public void displayModel(double featureWts[]) throws IOException {
        displayModel(featureWts,System.out);
    }
    public void displayModel(double featureWts[], PrintStream out) throws IOException {
        displayModel(featureWts, out, false);
    }
    /** Prints one line per feature: either "name label state weight" or "fullName weight". */
    public void displayModel(double featureWts[], PrintStream out, boolean origFName) throws IOException {
        int numF = numFeatures();
        for (int fIndex = 0; fIndex < numF; fIndex++) {
            Object feature = featureIdentifier(fIndex).name;
            int classIndex = featureIdentifier(fIndex).stateId;
            int label = model.label(classIndex);
            if (!origFName)
                out.println(feature + " " + label + " " + classIndex + " " + featureWts[fIndex]);
            else
                out.println(featureName(fIndex) + " "+featureWts[fIndex]);
        }
    }
    public boolean fixedTransitionFeatures() {
        return _fixedTransitions;
    }
    // returns the label-independent featureId of the current feature
    public int xFeatureIdCurrent() {
        return currentFeatureType.labelIndependentId(featureToReturn);
    }
    public void addDict(WordsInTrain ngramDict) {
        otherDicts.add(ngramDict);
    }
}