package edu.ucla.nesl.mca.classifier;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import android.util.Log;
import edu.ucla.nesl.mca.feature.Feature;
import edu.ucla.nesl.mca.feature.Feature.OPType;
import edu.ucla.nesl.mca.xdr.XDRDataInput;
import edu.ucla.nesl.mca.xdr.XDRDataOutput;
import edu.ucla.nesl.mca.xdr.XDRSerializable;

/**
 * Decision-tree classifier whose nodes are loaded from a JSON model
 * (see {@link #getModel(JSONObject)}) or from an XDR stream
 * (see {@link #readXDR(XDRDataInput)}).
 *
 * <p>Basic assumptions:
 * <ul>
 *   <li>This is always a classification.</li>
 *   <li>No MissingValueStrategy is implemented.</li>
 *   <li>No MissingValuePenalty is implemented.</li>
 *   <li>NoTrueChild is not allowed.</li>
 *   <li>SplitCharacteristic treats binarySplit as MultiSplit.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code RealOperator} is not declared in this file (a former
 * local definition with the operators {@code <}, {@code <=}, {@code >} and
 * {@code >=} was commented out); it is presumably inherited from
 * {@link Classifier} or declared in this package. The code below relies on
 * {@code RealOperator.values()}, {@code toString()} and
 * {@code evaluate(double, double)}.
 */
public class DecisionTree extends Classifier implements XDRSerializable {

    /**
     * A single node of the tree: either an internal node that tests a feature
     * or a leaf that carries a result.
     */
    class TreeNode implements XDRSerializable {
        /** ID for this node */
        private int m_id = -1;

        /** Feature tested by this node; stays null for leaf (result) nodes */
        private Feature m_feature = null;

        /** Type of this node's feature: NOMINAL or REAL */
        private OPType m_type = null;

        /** Operator if type is REAL */
        private RealOperator m_realOp = null;

        /** Threshold if type is REAL */
        private double m_realThes = Double.NaN;

        /** Type of result: NOMINAL or REAL; stays null for internal nodes */
        private OPType m_resultType = null;

        /** Result if resultType is REAL */
        private double m_realResult = Double.NaN;

        /** Result if resultType is NOMINAL */
        private String m_nominalResult = null;

        /** Value of the optional "Parameter" entry, passed to Feature.evaluate; -1 if absent */
        private int m_parameter = -1;

        /** Child nodes of this node */
        private ArrayList<TreeNode> m_childNodes = new ArrayList<TreeNode>();

        /** Temp array to store children IDs, not exported to XDR */
        private int[] childList;

        /** Number of valid entries in {@link #childList} */
        private int childCount;

        /** Creates an empty node; used for readXDR. */
        public TreeNode() {
            // Used for readXDR
        }

        /**
         * Builds a node from its JSON description.
         *
         * <p>A node with a "FeatureID" entry is an internal node: its feature
         * is looked up in the parent's inputs, an optional "Parameter" is
         * read, REAL features additionally read "Operator" and "Value", and
         * the IDs listed under "ChildNode" are remembered for
         * {@link #updateChild(HashMap)}. A node with a "Result" entry is a
         * leaf whose result type is taken from the parent's output.
         *
         * @param nodeObj JSON object describing this node
         * @param parent  tree that supplies the input features and the output type
         * @throws JSONException if a required entry is missing, the feature
         *         cannot be resolved, the operator string matches no
         *         {@code RealOperator}, or the node has neither "FeatureID"
         *         nor "Result"
         */
        public TreeNode(JSONObject nodeObj, DecisionTree parent) throws JSONException {
            m_id = nodeObj.getInt("ID");
            Log.i("DecisionTree", "Node ID=" + m_id);

            if (nodeObj.has("FeatureID")) {
                int featureID = nodeObj.getInt("FeatureID");
                Log.i("DecisionTree", "FeatureID=" + featureID);
                m_feature = parent.getInputs().getFeature(featureID);
                if (m_feature == null) {
                    // Fail with a descriptive error instead of a NullPointerException below.
                    throw new JSONException("Node " + m_id
                            + " references unknown FeatureID " + featureID);
                }
                Log.i("DecisionTree", "Feature=" + m_feature.getName());

                if (nodeObj.has("Parameter")) {
                    m_parameter = nodeObj.getInt("Parameter");
                }

                m_type = m_feature.getOpType();
                if (m_type == OPType.REAL) {
                    String op = nodeObj.getString("Operator");
                    for (RealOperator o : RealOperator.values()) {
                        if (o.toString().equals(op)) {
                            m_realOp = o;
                            break;
                        }
                    }
                    if (m_realOp == null) {
                        // Reject the model now; otherwise evaluation would hit a null operator.
                        throw new JSONException("Node " + m_id
                                + " has unknown Operator \"" + op + "\"");
                    }
                    m_realThes = nodeObj.getDouble("Value");
                }

                JSONArray childNodeList = nodeObj.getJSONArray("ChildNode");
                childCount = childNodeList.length();
                childList = new int[childCount];
                for (int i = 0; i < childCount; i++) {
                    childList[i] = childNodeList.getInt(i);
                }
            } else if (nodeObj.has("Result")) {
                m_resultType = parent.getOutput().getOpType();
                if (m_resultType == OPType.REAL) {
                    m_realResult = nodeObj.getDouble("Result");
                } else if (m_resultType == OPType.NOMINAL) {
                    m_nominalResult = nodeObj.getString("Result");
                    Log.i("DecisionTree", "Result=" + m_nominalResult);
                }
            } else {
                throw new JSONException("Cannot have a node with no Feature nor Result defined.");
            }
        }

