/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.conf.cmdline; import java.io.FileOutputStream; import java.io.InputStream; import java.io.PrintWriter; import java.util.List; import java.util.Properties; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.commons.cli2.util.HelpFormatter; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.classifier.sgd.LogisticModelParameters; import org.apache.mahout.classifier.sgd.TrainLogistic; import com.cloudera.iterativereduce.ConfigFields; import com.cloudera.iterativereduce.yarn.client.Client; import com.google.common.collect.Lists; public class ModelTrainerCmdLineDriver extends Client { private static String input_dir = ""; private static String output_dir = ""; public static void main(String[] args) throws Exception { mainToOutput(args, new PrintWriter(System.out, true)); int rc = ToolRunner.run(new Configuration(), new ModelTrainerCmdLineDriver(), args); // Log, because been bitten before on daemon threads; sanity check System.out.println("Calling System.exit(" + rc + ")"); System.exit(rc); } static void mainToOutput(String[] args, PrintWriter output) throws Exception { if (parseArgs(args)) { output.write("Parse:correct"); } // if } // mainToOutput private static boolean parseArgs(String[] args) { DefaultOptionBuilder builder = new DefaultOptionBuilder(); Option help = builder.withLongName("help").withDescription( "print this list").create(); // Option quiet = // builder.withLongName("quiet").withDescription("be extra quiet").create(); // Option scores = // builder.withLongName("scores").withDescription("output score diagnostics during training").create(); ArgumentBuilder argumentBuilder = new ArgumentBuilder(); Option inputFile = builder .withLongName("input") .withRequired(true) .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) .withDescription("where to get training data").create(); Option outputFile = builder.withLongName("output").withRequired(true) .withArgument( argumentBuilder.withName("output").withMaximum(1).create()) .withDescription("where to get training data").create(); Option features = builder.withLongName("features").withArgument( argumentBuilder.withName("numFeatures").withDefault("1000") .withMaximum(1).create()).withDescription( "the number of internal hashed features to use").create(); // optionally can be { 20Newsgroups, rcv1 } Option RecordFactoryType = builder.withLongName("recordFactoryType") .withArgument( argumentBuilder.withName("recordFactoryType").withDefault( "20Newsgroups").withMaximum(1).create()).withDescription( "the record vectorization factory to use").create(); Option passes = builder.withLongName("passes").withArgument( argumentBuilder.withName("passes").withDefault("2").withMaximum(1) .create()).withDescription( "the number of times to pass over the input data").create(); Option lambda = builder.withLongName("lambda").withArgument( argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1) .create()) .withDescription("the amount of coefficient decay to use").create(); Option rate = builder.withLongName("rate").withArgument( argumentBuilder.withName("learningRate").withDefault("1e-3") .withMaximum(1).create()).withDescription("the learning rate") .create(); Option noBias = builder.withLongName("noBias").withDescription( "don't include a bias term").create(); Group normalArgs = new GroupBuilder().withOption(help) .withOption(inputFile).withOption(outputFile).withOption( RecordFactoryType).withOption(passes).withOption(lambda) .withOption(rate).withOption(noBias).withOption(features).create(); Parser parser = new Parser(); parser.setHelpOption(help); parser.setHelpTrigger("--help"); parser.setGroup(normalArgs); parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); CommandLine cmdLine = parser.parseAndHelp(args); if (cmdLine == null) { System.out.println("null!"); return false; } input_dir = getStringArgument(cmdLine, inputFile); output_dir = getStringArgument(cmdLine, outputFile); /* * TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile); * TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile); * * List<String> typeList = Lists.newArrayList(); for (Object x : * cmdLine.getValues(types)) { typeList.add(x.toString()); } * * List<String> predictorList = Lists.newArrayList(); for (Object x : * cmdLine.getValues(predictors)) { predictorList.add(x.toString()); } * * lmp = new LogisticModelParameters(); * lmp.setTargetVariable(getStringArgument(cmdLine, target)); * lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, * targetCategories)); lmp.setNumFeatures(getIntegerArgument(cmdLine, * features)); lmp.setUseBias(!getBooleanArgument(cmdLine, noBias)); * lmp.setTypeMap(predictorList, typeList); * * lmp.setLambda(getDoubleArgument(cmdLine, lambda)); * lmp.setLearningRate(getDoubleArgument(cmdLine, rate)); * * TrainLogistic.scores = getBooleanArgument(cmdLine, scores); * TrainLogistic.passes = getIntegerArgument(cmdLine, passes); */ return true; } public Configuration generateDebugConfigurationObject() { Configuration c = new Configuration(); // feature vector size c.setInt("com.cloudera.knittingboar.setup.FeatureVectorSize", 10000); c.setInt("com.cloudera.knittingboar.setup.numCategories", 20); c.setInt("com.cloudera.knittingboar.setup.BatchSize", 200); c.setInt("com.cloudera.knittingboar.setup.NumberPasses", 1); // local input split path c.set("com.cloudera.knittingboar.setup.LocalInputSplitPath", "hdfs://127.0.0.1/input/0"); // setup 20newsgroups c.set("com.cloudera.knittingboar.setup.RecordFactoryClassname", "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory"); return c; } private void BuildPropertiesFile() throws Exception { // Setup app.properties InputStream is = Thread.currentThread().getContextClassLoader() .getResourceAsStream("app.properties"); if (is == null) throw new RuntimeException( "Could not find 'app.properties' template file in classpath"); Properties props = new Properties(); props.load(is); props.put(ConfigFields.JAR_PATH, "/dev/null"); // what about these? props.put(ConfigFields.APP_JAR_PATH, "/dev/null"); // what about these? props.put(ConfigFields.APP_INPUT_PATH, ModelTrainerCmdLineDriver.input_dir); props.put(ConfigFields.APP_OUTPUT_PATH, ModelTrainerCmdLineDriver.output_dir); props.put(ConfigFields.YARN_MASTER, "com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode"); props.put(ConfigFields.YARN_WORKER, "com.cloudera.knittingboar.sgd.iterativereduce.POLRWorkerNode"); props.put("com.cloudera.knittingboar.setup.FeatureVectorSize", 10000); props.put("com.cloudera.knittingboar.setup.numCategories", 20); props.put("com.cloudera.knittingboar.setup.BatchSize", 200); props.put("com.cloudera.knittingboar.setup.NumberPasses", 1); // local input split path // props.put( "com.cloudera.knittingboar.setup.LocalInputSplitPath", // "hdfs://127.0.0.1/input/0" ); // setup 20newsgroups props.put("com.cloudera.knittingboar.setup.RecordFactoryClassname", "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory"); props.store(new FileOutputStream("app.properties"), null); } /* * public void Train() { * * Client client = new Client(); client.setConf(yarnCluster.getConfig()); * client.run(new String[] { testDir + "/app.properties"}); * * } */ private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { return cmdLine.hasOption(option); } private static String getStringArgument(CommandLine cmdLine, Option inputFile) { return (String) cmdLine.getValue(inputFile); } }