package context.arch.discoverer.query;

import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import context.arch.comm.DataObject;
import context.arch.comm.DataObjects;
import context.arch.discoverer.ComponentDescription;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Adapter class to wrap a Weka Classifier so that it works with the Context Toolkit.
 * It handles formatting the classifier and sending/retrieving its information
 * over the Discoverer network.
 * Note that this supports only classification (to nominal strings), not regression.
 * 
 * @author Brian Y. Lim
 */
public class ClassifierWrapper {

    public static final String CLASSIFIER_WRAPPER = "CLASSIFIER_WRAPPER";
    public static final String CLASSIFIER = "CLASSIFIER";
    public static final String DATASET_HEADER = "DATASET_HEADER";

    public static final int CACHE_LIMIT = 10;
    private LinkedHashMap<Instance, String> instanceClassifications = new LRUCache<Instance, String>(CACHE_LIMIT);

    private List<String> outcomeValues = new ArrayList<String>();

    protected Classifier classifier;
    protected Instances header;
    protected Attribute classAttribute;
    protected int NUM_ATTRIBUTES;

    private String classifierFileName;
    private String headerFileName;

    @SuppressWarnings("unchecked")
    public ClassifierWrapper(String classifierFileName, String headerFileName) {
        this.classifierFileName = classifierFileName;
        this.headerFileName = headerFileName;

        // extract classifier from serialized file
        this.classifier = loadClassifier(classifierFileName);

        // extract Instances dataset header from serialized file
        this.header = loadDataset(headerFileName);

        // extract outcome values
        this.classAttribute = header.classAttribute();
        Enumeration<String> values = classAttribute.enumerateValues();
        while (values.hasMoreElements()) {
            outcomeValues.add(values.nextElement());
        }

        NUM_ATTRIBUTES = header.numAttributes();
    }

    /**
     * @return name of the .model file containing the serialized Weka classifier model
     */
    public String getClassifierFileName() {
        return classifierFileName;
    }

    /**
     * @return name of the .arff file that contains the header information of Weka attributes for the dataset
     */
    public String getHeaderFileName() {
        return headerFileName;
    }

    /**
     * @return the Weka classifier
     */
    public Classifier getClassifier() {
        return classifier;
    }

    /**
     * @return a (possibly empty) dataset containing the header information of Weka attributes
     */
    public Instances getHeader() {
        return header;
    }

    /**
     * @return number of possible states (classes) for the outcome
     */
    public int numOutcomeValues() {
        return outcomeValues.size();
    }

    /**
     * @param index index of the outcome class label
     * @return the outcome class label at the given index
     */
    public String getOutcomeValue(int index) {
        return outcomeValues.get(index);
    }

    public List<String> getOutcomeValues() {
        return Collections.unmodifiableList(outcomeValues);
    }

    public String getClassAttributeName() {
        return classAttribute.name();
    }
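    /*
     * Typical usage (a minimal sketch, assuming ClassifierWrapper is constructed
     * directly as declared here; the file names "coffee.model" and "coffee.arff"
     * are hypothetical stand-ins for a serialized Weka model and its matching
     * dataset header):
     *
     *   ClassifierWrapper wrapper = new ClassifierWrapper("coffee.model", "coffee.arff");
     *   System.out.println(wrapper.getClassAttributeName()); // name of the outcome attribute
     *   System.out.println(wrapper.getOutcomeValues());      // its possible class labels
     */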
    /**
     * Classifies a Weka Instance, caching results for recently seen instances.
     * @param instance the instance to classify
     * @return the class label, or null if classification failed or was invalid (e.g. null values in attributes)
     */
    protected String classify(Instance instance) {
        // return cached result if recently classified
        if (instanceClassifications.containsKey(instance)) {
            return instanceClassifications.get(instance);
        }

        try {
            double value = classifier.classifyInstance(instance);

            // save label back into instance
            // TODO: not guaranteed to always be stored in other circumstances
            instance.setValue(classAttribute, value);

//            // for debugging
//            double[] distroForInstance = classifier.distributionForInstance(instance);
//            System.out.println("ClassifierWrapper.classify distroForInstance");
//            for (int i = 0; i < distroForInstance.length; i++) {
//                System.out.println("\t " + classAttribute.value(i) + ": " + distroForInstance[i]);
//            }

            String strValue = classAttribute.value((int)value);

            // cache result
            instanceClassifications.put(instance, strValue);

            return strValue;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Checks whether the widget state can be extracted as an appropriate Instance,
     * since other widgets are also queried.
     * If this fails, then classification would fail and return null.
     * @param widgetState the widget state to check
     * @return true if an Instance can be extracted from the widget state
     */
    protected boolean isInstanceExtractable(ComponentDescription widgetState) {
        return true; // TODO: not sure if this needs to be checked
    }

    /**
     * Classifies the state of a widget.
     * Assumes that widgetState has been validated for instance extraction.
     * @param widgetState the widget state to classify
     * @return the class label, or null if no Instance could be extracted
     */
    public String classify(ComponentDescription widgetState) {
        Instance instance = extractInstance(widgetState);
        if (instance == null) { return null; }

        String outcomeValue = classify(instance);

        // store value back into widgetState
        widgetState.getNonConstantAttributes().addAttribute(classAttribute.name(), outcomeValue);
//        System.out.println("ClassifierWrapper.classify stored: " + String.valueOf(Enactor.getAtt(classAttribute.name(), widgetState.getNonConstantAttributes())));

        return outcomeValue;
    }
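    /*
     * Sketch of the classification flow, assuming a widget state obtained from
     * the Discoverer whose attributes match the dataset header (the variables
     * wrapper and widgetState are hypothetical):
     *
     *   String outcome = wrapper.classify(widgetState);
     *   if (outcome != null) {
     *       // outcome is one of wrapper.getOutcomeValues(), and has also been
     *       // written back into widgetState's non-constant attributes
     *   }
     */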
    /**
     * Calls distributionForInstance on the Instance after extracting it from the ComponentDescription.
     * @param widgetState the widget state to classify
     * @return the class probability distribution, or null if classification failed
     */
    public double[] distributionForInstance(ComponentDescription widgetState) {
        Instance instance = extractInstance(widgetState);
        try {
            return classifier.distributionForInstance(instance); // TODO: utilize caching!
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Extracts a Weka Instance from the attributes of the ComponentDescription widget state.
     * @param widgetState state of a widget from which to extract the instance
     * @return the extracted Instance, or null if extraction is not possible
     */
    public Instance extractInstance(ComponentDescription widgetState) {
        if (!isInstanceExtractable(widgetState)) { return null; }

        Instance instance = new DenseInstance(NUM_ATTRIBUTES);

        for (int i = 0; i < NUM_ATTRIBUTES; i++) {
            Attribute attr = header.attribute(i);

            // add attribute value, depending on type
            if (attr.isNumeric()) {
                double attrVal = widgetState.getAttributeValue(attr.name());
                instance.setValue(attr, attrVal);
            } else { // nominal or string
                String attrVal = widgetState.getAttributeValue(attr.name());
                if (attrVal != null && !attrVal.equals("null")) {
                    instance.setValue(attr, attrVal); // setValue(Attribute, String) handles nominal and string attributes
                }
            }
        }

        // set dataset
        instance.setDataset(header);

        return instance;
    }

    /**
     * Converts this wrapper to a DataObject for transmission over the Discoverer network.
     * @return a DataObject encoding the classifier and dataset header file names
     */
    public DataObject toDataObject() {
        DataObjects v = new DataObjects();
        v.add(new DataObject(CLASSIFIER, classifierFileName));
        v.add(new DataObject(DATASET_HEADER, headerFileName));
        return new DataObject(CLASSIFIER_WRAPPER, v);
    }

    public static ClassifierWrapper fromDataObject(DataObject data) {
        @SuppressWarnings("unused")
        String classifierFileName = data.getDataObject(CLASSIFIER).getValue();
        @SuppressWarnings("unused")
        String headerFileName = data.getDataObject(DATASET_HEADER).getValue();

        return null; // TODO: this class is meant to be subclassed, so it cannot instantiate here; needs a factory. But maybe this never gets called anyway.
    }

    /**
     * Extracts the header as an empty Instances dataset from an .arff file.
     * Assumes that the last attribute is the class attribute.
     * @param datasetFileName name of the .arff file to load
     * @return the dataset header, or null if loading failed
     */
    public static Instances loadDataset(String datasetFileName) {
        Reader arffReader = null;
        try {
            arffReader = new FileReader(datasetFileName);
            Instances header = new Instances(arffReader);
            header.setClassIndex(header.numAttributes() - 1); // last attribute is class
            return header;
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if (arffReader != null) { arffReader.close(); }
            } catch (IOException e) {}
        }
        return null;
    }

    /**
     * Loads a serialized Weka classifier model from a .model file.
     * @param classifierFileName name of the .model file to load
     * @return the deserialized classifier, or null if loading failed
     */
    public static Classifier loadClassifier(String classifierFileName) {
        ObjectInputStream ois = null;
        try {
            ois = new ObjectInputStream(new FileInputStream(classifierFileName));
            Classifier classifier = (Classifier) ois.readObject();
            return classifier;
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } finally {
            try {
                if (ois != null) { ois.close(); }
            } catch (IOException e) {}
        }
        return null;
    }
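    /*
     * The static loaders can also be used on their own (a sketch; the file
     * names are hypothetical). Both return null on failure, so check before use:
     *
     *   Classifier c = ClassifierWrapper.loadClassifier("coffee.model");
     *   Instances header = ClassifierWrapper.loadDataset("coffee.arff");
     *   if (c != null && header != null) {
     *       // ready to build instances against header and classify with c
     *   }
     */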
    /**
     * LRU (Least Recently Used, instead of FIFO) cache for storing the classification results of instances.
     * This minimizes redundant classifications of recently seen instances.
     * Internally manages limiting its size.
     * See: http://www.java-alg.info/O.Reilly-Java.Generics.and.Collections/0596527756/javagenerics-CHP-16-SECT-2.html
     */
    public static class LRUCache<K, V> extends LinkedHashMap<K, V> {

        private static final long serialVersionUID = 3752030986272893668L;

        private int maxEntries;

        public LRUCache(int maxEntries) {
            super(maxEntries, // set initial capacity to max
                    1,        // no need to grow, so just use a load factor of one
                    true);    // order the map by access, instead of insertion
            this.maxEntries = maxEntries;
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
            return size() > maxEntries;
        }
    }

}
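/*
 * Eviction behavior of LRUCache (a sketch): with maxEntries = 2, adding a
 * third entry evicts the least recently accessed entry, not the oldest insertion:
 *
 *   ClassifierWrapper.LRUCache<String, Integer> cache =
 *           new ClassifierWrapper.LRUCache<String, Integer>(2);
 *   cache.put("a", 1);
 *   cache.put("b", 2);
 *   cache.get("a");    // touching "a" makes it most recently used
 *   cache.put("c", 3); // evicts "b", the least recently used entry
 *   // cache now contains {a=1, c=3}
 */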