        public int getM_id() {
            return m_id;
        }

        public void setM_id(int m_id) {
            this.m_id = m_id;
        }

        public Feature getM_feature() {
            return m_feature;
        }

        public void setM_feature(Feature m_feature) {
            this.m_feature = m_feature;
        }

        public OPType getM_type() {
            return m_type;
        }

        public void setM_type(OPType m_type) {
            this.m_type = m_type;
        }

        public RealOperator getM_realOp() {
            return m_realOp;
        }

        public void setM_realOp(RealOperator m_realOp) {
            this.m_realOp = m_realOp;
        }

        public double getM_realThes() {
            return m_realThes;
        }

        public void setM_realThes(double m_realThes) {
            this.m_realThes = m_realThes;
        }

        public OPType getM_resultType() {
            return m_resultType;
        }

        public void setM_resultType(OPType m_resultType) {
            this.m_resultType = m_resultType;
        }

        public double getM_realResult() {
            return m_realResult;
        }

        public void setM_realResult(double m_realResult) {
            this.m_realResult = m_realResult;
        }

        public String getM_nominalResult() {
            return m_nominalResult;
        }

        public void setM_nominalResult(String m_nominalResult) {
            this.m_nominalResult = m_nominalResult;
        }

        public int getM_parameter() {
            return m_parameter;
        }

        public void setM_parameter(int m_parameter) {
            this.m_parameter = m_parameter;
        }

        public ArrayList<TreeNode> getM_childNodes() {
            return m_childNodes;
        }

        public void setM_childNodes(ArrayList<TreeNode> m_childNodes) {
            this.m_childNodes = m_childNodes;
        }

        /** @return the ID of this node (same value as {@link #getM_id()}) */
        public int getID() {
            return m_id;
        }

        /**
         * Appends the child nodes named by the stored child IDs to this node's
         * child list, in the order the IDs were read.
         *
         * <p>An ID that is missing from {@code nodeDict} is appended as
         * {@code null}; {@link DecisionTree#getModel(JSONObject)} validates the
         * IDs before calling this method.
         *
         * @param nodeDict map from node ID to node
         */
        public void updateChild(HashMap<Integer, TreeNode> nodeDict) {
            for (int i = 0; i < childCount; i++) {
                m_childNodes.add(nodeDict.get(childList[i]));
            }
        }

        /** Not implemented; evaluation is done by {@link DecisionTree#evaluate()}. */
        public void evaluate() {
        }

        /** Not implemented yet: currently writes nothing to {@code output}. */
        @Override
        public void writeXDR(XDRDataOutput output) throws IOException {
        }

        /** Not implemented yet: currently reads nothing from {@code input}. */
        @Override
        public void readXDR(XDRDataInput input) throws IOException {
        }
    }

    /** The root of the tree */
    protected TreeNode m_root = null;

    /** Default evaluation result */
    protected String defaultResult;

    public DecisionTree() {
        // Nothing to do here
    }

    public TreeNode getM_root() {
        return m_root;
    }

    public void setM_root(TreeNode m_root) {
        this.m_root = m_root;
    }

