/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.df.tools; import java.lang.reflect.Field; import java.text.DecimalFormat; import java.util.HashMap; import java.util.Map; import org.apache.commons.lang.ArrayUtils; import org.apache.mahout.classifier.df.data.Data; import org.apache.mahout.classifier.df.data.Dataset; import org.apache.mahout.classifier.df.data.Instance; import org.apache.mahout.classifier.df.node.CategoricalNode; import org.apache.mahout.classifier.df.node.Leaf; import org.apache.mahout.classifier.df.node.Node; import org.apache.mahout.classifier.df.node.NumericalNode; /** * This tool is to visualize the Decision tree */ public final class TreeVisualizer { private TreeVisualizer() { } private static String doubleToString(double value) { DecimalFormat df = new DecimalFormat("0.##"); return df.format(value); } private static String toStringNode(Node node, Dataset dataset, String[] attrNames, Map<String, Field> fields, int layer) throws IllegalAccessException { StringBuilder buff = new StringBuilder(); if (node instanceof CategoricalNode) { CategoricalNode cnode = (CategoricalNode) node; int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode); Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode); String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset); for (int i = 0; i < childs.length; i++) { buff.append('\n'); for (int j = 0; j < layer; j++) { buff.append("| "); } buff.append((attrNames == null ? attr : attrNames[attr]) + " = " + attrValues[attr][i]); int index = ArrayUtils.indexOf(values, i); if (index >= 0) { buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1)); } } } else if (node instanceof NumericalNode) { NumericalNode nnode = (NumericalNode) node; int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); double split = (Double) fields.get("NumericalNode.split").get(nnode); Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); buff.append('\n'); for (int j = 0; j < layer; j++) { buff.append("| "); } buff.append((attrNames == null ? attr : attrNames[attr]) + " < " + doubleToString(split)); buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1)); buff.append('\n'); for (int j = 0; j < layer; j++) { buff.append("| "); } buff.append((attrNames == null ? attr : attrNames[attr]) + " >= " + doubleToString(split)); buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1)); } else if (node instanceof Leaf) { Leaf leaf = (Leaf) node; double label = (Double) fields.get("Leaf.label").get(leaf); if (dataset.isNumerical(dataset.getLabelId())) { buff.append(" : ").append(doubleToString(label)); } else { buff.append(" : ").append(dataset.getLabelString((int) label)); } } return buff.toString(); } private static Map<String, Field> getReflectMap() throws Exception { Map<String, Field> fields = new HashMap<String, Field>(); Field m = CategoricalNode.class.getDeclaredField("attr"); m.setAccessible(true); fields.put("CategoricalNode.attr", m); m = CategoricalNode.class.getDeclaredField("values"); m.setAccessible(true); fields.put("CategoricalNode.values", m); m = CategoricalNode.class.getDeclaredField("childs"); m.setAccessible(true); fields.put("CategoricalNode.childs", m); m = NumericalNode.class.getDeclaredField("attr"); m.setAccessible(true); fields.put("NumericalNode.attr", m); m = NumericalNode.class.getDeclaredField("split"); m.setAccessible(true); fields.put("NumericalNode.split", m); m = NumericalNode.class.getDeclaredField("loChild"); m.setAccessible(true); fields.put("NumericalNode.loChild", m); m = NumericalNode.class.getDeclaredField("hiChild"); m.setAccessible(true); fields.put("NumericalNode.hiChild", m); m = Leaf.class.getDeclaredField("label"); m.setAccessible(true); fields.put("Leaf.label", m); m = Dataset.class.getDeclaredField("values"); m.setAccessible(true); fields.put("Dataset.values", m); return fields; } /** * Decision tree to String * @param tree * Node of tree * @param dataset * @param attrNames * attribute names */ public static String toString(Node tree, Dataset dataset, String[] attrNames) throws Exception { return toStringNode(tree, dataset, attrNames, getReflectMap(), 0); } /** * Print Decision tree * @param tree * Node of tree * @param dataset * @param attrNames * attribute names */ public static void print(Node tree, Dataset dataset, String[] attrNames) throws Exception { System.out.println(toString(tree, dataset, attrNames)); } private static String toStringPredict(Node node, Instance instance, Dataset dataset, String[] attrNames, Map<String, Field> fields) throws IllegalAccessException { StringBuilder buff = new StringBuilder(); if (node instanceof CategoricalNode) { CategoricalNode cnode = (CategoricalNode) node; int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode); Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode); String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset); int index = ArrayUtils.indexOf(values, instance.get(attr)); if (index >= 0) { buff.append((attrNames == null ? attr : attrNames[attr]) + " = " + attrValues[attr][(int) instance.get(attr)]); buff.append(" -> "); buff.append(toStringPredict(childs[index], instance, dataset, attrNames, fields)); } } else if (node instanceof NumericalNode) { NumericalNode nnode = (NumericalNode) node; int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); double split = (Double) fields.get("NumericalNode.split").get(nnode); Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); if (instance.get(attr) < split) { buff.append("(" + (attrNames == null ? attr : attrNames[attr]) + " = " + doubleToString(instance.get(attr)) + ") < " + doubleToString(split)); buff.append(" -> "); buff.append(toStringPredict(loChild, instance, dataset, attrNames, fields)); } else { buff.append("(" + (attrNames == null ? attr : attrNames[attr]) + " = " + doubleToString(instance.get(attr)) + ") >= " + doubleToString(split)); buff.append(" -> "); buff.append(toStringPredict(hiChild, instance, dataset, attrNames, fields)); } } else if (node instanceof Leaf) { Leaf leaf = (Leaf) node; double label = (Double) fields.get("Leaf.label").get(leaf); if (dataset.isNumerical(dataset.getLabelId())) { buff.append(doubleToString(label)); } else { buff.append(dataset.getLabelString((int) label)); } } return buff.toString(); } /** * Predict trace to String * @param tree * Node of tree * @param data * @param attrNames * attribute names */ public static String[] predictTrace(Node tree, Data data, String[] attrNames) throws Exception { Map<String, Field> reflectMap = getReflectMap(); String[] prediction = new String[data.size()]; for (int i = 0; i < data.size(); i++) { prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap); } return prediction; } /** * Print predict trace * @param tree * Node of tree * @param data * @param attrNames * attribute names */ public static void predictTracePrint(Node tree, Data data, String[] attrNames) throws Exception { Map<String, Field> reflectMap = getReflectMap(); for (int i = 0; i < data.size(); i++) { System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap)); } } }