package hex.genmodel.tools;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.drf.DrfMojoModel;
import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import java.io.*;
/**
 * Command-line tool that prints the dot (graphviz) representation of one or more trees in a DRF
 * or GBM MOJO model.
 *
 * <p>Exit codes: 0 on success, 1 on invalid usage or an unsupported model type, 2 if an exception
 * is thrown while producing the output.
 */
public class PrintMojo {
  /** Default for {@code --levels}; a constant so the usage text always reports the real default. */
  private static final int DEFAULT_MAX_LEVELS_PER_EDGE = 10;

  /** Model loaded via {@code -i/--input}; {@code null} until an input has been given. */
  private GenModel genModel;
  /** {@code --raw}: also call {@link SharedTreeGraph#print()} before emitting the dot output. */
  private boolean printRaw = false;
  /** {@code --tree}: index of the tree to print; -1 (the default) is documented as "all". */
  private int treeToPrint = -1;
  /** {@code --levels}: number of levels per edge to print. */
  private int maxLevelsToPrintPerEdge = DEFAULT_MAX_LEVELS_PER_EDGE;
  /** {@code --detail}: print additional detailed information such as node numbers. */
  private boolean detail = false;
  /** {@code -o/--output}: destination file; {@code null} means standard output. */
  private String outputFileName = null;
  /** {@code --title}: title forced onto the graph; {@code null} when not given. */
  private String optionalTitle = null;

  /**
   * Entry point: parses the arguments, prints the requested tree graph(s) and exits.
   *
   * @param args command-line arguments, see the usage text
   */
  public static void main(String[] args) {
    PrintMojo tool = new PrintMojo();
    // Invalid arguments terminate the JVM inside parseArgs (via usage()).
    tool.parseArgs(args);
    try {
      tool.run();
    } catch (Exception e) {
      e.printStackTrace();
      System.exit(2);
    }
    System.exit(0);
  }

  /**
   * Loads the MOJO at the given path into {@link #genModel}.
   *
   * @param modelName path of the MOJO file
   * @throws IOException if the MOJO cannot be read
   */
  private void loadMojo(String modelName) throws IOException {
    genModel = MojoModel.load(modelName);
  }

  /** Prints the usage text to standard output and terminates the JVM with exit code 1. */
  private static void usage() {
    System.out.println("Emit a human-consumable graph of a model for use with dot (graphviz).");
    System.out.println("The currently supported model types are DRF and GBM.");
    System.out.println("");
    System.out.println("Usage: java [...java args...] hex.genmodel.tools.PrintMojo [--tree n] [--levels n] [--title sss] [--detail] [--raw] -i inputMojoFile [-o outputFileName]");
    System.out.println("");
    System.out.println(" --tree Tree number to print.");
    System.out.println(" [default all]");
    System.out.println("");
    System.out.println(" --levels Number of levels per edge to print.");
    // Report the constant default, not the (possibly already overridden) current value.
    System.out.println(" [default " + DEFAULT_MAX_LEVELS_PER_EDGE + "]");
    System.out.println("");
    System.out.println(" --title (Optional) Force title of tree graph.");
    System.out.println("");
    System.out.println(" --detail Specify to print additional detailed information like node numbers.");
    System.out.println("");
    System.out.println(" --raw Additionally print the raw tree graph.");
    System.out.println("");
    System.out.println(" --input | -i Input mojo file.");
    System.out.println("");
    System.out.println(" --output | -o Output dot filename.");
    System.out.println(" [default stdout]");
    System.out.println("");
    System.out.println("Example:");
    System.out.println("");
    System.out.println(" (brew install graphviz)");
    System.out.println(" java -cp h2o.jar hex.genmodel.tools.PrintMojo --tree 0 -i model_mojo.zip -o model.gv");
    System.out.println(" dot -Tpng model.gv -o model.png");
    System.out.println(" open model.png");
    System.out.println("");
    System.exit(1);
  }

  /**
   * Returns the value of an option at {@code args[i]}; prints usage and exits if it is missing.
   *
   * @param args all command-line arguments
   * @param i index where the option's value is expected
   * @return {@code args[i]}
   */
  private static String requireValue(String[] args, int i) {
    if (i >= args.length) {
      usage(); // does not return
    }
    return args[i];
  }

  /**
   * Parses an integer option value; on failure prints an error and exits with code 1.
   *
   * @param option the option name, used in the error message
   * @param value the raw value to parse
   * @return the parsed integer
   */
  private static int parseIntArg(String option, String value) {
    try {
      return Integer.parseInt(value);
    } catch (NumberFormatException e) {
      System.out.println("ERROR: invalid " + option + " argument (" + value + ")");
      System.exit(1);
      return 0; // not reached: System.exit does not return normally
    }
  }

  /**
   * Parses the command line into this instance's fields, loading the MOJO when {@code -i} is seen.
   * Any error (unknown option, missing value, unreadable MOJO) prints usage and exits.
   *
   * @param args command-line arguments
   */
  private void parseArgs(String[] args) {
    try {
      for (int i = 0; i < args.length; i++) {
        String s = args[i];
        switch (s) {
          case "--tree":
            treeToPrint = parseIntArg("--tree", requireValue(args, ++i));
            break;
          case "--levels":
            maxLevelsToPrintPerEdge = parseIntArg("--levels", requireValue(args, ++i));
            break;
          case "--title":
            optionalTitle = requireValue(args, ++i);
            break;
          case "--detail":
            detail = true;
            break;
          case "--input":
          case "-i":
            loadMojo(requireValue(args, ++i));
            break;
          case "--raw":
            printRaw = true;
            break;
          case "-o":
          case "--output":
            outputFileName = requireValue(args, ++i);
            break;
          default:
            System.out.println("ERROR: Unknown command line argument: " + s);
            usage();
            break;
        }
      }
    } catch (Exception e) {
      // CLI boundary: most likely a MOJO that failed to load.
      e.printStackTrace();
      usage();
    }
  }

  /** Ensures a model was supplied; prints an error plus usage and exits otherwise. */
  private void validateArgs() {
    if (genModel == null) {
      System.out.println("ERROR: Must specify -i");
      usage();
    }
  }

  /**
   * Computes the tree graph for the loaded model and writes it in dot format to the output file,
   * or to standard output when no file was given.
   *
   * @throws Exception if graph computation fails, or an {@link IOException} if writing the output
   *     file fails
   */
  private void run() throws Exception {
    validateArgs();

    // Compute the graph before touching the output file, so an unsupported model or a failing
    // graph computation does not leave behind an empty or truncated file.
    SharedTreeGraph graph;
    if (genModel instanceof GbmMojoModel) {
      graph = ((GbmMojoModel) genModel)._computeGraph(treeToPrint);
    } else if (genModel instanceof DrfMojoModel) {
      graph = ((DrfMojoModel) genModel)._computeGraph(treeToPrint);
    } else {
      System.out.println("ERROR: Unknown MOJO type");
      System.exit(1);
      return; // not reached
    }

    if (printRaw) {
      graph.print();
    }

    if (outputFileName == null) {
      // Never close System.out; just make sure everything is flushed before exit.
      graph.printDot(System.out, maxLevelsToPrintPerEdge, detail, optionalTitle);
      System.out.flush();
    } else {
      try (PrintStream os = new PrintStream(new FileOutputStream(outputFileName))) {
        graph.printDot(os, maxLevelsToPrintPerEdge, detail, optionalTitle);
        // PrintStream swallows IOExceptions; surface write failures (e.g. disk full).
        if (os.checkError()) {
          throw new IOException("Error writing dot output to " + outputFileName);
        }
      }
    }
  }
}