    /**
     * Loads the tree from its JSON model: reads "Default Result", builds one
     * node per entry of "Nodes", takes the first entry as the root and then
     * links every node to its children.
     *
     * @param modelObj JSON object holding "Default Result" and "Nodes"
     * @throws JSONException if a required entry is missing or malformed, the
     *         node list is empty, or a node lists a child ID that no node has
     */
    @Override
    protected void getModel(JSONObject modelObj) throws JSONException {
        defaultResult = modelObj.getString("Default Result");
        JSONArray nodeList = modelObj.getJSONArray("Nodes");
        int nodeCount = nodeList.length();
        Log.i("DecisionTree", "node length=" + nodeCount);
        if (nodeCount == 0) {
            // The first node becomes the root, so an empty list cannot form a tree.
            throw new JSONException("Decision tree model must contain at least one node.");
        }

        // Read and build all the tree nodes
        TreeNode[] nodeArray = new TreeNode[nodeCount];
        HashMap<Integer, TreeNode> nodeDict = new HashMap<Integer, TreeNode>();
        for (int i = 0; i < nodeCount; i++) {
            nodeArray[i] = new TreeNode(nodeList.getJSONObject(i), this);
            nodeDict.put(nodeArray[i].getID(), nodeArray[i]);
        }

        // Reject dangling child references before any state of this tree is touched;
        // otherwise they would be linked as null children and fail during evaluation.
        for (TreeNode node : nodeArray) {
            for (int k = 0; k < node.childCount; k++) {
                if (!nodeDict.containsKey(node.childList[k])) {
                    throw new JSONException("Node " + node.m_id
                            + " references unknown child node ID " + node.childList[k]);
                }
            }
        }

        m_root = nodeArray[0];
        Log.i("DecisionTree", "Root Node ID=" + m_root.m_id);

        // Need to loop the node list once more to construct node hierarchy
        for (TreeNode node : nodeArray) {
            node.updateChild(nodeDict);
        }
    }

    /**
     * Performs a pre-order traversal of the subtree rooted at {@code node}.
     *
     * @param node root of the subtree; may be null
     * @return the visited nodes in pre-order; empty if {@code node} is null
     */
    public ArrayList<TreeNode> preOrderTraversal(TreeNode node) {
        ArrayList<TreeNode> visited = new ArrayList<TreeNode>();
        collectPreOrder(node, visited);
        return visited;
    }

    /** Appends {@code node} and then its descendants to {@code out}; null nodes are skipped. */
    private void collectPreOrder(TreeNode node, ArrayList<TreeNode> out) {
        if (node == null) {
            return;
        }
        out.add(node);
        for (TreeNode child : node.m_childNodes) {
            collectPreOrder(child, out);
        }
    }

    /** @return all nodes of the tree in pre-order; empty if no root is set */
    public ArrayList<TreeNode> traversal() {
        return preOrderTraversal(m_root);
    }

    /** @return the feature of the root node, or null if there is no root or the root is a leaf */
    public Feature getRootFeature() {
        return m_root == null ? null : m_root.m_feature;
    }

    /**
     * Writes the classifier name "TREE", the number of nodes and then every
     * node (in pre-order) to {@code output}, and closes {@code output}.
     *
     * @param output destination stream; closed by this method
     * @throws IOException if writing fails
     */
    @Override
    public void writeXDR(XDRDataOutput output) throws IOException {
        // Write the name of the classifier
        // Note: the length of the classifier name must be unified
        output.writeString("TREE");

        // Perform a pre-order traversal
        ArrayList<TreeNode> nodeList = traversal();

        // Write the number of nodes to XDR
        output.writeInt(nodeList.size());
        for (TreeNode node : nodeList) {
            node.writeXDR(output);
        }
        output.close();
    }

