package context.arch.intelligibility.weka.j48; import java.io.Reader; import java.io.StringReader; import java.util.Enumeration; import java.util.HashMap; import java.util.Map; import weka.classifiers.trees.J48; import weka.core.Attribute; import weka.core.Instances; import weka.gui.treevisualizer.Edge; import weka.gui.treevisualizer.Node; import weka.gui.treevisualizer.TreeBuild; import context.arch.discoverer.query.ClassifierWrapper; import context.arch.intelligibility.expression.Comparison; import context.arch.intelligibility.expression.DNF; import context.arch.intelligibility.expression.Parameter; import context.arch.intelligibility.expression.Reason; import context.arch.storage.AttributeNameValue; import context.arch.widget.ClassifierWidget; /** * Utility class to parse the decision tree model from WEKA's J48 class, * and build a Disjunctions of traces (Conjunctions) from each class value. * * Converts proprietary TreeDOT format used by J48 to format used in Intelligibility Toolkit. * In the former, nodes represent attribute names (e.g. "brightness") and edges represent conditions (e.g. "<= 5"). * In the latter, nodes represent full conditions (e.g. "brightness <= 5") * @author Brian Y. Lim * */ public class J48Parser { /** * Parse the J48 tree model with supplementary information about weka attributes from the header. * Each Disjunction in Disjunctive Normal Form (DNF). * @param cModel J48 tree model to parse * @param header to reference Weka Attributes * @return Map of Disjunctions of traces for each class value * @throws Exception */ @SuppressWarnings("unchecked") public static Map<String, DNF> parse(J48 cModel, Instances header) throws Exception { Map<String, DNF> valueTraces = new HashMap<String, DNF>(); // get tree node structure from model Reader treeDot = new StringReader(cModel.graph()); TreeBuild treeBuild = new TreeBuild(); Node treeRoot = treeBuild.create(treeDot); // set up one disjunction per class value Attribute classAttr = header.classAttribute(); Enumeration<String> values = classAttr.enumerateValues(); while (values.hasMoreElements()) { valueTraces.put(values.nextElement(), new DNF()); } // recursively parse parse( treeRoot, header, new Reason(), valueTraces ); // each Disjunction is naturally in DNF according to parsing process return valueTraces; } /** * Recursively parse the tree Node. * @param node of the tree structure from the J48 model * @param header to reference Weka Attributes * @param parentTrace of the trace to be appended to as we recurse * @param valueTraces to store Disjunctions of traces for each class value */ private static void parse(Node node, Instances header, Reason parentTrace, Map<String, DNF> valueTraces) { Edge edgeToChild; int i = 0; // need to use funky counter because Node doesn't return count of children while ((edgeToChild = node.getChild(i)) != null) { // depth-first search /* * Each branch means there's an alternative path, * so create a new trace for it. * Make sure it is a duplicate copy. * Always clone so that sibling traces would not see changes done by previous siblings */ Reason trace = parentTrace.clone(); /* * Next, prepare extension to the trace */ Parameter<?> childExpression = createParameter(node, edgeToChild, header); /* * Append child to trace */ // first check if a previous Comparison was already about this attribute, then just update its bounds trace.addOrMerge(childExpression); // System.out.println("trace = " + trace); // System.out.println("parentTrace = " + parentTrace); // recurse Node childNode = edgeToChild.getTarget(); parse(childNode, header, trace, valueTraces); i++; } if (i == 0) { // no child edges => is leaf node // expect format "yes (2.0)" or "Not Exercising (2.0)" String nodeLabel = node.getLabel(); String classValue = nodeLabel.split("[()]")[0].trim(); // just take the part before '(' valueTraces.get(classValue) // get Disjunction for the class value .add(parentTrace); // add to its traces } } @SuppressWarnings("unchecked") private static <T extends Comparable<? super T>> Parameter<?> createParameter(Node node, Edge edgeToChild, Instances header) { String attrName = node.getLabel(); // e.g. humidity // extract relation and value String cond = edgeToChild.getLabel(); // e.g. "<= 42", "= hello world" String[] condParts = cond.split(" ", 2); // split at first ' ' Comparison.Relation relation = Comparison.Relation.toRelation(condParts[0]); String strValue = condParts[1]; // cast value class type Attribute attr = header.attribute(attrName); T value = (T) AttributeNameValue.valueOf( ClassifierWidget.wekaTypeToClass(attr.type()), strValue); // create Expression for name and condition Parameter<?> expression; if (attr.isNumeric()) { expression = Comparison.instance(attrName, value, relation); } else { // assume nominal, string or date expression = Parameter.instance(attrName, value); } return expression; } /** * For testing * @param args */ public static void main(String[] args) { // load cModel and header from files J48 cModel = (J48) ClassifierWrapper.loadClassifier("demos/imautostatus-dtree/imautostatus.model"); Instances header = ClassifierWrapper.loadDataset("demos/imautostatus-dtree/imautostatus-test.arff"); // Instance instance = header.instance(0); // use one instance for testing // then parse it try { Map<String, DNF> valueTraces = J48Parser.parse(cModel, header); for (String value : valueTraces.keySet()) { DNF traces = valueTraces.get(value); System.out.println(value + "(size=" + traces.size() + "): " + traces); } } catch (Exception e) { e.printStackTrace(); } } }