package edu.cmu.minorthird.classify.sequential;
import java.util.Iterator;
import org.apache.log4j.Logger;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
/**
* @author William Cohen
*/
public class StackedSequenceLearner implements BatchSequenceClassifierLearner
{
private static Logger log = Logger.getLogger(StackedSequenceLearner.class);
private SequenceClassifierLearner baseLearner;
private StackingParams params;
/** Bundle of parameters for the StackedSequenceLearner. */
public static class StackingParams {
public int historySize=5, futureSize=5, stackingDepth=1;
public boolean useLogistic=true,useTargetPrediction=true,useConfidence=true;
public Splitter<Example[]> splitter=new CrossValSplitter<Example[]>(5);
int crossValSplits=5;
/** Number of instances before the current target for which the
* predicted class will be used as a feature. */
public int getHistorySize() { return historySize; }
public void setHistorySize(int newHistorySize) { this.historySize = newHistorySize; }
/** Number of instances after the current target for which the
* predicted class will be used as a feature. */
public int getFutureSize() { return futureSize; }
public void setFutureSize(int newFutureSize) { this.futureSize = newFutureSize; }
/** If true, adjust all class confidences by passing them thru a logistic function */
public boolean getUseLogisticOnConfidences() { return useLogistic; }
public void setUseLogisticOnConfidences(boolean flag) { useLogistic=flag; }
/** If true, use confidence in class predictions as weight for that feature. */
public boolean getUseConfidences() { return useConfidence; }
public void setUseConfidences(boolean flag) { useConfidence=flag; }
/** If true, adjust all class confidences by passing them thru a logistic function */
public boolean getUseTargetPrediction() { return useTargetPrediction; }
public void setUseTargetPrediction(boolean flag) { useTargetPrediction=flag; }
/** Number of iterations of stacking to use */
public int getStackingDepth() { return stackingDepth; }
public void setStackingDepth(int newStackingDepth) { this.stackingDepth = newStackingDepth; }
/* Number of cross-validation splits to use in making predictions */
public int getCrossValSplits() { return crossValSplits; }
public void setCrossValSplits(int newCrossValSplits) {
this.splitter = new CrossValSplitter<Example[]>(newCrossValSplits);
crossValSplits = newCrossValSplits;
}
}
/** Number of instances before the current target for which the
* predicted class will be used as a feature. */
@Override
public int getHistorySize() { return params.historySize; }
public void setHistorySize(int newHistorySize) { params.setHistorySize(newHistorySize); }
public StackingParams getParams() { return params; }
public StackedSequenceLearner()
{
this.baseLearner = new CMMLearner(new VotedPerceptron(),0);
this.params = new StackingParams();
}
public StackedSequenceLearner(SequenceClassifierLearner baseLearner,int depth)
{
this();
this.baseLearner = baseLearner;
params.setStackingDepth(depth);
}
public StackedSequenceLearner(ClassifierLearner baseLearner,int depth)
{
this();
this.baseLearner = new CMMLearner(baseLearner,0);
params.setStackingDepth(depth);
}
public StackedSequenceLearner(SequenceClassifierLearner baseLearner,int depth,int windowSize)
{
this();
this.baseLearner = baseLearner;
params.setStackingDepth(depth);
params.setHistorySize(windowSize);
params.setFutureSize(windowSize);
}
public StackedSequenceLearner(ClassifierLearner baseLearner,int depth,int windowSize)
{
this();
this.baseLearner = new CMMLearner(baseLearner,0);
params.setStackingDepth(depth);
params.setHistorySize(windowSize);
params.setFutureSize(windowSize);
}
@Override
public void setSchema(ExampleSchema schema) {;}
@Override
public SequenceClassifier batchTrain(SequenceDataset dataset)
{
SequenceClassifier[] m = new SequenceClassifier[params.stackingDepth+1];
SequenceDataset stackedDataset = dataset;
stackedDataset.setHistorySize(0);
ProgressCounter pc = new ProgressCounter("training stacked learner","stacking level",params.stackingDepth+1);
for (int d=0; d<=params.stackingDepth; d++) {
m[d] = new DatasetSequenceClassifierTeacher(stackedDataset).train(baseLearner);
if (d+1 <= params.stackingDepth) {
stackedDataset = stackDataset(stackedDataset);
//new ViewerFrame("Dataset "+(d+1), new SmartVanillaViewer(stackedDataset));
}
pc.progress();
}
pc.finished();
return new StackedSequenceClassifier(m, params);
}
/**
* Create a new dataset in which each instance has been augmented
* with the history features constructed from the *predicted* labels
* of previous examples, where the prediction is made using
* cross-validation.
*/
public SequenceDataset stackDataset(SequenceDataset dataset)
{
// String[] history = new String[params.historySize];
SequenceDataset result = new SequenceDataset();
Dataset.Split s = dataset.splitSequence(params.splitter);
// ExampleSchema schema = dataset.getSchema();
ProgressCounter pc = new ProgressCounter("labeling for stacking","fold",s.getNumPartitions());
for (int k=0; k<s.getNumPartitions(); k++) {
SequenceDataset trainData = (SequenceDataset)s.getTrain(k);
SequenceDataset testData = (SequenceDataset)s.getTest(k);
log.info("splitting with "+params.splitter+", preparing to train on "+trainData.size()
+" and test on "+testData.size());
SequenceClassifier c = new DatasetSequenceClassifierTeacher(trainData).train(baseLearner);
for (Iterator<Example[]> i=testData.sequenceIterator(); i.hasNext(); ) {
Example[] seq = i.next();
ClassLabel[] labels = c.classification(seq);
Example[] stackSeq = new Example[seq.length];
for (int j=0; j<seq.length; j++) {
//System.out.println("stackSeq["+j+"]="+stackSeq[j]);
Instance stackInstance = stackInstance(j,seq[j].asInstance(),labels,params);
stackSeq[j] = new Example(stackInstance,seq[j].getLabel());
}
result.addSequence( stackSeq );
}
log.info("splitting with "+params.splitter+", stored classified dataset");
pc.progress();
}
pc.finished();
result.setHistorySize(0);
return result;
}
static private Instance stackInstance(int j,Instance instancej,ClassLabel[] labels,StackingParams params)
{
int numNewFeatures = params.historySize+params.futureSize+(params.useTargetPrediction?1:0);
String[] features = new String[numNewFeatures];
double[] values = new double[numNewFeatures];
int index=0;
for (int m=j-params.historySize; m<=j+params.futureSize; m++) {
if (m!=j || params.useTargetPrediction) {
if (m>=0 && m<labels.length) {
features[index] = stackFeatureName(m-j,labels[m].bestClassName());
values[index] = 1.0;
if (params.useConfidence) {
double w = labels[m].bestWeight();
values[index] = params.useLogistic ? MathUtil.logistic(w) : w;
}
} else {
features[index] = stackFeatureName(m-j,"NULL");
values[index] = 1.0;
}
index++;
}
}
return new AugmentedInstance(instancej,features,values);
}
private static String stackFeatureName(int offsetFromTarget,String predictedClassName)
{
if (offsetFromTarget<0) return "pred.prev."+(-offsetFromTarget)+"."+predictedClassName;
else if (offsetFromTarget>0) return "pred.next."+offsetFromTarget+"."+predictedClassName;
else return "pred.here."+predictedClassName;
}
private class StackedSequenceClassifier implements SequenceClassifier,Visible
{
private SequenceClassifier[] m;
// private ExampleSchema schema;
private StackingParams params;
public StackedSequenceClassifier(SequenceClassifier[] m, StackingParams params)
{
this.m = m;
this.params = params;
}
@Override
public ClassLabel[] classification(Instance[] sequence)
{
// String[] history = new String[params.historySize];
ClassLabel[] labels = m[0].classification(sequence);
Instance[] augmentedSequence = new Instance[sequence.length];
for (int d=1; d<m.length; d++) {
// augment the examples with context from the last classifier
for (int j=0; j<sequence.length; j++) {
augmentedSequence[j] = stackInstance(j, sequence[j], labels, params);
}
// label the augmented examples
labels = m[d].classification(augmentedSequence);
}
return labels;
}
@Override
public String explain(Instance[] sequence)
{
return "not implemented";
}
@Override
public Explanation getExplanation(Instance[] sequence) {
Explanation ex = new Explanation(explain(sequence));
return ex;
}
@Override
public Viewer toGUI()
{
ParallelViewer v = new ParallelViewer();
for (int i=0; i<m.length; i++) {
final int k = i;
v.addSubView(
"Level "+k+" classifier",
new TransformedViewer( new SmartVanillaViewer(m[k]) ) {
static final long serialVersionUID=20080207L;
@Override
public Object transform(Object o) {
StackedSequenceClassifier s = (StackedSequenceClassifier)o;
return s.m[k];
}});
}
v.setContent(this);
return v;
}
}
}