package context.arch.discoverer.query;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationVector;
import be.ac.ulg.montefiore.run.jahmm.io.FileFormatException;
import be.ac.ulg.montefiore.run.jahmm.io.HmmReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfVectorReader;

import context.arch.comm.DataObject;
import context.arch.comm.DataObjects;
import context.arch.discoverer.ComponentDescription;
import context.arch.intelligibility.hmm.HmmSupervisedLearner;
import context.arch.widget.SequenceWidget;

import weka.classifiers.Classifier;
import weka.core.Instances;

/**
 * Wraps a jahmm Hidden Markov Model so that it can classify the state of a
 * widget into a sequence of named states. Subclasses define how an
 * observation sequence is extracted from a ComponentDescription.
 */
public abstract class HmmWrapper {

    public static final String HMM_WRAPPER = "HMM_WRAPPER";
    public static final String HMM_MODEL = "HMM_MODEL";
    public static final String HEADER_FILE_NAME = "HEADER_FILE_NAME";
    public static final String SEQUENCE_LENGTH = "SEQUENCE_LENGTH";

    public static final int CACHE_LIMIT = 10;
//    private LinkedHashMap<Instance, String> instanceClassifications = new BoundedSizeMap<Instance, String>(CACHE_LIMIT);
//    private List<String> outcomeValues = new ArrayList<String>();

    protected Hmm<ObservationVector> hmm;
    protected int numObservationValues;
    protected int sequenceLength;

    /*
     * These map the numeric indices used in the jahmm text format to names.
     */
    protected List<String> OUTPUT_NAMES; // names of states
    protected List<String> INPUT_NAMES;  // names of each feature input

    private String headerFileName;
    private String hmmModelFileName;

    public HmmWrapper(String headerFileName, String observationSequencesFileName, String stateSequencesFileName, int sequenceLength) {
        loadHeaderInfo(headerFileName);

        // learn model from source dataset
        HmmSupervisedLearner learner = new HmmSupervisedLearner(OUTPUT_NAMES.size(), INPUT_NAMES.size(), numObservationValues);
        this.hmm = learner.learn(
                new File(observationSequencesFileName),
                new File(stateSequencesFileName));

        this.sequenceLength = sequenceLength;
    }

    public HmmWrapper(String headerFileName, String hmmModelFileName, int sequenceLength) {
        loadHeaderInfo(headerFileName);

        // read the HMM from a previously saved jahmm model file
        try {
            this.hmm = HmmReader.read(
                    new FileReader(hmmModelFileName),
                    new OpdfVectorReader());
        } catch (FileFormatException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }

        this.headerFileName = headerFileName;
        this.hmmModelFileName = hmmModelFileName;
        this.sequenceLength = sequenceLength;
    }

    public void loadHeaderInfo(String fileName) {
        INPUT_NAMES = new ArrayList<String>();
        OUTPUT_NAMES = new ArrayList<String>();

        Properties properties = new Properties();
        try {
            properties.load(new FileInputStream(fileName));

            numObservationValues = Integer.parseInt(properties.getProperty("numObservationValues"));

            String[] inputNames = properties.getProperty("observationVector").split(",");
            for (String name : inputNames) {
                INPUT_NAMES.add(name.trim());
            }

            String[] outputNames = properties.getProperty("states").split(",");
            for (String name : outputNames) {
                OUTPUT_NAMES.add(name.trim());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
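    /*
     * Illustrative sketch of the header file that loadHeaderInfo expects
     * (a java.util.Properties file). The property keys come from the code
     * above; these particular values are hypothetical:
     *
     *   numObservationValues = 2
     *   observationVector = motion, door, light
     *   states = Asleep, Awake, Away
     */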
    public String[] getInputNames() {
        return INPUT_NAMES.toArray(new String[INPUT_NAMES.size()]);
    }

    public String[] getOutputNames() {
        return OUTPUT_NAMES.toArray(new String[OUTPUT_NAMES.size()]);
    }

    public int getSequenceLength() {
        return sequenceLength;
    }

    public Hmm<ObservationVector> getHmm() {
        return hmm;
    }

    public int numOutcomeValues() {
//        return outcomeValues.size();
        return OUTPUT_NAMES.size();
    }

    public String getOutcomeValue(int index) {
//        return outcomeValues.get(index);
        return OUTPUT_NAMES.get(index);
    }

    public List<String> getOutcomeValues() {
//        return Collections.unmodifiableList(outcomeValues);
        return Collections.unmodifiableList(OUTPUT_NAMES);
    }

    /**
     * Decodes the most likely (Viterbi) state sequence for an observation sequence.
     * @param obs observation sequence to classify
     * @return the most likely state sequence, mapped to state names
     */
    protected String[] classify(List<ObservationVector> obs) {
        // TODO: caching?
        int[] x = hmm.mostLikelyStateSequence(obs);

        String[] stateSeqs = new String[x.length];
        for (int t = 0; t < x.length; t++) {
            stateSeqs[t] = OUTPUT_NAMES.get(x[t]); // map numeric state indices to state names
        }

        System.out.println("classify obs = " + obs);
        System.out.println("classify x = " + HmmSupervisedLearner.toIntArrayString(x));

        return stateSeqs;
    }

    /**
     * @param stateSeq state sequence, by state name
     * @return per-step probabilities; the first is actually the prior (pi) of the
     * first state, the others are transition probabilities from matrix A
     */
    public double[] getTransitionProbabilities(List<String> stateSeq) {
        double[] probs = new double[stateSeq.size()];

        probs[0] = hmm.getPi(OUTPUT_NAMES.indexOf(stateSeq.get(0)));
        for (int t = 1; t < probs.length; t++) {
            probs[t] = hmm.getAij(
                    OUTPUT_NAMES.indexOf(stateSeq.get(t-1)),
                    OUTPUT_NAMES.indexOf(stateSeq.get(t)));
        }

        return probs;
    }

    /**
     * @param x state sequence, by state index
     * @return per-step probabilities; the first is actually the prior (pi) of the
     * first state, the others are transition probabilities from matrix A
     */
    public double[] getTransitionProbabilities(int[] x) {
        double[] probs = new double[x.length];

        probs[0] = hmm.getPi(x[0]);
        for (int t = 1; t < probs.length; t++) {
            probs[t] = hmm.getAij(x[t-1], x[t]);
        }

        return probs;
    }
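    /**
     * Illustrative helper (a sketch added for exposition; not part of the
     * original Context Toolkit API). The values returned by
     * getTransitionProbabilities compose into the prior probability of the
     * whole state sequence under the HMM:
     * P(x_0, ..., x_{T-1}) = pi(x_0) * A(x_0, x_1) * ... * A(x_{T-2}, x_{T-1}).
     */
    public double stateSequenceProbability(List<String> stateSeq) {
        double p = 1.0;
        for (double prob : getTransitionProbabilities(stateSeq)) {
            p *= prob; // multiply the prior, then each successive transition probability
        }
        return p;
    }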
    /**
     * Checks whether the widget state can be extracted as an appropriate
     * observation sequence, since other widgets are also queried.
     * If this fails, then classification would fail and return null.
     * @param widgetState
     * @return true if an observation sequence can be extracted from widgetState
     */
    protected abstract boolean isInstanceExtractable(ComponentDescription widgetState);

    /**
     * Assumes that widgetState has been validated for extracting an observation sequence.
     * @param widgetState
     * @return the most likely state sequence, by state name, or null if extraction failed
     */
    public List<String> classify(ComponentDescription widgetState) {
        List<ObservationVector> obs = extractObservations(widgetState);
        if (obs == null) { return null; }

        String[] outcomeSequence = classify(obs);

        // store value back into widgetState
        // TODO: stuff multiple values
//        Enactor.setAttValue(classAttribute.name(), outcomeValue, widgetState.getNonConstantAttributes());
//        System.out.println("ClassifierWrapper.classifiy stored: " + String.valueOf(Enactor.getAtt(classAttribute.name(), widgetState.getNonConstantAttributes())));

        return Arrays.asList(outcomeSequence);
    }

    protected double[] distributionForInstance(List<ObservationVector> obs) {
        // TODO: caching?
        int NUM_STATES = hmm.nbStates();
        double[] probs = new double[NUM_STATES];

        // get the most likely state sequence
        int[] x = hmm.mostLikelyStateSequence(obs);

        // then permute the last state, taking the joint probability of the
        // observations with each candidate final state (not normalized)
        int last_t = x.length - 1;
        for (int i = 0; i < NUM_STATES; i++) {
            x[last_t] = i;
            probs[i] = hmm.probability(obs, x);
        }

        return probs;
    }

    /**
     * Applicable only to getting the distribution over different final states;
     * the (earlier) rest of the sequence is fixed.
     * @param widgetState
     * @return joint probabilities per candidate final state, or null if extraction failed
     */
    public double[] distributionForInstance(ComponentDescription widgetState) {
        List<ObservationVector> obs = extractObservations(widgetState);
        if (obs == null) { return null; }
        return distributionForInstance(obs);
    }

    public List<ObservationVector> extractObservations(ComponentDescription widgetState) {
        if (widgetState == null) { return null; }
        if (!isInstanceExtractable(widgetState)) { return null; }

        /*
         * Need to iterate over time steps
         */
        List<ObservationVector> observations = new ArrayList<ObservationVector>();

        // grab input values for each time step
        for (int t = 0; t < sequenceLength; t++) {
            String seqIndexMarker = SequenceWidget.getTPrepend(t); // prepended marker

            double[] inputValues = new double[INPUT_NAMES.size()];
//            System.out.println("extractObservations widgetState = " + widgetState);

            // iterate over input names; note that order is important
            for (int i = 0; i < INPUT_NAMES.size(); i++) {
                String attrName = seqIndexMarker + INPUT_NAMES.get(i);
                Object value = widgetState.getAttributeValue(attrName);

                if (value == null) {
                    // value may be invalid if the sequence is not fully populated
                    // or ready yet; leave the default 0 in place
                    continue; // return null;
                }

                inputValues[i] = (Integer) value;
            }

            // set observation vector for this time step
            ObservationVector o = new ObservationVector(inputValues);
            observations.add(o);
        }

        return observations;
    }

    public DataObject toDataObject() {
        DataObjects v = new DataObjects();
        v.add(new DataObject(HMM_MODEL, hmmModelFileName));
        v.add(new DataObject(HEADER_FILE_NAME, headerFileName));
        v.add(new DataObject(SEQUENCE_LENGTH, "" + sequenceLength));
        return new DataObject(HMM_WRAPPER, v);
    }

    public static HmmWrapper fromDataObject(DataObject data) {
        @SuppressWarnings("unused")
        String hmmModelFileName = data.getDataObject(HMM_MODEL).getValue();
        @SuppressWarnings("unused")
        String headerFileName = data.getDataObject(HEADER_FILE_NAME).getValue();
        return null;
        // TODO: this is an abstract class, so it cannot be instantiated here; needs a factory
        // but maybe this never gets called either
    }

    public static Instances loadHeader(String headerFileName) {
        try {
            Reader arffReader = new FileReader(headerFileName);
            Instances header = new Instances(arffReader);
            header.setClassIndex(header.numAttributes() - 1); // last attribute is the class
            return header;
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static Classifier loadClassifier(String classifierFileName) {
        ObjectInputStream ois = null;
        try {
            ois = new ObjectInputStream(new FileInputStream(classifierFileName));
            Classifier classifier = (Classifier) ois.readObject();
            return classifier;
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } finally {
            try {
                if (ois != null) { ois.close(); }
            } catch (IOException e) {}
        }
        return null;
    }
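    /*
     * Usage sketch for the static loaders (a hypothetical example; the file
     * names are placeholders, not files that ship with this package):
     *
     *   Instances header = HmmWrapper.loadHeader("activity.arff");
     *   Classifier classifier = HmmWrapper.loadClassifier("activity.model");
     */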
    /**
     * LRU (Least Recently Used, rather than FIFO) map for storing the
     * classification results of instances, to minimize redundant
     * classifications of recently seen instances.
     * Internally limits its own size, evicting the least recently accessed
     * entry once the bound is exceeded.
     * See: http://www.java-alg.info/O.Reilly-Java.Generics.and.Collections/0596527756/javagenerics-CHP-16-SECT-2.html
     */
    public static class BoundedSizeMap<K, V> extends LinkedHashMap<K, V> {
        private static final long serialVersionUID = 3752030986272893668L;

        private int maxEntries;

        public BoundedSizeMap(int maxEntries) {
            super(maxEntries, // set initial capacity to max
                  1,          // no need to grow, so just use unity load factor
                  true);      // order the map by access, instead of insertion
            this.maxEntries = maxEntries;
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
            return size() > maxEntries;
        }
    }

}
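/*
 * Minimal concrete subclass sketch (hypothetical, for illustration only,
 * since HmmWrapper is abstract and leaves isInstanceExtractable to subclasses):
 *
 *   public class SimpleHmmWrapper extends HmmWrapper {
 *       public SimpleHmmWrapper(String headerFileName, String hmmModelFileName, int sequenceLength) {
 *           super(headerFileName, hmmModelFileName, sequenceLength);
 *       }
 *       @Override
 *       protected boolean isInstanceExtractable(ComponentDescription widgetState) {
 *           // e.g. require the first input of the first time step to be present
 *           return widgetState.getAttributeValue(
 *                   SequenceWidget.getTPrepend(0) + INPUT_NAMES.get(0)) != null;
 *       }
 *   }
 */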