/**
* Portions Copyright 2001 Sun Microsystems, Inc.
* Portions Copyright 1999-2001 Language Technologies Institute,
* Carnegie Mellon University.
* All Rights Reserved. Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*/
package edu.cmu.sphinx.alignment.tokenizer;
import java.io.*;
import java.net.URL;
import java.util.StringTokenizer;
import java.util.logging.Logger;
import java.util.regex.Pattern;
/**
* Implementation of a Classification and Regression Tree (CART) that is used
* more like a binary decision tree, with each node containing a decision or a
* final value. The decision nodes in the CART trees operate on an Item and
* have the following format:
*
* <pre>
* NODE feat operand value qfalse
* </pre>
*
* <p>
* Where <code>feat</code> is an string that represents a feature to pass to
* the <code>findFeature</code> method of an item.
*
* <p>
* The <code>value</code> represents the value to be compared against the
* feature obtained from the item via the <code>feat</code> string. The
* <code>operand</code> is the operation to do the comparison. The available
* operands are as follows:
*
* <ul>
* <li>< - the feature is less than value
* <li>=- the feature is equal to the value
* <li>>- the feature is greater than the value
* <li>MATCHES - the feature matches the regular expression stored in value
* <li>IN - [[[TODO: still guessing because none of the CART's in Flite seem to
* use IN]]] the value is in the list defined by the feature.
* </ul>
*
* <p>
* [[[TODO: provide support for the IN operator.]]]
*
* <p>
* For < and >, this CART coerces the value and feature to float's. For =,
* this CART coerces the value and feature to string and checks for string
* equality. For MATCHES, this CART uses the value as a regular expression and
* compares the obtained feature to that.
*
* <p>
* A CART is represented by an array in this implementation. The
* <code>qfalse</code> value represents the index of the array to go to if the
* comparison does not match. In this implementation, qtrue index is always
* implied, and represents the next element in the array. The root node of the
* CART is the first element in the array.
*
* <p>
* The interpretations always start at the root node of the CART and continue
* until a final node is found. The final nodes have the following form:
*
* <pre>
* LEAF value
* </pre>
*
* <p>
* Where <code>value</code> represents the value of the node. Reaching a final
* node indicates the interpretation is over and the value of the node is the
* interpretation result.
*/
public class DecisionTree {
/** Logger instance. */
private static final Logger logger = Logger.getLogger(DecisionTree.class.getSimpleName());
/**
* Entry in file represents the total number of nodes in the file. This
* should be at the top of the file. The format should be "TOTAL n" where n
* is an integer value.
*/
final static String TOTAL = "TOTAL";
/**
* Entry in file represents a node. The format should be
* "NODE feat op val f" where 'feat' represents a feature, op represents an
* operand, val is the value, and f is the index of the node to go to is
* there isn't a match.
*/
final static String NODE = "NODE";
/**
* Entry in file represents a final node. The format should be "LEAF val"
* where val represents the value.
*/
final static String LEAF = "LEAF";
/**
* OPERAND_MATCHES
*/
final static String OPERAND_MATCHES = "MATCHES";
/**
* The CART. Entries can be DecisionNode or LeafNode. An ArrayList could be
* used here -- I chose not to because I thought it might be quicker to
* avoid dealing with the dynamic resizing.
*/
Node[] cart = null;
/**
* The number of nodes in the CART.
*/
transient int curNode = 0;
/**
* Creates a new CART by reading from the given URL.
*
* @param url the location of the CART data
*
* @throws IOException if errors occur while reading the data
*/
public DecisionTree(URL url) throws IOException {
BufferedReader reader;
String line;
reader = new BufferedReader(new InputStreamReader(url.openStream()));
line = reader.readLine();
while (line != null) {
if (!line.startsWith("***")) {
parseAndAdd(line);
}
line = reader.readLine();
}
reader.close();
}
/**
* Creates a new CART by reading from the given reader.
*
* @param reader the source of the CART data
* @param nodes the number of nodes to read for this cart
*
* @throws IOException if errors occur while reading the data
*/
public DecisionTree(BufferedReader reader, int nodes) throws IOException {
this(nodes);
String line;
for (int i = 0; i < nodes; i++) {
line = reader.readLine();
if (!line.startsWith("***")) {
parseAndAdd(line);
}
}
}
/**
* Creates a new CART that will be populated with nodes later.
*
* @param numNodes the number of nodes
*/
private DecisionTree(int numNodes) {
cart = new Node[numNodes];
}
/**
* Dump the CART tree as a dot file.
* <p>
* The dot tool is part of the graphviz distribution at <a
* href="http://www.graphviz.org/">http://www.graphviz.org/</a>. If
* installed, call it as "dot -O -Tpdf *.dot" from the console to generate
* pdfs.
* </p>
*
* @param out The PrintWriter to write to.
*/
public void dumpDot(PrintWriter out) {
out.write("digraph \"" + "CART Tree" + "\" {\n");
out.write("rankdir = LR\n");
for (Node n : cart) {
out.println("\t\"node" + n.hashCode() + "\" [ label=\""
+ n.toString() + "\", color=" + dumpDotNodeColor(n)
+ ", shape=" + dumpDotNodeShape(n) + " ]\n");
if (n instanceof DecisionNode) {
DecisionNode dn = (DecisionNode) n;
if (dn.qtrue < cart.length && cart[dn.qtrue] != null) {
out.write("\t\"node" + n.hashCode() + "\" -> \"node"
+ cart[dn.qtrue].hashCode()
+ "\" [ label=" + "TRUE" + " ]\n");
}
if (dn.qfalse < cart.length && cart[dn.qfalse] != null) {
out.write("\t\"node" + n.hashCode() + "\" -> \"node"
+ cart[dn.qfalse].hashCode()
+ "\" [ label=" + "FALSE" + " ]\n");
}
}
}
out.write("}\n");
out.close();
}
protected String dumpDotNodeColor(Node n) {
if (n instanceof LeafNode) {
return "green";
}
return "red";
}
protected String dumpDotNodeShape(Node n) {
return "box";
}
/**
* Creates a node from the given input line and add it to the CART. It
* expects the TOTAL line to come before any of the nodes.
*
* @param line a line of input to parse
*/
protected void parseAndAdd(String line) {
StringTokenizer tokenizer = new StringTokenizer(line, " ");
String type = tokenizer.nextToken();
if (type.equals(LEAF) || type.equals(NODE)) {
cart[curNode] = getNode(type, tokenizer, curNode);
cart[curNode].setCreationLine(line);
curNode++;
} else if (type.equals(TOTAL)) {
cart = new Node[Integer.parseInt(tokenizer.nextToken())];
curNode = 0;
} else {
throw new Error("Invalid CART type: " + type);
}
}
/**
* Gets the node based upon the type and tokenizer.
*
* @param type <code>NODE</code> or <code>LEAF</code>
* @param tokenizer the StringTokenizer containing the data to get
* @param currentNode the index of the current node we're looking at
*
* @return the node
*/
protected Node getNode(String type, StringTokenizer tokenizer,
int currentNode) {
if (type.equals(NODE)) {
String feature = tokenizer.nextToken();
String operand = tokenizer.nextToken();
Object value = parseValue(tokenizer.nextToken());
int qfalse = Integer.parseInt(tokenizer.nextToken());
if (operand.equals(OPERAND_MATCHES)) {
return new MatchingNode(feature, value.toString(),
currentNode + 1, qfalse);
} else {
return new ComparisonNode(feature, value, operand,
currentNode + 1, qfalse);
}
} else if (type.equals(LEAF)) {
return new LeafNode(parseValue(tokenizer.nextToken()));
}
return null;
}
/**
* Coerces a string into a value.
*
* @param string of the form "type(value)"; for example, "Float(2.3)"
*
* @return the value
*/
protected Object parseValue(String string) {
int openParen = string.indexOf("(");
String type = string.substring(0, openParen);
String value = string.substring(openParen + 1, string.length() - 1);
if (type.equals("String")) {
return value;
} else if (type.equals("Float")) {
return new Float(Float.parseFloat(value));
} else if (type.equals("Integer")) {
return new Integer(Integer.parseInt(value));
} else if (type.equals("List")) {
StringTokenizer tok = new StringTokenizer(value, ",");
int size = tok.countTokens();
int[] values = new int[size];
for (int i = 0; i < size; i++) {
float fval = Float.parseFloat(tok.nextToken());
values[i] = Math.round(fval);
}
return values;
} else {
throw new Error("Unknown type: " + type);
}
}
/**
* Passes the given item through this CART and returns the interpretation.
*
* @param item the item to analyze
*
* @return the interpretation
*/
public Object interpret(Item item) {
int nodeIndex = 0;
DecisionNode decision;
while (!(cart[nodeIndex] instanceof LeafNode)) {
decision = (DecisionNode) cart[nodeIndex];
nodeIndex = decision.getNextNode(item);
}
logger.fine("LEAF " + cart[nodeIndex].getValue());
return ((LeafNode) cart[nodeIndex]).getValue();
}
/**
* A node for the CART.
*/
static abstract class Node {
/**
* The value of this node.
*/
protected Object value;
/**
* Create a new Node with the given value.
*/
public Node(Object value) {
this.value = value;
}
/**
* Get the value.
*/
public Object getValue() {
return value;
}
/**
* Return a string representation of the type of the value.
*/
public String getValueString() {
if (value == null) {
return "NULL()";
} else if (value instanceof String) {
return "String(" + value.toString() + ")";
} else if (value instanceof Float) {
return "Float(" + value.toString() + ")";
} else if (value instanceof Integer) {
return "Integer(" + value.toString() + ")";
} else {
return value.getClass().toString() + "(" + value.toString()
+ ")";
}
}
/**
* sets the line of text used to create this node.
*
* @param line the creation line
*/
public void setCreationLine(String line) {}
}
/**
* A decision node that determines the next Node to go to in the CART.
*/
abstract static class DecisionNode extends Node {
/**
* The feature used to find a value from an Item.
*/
private PathExtractor path;
/**
* Index of Node to go to if the comparison doesn't match.
*/
protected int qfalse;
/**
* Index of Node to go to if the comparison matches.
*/
protected int qtrue;
/**
* The feature used to find a value from an Item.
*/
public String getFeature() {
return path.toString();
}
/**
* Find the feature associated with this DecisionNode and the given
* item
*
* @param item the item to start from
* @return the object representing the feature
*/
public Object findFeature(Item item) {
return path.findFeature(item);
}
/**
* Returns the next node based upon the descision determined at this
* node
*
* @param item the current item.
* @return the index of the next node
*/
public final int getNextNode(Item item) {
return getNextNode(findFeature(item));
}
/**
* Create a new DecisionNode.
*
* @param feature the string used to get a value from an Item
* @param value the value to compare to
* @param qtrue the Node index to go to if the comparison matches
* @param qfalse the Node machine index to go to upon no match
*/
public DecisionNode(String feature, Object value, int qtrue, int qfalse) {
super(value);
this.path = new PathExtractor(feature, true);
this.qtrue = qtrue;
this.qfalse = qfalse;
}
/**
* Get the next Node to go to in the CART. The return value is an index
* in the CART.
*/
abstract public int getNextNode(Object val);
}
/**
* A decision Node that compares two values.
*/
static class ComparisonNode extends DecisionNode {
/**
* LESS_THAN
*/
final static String LESS_THAN = "<";
/**
* EQUALS
*/
final static String EQUALS = "=";
/**
* GREATER_THAN
*/
final static String GREATER_THAN = ">";
/**
* The comparison type. One of LESS_THAN, GREATER_THAN, or EQUAL_TO.
*/
String comparisonType;
/**
* Create a new ComparisonNode with the given values.
*
* @param feature the string used to get a value from an Item
* @param value the value to compare to
* @param comparisonType one of LESS_THAN, EQUAL_TO, or GREATER_THAN
* @param qtrue the Node index to go to if the comparison matches
* @param qfalse the Node index to go to upon no match
*/
public ComparisonNode(String feature, Object value,
String comparisonType, int qtrue, int qfalse) {
super(feature, value, qtrue, qfalse);
if (!comparisonType.equals(LESS_THAN)
&& !comparisonType.equals(EQUALS)
&& !comparisonType.equals(GREATER_THAN)) {
throw new Error("Invalid comparison type: " + comparisonType);
} else {
this.comparisonType = comparisonType;
}
}
/**
* Compare the given value and return the appropriate Node index.
* IMPLEMENTATION NOTE: LESS_THAN and GREATER_THAN, the Node's value
* and the value passed in are converted to floating point values. For
* EQUAL, the Node's value and the value passed in are treated as
* String compares. This is the way of Flite, so be it Flite.
*
* @param val the value to compare
*/
public int getNextNode(Object val) {
boolean yes = false;
int ret;
if (comparisonType.equals(LESS_THAN)
|| comparisonType.equals(GREATER_THAN)) {
float cart_fval;
float fval;
if (value instanceof Float) {
cart_fval = ((Float) value).floatValue();
} else {
cart_fval = Float.parseFloat(value.toString());
}
if (val instanceof Float) {
fval = ((Float) val).floatValue();
} else {
fval = Float.parseFloat(val.toString());
}
if (comparisonType.equals(LESS_THAN)) {
yes = (fval < cart_fval);
} else {
yes = (fval > cart_fval);
}
} else { // comparisonType = "="
String sval = val.toString();
String cart_sval = value.toString();
yes = sval.equals(cart_sval);
}
if (yes) {
ret = qtrue;
} else {
ret = qfalse;
}
logger.fine(trace(val, yes, ret));
return ret;
}
private String trace(Object value, boolean match, int next) {
return "NODE " + getFeature() + " [" + value + "] "
+ comparisonType + " [" + getValue() + "] "
+ (match ? "Yes" : "No") + " next " + next;
}
/**
* Get a string representation of this Node.
*/
public String toString() {
return "NODE " + getFeature() + " " + comparisonType + " "
+ getValueString() + " " + Integer.toString(qtrue) + " "
+ Integer.toString(qfalse);
}
}
/**
* A Node that checks for a regular expression match.
*/
static class MatchingNode extends DecisionNode {
Pattern pattern;
/**
* Create a new MatchingNode with the given values.
*
* @param feature the string used to get a value from an Item
* @param regex the regular expression
* @param qtrue the Node index to go to if the comparison matches
* @param qfalse the Node index to go to upon no match
*/
public MatchingNode(String feature, String regex, int qtrue, int qfalse) {
super(feature, regex, qtrue, qfalse);
this.pattern = Pattern.compile(regex);
}
/**
* Compare the given value and return the appropriate CART index.
*
* @param val the value to compare -- this must be a String
*/
public int getNextNode(Object val) {
return pattern.matcher((String) val).matches() ? qtrue : qfalse;
}
/**
* Get a string representation of this Node.
*/
public String toString() {
StringBuffer buf =
new StringBuffer(NODE + " " + getFeature() + " "
+ OPERAND_MATCHES);
buf.append(getValueString() + " ");
buf.append(Integer.toString(qtrue) + " ");
buf.append(Integer.toString(qfalse));
return buf.toString();
}
}
/**
* The final Node of a CART. This just a marker class.
*/
static class LeafNode extends Node {
/**
* Create a new LeafNode with the given value.
*
* @param the value of this LeafNode
*/
public LeafNode(Object value) {
super(value);
}
/**
* Get a string representation of this Node.
*/
public String toString() {
return "LEAF " + getValueString();
}
}
}