import java.io.*; import java.util.HashMap; class PredictCSV { private static String modelClassName; private static String inputCSVFileName; private static String outputCSVFileName; private static int skipFirstLine = -1; private static void usage() { System.out.println(""); System.out.println("usage: java [...java args...] PredictCSV (--header | --noheader) --model modelClassName --input inputCSVFileName --output outputCSVFileName"); System.out.println(""); System.out.println(" model class name is something like GBMModel_blahblahblahblah."); System.out.println(""); System.out.println(" inputCSV is the test data set."); System.out.println(" Specify --header or --noheader as appropriate."); System.out.println(""); System.out.println(" outputCSV is the prediction data set (one row per test data set)."); System.out.println(""); System.exit(1); } private static void usageHeader() { System.out.println("ERROR: One of --header or --noheader must be specified exactly once"); usage(); } private static void parseArgs(String[] args) { for (int i = 0; i < args.length; i++) { String s = args[i]; if (s.equals("--model")) { i++; if (i >= args.length) usage(); modelClassName = args[i]; } else if (s.equals("--input")) { i++; if (i >= args.length) usage(); inputCSVFileName = args[i]; } else if (s.equals("--output")) { i++; if (i >= args.length) usage(); outputCSVFileName = args[i]; } else if (s.equals("--header")) { if (skipFirstLine >= 0) usageHeader(); skipFirstLine = 1; } else if (s.equals("--noheader")) { if (skipFirstLine >= 0) usageHeader(); skipFirstLine = 0; } else { System.out.println("ERROR: Bad parameter: " + s); usage(); } } if (skipFirstLine < 0) { usageHeader(); } if (modelClassName == null) { System.out.println("ERROR: model not specified"); usage(); } if (inputCSVFileName == null) { System.out.println("ERROR: input not specified"); usage(); } if (outputCSVFileName == null) { System.out.println("ERROR: output not specified"); usage(); } } public static void main(String[] args) throws Exception{ parseArgs(args); water.genmodel.GeneratedModel model; model = (water.genmodel.GeneratedModel) Class.forName(modelClassName).newInstance(); BufferedReader input = new BufferedReader(new FileReader(inputCSVFileName)); BufferedWriter output = new BufferedWriter(new FileWriter(outputCSVFileName)); System.out.println("COLS " + model.getNumCols()); // Create map of input variable domain information. // This contains the categorical string to numeric mapping. HashMap<Integer,HashMap<String,Integer>> domainMap = new HashMap<Integer,HashMap<String,Integer>>(); for (int i = 0; i < model.getNumCols(); i++) { String[] domainValues = model.getDomainValues(i); if (domainValues != null) { HashMap<String,Integer> m = new HashMap<String,Integer>(); for (int j = 0; j < domainValues.length; j++) { System.out.println("Putting ("+ i +","+ j +","+ domainValues[j] +")"); m.put(domainValues[j], new Integer(j)); } domainMap.put(i, m); } } // Print outputCSV column names. if (model.isAutoEncoder()) { output.write(model.getHeader()); } else { output.write("predict"); for (int i = 0; i < model.getNumResponseClasses(); i++) { output.write(","); output.write(model.getDomainValues(model.getResponseIdx())[i]); } } output.write("\n"); // Loop over inputCSV one row at a time. int lineno = 0; String line = null; // An array to store predicted values float[] preds = new float[model.getPredsSize()]; while ((line = input.readLine()) != null) { lineno++; if (skipFirstLine > 0) { skipFirstLine = 0; String[] names = line.trim().split(","); String[] modelNames = model.getNames(); for (int i=0; i < Math.min(names.length, modelNames.length); i++ ) if ( !names[i].equals(modelNames[i]) ) { System.out.println("ERROR: Column names does not match: input column " + i + ". "+names[i]+" != model column "+modelNames[i] ); System.exit(1); } // go to the next line continue; } // Parse the CSV line. Don't handle quoted commas. This isn't a parser test. String trimmedLine = line.trim(); String[] inputColumnsArray = trimmedLine.split(","); int numInputColumns = model.isAutoEncoder() ? model.getNames().length : model.getNames().length-1; // we do not need response ! if (inputColumnsArray.length != numInputColumns) { System.out.println("WARNING: Line " + lineno + " has " + inputColumnsArray.length + " columns (expected " + numInputColumns + ")"); } // Assemble the input values for the row. double[] row = new double[numInputColumns]; int j = 0; for (j = 0; j < inputColumnsArray.length; j++) { String cellString = inputColumnsArray[j]; // System.out.println("Line " + lineno +" column ("+ model.getNames()[i] + " == " + i + ") cellString("+cellString+")"); String[] domainValues = model.getDomainValues(j); if (cellString.equals("") || // empty field is default NA (domainValues == null) && ( // if the column is enum then NA is part of domain by default ! cellString.equals("NA") || cellString.equals("N/A") || cellString.equals("-") ) ) { row[j] = Double.NaN; } else { if (domainValues != null) { HashMap m = (HashMap<String,Integer>) domainMap.get(j); assert (m != null); Integer cellOrdinalValue = (Integer) m.get(cellString); if (cellOrdinalValue == null) { System.out.println("WARNING: Line " + lineno + " column ("+ model.getNames()[j] + " == " + j +") has unknown categorical value (" + cellString + ")"); row[j] = Double.NaN; } else { row[j] = (double) cellOrdinalValue.intValue(); } } else { try { double value = Double.parseDouble(cellString); row[j] = value; } catch (java.lang.NumberFormatException e) { row[j] = Double.NaN; } } } } for (; j< numInputColumns; j++) row[j] = Double.NaN; // Do the prediction. //model.predict(row, preds); preds = model.predict(row, preds); // Emit the result to the output file. for (int i = 0; i < preds.length; i++) { if (i == 0 && model.isClassifier()) { // See if there is a domain to map this output value to. String[] domainValues = model.getDomainValues(model.getResponseIdx()); if (domainValues != null) { // Classification. double value = preds[i]; int valueAsInt = (int)value; if (value != (int)valueAsInt) { System.out.println("ERROR: Line " + lineno + " has non-integer output for classification (" + value + ")"); System.exit(1); } String predictedOutputClassLevel = domainValues[valueAsInt]; output.write(predictedOutputClassLevel); } } else { if (i > 0) output.write(","); output.write(Double.toHexString(preds[i])); if (!model.isClassifier() && !model.isAutoEncoder()) break; } } output.write("\n"); } // Clean up. output.close(); input.close(); // Predictions were successfully generated. Calling program can now compare them with something. System.exit(0); } }