package edu.cmu.minorthird.classify.multi; import org.apache.log4j.Logger; import edu.cmu.minorthird.classify.ClassifierLearner; import edu.cmu.minorthird.classify.Splitter; import edu.cmu.minorthird.classify.experiments.CrossValidatedDataset; import edu.cmu.minorthird.classify.transform.AbstractInstanceTransform; import edu.cmu.minorthird.classify.transform.PredictedClassTransform; import edu.cmu.minorthird.classify.transform.TransformingMultiClassifier; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.ParallelViewer; import edu.cmu.minorthird.util.gui.TransformedViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.Visible; /** * View result of some sort of train/test experiment for Data with Multiple Labels. * * @author Cameron Williams */ public class MultiCrossValidatedDataset implements Visible{ static private Logger log=Logger.getLogger(CrossValidatedDataset.class); private MultiClassifiedDataset[] cds; private MultiClassifiedDataset[] trainCds; private MultiEvaluation v; public MultiCrossValidatedDataset(ClassifierLearner learner,MultiDataset d, Splitter<MultiExample> splitter){ this(learner,d,splitter,false,false); } public MultiCrossValidatedDataset(ClassifierLearner learner,MultiDataset d, Splitter<MultiExample> splitter,boolean saveTrainPartitions){ this(learner,d,splitter,saveTrainPartitions,false); } public MultiCrossValidatedDataset(ClassifierLearner learner,MultiDataset d, Splitter<MultiExample> splitter,boolean saveTrainPartitions,boolean cross){ MultiDataset.MultiSplit s=d.MultiSplit(splitter); cds=new MultiClassifiedDataset[s.getNumPartitions()]; trainCds= saveTrainPartitions?new MultiClassifiedDataset[s.getNumPartitions()] :null; v=new MultiEvaluation(d.getMultiSchema()); ProgressCounter pc= new ProgressCounter("train/test","fold",s.getNumPartitions()); for(int k=0;k<s.getNumPartitions();k++){ MultiDataset trainData=s.getTrain(k); if(cross) trainData=trainData.annotateData(); MultiDataset testData=s.getTest(k); log.info("splitting with "+splitter+", preparing to train on "+ trainData.size()+" and test on "+testData.size()); MultiClassifier c= new MultiDatasetClassifierTeacher(trainData).train(learner); //if(cross) testData=testData.annotateData(c); if(cross){ AbstractInstanceTransform transformer=new PredictedClassTransform(c); c=new TransformingMultiClassifier(c,transformer); } MultiDatasetIndex testIndex=new MultiDatasetIndex(testData); cds[k]=new MultiClassifiedDataset(c,testData,testIndex); if(trainCds!=null) trainCds[k]=new MultiClassifiedDataset(c,trainData,testIndex); v.extend(c,testData); //v.setProperty("classesInFold"+(k+1), "train: "+classDistributionString(trainData.getSchema(),new MultiDatasetIndex(trainData)) // +" test: "+classDistributionString(testData.getSchema(),testIndex)); log.info("splitting with "+splitter+", stored classified dataset"); pc.progress(); } pc.finished(); } // private String classDistributionString(MultiExampleSchema multiSchema, // MultiDatasetIndex index){ // StringBuffer buf=new StringBuffer(""); // java.text.DecimalFormat fmt=new java.text.DecimalFormat("#####"); // ExampleSchema[] schemas=multiSchema.getSchemas(); // for(int x=0;x<schemas.length;x++){ // ExampleSchema schema=schemas[x]; // for(int i=0;i<schema.getNumberOfClasses();i++){ // if(buf.length()>0) // buf.append("; "); // String label=schema.getClassName(i); // buf.append(fmt.format(index.size(label))+" "+label); // } // } // return buf.toString(); // } @Override public Viewer toGUI(){ ParallelViewer main=new ParallelViewer(); for(int i=0;i<cds.length;i++){ final int k=i; System.out.println(i); main.addSubView("Test Partition "+(i+1),new TransformedViewer(cds[0] .toGUI()){ static final long serialVersionUID=20080130L; @Override public Object transform(Object o){ // MultiCrossValidatedDataset cvd=(MultiCrossValidatedDataset)o; return cds[k]; } }); } if(trainCds!=null){ for(int i=0;i<trainCds.length;i++){ final int k=i; main.addSubView("Train Partition "+(i+1),new TransformedViewer(cds[0] .toGUI()){ static final long serialVersionUID=20080130L; @Override public Object transform(Object o){ // MultiCrossValidatedDataset cvd=(MultiCrossValidatedDataset)o; return trainCds[k]; } }); } } main.addSubView("Overall Evaluation",new TransformedViewer(v.toGUI()){ static final long serialVersionUID=20080130L; @Override public Object transform(Object o){ MultiCrossValidatedDataset cvd=(MultiCrossValidatedDataset)o; return cvd.v; } }); main.setContent(this); return main; } public MultiEvaluation getEvaluation(){ return v; } public static void main(String[] args){ System.out.println("CrossValidatedDataset"); } }