/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify;
import edu.cmu.minorthird.classify.algorithms.svm.*;
import edu.cmu.minorthird.classify.algorithms.linear.*;
import edu.cmu.minorthird.util.gui.*;
import org.apache.log4j.*;
/**
 * Online version of a BatchClassifierLearner.
 *
 * <p>Every example passed to {@link #addExample} is accumulated in an
 * internal dataset.  While fewer than <code>minBatchTrainingSize</code>
 * examples have been collected, an on-line "bootstrap" learner is trained
 * incrementally and supplies the classifier.  After that, the batch learner
 * is (re-)trained on the whole dataset, but only when the dataset has grown
 * past <code>loadFactor</code> times its size at the previous batch
 * training round.
 *
 * @author William Cohen
 */
public class OnlineVersion extends OnlineClassifierLearner
{
    private static final Logger log = Logger.getLogger(OnlineVersion.class);

    /** Batch algorithm being emulated on-line. */
    private BatchClassifierLearner innerLearner;
    /** On-line learner used until enough examples are available. */
    private OnlineClassifierLearner bootstrapLearner;
    /** Growth ratio of the dataset that triggers re-training in getClassifier(). */
    private double loadFactor;
    /** Dataset size below which the bootstrap learner is used. */
    private int minBatchTrainingSize;
    /** Result of the most recent batch training round, or null if none yet. */
    private Classifier storedClassifier;
    /** Number of examples in the dataset at the most recent batch training round. */
    private int lastTrainingSetSize;
    /** All examples seen since the last reset(). */
    private Dataset dataset;

    /**
     * Emulate on-line learning with a batch algorithm.
     *
     * @param innerLearner batch learning algorithm
     * @param loadFactor re-train batch algorithm when number of available
     * examples is loadFactor * M, where M is the number of examples
     * available at the last training round.
     * @param bootstrapLearner on-line learner used for the first few rounds
     * @param minBatchTrainingSize use online bootstrapLearner until minBatchTrainingSize examples are available.
     */
    public OnlineVersion(
        BatchClassifierLearner innerLearner,double loadFactor,
        OnlineClassifierLearner bootstrapLearner,int minBatchTrainingSize)
    {
        this.innerLearner = innerLearner;
        this.loadFactor = loadFactor;
        this.bootstrapLearner = bootstrapLearner;
        this.minBatchTrainingSize = minBatchTrainingSize;
        reset();
    }

    /**
     * Same as the four-argument constructor, with a VotedPerceptron as the
     * bootstrap learner and a minimum batch training size of 10.
     *
     * @param innerLearner batch learning algorithm
     * @param loadFactor dataset growth ratio that triggers re-training
     */
    public OnlineVersion(BatchClassifierLearner innerLearner,double loadFactor)
    {
        this(innerLearner,loadFactor,new VotedPerceptron(),10);
    }

    /**
     * Same as the two-argument constructor, with a load factor of 1.5.
     *
     * @param innerLearner batch learning algorithm
     */
    public OnlineVersion(BatchClassifierLearner innerLearner)
    {
        this(innerLearner,1.5);
    }

    /** Same as the one-argument constructor, with an SVMLearner as the batch learner. */
    public OnlineVersion()
    {
        this(new SVMLearner());
    }

    /** @return the batch learner being emulated on-line */
    public BatchClassifierLearner getInnerLearner() { return innerLearner; }

    /** @param learner the batch learner to emulate on-line */
    public void setInnerLearner(BatchClassifierLearner learner) { this.innerLearner = learner; }

    /** @return the on-line learner used for the first few rounds */
    public OnlineClassifierLearner getBootstrapLearner() { return bootstrapLearner; }

    /** @param learner the on-line learner to use for the first few rounds */
    public void setBootstrapLearner(OnlineClassifierLearner learner) { this.bootstrapLearner = learner; }

    /** @return the dataset growth ratio that triggers re-training */
    public double getBatchLoadFactor() { return loadFactor; }

    /** @param d the dataset growth ratio that triggers re-training */
    public void setBatchLoadFactor(double d) { loadFactor = d; }

    /** @return the dataset size below which the bootstrap learner is used */
    public int getMinBatchTrainingSize() { return minBatchTrainingSize; }

    /** @param m the dataset size below which the bootstrap learner is used */
    public void setMinBatchTrainingSize(int m) { minBatchTrainingSize = m; }

    /** Forwards the schema to both the batch learner and the bootstrap learner. */
    @Override
    public final void setSchema(ExampleSchema schema)
    {
        innerLearner.setSchema(schema);
        bootstrapLearner.setSchema(schema);
    }

    /** @return the schema reported by the batch learner */
    @Override
    public final ExampleSchema getSchema()
    {
        return innerLearner.getSchema();
    }

    /**
     * Forget everything learned so far: discards the stored classifier and
     * the accumulated examples, and resets both wrapped learners.
     */
    @Override
    public final void reset()
    {
        storedClassifier = null;
        lastTrainingSetSize = 0;
        dataset = new BasicDataset();
        innerLearner.reset();
        bootstrapLearner.reset();
    }

    /**
     * Record a training example.  The example is always stored for later
     * batch training; it is also handed to the bootstrap learner as long as
     * the dataset (including this example) is still smaller than
     * minBatchTrainingSize.
     *
     * @param example the example to learn from
     */
    @Override
    public final void addExample(Example example)
    {
        dataset.add(example);
        if (dataset.size()<minBatchTrainingSize) {
            bootstrapLearner.addExample(example);
        }
    }

    /**
     * Run a final batch training round if examples have arrived since the
     * last round, or if no batch classifier has been trained yet.
     *
     * <p>Training deliberately has no GUI side effects: it opens no viewer
     * windows, so it can be called from non-interactive code.  Callers that
     * want to inspect the data or the classifier can display them themselves.
     */
    @Override
    public final void completeTraining()
    {
        if (dataset.size()>lastTrainingSetSize || storedClassifier==null) {
            log.info("final training for "+innerLearner+" on "+dataset.size()+" examples");
            storedClassifier = innerLearner.batchTrain(dataset);
            lastTrainingSetSize = dataset.size();
        }
    }

    /**
     * Return the current classifier.
     *
     * <p>While fewer than minBatchTrainingSize examples have been added, the
     * bootstrap learner's classifier is returned (this check comes first,
     * even if a batch classifier has already been stored).  Otherwise the
     * batch learner is re-trained when the dataset size exceeds loadFactor
     * times the size at the last training round, or when no batch classifier
     * exists yet; the stored batch classifier is then returned.
     *
     * @return the classifier to use for prediction at this point
     */
    @Override
    public final Classifier getClassifier()
    {
        if (dataset.size() < minBatchTrainingSize) {
            return bootstrapLearner.getClassifier();
        }
        if (dataset.size() > lastTrainingSetSize*loadFactor || storedClassifier==null) {
            log.info("re-training "+innerLearner+" on "+dataset.size()+" examples");
            storedClassifier = innerLearner.batchTrain(dataset);
            log.info("batch classifier is "+storedClassifier);
            lastTrainingSetSize = dataset.size();
        }
        return storedClassifier;
    }
}