package edu.cmu.minorthird.classify;
import java.util.*;
/**
* Multi-class version of a binary classifier.
*
* @author William Cohen
*/
public class OneVsAllLearner implements ClassifierLearner{
  /** Factory built from {@link #learnerName} by the constructors; no method of this class reads it. */
  protected ClassifierLearnerFactory learnerFactory;
  /** Prototype learner; {@link #setSchema} copies it once per class of the schema. */
  protected ClassifierLearner learner;
  /** Name (or BeanShell expression) describing the prototype learner. */
  protected String learnerName;
  /** One learner per class index of {@link #schema}; null until {@link #setSchema} is called. */
  protected List<ClassifierLearner> innerLearner=null;
  /** Schema most recently passed to {@link #setSchema}. */
  protected ExampleSchema schema;
  /** Create a new object from a fragment of bean shell code,
   * and make sure it's the correct type.
   *
   * @param s bean shell expression; "new " is prepended when it does not already start with "new"
   * @param expectedType type the evaluated object must be an instance of
   * @return the object produced by evaluating the expression
   * @throws IllegalArgumentException (the nested checked exception of this class) if the
   * expression cannot be evaluated or the result is not an instance of expectedType
   */
  static Object newObjectFromBSH(String s,Class<?> expectedType)
  throws IllegalArgumentException{
    try{
      bsh.Interpreter interp=new bsh.Interpreter();
      interp.eval("import edu.cmu.minorthird.classify.*;");
      interp.eval("import edu.cmu.minorthird.classify.experiments.*;");
      interp.eval("import edu.cmu.minorthird.classify.algorithms.linear.*;");
      interp.eval("import edu.cmu.minorthird.classify.algorithms.trees.*;");
      interp.eval("import edu.cmu.minorthird.classify.algorithms.knn.*;");
      interp.eval("import edu.cmu.minorthird.classify.algorithms.svm.*;");
      interp.eval("import edu.cmu.minorthird.classify.transform.*;");
      interp.eval("import edu.cmu.minorthird.classify.sequential.*;");
      interp.eval("import edu.cmu.minorthird.text.learn.*;");
      interp.eval("import edu.cmu.minorthird.text.*;");
      interp.eval("import edu.cmu.minorthird.ui.*;");
      interp.eval("import edu.cmu.minorthird.util.*;");
      if(!s.startsWith("new")){
        s="new "+s;
      }
      Object o=interp.eval(s);
      if(!expectedType.isInstance(o)){
        throw new IllegalArgumentException(s+" did not produce "+expectedType);
      }
      return o;
    }catch(bsh.EvalError e){
      System.out.println("ERROR: "+e.toString());
      // keep the original message but chain the cause so the interpreter's stack trace is not lost
      throw new IllegalArgumentException("error parsing '"+s+"':\n"+e,e);
    }
  }
  /**
   * Checked exception thrown by {@link #newObjectFromBSH}. Note that inside this class the
   * simple name refers to this type, not to java.lang.IllegalArgumentException.
   */
  public static class IllegalArgumentException extends Exception{
    static final long serialVersionUID=20071015;
    /** @param s detail message */
    public IllegalArgumentException(String s){
      super(s);
    }
    /**
     * @param s detail message
     * @param cause underlying failure, kept for diagnostics
     */
    public IllegalArgumentException(String s,Throwable cause){
      super(s,cause);
    }
  }
  /** Creates a learner whose prototype is built from the expression "new MaxEntLearner()". */
  public OneVsAllLearner(){
    //this(new ClassifierLearnerFactory("new VotedPerceptron()"));
    this("new MaxEntLearner()");
  }
  // /**
  // * @deprecated use OneVsAllLearner(BatchClassifierLearner learner)
  // * @param learnerFactory a ClassifierLearnerFactory which should produce a BinaryClassifier with each call.
  // */
  //
  // public OneVsAllLearner(ClassifierLearnerFactory learnerFactory){
  // this.learnerFactory=learnerFactory;
  // }
  /**
   * Creates a learner whose prototype is obtained by evaluating a bean shell expression.
   * If evaluation fails the stack trace is printed and the prototype learner is left null.
   *
   * @param learnerName bean shell expression producing a ClassifierLearner
   */
  public OneVsAllLearner(String learnerName){
    this.learnerName=learnerName;
    learnerFactory=new ClassifierLearnerFactory(learnerName);
    try{
      this.learner=(ClassifierLearner)newObjectFromBSH(learnerName,ClassifierLearner.class);
    }catch(Exception e){
      e.printStackTrace();
    }
  }
  /**
   * Creates a learner using the given prototype; its toString() value is used as the learner name.
   *
   * @param learner prototype learner, copied once per class in {@link #setSchema}
   */
  public OneVsAllLearner(ClassifierLearner learner){
    this.learner=learner;
    this.learnerName=learner.toString();
    learnerFactory=new ClassifierLearnerFactory(learnerName);
  }
  /**
   * Replaces the prototype learner. Per-class learners already created by
   * {@link #setSchema} are not rebuilt by this call.
   *
   * @param learner new prototype learner
   */
  public void setInnerLearner(ClassifierLearner learner){
    this.learner=learner;
  }
  /** @return the prototype learner */
  public ClassifierLearner getInnerLearner(){
    return learner;
  }
  /**
   * Copies this learner, giving the copy its own list of copied per-class learners.
   *
   * @return the copy, or null if cloning itself failed (the failure is printed)
   */
  @Override
  public ClassifierLearner copy(){
    OneVsAllLearner copy=null;
    try{
      copy=(OneVsAllLearner)this.clone();
      if(innerLearner!=null){
        // clone() is shallow, so copy.innerLearner initially aliases this.innerLearner.
        // Install a fresh list first: clearing the aliased list would wipe out the
        // original's per-class learners and leave nothing to copy.
        List<ClassifierLearner> copiedInner=new ArrayList<ClassifierLearner>(innerLearner.size());
        copy.innerLearner=copiedInner;
        for(ClassifierLearner inner:innerLearner){
          copiedInner.add(inner.copy());
        }
      }
    }catch(Exception e){
      System.out.println("Can't clone!");
      e.printStackTrace();
    }
    return copy;
  }
  /**
   * Stores the schema and creates one copy of the prototype learner per class,
   * each configured with the binary example schema.
   *
   * @param schema the multi-class schema to learn
   */
  @Override
  public void setSchema(ExampleSchema schema){
    this.schema=schema;
    innerLearner=new ArrayList<ClassifierLearner>();
    for(int i=0;i<schema.getNumberOfClasses();i++){
      ClassifierLearner binaryLearner=learner.copy();
      innerLearner.add(binaryLearner);
      binaryLearner.setSchema(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
    }
  }
  /** @return the schema most recently passed to {@link #setSchema} */
  @Override
  public ExampleSchema getSchema(){
    return schema;
  }
  /** Resets every per-class learner; does nothing if {@link #setSchema} has not been called. */
  @Override
  public void reset(){
    if(innerLearner!=null){
      for(ClassifierLearner inner:innerLearner){
        inner.reset();
      }
    }
  }
  /**
   * Passes the instance pool to every per-class learner. Requires a prior {@link #setSchema} call.
   *
   * @param iterator source of pool instances; fully consumed by this call
   */
  @Override
  public void setInstancePool(Iterator<Instance> iterator){
    // buffer the instances so that each per-class learner gets its own iterator over all of them
    List<Instance> list=new ArrayList<Instance>();
    while(iterator.hasNext()){
      list.add(iterator.next());
    }
    for(ClassifierLearner inner:innerLearner){
      inner.setInstancePool(list.iterator());
    }
  }
  /**
   * Requires a prior {@link #setSchema} call.
   *
   * @return true if any per-class learner has a pending query
   */
  @Override
  public boolean hasNextQuery(){
    for(ClassifierLearner inner:innerLearner){
      if(inner.hasNextQuery()){
        return true;
      }
    }
    return false;
  }
  /**
   * Requires a prior {@link #setSchema} call.
   *
   * @return the next query of the first per-class learner (in class-index order)
   * that has one, or null if none does
   */
  @Override
  public Instance nextQuery(){
    for(ClassifierLearner inner:innerLearner){
      if(inner.hasNextQuery()){
        return inner.nextQuery();
      }
    }
    return null;
  }
  /**
   * Feeds the example to every per-class learner, labeled positive for the learner whose
   * index matches the example's best class name in the schema and negative for all others.
   * Requires a prior {@link #setSchema} call.
   *
   * @param answeredQuery the labeled multi-class example
   */
  @Override
  public void addExample(Example answeredQuery){
    int classIndex=schema.getClassIndex(answeredQuery.getLabel().bestClassName());
    for(int i=0;i<innerLearner.size();i++){
      ClassLabel label=classIndex==i?ClassLabel.positiveLabel(1.0):ClassLabel.negativeLabel(-1.0);
      innerLearner.get(i).addExample(new Example(answeredQuery.asInstance(),label));
    }
  }
  /** Signals end of training to every per-class learner. Requires a prior {@link #setSchema} call. */
  @Override
  public void completeTraining(){
    for(ClassifierLearner inner:innerLearner){
      inner.completeTraining();
    }
  }
  /**
   * Requires a prior {@link #setSchema} call.
   *
   * @return a OneVsAllClassifier built from the schema's valid class names and the
   * classifier of each per-class learner, in class-index order
   */
  @Override
  public Classifier getClassifier(){
    Classifier[] classifiers=new Classifier[innerLearner.size()];
    for(int i=0;i<classifiers.length;i++){
      classifiers[i]=innerLearner.get(i).getClassifier();
    }
    return new OneVsAllClassifier(schema.validClassNames(),classifiers);
  }
}