/** * Copyright (c) 2007-2014 The LIBLINEAR Project. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this list of conditions * and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, this list of * conditions and the following disclaimer in the documentation and/or other materials provided with * the distribution. * * 3. Neither name of copyright holders nor the names of its contributors may be used to endorse or * promote products derived from this software without specific prior written permission. * * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF * THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ package de.bwaldvogel.liblinear; import static de.bwaldvogel.liblinear.Linear.atof; import static de.bwaldvogel.liblinear.Linear.atoi; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.NoSuchElementException; import java.util.StringTokenizer; public class Train { public static void main(String[] args) throws IOException, InvalidInputDataException { new Train().run(args); } private double bias = 1; private boolean cross_validation = false; private String inputFilename; private String modelFilename; private int nr_fold; private Parameter param = null; private Problem prob = null; private void do_cross_validation() { double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double[] target = new double[prob.l]; long start, stop; start = System.currentTimeMillis(); Linear.crossValidation(prob, param, nr_fold, target); stop = System.currentTimeMillis(); System.out.println("time: " + (stop - start) + " ms"); if (param.solverType.isSupportVectorRegression()) { for (int i = 0; i < prob.l; i++) { double y = prob.y[i]; double v = target[i]; total_error += (v - y) * (v - y); sumv += v; sumy += y; sumvv += v * v; sumyy += y * y; sumvy += v * y; } System.out.printf("Cross Validation Mean squared error = %g%n", total_error / prob.l); System.out.printf("Cross Validation Squared correlation coefficient = %g%n", // (prob.l * sumvy - sumv * sumy) * (prob.l * sumvy - sumv * sumy) / ((prob.l * sumvv - sumv * sumv) * (prob.l * sumyy - sumy * sumy))); } else { int total_correct = 0; for (int i = 0; i < prob.l; i++) { if (target[i] == prob.y[i]) { ++total_correct; } } System.out.printf("correct: %d%n", total_correct); System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * total_correct / prob.l); } } private void exit_with_help() { System.out.printf("Usage: train [options] training_set_file [model_file]%n" // + "options:%n" + "-s type : set type of solver (default 1)%n" + " for multi-class classification%n" + " 0 -- L2-regularized logistic regression (primal)%n" + " 1 -- L2-regularized L2-loss support vector classification (dual)%n" + " 2 -- L2-regularized L2-loss support vector classification (primal)%n" + " 3 -- L2-regularized L1-loss support vector classification (dual)%n" + " 4 -- support vector classification by Crammer and Singer%n" + " 5 -- L1-regularized L2-loss support vector classification%n" + " 6 -- L1-regularized logistic regression%n" + " 7 -- L2-regularized logistic regression (dual)%n" + " for regression%n" + " 11 -- L2-regularized L2-loss support vector regression (primal)%n" + " 12 -- L2-regularized L2-loss support vector regression (dual)%n" + " 13 -- L2-regularized L1-loss support vector regression (dual)%n" + "-c cost : set the parameter C (default 1)%n" + "-p epsilon : set the epsilon in loss function of SVR (default 0.1)%n" + "-e epsilon : set tolerance of termination criterion%n" + " -s 0 and 2%n" + " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,%n" + " where f is the primal function and pos/neg are # of%n" + " positive/negative data (default 0.01)%n" + " -s 11%n" + " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)%n" + " -s 1, 3, 4 and 7%n" + " Dual maximal violation <= eps; similar to libsvm (default 0.1)%n" + " -s 5 and 6%n" + " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,%n" + " where f is the primal function (default 0.01)%n" + " -s 12 and 13\n" + " |f'(alpha)|_1 <= eps |f'(alpha0)|,\n" + " where f is the dual function (default 0.1)\n" + "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)%n" + "-wi weight: weights adjust the parameter C of different classes (see README for details)%n" + "-v n: n-fold cross validation mode%n" + "-q : quiet mode (no outputs)%n"); System.exit(1); } Problem getProblem() { return prob; } double getBias() { return bias; } Parameter getParameter() { return param; } void parse_command_line(String argv[]) { int i; // eps: see setting below param = new Parameter(SolverType.L2R_L2LOSS_SVC_DUAL, 1, Double.POSITIVE_INFINITY, 0.1); // default values bias = -1; cross_validation = false; // parse options for (i = 0; i < argv.length; i++) { if (argv[i].charAt(0) != '-') { break; } if (++i >= argv.length) { exit_with_help(); } switch (argv[i - 1].charAt(1)) { case 's': param.solverType = SolverType.getById(atoi(argv[i])); break; case 'c': param.setC(atof(argv[i])); break; case 'p': param.setP(atof(argv[i])); break; case 'e': param.setEps(atof(argv[i])); break; case 'B': bias = atof(argv[i]); break; case 'w': int weightLabel = atoi(argv[i - 1].substring(2)); double weight = atof(argv[i]); param.weightLabel = addToArray(param.weightLabel, weightLabel); param.weight = addToArray(param.weight, weight); break; case 'v': cross_validation = true; nr_fold = atoi(argv[i]); if (nr_fold < 2) { System.err.println("n-fold cross validation: n must >= 2"); exit_with_help(); } break; case 'q': i--; Linear.disableDebugOutput(); break; default: System.err.println("unknown option"); exit_with_help(); } } // determine filenames if (i >= argv.length) { exit_with_help(); } inputFilename = argv[i]; if (i < argv.length - 1) { modelFilename = argv[i + 1]; } else { int p = argv[i].lastIndexOf('/'); ++p; // whew... modelFilename = argv[i].substring(p) + ".model"; } if (param.eps == Double.POSITIVE_INFINITY) { switch (param.solverType) { case L2R_LR: case L2R_L2LOSS_SVC: param.setEps(0.01); break; case L2R_L2LOSS_SVR: param.setEps(0.001); break; case L2R_L2LOSS_SVC_DUAL: case L2R_L1LOSS_SVC_DUAL: case MCSVM_CS: case L2R_LR_DUAL: param.setEps(0.1); break; case L1R_L2LOSS_SVC: case L1R_LR: param.setEps(0.01); break; case L2R_L1LOSS_SVR_DUAL: case L2R_L2LOSS_SVR_DUAL: param.setEps(0.1); break; default: throw new IllegalStateException("unknown solver type: " + param.solverType); } } } /** * reads a problem from LibSVM format * * @param file * the SVM file * @throws IOException * obviously in case of any I/O exception ;) * @throws InvalidInputDataException * if the input file is not correctly formatted */ public static Problem readProblem(File file, double bias) throws IOException, InvalidInputDataException { BufferedReader fp = new BufferedReader(new FileReader(file)); List<Double> vy = new ArrayList<Double>(); List<Feature[]> vx = new ArrayList<Feature[]>(); int max_index = 0; int lineNr = 0; try { while (true) { String line = fp.readLine(); if (line == null) { break; } lineNr++; StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); String token; try { token = st.nextToken(); } catch (NoSuchElementException e) { throw new InvalidInputDataException("empty line", file, lineNr, e); } try { vy.add(atof(token)); } catch (NumberFormatException e) { throw new InvalidInputDataException("invalid label: " + token, file, lineNr, e); } int m = st.countTokens() / 2; Feature[] x; if (bias >= 0) { x = new Feature[m + 1]; } else { x = new Feature[m]; } int indexBefore = 0; for (int j = 0; j < m; j++) { token = st.nextToken(); int index; try { index = atoi(token); } catch (NumberFormatException e) { throw new InvalidInputDataException("invalid index: " + token, file, lineNr, e); } // assert that indices are valid and sorted if (index < 0) { throw new InvalidInputDataException("invalid index: " + index, file, lineNr); } if (index <= indexBefore) { throw new InvalidInputDataException("indices must be sorted in ascending order", file, lineNr); } indexBefore = index; token = st.nextToken(); try { double value = atof(token); x[j] = new FeatureNode(index, value); } catch (NumberFormatException e) { throw new InvalidInputDataException("invalid value: " + token, file, lineNr); } } if (m > 0) { max_index = Math.max(max_index, x[m - 1].getIndex()); } vx.add(x); } return constructProblem(vy, vx, max_index, bias); } finally { fp.close(); } } void readProblem(String filename) throws IOException, InvalidInputDataException { prob = Train.readProblem(new File(filename), bias); } private static int[] addToArray(int[] array, int newElement) { int length = array != null ? array.length : 0; int[] newArray = new int[length + 1]; if (array != null && length > 0) { System.arraycopy(array, 0, newArray, 0, length); } newArray[length] = newElement; return newArray; } private static double[] addToArray(double[] array, double newElement) { int length = array != null ? array.length : 0; double[] newArray = new double[length + 1]; if (array != null && length > 0) { System.arraycopy(array, 0, newArray, 0, length); } newArray[length] = newElement; return newArray; } private static Problem constructProblem(List<Double> vy, List<Feature[]> vx, int max_index, double bias) { Problem prob = new Problem(); prob.bias = bias; prob.l = vy.size(); prob.n = max_index; if (bias >= 0) { prob.n++; } prob.x = new Feature[prob.l][]; for (int i = 0; i < prob.l; i++) { prob.x[i] = vx.get(i); if (bias >= 0) { assert prob.x[i][prob.x[i].length - 1] == null; prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias); } } prob.y = new double[prob.l]; for (int i = 0; i < prob.l; i++) { prob.y[i] = vy.get(i).doubleValue(); } return prob; } private void run(String[] args) throws IOException, InvalidInputDataException { parse_command_line(args); readProblem(inputFilename); if (cross_validation) { do_cross_validation(); } else { Model model = Linear.train(prob, param); Linear.saveModel(new File(modelFilename), model); } } }