package edu.cmu.minorthird.classify.sequential;
import java.util.Iterator;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
/**
* Sequential learner based on the CRF algorithm. The source for the iitb.CRF
* package is available from http://crf.sourceforge.net.
* This class implements the semi-Markov version of CRF.
*
* @author Sunita Sarawagi
*/
public class SegmentCRFLearner extends CRFLearner implements BatchSegmenterLearner,SequenceConstants,Segmenter
{
/** Serialization version identifier for this class. */
static final long serialVersionUID=20080207L;
/**
 * Class index used as the default label in SegmentDataSequence.init.
 * Starts at 0 and is overwritten in allocModel with the schema index of
 * ExampleSchema.NEG_CLASS_NAME.
 * NOTE(review): the field is static but is assigned from an instance method,
 * so the value is shared by every learner instance -- confirm this is intended.
 */
static int negativeClass = 0;
/**
 * Set in allocModel from dataset.getMaxWindowSize() and passed to the CRF
 * options under the "MaxMemory" key.
 */
int maxMemory;
/**
 * Creates a learner with an empty argument string by delegating to
 * {@link #SegmentCRFLearner(String)}.
 */
public SegmentCRFLearner() {
this("");
}
/**
 * Creates a learner, forwarding the argument string unchanged to the
 * CRFLearner constructor.
 *
 * @param args argument string handed to the superclass constructor
 */
public SegmentCRFLearner(String args) {
super(args);
}
/**
 * Presents a CandidateSegmentGroup to the iitb.CRF package as a segment-labelled
 * data sequence.
 *
 * Per position it keeps a class index ({@code labels}) and a segment length
 * ({@code segLengths}). init and setSegment store the segment's length at the
 * segment's first position and -1 at the remaining positions of that segment.
 */
class SegmentDataSequence implements iitb.CRF.SegmentDataSequence {
    /** The wrapped candidate segments; null until the constructor or init supplies a group. */
    CandidateSegmentGroup segs;
    /** Class index per position. Reused across init calls, so it may be longer than length(). */
    int[] labels = null;
    /** Segment length at a segment's first position, -1 inside a segment; 0 when never written. */
    int[] segLengths;

    /**
     * Wraps the given group and allocates zero-filled label/length arrays.
     * No labels are derived from the group here; see init for that.
     */
    SegmentDataSequence(CandidateSegmentGroup tokens) {
        segs = tokens;
        alloc();
    }

    /** Creates an empty sequence; init must be called before length() or any accessor is used. */
    SegmentDataSequence() {}

    /** (Re)allocates both arrays when they are missing or shorter than the current sequence. */
    void alloc() {
        if ((labels == null) || (length() > labels.length)) {
            labels = new int[length()];
            segLengths = new int[length()];
        }
    }

    @Override
    public int length() {
        return segs.getSequenceLength();
    }

    /**
     * Wraps the given group and fills labels/segLengths from its subsequence labels.
     *
     * Scanning left to right, each position defaults to a length-1 segment labelled
     * negativeClass. If some window starting there (tried from length 1 up to
     * getMaxWindowSize()) has a non-null instance and a non-negative label, the
     * shortest such window becomes a segment with that label and scanning resumes
     * right after it.
     */
    void init(CandidateSegmentGroup tokens) {
        segs = tokens;
        alloc();
        final int maxWindow = tokens.getMaxWindowSize();
        int pos = 0;
        while (pos < length()) {
            labels[pos] = negativeClass;
            segLengths[pos] = 1;
            int advance = 1;
            for (int len = 1; len <= maxWindow; len++) {
                Instance inst = tokens.getSubsequenceInstance(pos, pos + len);
                ClassLabel label = tokens.getSubsequenceLabel(pos, pos + len);
                if (inst != null && !label.isNegative()) {
                    final int classIndex = schema.getClassIndex(label.bestClassName());
                    // Mark every position of the segment as interior, then
                    // record the real length at the segment start.
                    for (int k = pos; k < pos + len; k++) {
                        segLengths[k] = -1;
                        labels[k] = classIndex;
                    }
                    segLengths[pos] = len;
                    advance = len;
                    break;
                }
            }
            pos += advance;
        }
    }