    /**
     * Reads a tree from {@code input}. The stream is expected to start with
     * the classifier name; if it is not "TREE" nothing else is read and the
     * tree is left unchanged. Otherwise the node count and the nodes are read,
     * the first node becomes the root and children are linked by ID.
     *
     * @param input source stream
     * @throws IOException if reading fails, the node count is negative, or a
     *         node lists a child ID that no node has
     */
    @Override
    public void readXDR(XDRDataInput input) throws IOException {
        // Check classifier type
        String classifier = input.readString();
        if (!"TREE".equals(classifier)) {
            return;
        }

        // read number of nodes
        int n = input.readInt();
        if (n < 0) {
            throw new IOException("Invalid node count in XDR stream: " + n);
        }
        if (n == 0) {
            // An empty tree is what writeXDR emits when no root is set.
            m_root = null;
            return;
        }

        // map the IDs to each node
        ArrayList<TreeNode> list = new ArrayList<TreeNode>(n);
        HashMap<Integer, TreeNode> map = new HashMap<Integer, TreeNode>();

        // read all nodes
        for (int i = 0; i < n; i++) {
            TreeNode node = new TreeNode();
            node.readXDR(input);
            map.put(node.m_id, node);
            list.add(node);
        }

        // build the hierarchy of the tree
        for (TreeNode node : list) {
            int[] children = node.childList;
            if (children == null) {
                // Nodes created through the no-arg constructor have no child ID list
                // unless TreeNode.readXDR filled it in.
                continue;
            }
            for (int j = 0; j < children.length; j++) {
                TreeNode child = map.get(children[j]);
                if (child == null) {
                    throw new IOException("Node " + node.m_id
                            + " references unknown child node ID " + children[j]);
                }
                node.m_childNodes.add(child);
            }
        }
        m_root = list.get(0);
    }

    /**
     * Walks the tree from the root to a leaf and returns the leaf's result.
     *
     * <p>At an internal REAL node the feature value is compared against the
     * node's threshold with the node's operator: child 0 is taken when the
     * comparison holds, child 1 otherwise. At an internal NOMINAL node the
     * feature value is rounded and used as the child index.
     *
     * @return a {@code Double} for a REAL leaf or a {@code String} for a
     *         NOMINAL leaf
     * @throws IllegalStateException if the tree has no root, a child is
     *         missing, or a node has neither a supported result type nor a
     *         supported feature type
     */
    @Override
    public Object evaluate() {
        TreeNode cur = m_root;
        if (cur == null) {
            throw new IllegalStateException("DecisionTree has no root node to evaluate.");
        }
        // A root that is itself a leaf has no feature to report.
        if (cur.m_feature != null) {
            Log.i("DecisionTree", "root name " + cur.m_feature.getName());
            Log.i("DecisionTree", "root value = " + cur.m_feature.evaluate(cur.m_parameter));
        }

        while (true) {
            OPType resultType = cur.getM_resultType();
            if (resultType != null) {
                // Leaf node: produce the result.
                if (resultType == OPType.REAL) {
                    return Double.valueOf(cur.getM_realResult());
                }
                if (resultType == OPType.NOMINAL) {
                    String res = cur.getM_nominalResult();
                    Log.i("DecisionTreeEvaluate", "Reach leaf node, result=" + res);
                    StringBuilder features = new StringBuilder();
                    for (Feature fea : LogUtil.features) {
                        features.append(fea.getName() + "(" + fea.getParameter() + ")="
                                + fea.getDataValue() + "; ");
                    }
                    Log.i("DecisionTreeEvaluate", "Result done: " + features + res);
                    LogUtil.features.clear();
                    return res;
                }
                // Without this the loop would spin forever on the same node.
                throw new IllegalStateException("Leaf node " + cur.getM_id()
                        + " has unsupported result type " + resultType);
            }

            // Internal node: pick the child to descend into.
            OPType type = cur.getM_type();
            TreeNode next;
            if (type == OPType.REAL) {
                double var = (Double) cur.m_feature.evaluate(cur.m_parameter);
                Log.i("DecisionTreeEvaluate", "value=" + var + " threshold=" + cur.getM_realThes());
                if (cur.getM_realOp().evaluate(var, cur.getM_realThes())) {
                    Log.i("DecisionTreeEvaluate", "go to left child");
                    next = cur.getM_childNodes().get(0);
                } else {
                    Log.i("DecisionTreeEvaluate", "go to right child");
                    next = cur.getM_childNodes().get(1);
                }
            } else if (type == OPType.NOMINAL) {
                int sel = (int) Math.round((Double) cur.m_feature.evaluate(cur.m_parameter));
                next = cur.getM_childNodes().get(sel);
            } else {
                // Without this the loop would spin forever on the same node.
                throw new IllegalStateException("Node " + cur.getM_id()
                        + " has neither a result nor a supported feature type (type=" + type + ")");
            }

            if (next == null) {
                throw new IllegalStateException("Node " + cur.getM_id() + " has a missing child node.");
            }
            cur = next;
        }
    }
}