/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.df.tools;
import java.lang.reflect.Field;
import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
/**
 * Renders decision trees made of {@link CategoricalNode}, {@link NumericalNode} and {@link Leaf}
 * nodes as indented text, and traces the path an instance takes through such a tree.
 *
 * <p>The node classes keep their split attributes and children in private fields, so this tool
 * reads them through reflection.
 */
public final class TreeVisualizer {

  /** Prefix repeated once per tree level to indent nested branches. */
  private static final String INDENT = "| ";

  private TreeVisualizer() {
  }

  /**
   * Formats {@code value} with at most two fraction digits and no trailing zeros.
   */
  private static String doubleToString(double value) {
    // DecimalFormat is not thread-safe, so a fresh instance is created per call.
    // NOTE(review): uses the default locale's decimal separator.
    DecimalFormat df = new DecimalFormat("0.##");
    return df.format(value);
  }

  /**
   * Returns the display name of attribute {@code attr}: its entry in {@code attrNames}, or the
   * attribute index itself when no names were supplied.
   */
  private static String attrName(String[] attrNames, int attr) {
    return attrNames == null ? String.valueOf(attr) : attrNames[attr];
  }

  /** Appends a line break followed by {@code layer} indentation prefixes. */
  private static void newLine(StringBuilder buff, int layer) {
    buff.append('\n');
    for (int j = 0; j < layer; j++) {
      buff.append(INDENT);
    }
  }

  /**
   * Renders the label stored in {@code leaf}: as a number when the dataset's label is numerical,
   * otherwise as the dataset's label string for that code.
   */
  private static String leafLabel(Leaf leaf, Dataset dataset, Map<String, Field> fields)
    throws IllegalAccessException {
    double label = (Double) fields.get("Leaf.label").get(leaf);
    if (dataset.isNumerical(dataset.getLabelId())) {
      return doubleToString(label);
    }
    return dataset.getLabelString((int) label);
  }

  /**
   * Recursively renders the subtree rooted at {@code node}.
   *
   * @param node subtree root
   * @param dataset dataset describing attribute values and labels
   * @param attrNames attribute display names, or {@code null} to print attribute indices
   * @param fields reflective accessors built by {@link #getReflectMap()}
   * @param layer depth of {@code node}, used for indentation
   * @return the rendered subtree (each branch begins with a line break)
   */
  private static String toStringNode(Node node, Dataset dataset, String[] attrNames,
    Map<String, Field> fields, int layer) throws IllegalAccessException {
    StringBuilder buff = new StringBuilder();
    if (node instanceof CategoricalNode) {
      CategoricalNode cnode = (CategoricalNode) node;
      int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
      double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
      Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
      String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
      // values[k] is the attribute value code routed to childs[k]. Walk every value code the
      // dataset knows so branches come out in value order, skipping values with no branch.
      for (int value = 0; value < attrValues[attr].length; value++) {
        int index = ArrayUtils.indexOf(values, value);
        if (index < 0) {
          continue;
        }
        newLine(buff, layer);
        buff.append(attrName(attrNames, attr)).append(" = ").append(attrValues[attr][value]);
        buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
      }
    } else if (node instanceof NumericalNode) {
      NumericalNode nnode = (NumericalNode) node;
      int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
      double split = (Double) fields.get("NumericalNode.split").get(nnode);
      Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
      Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
      String name = attrName(attrNames, attr);
      String splitText = doubleToString(split);
      newLine(buff, layer);
      buff.append(name).append(" < ").append(splitText);
      buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
      newLine(buff, layer);
      buff.append(name).append(" >= ").append(splitText);
      buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
    } else if (node instanceof Leaf) {
      buff.append(" : ").append(leafLabel((Leaf) node, dataset, fields));
    }
    return buff.toString();
  }

  /**
   * Makes the private field {@code name} of {@code owner} accessible and registers it in
   * {@code fields} under the key {@code "<SimpleClassName>.<name>"}.
   */
  private static void putField(Map<String, Field> fields, Class<?> owner, String name)
    throws NoSuchFieldException {
    Field field = owner.getDeclaredField(name);
    field.setAccessible(true);
    fields.put(owner.getSimpleName() + '.' + name, field);
  }

