/* Copyright 2003, Carnegie Mellon, All Rights Reserved */ package edu.cmu.minorthird.classify.sequential; import edu.cmu.minorthird.classify.*; import edu.cmu.minorthird.classify.experiments.ClassifiedDataset; import edu.cmu.minorthird.util.gui.*; import java.util.HashMap; import java.util.Iterator; import java.util.Map; /** * A SequenceDataset that has been classified with a * SequenceClassifier. * * @author William Cohen */ public class ClassifiedSequenceDataset implements Visible { // private SequenceClassifier sequenceClassifier; private SequenceDataset sequenceDataset; private Classifier adaptedClassifier; public ClassifiedSequenceDataset(SequenceClassifier sequenceClassifier,SequenceDataset sequenceDataset) { // this.sequenceClassifier = sequenceClassifier; this.sequenceDataset = sequenceDataset; this.adaptedClassifier = new AdaptedSequenceClassifier(sequenceClassifier,sequenceDataset); } public Classifier getClassifier() { return adaptedClassifier; } @Override public Viewer toGUI() { Viewer cdv = new ClassifiedDataset(adaptedClassifier,sequenceDataset).toGUI(); Viewer v = new TransformedViewer(cdv) { static final long serialVersionUID=20080207L; @Override public Object transform(Object o) { ClassifiedSequenceDataset csd = (ClassifiedSequenceDataset)o; return new ClassifiedDataset(csd.adaptedClassifier,csd.sequenceDataset); } }; v.setContent(this); return v; } /** Classifies examples from the sequenceDataset, by (a) mapping an * example to it position in the containing sequence (b) classifying the * containing sequence - caching it if necessary. */ private static class AdaptedSequenceClassifier implements Classifier, Visible { private SequenceClassifier sequenceClassifier; private class Place { Example[] seq; int index; public Place(Example[] seq,int index) { this.seq=seq; this.index=index; } } private Map<Object,Place> instanceToPlace = new HashMap<Object,Place>(); private Map<Example[],ClassLabel[]> classifiedSeq = new HashMap<Example[],ClassLabel[]>(); private Map<Example[],String> explainedSeq = new HashMap<Example[],String>(); public AdaptedSequenceClassifier(SequenceClassifier sequenceClassifier,SequenceDataset sequenceDataset) { this.sequenceClassifier = sequenceClassifier; for (Iterator<Example[]> i=sequenceDataset.sequenceIterator(); i.hasNext(); ) { Example[] seq = i.next(); for (int j=0; j<seq.length; j++) { instanceToPlace.put( seq[j].getSource(), new Place(seq,j) ); } } } @Override public ClassLabel classification(Instance instance) { Place place = instanceToPlace.get(instance.getSource()); if (place==null) throw new IllegalArgumentException("instance src"+instance.getSource()+" not in "+instanceToPlace); ClassLabel[] labelSeq = classifiedSeq.get(place.seq); if (labelSeq==null) { classifiedSeq.put(place.seq, (labelSeq=sequenceClassifier.classification(place.seq)) ); } return labelSeq[place.index]; } @Override public String explain(Instance instance) { Place place = instanceToPlace.get(instance.getSource()); if (place==null) throw new IllegalArgumentException("no explanation available"); String explanation = explainedSeq.get(place.seq); if (explanation==null) { explainedSeq.put(place.seq, (explanation=sequenceClassifier.explain(place.seq)) ); } return explanation; } @Override public Explanation getExplanation(Instance instance) { Place place = instanceToPlace.get(instance.getSource()); if (place==null) throw new IllegalArgumentException("no explanation available"); Explanation ex = sequenceClassifier.getExplanation(place.seq); return ex; } @Override public Viewer toGUI() { return new SmartVanillaViewer(sequenceClassifier); } } }