/** * Copyright (c) 2010, Regents of the University of Colorado All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. Redistributions in binary * form must reproduce the above copyright notice, this list of conditions and * the following disclaimer in the documentation and/or other materials provided * with the distribution. Neither the name of the University of Colorado at * Boulder nor the names of its contributors may be used to endorse or promote * products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ package clear.engine; import clear.train.AbstractTrainer; import clear.train.BinaryTrainer; import clear.train.OneVsAllTrainer; import clear.train.algorithm.IAlgorithm; import clear.train.algorithm.LibLinearL2; import clear.train.algorithm.RRM; import clear.train.kernel.NoneKernel; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; /** * Trains a classifier. * * @author Jinho D. Choi <b>Last update:</b> 11/8/2010 */ public class MLTrain { @Option(name = "-i", usage = "instance file", required = true, metaVar = "REQUIRED") String s_instanceFile; @Option(name = "-m", usage = "model file", required = true, metaVar = "REQUIRED") String s_modelFile; @Option(name = "-a", usage = "algorithm ::= " + IAlgorithm.LIBLINEAR_L2 + " (LibLinear L2-SVM; default) |\n " + IAlgorithm.RRM + " (Robust Risk Minimization)", metaVar = "OPTIONAL") String s_algorithm = IAlgorithm.LIBLINEAR_L2; @Option(name = "-s", usage = "strategy ::= " + AbstractTrainer.ST_BINARY + " (binary) | " + AbstractTrainer.ST_ONE_VS_ALL + " (one-vs-all; default)", metaVar = "OPTIONAL") byte i_strategy = AbstractTrainer.ST_ONE_VS_ALL; @Option(name = "-n", usage = "# of threads to train with (default = 2)", metaVar = "OPTIONAL") int i_numThreads = 2; @Option(name = "-L", usage = "LIB: loss type ::= 1 (L1-loss; default) | 2 (L2-loss)", metaVar = "OPTIONAL") byte i_lossType = 1; @Option(name = "-E", usage = "LIB: termination criterion (default = 0.1)\nRRM: learning rate (default = 0.001)", metaVar = "OPTIONAL") double d_e = 0.1; @Option(name = "-B", usage = "LIB: bias (default = -1)", metaVar = "OPTIONAL") double d_bias = -1; @Option(name = "-C", usage = "LIB: penalty (default = 0.1)\nRRM: regularization (default = 0.1)", metaVar = "OPTIONAL") double d_c = 0.1; @Option(name = "-K", usage = "RRM: max # of iterations (default = 40)", metaVar = "OPTIONAL") int i_K = 40; @Option(name = "-M", usage = "RRM: initial weights (default = 1.0)", metaVar = "OPTIONAL") double d_mu = 1.0; public MLTrain(String[] args) { CmdLineParser cmd = new CmdLineParser(this); try { cmd.parseArgument(args); long st = System.currentTimeMillis(); IAlgorithm algorithm; if (s_algorithm.equals(IAlgorithm.LIBLINEAR_L2)) { algorithm = new LibLinearL2(i_lossType, d_c, d_e, d_bias); } else // RRM { d_e = 0.001; algorithm = new RRM(i_K, d_mu, d_e, d_c); } if (i_strategy == AbstractTrainer.ST_BINARY) { new BinaryTrainer(s_modelFile, algorithm, new NoneKernel(s_instanceFile)); } else // One-vs-all { new OneVsAllTrainer(s_modelFile, algorithm, new NoneKernel(s_instanceFile), i_numThreads); } long time = System.currentTimeMillis() - st; System.out.printf("\n* Training time: %d hours, %d minutes\n", time / (1000 * 3600), time / (1000 * 60)); } catch (CmdLineException e) { System.err.println(e.getMessage()); cmd.printUsage(System.err); } } static public void main(String[] args) { new MLTrain(args); } }