/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.experiments; import java.awt.Component; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.TreeMap; import javax.swing.ButtonGroup; import javax.swing.JCheckBox; import javax.swing.JComponent; import javax.swing.JRadioButton; import javax.swing.JScrollPane; import javax.swing.JSplitPane; import javax.swing.JTabbedPane; import javax.swing.JTable; import javax.swing.ScrollPaneConstants; import javax.swing.table.AbstractTableModel; import javax.swing.table.TableCellRenderer; import edu.cmu.minorthird.classify.Classifier; import edu.cmu.minorthird.classify.ClassifierLearner; import edu.cmu.minorthird.classify.Dataset; import edu.cmu.minorthird.classify.DatasetClassifierTeacher; import edu.cmu.minorthird.classify.DatasetIndex; import edu.cmu.minorthird.classify.Example; import edu.cmu.minorthird.classify.Explanation; import edu.cmu.minorthird.classify.Feature; import edu.cmu.minorthird.classify.GUI; import edu.cmu.minorthird.classify.RandomAccessDataset; import edu.cmu.minorthird.classify.SampleDatasets; import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.ComponentViewer; import edu.cmu.minorthird.util.gui.Controllable; import edu.cmu.minorthird.util.gui.ControlledViewer; import edu.cmu.minorthird.util.gui.MessageViewer; import edu.cmu.minorthird.util.gui.VanillaViewer; import edu.cmu.minorthird.util.gui.Viewer; import edu.cmu.minorthird.util.gui.ViewerControls; import edu.cmu.minorthird.util.gui.ViewerFrame; import edu.cmu.minorthird.util.gui.Visible; /** * Pairs a dataset and a classifier, for easy inspection of the actions of a * classifier. * * @author William Cohen */ public class ClassifiedDataset implements Visible{ private Classifier classifier; private RandomAccessDataset dataset; private DatasetIndex index; public ClassifiedDataset(Classifier classifier,Dataset dataset){ this(classifier,dataset,new DatasetIndex(dataset)); } public ClassifiedDataset(Classifier classifier,Dataset dataset, DatasetIndex index){ this.classifier=classifier; if(dataset instanceof RandomAccessDataset){ this.dataset=(RandomAccessDataset)dataset; }else{ this.dataset=new RandomAccessDataset(); for(Iterator<Example> i=dataset.iterator();i.hasNext();){ this.dataset.add(i.next()); } } this.index=index; } public Classifier getClassifier(){ return classifier; } public Dataset getDataset(){ return dataset; } @Override public Viewer toGUI(){ Viewer v=new MessageViewer(new MyViewer()); v.setContent(this); return v; } /** * A toolbar to govern how data is filtered. */ private static class DataControls extends ViewerControls{ static final long serialVersionUID=20080130L; public JCheckBox filterOnFeatureBox; public Feature targetFeature; public JRadioButton correctButton; //public JRadioButton incorrectButton; public JRadioButton allButton; @Override public void initialize(){ // indicates if we should filter on some feature filterOnFeatureBox=new JCheckBox(); filterOnFeatureBox.setText("[none]"); targetFeature=null; add(filterOnFeatureBox); ButtonGroup group=new ButtonGroup(); correctButton=addButton("correct",false,group); //incorrectButton=addButton("incorrect",false,group); allButton=addButton("all",true,group); addApplyButton(); } private JRadioButton addButton(String s,boolean selected,ButtonGroup group){ JRadioButton button=new JRadioButton(s,selected); group.add(button); add(button); return button; } } /** * A toolbar-controlled viewer for data/classifications in a classified * dataset */ static private class ControlledDataViewer extends ComponentViewer implements Controllable{ static final long serialVersionUID=20080130L; // cached last display private ClassifiedDataset cd; // If true, only show example for which the classification is // correct (or incorrect) depending on targetCorrectness private boolean filterOnCorrectness=false; private boolean targetCorrectness=false; // * If true, only show example which contain the target feature private boolean filterOnFeature=false; private Feature targetFeature=null; @Override public void applyControls(ViewerControls controls){ DataControls dc=(DataControls)controls; if(dc.allButton.isSelected()) filterOnCorrectness=false; else{ filterOnCorrectness=true; targetCorrectness=dc.correctButton.isSelected(); } filterOnFeature=dc.filterOnFeatureBox.isSelected(); targetFeature=dc.targetFeature; // setContent here is incorrect - we want to bypass // any caching and force an update receiveContent(cd); revalidate(); } @Override public JComponent componentFor(Object o){ cd=(ClassifiedDataset)o; JTable jtable=new JTable(new MyTableModel(filteredClassifiedDataset())); jtable.setDefaultRenderer(Example.class,new TableCellRenderer(){ @Override public Component getTableCellRendererComponent(JTable table, Object value,boolean isSelected,boolean hasFocus,int row,int column){ return GUI.conciseExampleRendererComponent((Example)value,60, isSelected); } }); monitorSelections(jtable,1); JScrollPane scrollpane=new JScrollPane(jtable); scrollpane .setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_AS_NEEDED); return scrollpane; } private ClassifiedDataset filteredClassifiedDataset(){ if(!filterOnCorrectness&&!filterOnFeature){ return cd; }else{ RandomAccessDataset filteredData=new RandomAccessDataset(); ProgressCounter pc= new ProgressCounter("classifying for ClassifiedDataset","example", filteredData.size()); for(Iterator<Example> i=cd.dataset.iterator();i.hasNext();){ Example e=i.next(); boolean pass1=true; if(filterOnCorrectness) pass1= targetCorrectness==cd.classifier.classification(e).isCorrect( e.getLabel()); boolean pass2=true; if(filterOnFeature) pass2=targetFeature==null||e.getWeight(targetFeature)>0; if(pass1&&pass2){ filteredData.add(e); } pc.progress(); } pc.finished(); return new ClassifiedDataset(cd.classifier,filteredData,cd.index); } } /** models the data in the RandomAccessDataset of the ClassifiedDataset */ private class MyTableModel extends AbstractTableModel{ static final long serialVersionUID=20080130L; private ClassifiedDataset cd; public MyTableModel(ClassifiedDataset cd){ this.cd=cd; } @Override public int getRowCount(){ return cd.dataset.size(); } @Override public int getColumnCount(){ return 2; } // predicted, actual, instance @Override public Object getValueAt(int row,int col){ if(col==0) return cd.classifier.classification(cd.dataset.getExample(row)); else if(col==1) return cd.dataset.getExample(row); else throw new IllegalArgumentException("illegal col "+col); } @Override public String getColumnName(int col){ if(col==0) return "Prediction"; else if(col==1) return "Example"; else throw new IllegalArgumentException("illegal col "+col); } } } static public class ExplanationViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; Explanation ex; public ExplanationViewer(Explanation ex){ this.ex=ex; setContent(ex); } @Override public boolean canReceive(Object o){ return o instanceof Explanation; } @Override public JComponent componentFor(Object o){ ex=(Explanation)o; JScrollPane p=new JScrollPane(ex.getExplanation()); return p; } } /** * Viewer for a classified dataset */ static private class MyViewer extends ComponentViewer{ static final long serialVersionUID=20080130L; private Viewer instanceViewer,classifierViewer,explanationViewer; private ControlledViewer dataViewer; private ClassifiedDataset cd; @Override public JComponent componentFor(Object o){ cd=(ClassifiedDataset)o; JSplitPane left=new JSplitPane(); left.setOrientation(JSplitPane.VERTICAL_SPLIT); left.setResizeWeight(0.75); dataViewer= new ControlledViewer(new ControlledDataViewer(),new DataControls()); dataViewer.setContent(cd); left.setTopComponent(dataViewer); instanceViewer=GUI.newSourcedExampleViewer(); left.setBottomComponent(instanceViewer); dataViewer.setSuperView(this,"data"); instanceViewer.setSuperView(this,"instance"); JSplitPane right=new JSplitPane(); right.setOrientation(JSplitPane.VERTICAL_SPLIT); right.setResizeWeight(0.75); classifierViewer= (cd.classifier instanceof Visible)?((Visible)cd.classifier).toGUI() :new VanillaViewer(cd.classifier); right.setTopComponent(classifierViewer); explanationViewer=new ExplanationViewer(new Explanation("[explanation]")); right.setBottomComponent(explanationViewer); classifierViewer.setSuperView(this,"classifier"); explanationViewer.setSuperView(this,"explanation"); JSplitPane split=new JSplitPane(); split.setOrientation(JSplitPane.HORIZONTAL_SPLIT); split.setResizeWeight(0.50); split.setLeftComponent(left); split.setRightComponent(right); Evaluation e=new Evaluation(cd.dataset.getSchema()); e.extend(cd.classifier,cd.dataset,Evaluation.DEFAULT_PARTITION_ID); Viewer evalViewer=e.toGUI(); JTabbedPane main=new JTabbedPane(); main.add("Details",split); main.add("Evaluation",evalViewer); evalViewer.setSuperView(this,"evaluation"); return main; } @Override protected boolean canHandle(int signal,Object argument,List<Viewer> senders){ // protected boolean canHandle(int signal,Object argument){ return (signal==OBJECT_SELECTED)&&(argument instanceof Example)|| (signal==OBJECT_SELECTED)&&(argument instanceof Feature); } @Override protected void handle(int signal,Object argument,List<Viewer> senders){ // protected void handle(int signal,Object argument){ if(argument instanceof Example){ Example example=(Example)argument; instanceViewer.setContent(example); explanationViewer.setContent(cd.classifier.getExplanation(example)); revalidate(); }else if(argument instanceof Feature){ DataControls dc=(DataControls)dataViewer.getControls(); dc.targetFeature=(Feature)argument; dc.filterOnFeatureBox.setText(argument.toString()); sendSignal(TEXT_MESSAGE,featureSummary((Feature)argument,cd.index)); } } private String featureSummary(Feature f,DatasetIndex index){ StringBuffer buf=new StringBuffer(f.toString()); buf.append(" appears in "); buf.append(index.size(f)); buf.append(" examples:"); Map<String,Integer> map=new TreeMap<String,Integer>(); for(int i=0;i<index.size(f);i++){ String label=index.getExample(f,i).getLabel().bestClassName(); Integer count=map.get(label); if(count==null) map.put(label,(count=new Integer(0))); map.put(label,new Integer(count.intValue()+1)); } for(Iterator<String> j=map.keySet().iterator();j.hasNext();){ String label=j.next(); Integer count=map.get(label); buf.append(" "+count+":"+label); } return buf.toString(); } } public static void main(String[] args){ Dataset train=SampleDatasets.sampleData("toy",false); // ClassifierLearner learner = new DecisionTreeLearner(); ClassifierLearner learner=new NaiveBayes(); // ClassifierLearner learner = new AdaBoost(new BatchVersion(new // NaiveBayes()),4); // ClassifierLearner learner = new AdaBoost(new DecisionTreeLearner(), 3); Classifier cl=new DatasetClassifierTeacher(train).train(learner); Dataset test=SampleDatasets.sampleData("toy",true); ClassifiedDataset cd=new ClassifiedDataset(cl,test); new ViewerFrame("ClassifiedDataset",cd.toGUI()); } }