/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.multi; import java.awt.Component; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Set; import java.util.TreeSet; import javax.swing.JComponent; import javax.swing.JList; import javax.swing.JScrollPane; import javax.swing.ListCellRenderer; import edu.cmu.minorthird.classify.BasicDataset; import edu.cmu.minorthird.classify.BatchVersion; import edu.cmu.minorthird.classify.CascadingBinaryLearner; import edu.cmu.minorthird.classify.ClassLabel; import edu.cmu.minorthird.classify.ClassifierLearner; import edu.cmu.minorthird.classify.Dataset; import edu.cmu.minorthird.classify.DatasetLoader; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.ExampleSchema; import edu.cmu.minorthird.classify.FeatureFactory; import edu.cmu.minorthird.classify.Instance; import edu.cmu.minorthird.classify.Splitter; import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron; import edu.cmu.minorthird.classify.experiments.CrossValSplitter; import edu.cmu.minorthird.util.Saveable; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.Visible; import edu.cmu.minorthird.util.gui.ZoomedViewer; /** * A set of examples for learning. * * @author Cameron Williams */ public class MultiDataset implements Dataset,Visible,Saveable{ static final long serialVersionUID=20080130L; protected List<MultiExample> examples=new ArrayList<MultiExample>(); protected List<Instance> unlabeledExamples=new ArrayList<Instance>(); protected FeatureFactory factory=new FeatureFactory(); protected List<Set<String>> classNameSets; public int numPosExamples=0; /** Overridden, provides ExampleSchema for first dimension */ @Override public ExampleSchema getSchema(){ ExampleSchema schema= new ExampleSchema(classNameSets.get(0) .toArray(new String[classNameSets.get(0).size()])); if(schema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA)) return ExampleSchema.BINARY_EXAMPLE_SCHEMA; else return schema; } public MultiExampleSchema getMultiSchema(){ ExampleSchema[] schemas=new ExampleSchema[classNameSets.size()]; for(int i=0;i<schemas.length;i++){ schemas[i]= new ExampleSchema(classNameSets.get(i) .toArray(new String[classNameSets.get(i).size()])); if(schemas.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA)) schemas[i]=ExampleSchema.BINARY_EXAMPLE_SCHEMA; } MultiExampleSchema multiSchema=new MultiExampleSchema(schemas); return multiSchema; } public MultiExample getMultiExample(int i){ return examples.get(i); } @Override public FeatureFactory getFeatureFactory(){ return factory; } // // methods for semisupervised data, part of the SemiSupervisedDataset // interface // public void addUnlabeled(Instance instance){ unlabeledExamples.add(factory.compress(instance)); } public Iterator<Instance> iteratorOverUnlabeled(){ return unlabeledExamples.iterator(); } // public ArrayList getUnlabeled() { return this.unlabeledExamples; } public int sizeUnlabeled(){ return unlabeledExamples.size(); } public boolean hasUnlabeled(){ return (unlabeledExamples.size()>0)?true:false; } @Override public void add(Example example){ throw new IllegalArgumentException( "You must add a MultiExample to a MutiDataset"); } @Override public void add(Example example,boolean compress){ throw new IllegalArgumentException( "You must add a MultiExample to a MutiDataset"); } // // methods for labeled data, part of the Dataset interface // public void addMulti(MultiExample example){ if(classNameSets==null){ classNameSets=new ArrayList<Set<String>>(example.getMultiLabel().numDimensions()); for(int i=0;i<classNameSets.size();i++){ classNameSets.add(new TreeSet<String>()); } } if(classNameSets.size()!=example.getMultiLabel().numDimensions()) throw new IllegalArgumentException( "This example does not have the same number of dimensions as previous examples"); examples.add(factory.compress(example)); List<Set<String>> possibleLabels=example.getMultiLabel().possibleLabels(); for(int i=0;i<classNameSets.size();i++){ classNameSets.get(i).addAll(possibleLabels.get(i)); } // Maybe change ClassLabel cl=example.getLabel(); if(cl.isPositive()) numPosExamples++; } public Dataset[] separateDatasets(){ Example[] ex_one=(examples.get(0)).getExamples(); Dataset[] d=new BasicDataset[ex_one.length]; for(int i=0;i<d.length;i++){ d[i]=new BasicDataset(); } for(int i=0;i<examples.size();i++){ Example[] ex=(examples.get(i)).getExamples(); for(int j=0;j<ex.length;j++){ d[j].add(ex[j]); } } return d; } public int getNumPosExamples(){ return numPosExamples; } // Why don't we just overwrite these methods? Also, it's not an illegal argument. - frank @Override public Iterator<Example> iterator(){ throw new IllegalArgumentException( "Must use multiIterator to iterate through MultiExamples"); } public Iterator<MultiExample> multiIterator(){ return examples.iterator(); } @Override public int size(){ return examples.size(); } @Override public void shuffle(Random r){ Collections.shuffle(examples,r); } @Override public void shuffle(){ shuffle(new Random(999)); } @Override public Dataset shallowCopy(){ MultiDataset copy=new MultiDataset(); for(Iterator<MultiExample> i=multiIterator();i.hasNext();){ copy.addMulti(i.next()); } return copy; } // // Implement Saveable interface. // static private final String FORMAT_NAME="Minorthird MultiDataset"; @Override public String[] getFormatNames(){ return new String[]{FORMAT_NAME}; } @Override public String getExtensionFor(String s){ return ".multidata"; } @Override public void saveAs(File file,String format) throws IOException{ if(!format.equals(FORMAT_NAME)) throw new IllegalArgumentException("illegal format "+format); DatasetLoader.save(this,file); } @Override public Object restore(File file) throws IOException{ try{ return DatasetLoader.loadFile(file); }catch(NumberFormatException ex){ throw new IllegalStateException("error loading from "+file+": "+ex); } } /** A string view of the dataset */ @Override public String toString(){ StringBuffer buf=new StringBuffer(""); for(Iterator<MultiExample> i=this.multiIterator();i.hasNext();){ MultiExample ex=i.next(); buf.append(ex.toString()); buf.append("\n"); } return buf.toString(); } public MultiDataset annotateData(){ MultiDataset annotatedDataset=new MultiDataset(); Splitter<MultiExample> splitter=new CrossValSplitter<MultiExample>(9); MultiDataset.MultiSplit s=this.MultiSplit(splitter); for(int x=0;x<9;x++){ MultiClassifierTeacher teacher= new MultiDatasetClassifierTeacher(s.getTrain(x)); ClassifierLearner lnr= new CascadingBinaryLearner(new BatchVersion(new VotedPerceptron())); MultiClassifier c=teacher.train(lnr); for(Iterator<MultiExample> i=s.getTest(x).multiIterator();i.hasNext();){ MultiExample ex=i.next(); Instance instance=ex.asInstance(); MultiClassLabel predicted=c.multiLabelClassification(instance); Instance annotatedInstance= new InstanceFromPrediction(instance,predicted.bestClassName()); MultiExample newEx= new MultiExample(annotatedInstance,ex.getMultiLabel(),ex .getWeight()); annotatedDataset.addMulti(newEx); } } return annotatedDataset; } public MultiDataset annotateData(MultiClassifier multiClassifier){ MultiDataset annotatedDataset=new MultiDataset(); for(Iterator<MultiExample> i=this.multiIterator();i.hasNext();){ MultiExample ex=i.next(); Instance instance=ex.asInstance(); MultiClassLabel predicted= multiClassifier.multiLabelClassification(instance); Instance annotatedInstance= new InstanceFromPrediction(instance,predicted.bestClassName()); MultiExample newEx= new MultiExample(annotatedInstance,ex.getMultiLabel(),ex.getWeight()); annotatedDataset.addMulti(newEx); } return annotatedDataset; } /** A GUI view of the dataset. */ @Override public Viewer toGUI(){ Viewer dbGui=new SimpleDatasetViewer(); dbGui.setContent(this); Viewer instGui=GUI.newSourcedMultiExampleViewer(); return new ZoomedViewer(dbGui,instGui); } public static class SimpleDatasetViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; @Override public boolean canReceive(Object o){ return o instanceof Dataset; } @Override public JComponent componentFor(Object o){ final MultiDataset d=(MultiDataset)o; final MultiExample[] tmp=new MultiExample[d.size()]; int k=0; for(Iterator<MultiExample> i=d.multiIterator();i.hasNext();){ tmp[k++]=i.next(); } final JList jList=new JList(tmp); jList.setCellRenderer(new ListCellRenderer(){ @Override public Component getListCellRendererComponent(JList el,Object v, int index,boolean sel,boolean focus){ return GUI.conciseMultiExampleRendererComponent( tmp[index],60,sel); } }); monitorSelections(jList); return new JScrollPane(jList); } } // // splitter // @Override public Split split(final Splitter<Example> splitter){ System.err.println("Split split() not implemented."); return null; } public class MultiSplit{ Splitter<MultiExample> splitter; public MultiSplit(Splitter<MultiExample> splitter){ this.splitter=splitter; } public int getNumPartitions(){ return splitter.getNumPartitions(); } public MultiDataset getTrain(int k){ return invertMultiIteration(splitter.getTrain(k)); } public MultiDataset getTest(int k){ return invertMultiIteration(splitter.getTest(k)); } } public MultiSplit MultiSplit(final Splitter<MultiExample> splitter){ splitter.split(examples.iterator()); return new MultiSplit(splitter); } private MultiDataset invertMultiIteration(Iterator<MultiExample> i){ MultiDataset copy=new MultiDataset(); while(i.hasNext()) copy.addMulti(i.next()); return copy; } // // test routine // /** Simple test routine */ static public void main(String[] args){ System.out.println("Not working yet"); } }