package edu.cmu.minorthird.classify.sequential;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.lookup.SoftDictionary;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.transform.DictionaryTransform;
import edu.cmu.minorthird.classify.transform.InstanceTransform;
import edu.cmu.minorthird.classify.transform.LeaveOneOutDictTransformLearner;
/**
 * Extends a {@link BatchSegmenterLearner} by including a dictionary.
 *
 * <p>Distance to the closest dictionary entry will be used as an
 * additional feature. At training time the dataset is passed through a
 * {@link DictionaryTransform} built from one soft dictionary and one set
 * of string distances (the same ones are used for every class in the
 * schema); the wrapped learner is then trained on the transformed
 * dataset, and the resulting segmenter is paired with that transform.
 *
 * @author William Cohen
 */
public class DictVersion implements BatchSegmenterLearner
{
    private final String[] featurePattern;
    private final BatchSegmenterLearner innerLearner;
    private final SoftDictionary softDictionary;
    private final StringDistance[] distances;

    /**
     * Creates a learner that uses
     * {@link LeaveOneOutDictTransformLearner#DEFAULT_PATTERN} as the feature
     * pattern and loads its dictionary from a file.
     *
     * @param innerLearner learner trained on the transformed dataset
     * @param dictFile file the soft dictionary is loaded from
     * @param distanceNames specification of the distances, in the form
     *        accepted by {@link DistanceLearnerFactory#buildArray}
     * @throws IOException if the dictionary cannot be loaded
     * @throws FileNotFoundException if the dictionary file is not found
     */
    public DictVersion(BatchSegmenterLearner innerLearner, File dictFile, String distanceNames)
        throws IOException,FileNotFoundException
    {
        this(LeaveOneOutDictTransformLearner.DEFAULT_PATTERN, innerLearner, dictFile, distanceNames);
    }

    /**
     * Creates a learner whose dictionary is loaded from a file.
     *
     * @param featurePattern feature pattern handed to the {@link DictionaryTransform}
     * @param innerLearner learner trained on the transformed dataset
     * @param dictFile file the soft dictionary is loaded from
     * @param distanceNames specification of the distances, in the form
     *        accepted by {@link DistanceLearnerFactory#buildArray}
     * @throws IOException if the dictionary cannot be loaded
     * @throws FileNotFoundException if the dictionary file is not found
     */
    public DictVersion(String[] featurePattern, BatchSegmenterLearner innerLearner, File dictFile, String distanceNames)
        throws IOException,FileNotFoundException
    {
        // Delegate so that all field initialisation lives in a single constructor.
        this(featurePattern, innerLearner, loadDictionary(dictFile), distanceNames);
    }

    /**
     * Creates a learner around an already constructed dictionary.
     *
     * @param featurePattern feature pattern handed to the {@link DictionaryTransform}
     * @param innerLearner learner trained on the transformed dataset
     * @param softDictionary dictionary used for lookups and for training
     *        any distances that are learners
     * @param distanceNames specification of the distances, in the form
     *        accepted by {@link DistanceLearnerFactory#buildArray}
     */
    public DictVersion(
        String[] featurePattern, BatchSegmenterLearner innerLearner, SoftDictionary softDictionary, String distanceNames)
    {
        this.featurePattern = featurePattern;
        this.innerLearner = innerLearner;
        this.softDictionary = softDictionary;
        this.distances = buildDistances(softDictionary, distanceNames);
    }

    /** Creates a new soft dictionary and loads the given file into it. */
    private static SoftDictionary loadDictionary(File dictFile)
        throws IOException,FileNotFoundException
    {
        SoftDictionary dictionary = new SoftDictionary();
        dictionary.load(dictFile);
        return dictionary;
    }

    /**
     * Builds the array of distances named by {@code distanceNames}. Any
     * entry that is a {@link StringDistanceLearner} is replaced by the
     * distance obtained by training it with the dictionary's teacher.
     */
    private static StringDistance[] buildDistances(SoftDictionary dictionary, String distanceNames)
    {
        StringDistance[] built = DistanceLearnerFactory.buildArray(distanceNames);
        // Collect the results in a fresh array rather than overwriting the
        // elements of the array handed back by the factory.
        StringDistance[] result = new StringDistance[built.length];
        for (int d = 0; d < built.length; d++) {
            StringDistance distance = built[d];
            if (distance instanceof StringDistanceLearner) {
                distance = dictionary.getTeacher().train( (StringDistanceLearner)distance );
            }
            result[d] = distance;
        }
        return result;
    }

    /** Does nothing: the schema is read from the dataset passed to {@link #batchTrain}. */
    @Override
    public void setSchema(ExampleSchema schema)
    {
        // intentionally empty
    }

    /**
     * Trains the inner learner on a dictionary-transformed copy of the dataset.
     *
     * @param dataset the training data; its schema determines the number of classes
     * @return a segmenter that pairs the transform with the segmenter
     *         learned by the inner learner
     */
    @Override
    public Segmenter batchTrain(SegmentDataset dataset)
    {
        // in this case, we don't need to learn a transform, we can just construct it...
        ExampleSchema schema = dataset.getSchema();
        int numClasses = schema.getNumberOfClasses();

        // the constructor requires one dictionary and one set of distances per class;
        // every class gets the same dictionary and its own copy of the distance array
        SoftDictionary[] dictPerClass = new SoftDictionary[numClasses];
        StringDistance[][] distPerClass = new StringDistance[numClasses][];
        for (int i = 0; i < numClasses; i++) {
            dictPerClass[i] = softDictionary;
            distPerClass[i] = distances.clone();
        }

        InstanceTransform transform = new DictionaryTransform(schema,dictPerClass,featurePattern,distPerClass);
        SegmentTransform segmentTransform = new SegmentTransform(transform);

        // now train on the transformed dataset
        SegmentDataset transformedDataset = segmentTransform.transform(dataset);
        Segmenter segmenter = innerLearner.batchTrain( transformedDataset );

        // return a transforming version of the learned segmenter
        return new TransformingSegmenter( transform, segmenter );
    }

    /**
     * Smoke test: constructs a {@code DictVersion} around a
     * {@link SegmentCRFLearner}.
     *
     * @param args {@code args[0]} is the dictionary file, {@code args[1]}
     *        the distance names
     * @throws IllegalArgumentException if fewer than two arguments are given
     */
    public static void main(String[] args)
        throws IOException,FileNotFoundException
    {
        if (args == null || args.length < 2) {
            // Fail with a usage hint instead of a bare ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("usage: DictVersion <dictFile> <distanceNames>");
        }
        new DictVersion(new SegmentCRFLearner(), new File(args[0]), args[1]);
    }
}