package edu.cmu.minorthird.classify;

import java.util.*;

/**
 * Multi-class version of a binary classifier. Puts classifiers in order of ones with
 * the most positive examples first.
 *
 * <p>After the inner learners finish training they are re-ordered by the number of
 * positive examples in each learner's dataset (descending), and the class names are
 * re-ordered to match before being handed to a {@link OneVsAllClassifier}.
 *
 * @author Cameron Williams
 */
public class MostFrequentFirstLearner extends OneVsAllLearner{

	/**
	 * Class names in the same order as the sorted inner learners. Assigned by
	 * {@link #completeTraining()}; it is not assigned anywhere else in this class,
	 * so it is still null if {@link #getClassifier()} is called first.
	 */
	public String[] sortedClassNames;

	/** Creates a learner using the superclass defaults. */
	public MostFrequentFirstLearner(){
		super();
	}

	//     /**
	//      * @deprecated use MostFrequentFirstLearner(BatchClassifierLearner learner)
	//      * @param learnerFactory a ClassifierLearnerFactory which should produce a BinaryClassifier with each call.
	//      */
	//     public MostFrequentFirstLearner(ClassifierLearnerFactory learnerFactory){
	// 	super(learnerFactory);
	//     }

	/**
	 * Delegates to the superclass constructor taking a String.
	 *
	 * @param l passed through unchanged to {@code OneVsAllLearner(String)}
	 *          (presumably a learner name/specification -- confirm against the superclass)
	 */
	public MostFrequentFirstLearner(String l){
		super(l);
	}

	/**
	 * Creates a learner from an existing batch learner. The learner's
	 * {@code toString()} value is stored as the learner name and used to build the
	 * {@link ClassifierLearnerFactory}.
	 *
	 * @param learner the batch learner to store and to derive the learner name from
	 */
	public MostFrequentFirstLearner(BatchClassifierLearner learner){
		this.learner=learner;
		this.learnerName=learner.toString();
		learnerFactory=new ClassifierLearnerFactory(learnerName);
	}

	/**
	 * Counts the examples with a positive label in the given learner's dataset.
	 *
	 * <p>Best effort: if reading the dataset fails, the stack trace is printed and the
	 * learner is treated as having zero positive examples, so it is ordered last
	 * rather than aborting the sort.
	 *
	 * @param batchLearner learner whose {@code dataset} field is inspected
	 * @return number of positive examples, or 0 if the dataset could not be read
	 */
	private int countPositiveExamples(BatchClassifierLearner batchLearner){
		int numPosEx=0;
		try{
			Dataset d=batchLearner.dataset;
			for(Iterator<Example> it=d.iterator();it.hasNext();){
				Example example=it.next();
				if(example.getLabel().isPositive()){
					numPosEx++;
				}
			}
		}catch(Exception e){
			e.printStackTrace();
			return 0;
		}
		return numPosEx;
	}

	/**
	 * Re-orders {@code innerLearner} so that learners with more positive examples
	 * come first, and fills {@link #sortedClassNames} with the class names in the
	 * matching order. Learners with equal counts keep their original relative order.
	 *
	 * <p>Learners with zero positive examples are placed last. (Previously such a
	 * learner was never selected, which led to {@code remove(-1)} and an
	 * {@code IndexOutOfBoundsException}.)
	 */
	private void sortLearners(){
		String[] classNames=schema.validClassNames();
		sortedClassNames=new String[schema.getNumberOfClasses()];

		// Snapshot learners, their class names and their positive-example counts once,
		// before the shared list is cleared.
		final int numLearners=innerLearner.size();
		List<BatchClassifierLearner> unsortedLearners=new ArrayList<BatchClassifierLearner>(numLearners);
		String[] unsortedClassNames=new String[numLearners];
		int[] positiveCounts=new int[numLearners];
		for(int i=0;i<numLearners;i++){
			BatchClassifierLearner batchLearner=(BatchClassifierLearner)innerLearner.get(i);
			unsortedLearners.add(batchLearner);
			unsortedClassNames[i]=classNames[i];
			positiveCounts[i]=countPositiveExamples(batchLearner);
		}

		//clear list so that it can be reconstructed in sorted order
		innerLearner.clear();

		// Selection sort over the precomputed counts. The first unplaced learner is
		// always a candidate, so a pick exists even when every remaining count is 0;
		// the strict '>' keeps ties in their original order.
		boolean[] placed=new boolean[numLearners];
		for(int position=0;position<numLearners;position++){
			int best=-1;
			for(int j=0;j<numLearners;j++){
				if(placed[j]){
					continue;
				}
				if(best<0||positiveCounts[j]>positiveCounts[best]){
					best=j;
				}
			}
			placed[best]=true;
			innerLearner.add(unsortedLearners.get(best));
			sortedClassNames[position]=unsortedClassNames[best];
		}
	}

	/**
	 * Completes training of every inner learner, then sorts the learners (and the
	 * class names) by descending number of positive examples.
	 */
	@Override
	public void completeTraining(){
		for(int i=0;i<innerLearner.size();i++){
			(innerLearner.get(i)).completeTraining();
		}
		sortLearners();
	}

	/**
	 * Builds a {@link OneVsAllClassifier} from the inner learners' classifiers, in
	 * their current order, paired with {@link #sortedClassNames}.
	 *
	 * @return a new one-vs-all classifier over the inner classifiers
	 */
	@Override
	public Classifier getClassifier(){
		Classifier[] classifiers=new Classifier[innerLearner.size()];
		for(int i=0;i<innerLearner.size();i++){
			classifiers[i]=((innerLearner.get(i))).getClassifier();
		}
		return new OneVsAllClassifier(sortedClassNames,classifiers);
	}

}