/* Copyright 2003-2004, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.sequential;
import java.io.Serializable;
import javax.swing.JComponent;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OneVsAllClassifier;
import edu.cmu.minorthird.classify.OnlineClassifierLearner;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
/**
 * Static helpers used by the sequential learners: duplicating a prototype online
 * learner, and assembling per-class classifiers into a one-vs-all multi-class
 * classifier.
 */
public class SequenceUtils
{
	/**
	 * Create an array of n copies of the prototype learner.
	 * Each element is produced by {@code prototype.copy()} and then {@code reset()}.
	 *
	 * @param prototype the learner to copy
	 * @param n the number of copies to create
	 * @return a new array holding n freshly copied and reset learners
	 * @throws IllegalArgumentException if n is negative, or if copying, casting or
	 *         resetting a copy fails (the underlying failure is attached as the cause)
	 */
	public static OnlineClassifierLearner[] duplicatePrototypeLearner(OnlineClassifierLearner prototype,int n)
	{
		// Reject a bad size explicitly; otherwise it would surface as a misleading
		// "must be cloneable" error from the catch below.
		if (n < 0) {
			throw new IllegalArgumentException("n must be non-negative, was " + n);
		}
		OnlineClassifierLearner[] result = new OnlineClassifierLearner[n];
		try {
			for (int i = 0; i < n; i++) {
				result[i] = (OnlineClassifierLearner) prototype.copy();
				result[i].reset();
			}
		} catch (Exception ex) {
			// Keep the original failure as the cause so it is not lost.
			throw new IllegalArgumentException("innerLearner must be cloneable", ex);
		}
		return result;
	}

	/**
	 * Wraps the OneVsAllClassifier, and provides more convenient constructors that
	 * take the class names from an {@link ExampleSchema}.
	 */
	public static class MultiClassClassifier extends OneVsAllClassifier implements Serializable
	{
		private static final long serialVersionUID = 20080207L;

		/**
		 * Builds the classifier from one learner per class. Each learner's current
		 * classifier is wrapped as a binary classifier at construction time.
		 * NOTE(review): learners[i] presumably corresponds to the i-th entry of
		 * schema.validClassNames() -- confirm against OneVsAllClassifier.
		 *
		 * @param schema supplies the class names
		 * @param learners one learner per class name
		 */
		public MultiClassClassifier(ExampleSchema schema,ClassifierLearner[] learners)
		{
			super(schema.validClassNames(), getBinaryClassifiers(learners));
		}

		/**
		 * Builds the classifier from already-trained per-class classifiers.
		 *
		 * @param schema supplies the class names
		 * @param classifiers one classifier per class name
		 */
		public MultiClassClassifier(ExampleSchema schema,Classifier[] classifiers)
		{
			super(schema.validClassNames(), classifiers);
		}

		/** Returns a new ExampleSchema built from this classifier's class names. */
		public ExampleSchema getSchema()
		{
			return new ExampleSchema(getClassNames());
		}

		/**
		 * Adapts each learner's current classifier to a BinaryClassifier whose score
		 * is the positive weight of the wrapped classifier's classification.
		 *
		 * @param learners the learners whose classifiers are wrapped
		 * @return an array parallel to {@code learners}
		 */
		public static BinaryClassifier[] getBinaryClassifiers(ClassifierLearner[] learners)
		{
			BinaryClassifier[] result = new BinaryClassifier[learners.length];
			for (int i = 0; i < learners.length; i++) {
				result[i] = new MyBinaryClassifier(learners[i].getClassifier());
			}
			return result;
		}

		/**
		 * Presents an arbitrary Classifier as a BinaryClassifier, using
		 * {@code classification(instance).posWeight()} as the score.
		 */
		private static class MyBinaryClassifier extends BinaryClassifier implements Visible
		{
			static final long serialVersionUID=20080207L;

			/** The wrapped classifier. */
			private final Classifier c;

			public MyBinaryClassifier(Classifier c)
			{
				this.c = c;
			}

			@Override
			public double score(Instance instance)
			{
				return c.classification(instance).posWeight();
			}

			@Override
			public String explain(Instance instance)
			{
				return c.explain(instance);
			}

			/** Returns a new Explanation rooted at the wrapped classifier's top explanation node. */
			@Override
			public Explanation getExplanation(Instance instance)
			{
				return new Explanation(c.getExplanation(instance).getTopNode());
			}

			@Override
			public Viewer toGUI()
			{
				Viewer v = new ComponentViewer() {
					static final long serialVersionUID=20080207L;
					@Override
					public JComponent componentFor(Object o) {
						// Render the classifier of the object actually displayed (o), in both
						// branches; previously the fallback used the creating instance's
						// classifier, which is wrong if the content is set to another object.
						Classifier inner = ((MyBinaryClassifier) o).c;
						return (inner instanceof Visible) ? ((Visible) inner).toGUI() : new VanillaViewer(inner);
					}
				};
				v.setContent(this);
				return v;
			}

			@Override
			public String toString()
			{
				return "[MyBC "+c+"]";
			}
		}
	}
}