package edu.cmu.minorthird.classify;
import edu.cmu.minorthird.classify.experiments.*;
import java.util.*;
/**
* Multi-class version of a binary classifier. Trains one binary learner per class and
* orders the resulting classifiers by their cross-validated kappa, highest first.
*
* @author Cameron Williams
*/
public class CascadingBinaryLearner extends OneVsAllLearner{
	/** Number of folds used when cross-validating each per-class binary learner. */
	private static final int CROSS_VALIDATION_FOLDS=9;

	/**
	 * Class names in the same order as the sorted inner learners. Populated by
	 * {@link #completeTraining()}; null before that.
	 */
	public String[] sortedClassNames;
	/** One binary dataset per class, indexed like the schema's class names. */
	private List<Dataset> data=null;
	/** One evaluation per class, indexed like the schema's class names. */
	private List<Evaluation> eval=null;

	public CascadingBinaryLearner(){
		super();
	}

	/**
	 * @param l learner specification, passed through unchanged to the superclass constructor
	 */
	public CascadingBinaryLearner(String l){
		super(l);
	}

	/**
	 * @param learner prototype learner; it is copied once per class in {@link #setSchema}
	 */
	public CascadingBinaryLearner(BatchClassifierLearner learner){
		this.learner=learner;
		this.learnerName=learner.toString();
		learnerFactory=new ClassifierLearnerFactory(learnerName);
	}

	/**
	 * Records the schema and creates, for every class in it, a fresh copy of the
	 * prototype learner (configured with the binary schema) and an empty dataset.
	 */
	@Override
	public void setSchema(ExampleSchema schema){
		this.schema=schema;
		innerLearner=new ArrayList<ClassifierLearner>();
		data=new ArrayList<Dataset>();
		for(int i=0;i<schema.getNumberOfClasses();i++){
			innerLearner.add(learner.copy());
			innerLearner.get(i).setSchema(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
			data.add(new BasicDataset());
		}
	}

	/**
	 * Evaluates every inner learner on its own binary dataset with a cross-validation
	 * splitter and stores the evaluations in {@code eval}, index-aligned with
	 * {@code innerLearner} as it is ordered before sorting.
	 */
	private void createRankings(){
		Splitter<Example> splitter=new CrossValSplitter<Example>(CROSS_VALIDATION_FOLDS);
		eval=new ArrayList<Evaluation>();
		for(int i=0;i<innerLearner.size();i++){
			Evaluation evaluation=Tester.evaluate(innerLearner.get(i),data.get(i),splitter);
			eval.add(evaluation);
		}
	}

	/**
	 * Returns the kappa of the evaluation, or negative infinity when it cannot be
	 * computed or is NaN, so such a learner is ranked last instead of breaking the sort.
	 */
	private static double kappaOf(Evaluation evaluation){
		try{
			double kappa=evaluation.kappa();
			return Double.isNaN(kappa)?Double.NEGATIVE_INFINITY:kappa;
		}catch(Exception e){
			e.printStackTrace();
			return Double.NEGATIVE_INFINITY;
		}
	}

	/**
	 * Reorders {@code innerLearner} by descending kappa and fills
	 * {@code sortedClassNames} with the class names in the matching order.
	 * Among equal scores the learner that came later in the original order is
	 * placed first (the comparison is {@code >=}).
	 */
	private void sortLearners(){
		String[] classNames=schema.validClassNames();
		int count=innerLearner.size();

		// Learners, names and scores are kept as parallel lists and shrunk together,
		// so index j always refers to the same class in all three.
		List<ClassifierLearner> remainingLearners=new ArrayList<ClassifierLearner>(count);
		List<String> remainingNames=new ArrayList<String>(count);
		List<Double> remainingKappas=new ArrayList<Double>(count);
		for(int i=0;i<count;i++){
			remainingLearners.add(innerLearner.get(i));
			remainingNames.add(classNames[i]);
			remainingKappas.add(Double.valueOf(kappaOf(eval.get(i))));
		}

		sortedClassNames=new String[schema.getNumberOfClasses()];
		// clear list so that it can be reconstructed in sorted order
		innerLearner.clear();

		int position=0;
		while(!remainingLearners.isEmpty()){
			// find the remaining learner with the highest kappa
			int best=0;
			for(int j=1;j<remainingKappas.size();j++){
				if(remainingKappas.get(j).doubleValue()>=remainingKappas.get(best).doubleValue()){
					best=j;
				}
			}
			innerLearner.add(remainingLearners.remove(best));
			sortedClassNames[position]=remainingNames.remove(best);
			// 'best' is an int, so this is remove(int index), not remove(Object)
			remainingKappas.remove(best);
			position++;
		}
	}

	/**
	 * Converts the multi-class example into one binary example per class (positive
	 * for the example's best class, negative for every other class) and feeds it to
	 * the corresponding inner learner and per-class dataset.
	 */
	@Override
	public void addExample(Example answeredQuery){
		int classIndex=schema.getClassIndex(answeredQuery.getLabel().bestClassName());
		for(int i=0;i<innerLearner.size();i++){
			ClassLabel label=classIndex==i?ClassLabel.positiveLabel(1.0):ClassLabel.negativeLabel(-1.0);
			Example example=new Example(answeredQuery.asInstance(),label);
			innerLearner.get(i).addExample(example);
			data.get(i).add(example);
		}
	}

	/**
	 * Completes training of every inner learner, then evaluates them and sorts them
	 * by kappa.
	 */
	@Override
	public void completeTraining(){
		for(int i=0;i<innerLearner.size();i++){
			innerLearner.get(i).completeTraining();
		}
		// NOTE(review): createRankings() hands these same learner instances to
		// Tester.evaluate; confirm that evaluation leaves them trained on the full data.
		createRankings();
		sortLearners();
	}

	/**
	 * Builds a one-vs-all classifier from the inner learners' classifiers, paired
	 * with {@code sortedClassNames}.
	 */
	@Override
	public Classifier getClassifier(){
		Classifier[] classifiers=new Classifier[innerLearner.size()];
		for(int i=0;i<innerLearner.size();i++){
			classifiers[i]=innerLearner.get(i).getClassifier();
		}
		return new OneVsAllClassifier(sortedClassNames,classifiers);
	}
}