package hex.genmodel.tools; import hex.ModelCategory; import hex.genmodel.GenModel; import hex.genmodel.MojoModel; import hex.genmodel.easy.EasyPredictModelWrapper; import hex.genmodel.easy.RowData; import hex.genmodel.easy.prediction.*; import au.com.bytecode.opencsv.CSVReader; import java.io.BufferedWriter; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; /** * Simple driver program for reading a CSV file and making predictions. * * This driver program is used as a test harness by several tests in the testdir_javapredict directory. * <p></p> * See the top-of-tree master version of this file <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-genmodel/src/main/java/hex/genmodel/tools/PredictCsv.java" target="_blank">here on github</a>. */ public class PredictCsv { private String modelName; private String inputCSVFileName; private String outputCSVFileName; private boolean useDecimalOutput = false; // Model instance private EasyPredictModelWrapper model; public static void main(String[] args) { // Parse command line arguments PredictCsv main = new PredictCsv(); main.parseArgs(args); // Run the main program try { main.run(); } catch (Exception e) { e.printStackTrace(); System.exit(2); } // Predictions were successfully generated. System.exit(0); } private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) { // Assemble the input values for the row. RowData row = new RowData(); int maxI = Math.min(inputColumnNames.length, splitLine.length); for (int i = 0; i < maxI; i++) { String columnName = inputColumnNames[i]; String cellData = splitLine[i]; switch (cellData) { case "": case "NA": case "N/A": case "-": continue; default: row.put(columnName, cellData); } } return row; } private String myDoubleToString(double d) { if (Double.isNaN(d)) { return "NA"; } return useDecimalOutput? Double.toString(d) : Double.toHexString(d); } private void run() throws Exception { ModelCategory category = model.getModelCategory(); CSVReader reader = new CSVReader(new FileReader(inputCSVFileName)); BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName)); // Emit outputCSV column names. switch (category) { case AutoEncoder: output.write(model.getHeader()); break; case Binomial: case Multinomial: output.write("predict"); String[] responseDomainValues = model.getResponseDomainValues(); for (String s : responseDomainValues) { output.write(","); output.write(s); } break; case Clustering: output.write("cluster"); break; case Regression: output.write("predict"); break; default: throw new Exception("Unknown model category " + category); } output.write("\n"); // Loop over inputCSV one row at a time. // // TODO: performance of scoring can be considerably improved if instead of scoring each row at a time we passed // all the rows to the score function, in which case it can evaluate each tree for each row, avoiding // multiple rounds of fetching each tree from the filesystem. // int lineNum = 0; try { String[] inputColumnNames = null; String[] splitLine; while ((splitLine = reader.readNext()) != null) { lineNum++; // Handle the header. if (lineNum == 1) { inputColumnNames = splitLine; continue; } // Parse the CSV line. Don't handle quoted commas. This isn't a parser test. RowData row = formatDataRow(splitLine, inputColumnNames); // Do the prediction. // Emit the result to the output file. switch (category) { case AutoEncoder: { throw new UnsupportedOperationException(); // AutoEncoderModelPrediction p = model.predictAutoEncoder(row); // break; } case Binomial: { BinomialModelPrediction p = model.predictBinomial(row); output.write(p.label); output.write(","); for (int i = 0; i < p.classProbabilities.length; i++) { if (i > 0) { output.write(","); } output.write(myDoubleToString(p.classProbabilities[i])); } break; } case Multinomial: { MultinomialModelPrediction p = model.predictMultinomial(row); output.write(p.label); output.write(","); for (int i = 0; i < p.classProbabilities.length; i++) { if (i > 0) { output.write(","); } output.write(myDoubleToString(p.classProbabilities[i])); } break; } case Clustering: { ClusteringModelPrediction p = model.predictClustering(row); output.write(myDoubleToString(p.cluster)); break; } case Regression: { RegressionModelPrediction p = model.predictRegression(row); output.write(myDoubleToString(p.value)); break; } default: throw new Exception("Unknown model category " + category); } output.write("\n"); } } catch (Exception e) { System.out.println("Caught exception on line " + lineNum); System.out.println(""); e.printStackTrace(); System.exit(1); } // Clean up. output.close(); reader.close(); } private void loadModel(String modelName) throws Exception { try { loadMojo(modelName); } catch (IOException e) { loadPojo(modelName); // may throw an exception too } } private void loadPojo(String className) throws Exception { GenModel genModel = (GenModel) Class.forName(className).newInstance(); model = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genModel).setConvertUnknownCategoricalLevelsToNa(true)); } private void loadMojo(String modelName) throws IOException { GenModel genModel = MojoModel.load(modelName); model = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genModel).setConvertUnknownCategoricalLevelsToNa(true)); } private static void usage() { System.out.println(""); System.out.println("Usage: java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName"); System.out.println(" --pojo pojoName --input inputFile --output outputFile --decimal"); System.out.println(""); System.out.println(" --mojo Name of the zip file containing model's MOJO."); System.out.println(" --pojo Name of the java class containing the model's POJO. Either this "); System.out.println(" parameter or --model must be specified."); System.out.println(" --input CSV file containing the test data set to score."); System.out.println(" --output Name of the output CSV file with computed predictions."); System.out.println(" --decimal Use decimal numbers in the output (default is to use hexademical)."); System.out.println(""); System.exit(1); } private void parseArgs(String[] args) { try { for (int i = 0; i < args.length; i++) { String s = args[i]; if (s.equals("--header")) continue; if (s.equals("--decimal")) useDecimalOutput = true; else { i++; if (i >= args.length) usage(); String sarg = args[i]; switch (s) { case "--model": loadModel(sarg); break; case "--mojo": loadMojo(sarg); break; case "--pojo": loadPojo(sarg); break; case "--input": inputCSVFileName = sarg; break; case "--output": outputCSVFileName = sarg; break; default: System.out.println("ERROR: Unknown command line argument: " + s); usage(); } } } } catch (Exception e) { e.printStackTrace(); usage(); } } }