    @Override
    public int y(int i) {
        return labels[i];
    }

    /** Always returns null; features are read from {@code segs} instead. */
    @Override
    public Object x(int i) {
        return null;
    }

    @Override
    public void set_y(int i, int label) {
        labels[i] = label;
    }

    /**
     * Converts the stored labels and segment lengths into a Segmentation, one
     * Segment(start, start+length, label) per stored segment.
     *
     * @throws IllegalStateException if a position reached as a segment start has a
     *         non-positive length (previously this looped forever or indexed out of bounds)
     */
    Segmentation getSegments() {
        Segmentation result = new Segmentation(schema);
        int start = 0;
        while (start < length()) {
            final int len = segLengths[start];
            if (len <= 0) {
                throw new IllegalStateException(
                    "no segment starts at position " + start + " (stored length " + len + ")");
            }
            result.add(new Segmentation.Segment(start, start + len, labels[start]));
            start += len;
        }
        return result;
    }

    /** Returns the inclusive index of the last position of the segment starting at segmentStart. */
    @Override
    public int getSegmentEnd(int segmentStart) {
        return segLengths[segmentStart] + segmentStart - 1;
    }

    /** Labels positions segmentStart..segmentEnd (inclusive) with y and records them as one segment. */
    @Override
    public void setSegment(int segmentStart, int segmentEnd, int y) {
        for (int pos = segmentStart; pos <= segmentEnd; pos++) {
            labels[pos] = y;
            segLengths[pos] = -1;
        }
        segLengths[segmentStart] = segmentEnd - segmentStart + 1;
    }
}
/**
 * iitb.CRF.DataIter over the candidate segment groups of a SegmentDataset.
 *
 * One SegmentDataSequence is re-initialised and handed out on every next(),
 * so a returned sequence is overwritten by the following call.
 */
class CRFSegmentDataIter implements iitb.CRF.DataIter {
    /** Iterator of the current scan; null until startScan() has been called. */
    Iterator<CandidateSegmentGroup> iter;
    SegmentDataset dataset;
    /** The single sequence object reused for every group. */
    SegmentDataSequence segData;

    CRFSegmentDataIter(SegmentDataset ds) {
        this.segData = new SegmentDataSequence();
        this.dataset = ds;
    }

    /** Begins a fresh pass over the dataset's candidate segment groups. */
    @Override
    public void startScan() {
        this.iter = this.dataset.candidateSegmentGroupIterator();
    }

    @Override
    public boolean hasNext() {
        return this.iter.hasNext();
    }

    /** Loads the next group into the shared sequence and returns that sequence. */
    @Override
    public iitb.CRF.DataSequence next() {
        final CandidateSegmentGroup group = this.iter.next();
        this.segData.init(group);
        return this.segData;
    }
}
/**
 * MTFeatureTypes variant whose features come from a whole candidate segment
 * rather than from a single position.
 */
class NestedMTFeatureTypes extends MTFeatureTypes {
    static final long serialVersionUID=20080207L;

    NestedMTFeatureTypes(iitb.Model.NestedFeatureGenImpl gen) {
        super(gen);
    }

    /**
     * Points the feature scan at the instance that the candidate group returns
     * for getSubsequenceInstance(prevPos+1, pos+1).
     *
     * NOTE(review): the looked-up instance is dereferenced without a null
     * check -- confirm callers only pass segment bounds for which the group has one.
     *
     * @param data must be a SegmentDataSequence (cast unconditionally)
     * @return whatever startScan() returns
     */
    @Override
    public boolean startScanFeaturesAt(iitb.CRF.DataSequence data, int prevPos, int pos) {
        final SegmentDataSequence sequence = (SegmentDataSequence) data;
        final int from = prevPos + 1;
        final int to = pos + 1;
        example = sequence.segs.getSubsequenceInstance(from, to);
        featureLooper = example.featureIterator();
        return startScan();
    }
}
/**
 * NestedFeatureGenImpl that registers, in order, edge features, start features
 * and the segment-level NestedMTFeatureTypes.
 */
public class SemiMTFeatureGenImpl extends iitb.Model.NestedFeatureGenImpl {
    static final long serialVersionUID=20080207L;

