package weka.classifiers.meta; import java.util.ArrayList; import java.util.Random; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.core.*; /** * * @author ajb */ public class HeterogeneousEnsemble extends AbstractClassifier{ //The classifiers MUST be added externally ArrayList<Classifier> ensemble; Instances train; double[] weights; double weightsSD=0; public enum WeightType{EQUAL,CV} WeightType w; private HeterogeneousEnsemble(){ } public HeterogeneousEnsemble(ArrayList<Classifier> cl){ ensemble=new ArrayList<Classifier>(cl); weights=new double[ensemble.size()]; w=WeightType.EQUAL; } public HeterogeneousEnsemble(Classifier[] cl){ ensemble=new ArrayList<Classifier>(); for(Classifier c:cl) ensemble.add(c); weights=new double[ensemble.size()]; w=WeightType.EQUAL; } public void useCVWeighting(boolean c){ if(c) w=WeightType.CV; else w=WeightType.EQUAL; } @Override public void buildClassifier(Instances data) throws Exception { train=data; for(Classifier c:ensemble) c.buildClassifier(train); //Weighting for voting here switch(w){ case EQUAL: for(int i=0;i<ensemble.size();i++) weights[i]=1.0/ensemble.size(); break; case CV: findCVWeights(); break; default: System.out.println("Error: weight method not implemented"); throw new UnsupportedOperationException(); } } @Override public double[] distributionForInstance(Instance ins){ double[] dist,temp; dist=new double[ins.numClasses()]; for(int i=0;i<ensemble.size();i++){ try{ Classifier c=ensemble.get(i); temp=c.distributionForInstance(ins); for(int j=0;j<dist.length;j++) dist[j]+=weights[i]*temp[j]; }catch(Exception e){ e.printStackTrace(); System.out.println("Error classifying instance with classifier "); System.exit(0); } } double x=dist[0]; for(int i=1;i<dist.length;i++) x+=dist[i]; for(int i=0;i<dist.length;i++) dist[i]/=x; return dist; } private static final double THRESHOLD1=100; public void findCVWeights() throws Exception { weights=new double[ensemble.size()]; double sum=0,sumSq=0; int folds=train.numInstances(); if(folds>THRESHOLD1){ folds=10; } System.out.print("\n Finding CV Accuracy WITHIN ensemble: "); for(int i=0;i<ensemble.size();i++){ Evaluation evaluation = new Evaluation(train); evaluation.crossValidateModel(ensemble.get(i), train, folds, new Random()); weights[i]=1-evaluation.errorRate(); sum+=weights[i]; sumSq+=weights[i]*weights[i]; System.out.print(","+weights[i]); } System.out.print("\n"); for(int i=0;i<weights.length;i++) weights[i]/=sum; weightsSD=(sumSq - sum*sum/(weights.length))/(weights.length-1); weightsSD=Math.sqrt(weightsSD); } @Override public String getRevision() { throw new UnsupportedOperationException("Not supported yet."); } public double[] getWeights(){ return weights;} public double getWeightsSD(){ return weightsSD;} }