  /**
   * Builds the reflective accessors for the private node and dataset fields this tool reads.
   *
   * @return accessors keyed as {@code "CategoricalNode.attr"}, {@code "Leaf.label"}, etc.
   * @throws Exception if a field is missing or cannot be made accessible
   */
  private static Map<String, Field> getReflectMap() throws Exception {
    Map<String, Field> fields = new HashMap<String, Field>();
    putField(fields, CategoricalNode.class, "attr");
    putField(fields, CategoricalNode.class, "values");
    putField(fields, CategoricalNode.class, "childs");
    putField(fields, NumericalNode.class, "attr");
    putField(fields, NumericalNode.class, "split");
    putField(fields, NumericalNode.class, "loChild");
    putField(fields, NumericalNode.class, "hiChild");
    putField(fields, Leaf.class, "label");
    putField(fields, Dataset.class, "values");
    return fields;
  }

  /**
   * Renders a decision tree as indented text, one branch per line, with {@code INDENT} repeated
   * once per depth level and each leaf's label appended as {@code " : label"}.
   *
   * @param tree root node of the tree
   * @param dataset dataset describing attribute values and labels
   * @param attrNames attribute display names, or {@code null} to print attribute indices
   * @return the rendered tree; it begins with a line break unless the root is a leaf
   * @throws Exception if the node or dataset fields cannot be accessed reflectively
   */
  public static String toString(Node tree, Dataset dataset, String[] attrNames)
    throws Exception {
    return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
  }

  /**
   * Prints the rendering produced by {@link #toString(Node, Dataset, String[])} to standard
   * output.
   *
   * @param tree root node of the tree
   * @param dataset dataset describing attribute values and labels
   * @param attrNames attribute display names, or {@code null} to print attribute indices
   * @throws Exception if the node or dataset fields cannot be accessed reflectively
   */
  public static void print(Node tree, Dataset dataset, String[] attrNames) throws Exception {
    System.out.println(toString(tree, dataset, attrNames));
  }

  /**
   * Renders the path {@code instance} follows from {@code node} down to a leaf, as
   * {@code "step -> step -> label"}.
   *
   * @return the trace; for a categorical node whose branches do not include the instance's
   *     value, nothing is appended and the trace ends there
   */
  private static String toStringPredict(Node node, Instance instance, Dataset dataset,
    String[] attrNames, Map<String, Field> fields) throws IllegalAccessException {
    StringBuilder buff = new StringBuilder();
    if (node instanceof CategoricalNode) {
      CategoricalNode cnode = (CategoricalNode) node;
      int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
      double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
      Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
      String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
      double value = instance.get(attr);
      int index = ArrayUtils.indexOf(values, value);
      if (index >= 0) {
        buff.append(attrName(attrNames, attr)).append(" = ")
            .append(attrValues[attr][(int) value]);
        buff.append(" -> ");
        buff.append(toStringPredict(childs[index], instance, dataset, attrNames, fields));
      }
    } else if (node instanceof NumericalNode) {
      NumericalNode nnode = (NumericalNode) node;
      int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
      double split = (Double) fields.get("NumericalNode.split").get(nnode);
      Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
      Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
      double value = instance.get(attr);
      boolean low = value < split;
      buff.append('(').append(attrName(attrNames, attr)).append(" = ")
          .append(doubleToString(value)).append(low ? ") < " : ") >= ")
          .append(doubleToString(split));
      buff.append(" -> ");
      buff.append(toStringPredict(low ? loChild : hiChild, instance, dataset, attrNames, fields));
    } else if (node instanceof Leaf) {
      buff.append(leafLabel((Leaf) node, dataset, fields));
    }
    return buff.toString();
  }

  /**
   * Renders, for every instance in {@code data}, the path it follows through {@code tree}.
   *
   * @param tree root node of the tree
   * @param data instances to trace; its dataset supplies attribute values and labels
   * @param attrNames attribute display names, or {@code null} to print attribute indices
   * @return one trace per instance, in the same order as {@code data}
   * @throws Exception if the node or dataset fields cannot be accessed reflectively
   */
  public static String[] predictTrace(Node tree, Data data, String[] attrNames)
    throws Exception {
    Map<String, Field> reflectMap = getReflectMap();
    String[] prediction = new String[data.size()];
    for (int i = 0; i < data.size(); i++) {
      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap);
    }
    return prediction;
  }

  /**
   * Prints the trace of every instance in {@code data} to standard output, one per line.
   *
   * @param tree root node of the tree
   * @param data instances to trace; its dataset supplies attribute values and labels
   * @param attrNames attribute display names, or {@code null} to print attribute indices
   * @throws Exception if the node or dataset fields cannot be accessed reflectively
   */
  public static void predictTracePrint(Node tree, Data data, String[] attrNames)
    throws Exception {
    Map<String, Field> reflectMap = getReflectMap();
    for (int i = 0; i < data.size(); i++) {
      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), attrNames,
          reflectMap));
    }
  }
}