    /**
     * @param numLabels  forwarded to the NestedFeatureGenImpl constructor
     * @param labelNames class names; one history Feature is built per name for the edge features
     * @param options    forwarded to the NestedFeatureGenImpl constructor
     * @throws Exception propagated from the underlying iitb.Model calls
     */
    public SemiMTFeatureGenImpl(int numLabels, String[] labelNames, java.util.Properties options) throws Exception {
        super(numLabels, options, false);

        final Feature[] historyFeatures = new Feature[labelNames.length];
        int slot = 0;
        for (String labelName : labelNames) {
            historyFeatures[slot++] = new Feature(new String[] {HISTORY_FEATURE, "1", labelName});
        }
        addFeature(new iitb.Model.EdgeFeatures(this, historyFeatures));

        final Feature startFeature = new Feature(new String[] {HISTORY_FEATURE, "1", NULL_CLASS_NAME});
        addFeature(new iitb.Model.StartFeatures(this, startFeature));

        // No end features are registered; the original EndFeatures call is disabled:
        //addFeature(new iitb.Model.EndFeatures(model, new Feature("E")));
        addFeature(new NestedMTFeatureTypes(this));
    }
}
/**
 * The nested CRF model. Created in allocModel (which also stores it in crfModel)
 * and used by segmentation to decode; null until allocModel has run.
 */
iitb.CRF.NestedCRF nestedCrfModel;
/**
 * Builds the feature generator and nested CRF model for the given dataset and
 * returns an iterator that feeds the dataset to the CRF trainer.
 *
 * Side effects: sets maxMemory and the "MaxMemory" option from the dataset's
 * maximum window size, updates the static negativeClass from the schema, and
 * assigns featureGen, nestedCrfModel and crfModel. The schema field must
 * already be set.
 *
 * @param dataset the training data
 * @return a CRFSegmentDataIter over the dataset
 * @throws Exception propagated from feature generator or model construction
 */
iitb.CRF.DataIter allocModel(SegmentDataset dataset) throws Exception {
    final int windowSize = dataset.getMaxWindowSize();
    maxMemory = windowSize;
    options.setProperty("MaxMemory", String.valueOf(maxMemory));

    negativeClass = schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME);

    final int classCount = schema.getNumberOfClasses();
    featureGen = new SemiMTFeatureGenImpl(classCount, schema.validClassNames(), options);

    final iitb.CRF.NestedCRF model = new iitb.CRF.NestedCRF(featureGen.numStates(), featureGen, options);
    nestedCrfModel = model;
    crfModel = model;

    return new CRFSegmentDataIter(dataset);
}
/**
 * Trains the CRF on the given dataset and returns this learner, which then
 * acts as the Segmenter.
 *
 * @param dataset the training data; its schema becomes this learner's schema
 * @return this object
 * @throws IllegalStateException if model allocation or training throws; the
 *         original exception is attached as the cause
 */
@Override
public Segmenter batchTrain(SegmentDataset dataset) {
    try {
        schema = dataset.getSchema();
        doTrain(allocModel(dataset));
        return this;
    } catch (Exception e) {
        e.printStackTrace();
        // Chain the original exception so callers keep its type and stack trace.
        throw new IllegalStateException("error in CRF: "+e, e);
    }
}
/**
 * Decodes the given candidate segment group with the trained nested CRF and
 * returns the resulting segmentation.
 *
 * @param g the candidate segments to label
 * @return the segments stored in the sequence after nestedCrfModel.apply has run
 */
@Override
public Segmentation segmentation(CandidateSegmentGroup g) {
    final SegmentDataSequence decoded = new SegmentDataSequence(g);
    nestedCrfModel.apply(decoded);
    // The original source has this call disabled:
    // featureGen.mapStatesToLabels(seq);
    final Segmentation result = decoded.getSegments();
    return result;
}
/**
 * Return some string that 'explains' the classification.
 * Explanations are not implemented here: the argument is ignored and the
 * fixed string "not supported" is always returned.
 */
@Override
public String explain(CandidateSegmentGroup g) {
return "not supported";
}
};