// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.core; import java.io.BufferedReader; import java.io.IOException; import java.io.Writer; import java.security.InvalidParameterException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import marmot.util.FileUtils; import marmot.util.Mutable; import marmot.util.StringUtils; public class Options extends java.util.Properties { public static final long serialVersionUID = 1L; public static final String BEAM_SIZE = "beam-size"; public static final String ORDER = "order"; public static final String PRUNE = "prune"; public static final String NUM_ITERATIONS = "num-iterations"; public static final String PENALTY = "penalty"; public static final String PROB_THRESHOLD = "prob-threshold"; public static final String SHUFFLE = "shuffle"; public static final String CANDIDATES_PER_STATE = "candidates-per-state"; public static final String EFFECTIVE_ORDER = "effective-order"; public static final String VECTOR_SIZE = "initial-vector-size"; public static final String VERBOSE = "verbose"; public static final String QUADRATIC_PENALTY = "quadratic-penalty"; public static final String ORACLE = "oracle"; public static final String MAX_TRANSITION_FEATURE_LEVEL = "max-transition-feature-level"; public static final String VERY_VERBOSE = "very-verbose"; public static final String TRAINER = "trainer"; public static final String AVERAGING = "averaging"; public static final String SEED = "seed"; public static final String OPTIMIZE_NUM_ITERATIONS = "optimize-num-iterations"; private static final Map<String, String> DEFALUT_VALUES_ = new HashMap<String, String>(); private static final Map<String, String> COMMENTS_ = new HashMap<String, String>(); static { DEFALUT_VALUES_.put(BEAM_SIZE, "1"); COMMENTS_ .put(BEAM_SIZE, "Specify the beam size of the n-best decoder."); DEFALUT_VALUES_.put(ORDER, "2"); COMMENTS_.put(ORDER, "Set the model order."); DEFALUT_VALUES_.put(PRUNE, "true"); COMMENTS_.put(PRUNE, "Whether to use pruning."); DEFALUT_VALUES_.put(NUM_ITERATIONS, "10"); COMMENTS_.put(NUM_ITERATIONS, "Number of training iterations."); DEFALUT_VALUES_.put(PENALTY, "0.0"); COMMENTS_.put(PENALTY, "L1 penalty parameter."); DEFALUT_VALUES_.put(PROB_THRESHOLD, "0.01"); COMMENTS_ .put(PROB_THRESHOLD, "Initial pruning threshold. Changing this value should have almost no effect."); DEFALUT_VALUES_.put(SHUFFLE, "true"); COMMENTS_.put(SHUFFLE, "Whether to shuffle between training iterations."); DEFALUT_VALUES_.put(CANDIDATES_PER_STATE, "[4, 2, 1.5]"); COMMENTS_ .put(CANDIDATES_PER_STATE, "Average number of states to obtain after pruning at each order. These are the mu values from the paper."); DEFALUT_VALUES_.put(EFFECTIVE_ORDER, "1"); COMMENTS_.put(EFFECTIVE_ORDER, "Maximal order to reach before increasing the level."); DEFALUT_VALUES_.put(VECTOR_SIZE, "10000000"); COMMENTS_.put(VECTOR_SIZE, "Size of the weight vector."); DEFALUT_VALUES_.put(VERBOSE, "false"); COMMENTS_.put(VERBOSE, "Whether to print status messages."); DEFALUT_VALUES_.put(QUADRATIC_PENALTY, "0.0"); COMMENTS_.put(QUADRATIC_PENALTY, "L2 penalty parameter."); DEFALUT_VALUES_.put(ORACLE, "false"); COMMENTS_ .put(ORACLE, "Whether to do oracle pruning. Probably not relevant. Have a look at the paper!"); DEFALUT_VALUES_.put(MAX_TRANSITION_FEATURE_LEVEL, "-1"); COMMENTS_.put(MAX_TRANSITION_FEATURE_LEVEL, "Something for testing the code. Don't change it."); DEFALUT_VALUES_.put(VERY_VERBOSE, "false"); COMMENTS_.put(VERY_VERBOSE, "Whether to print a lot of status messages."); DEFALUT_VALUES_.put(TRAINER, CrfTrainer.class.getCanonicalName()); COMMENTS_ .put(TRAINER, "Which trainer to use. (There is also a perceptron trainer but don't use it.)"); DEFALUT_VALUES_.put(AVERAGING, "true"); COMMENTS_.put(AVERAGING, "Whether to use averaging. Perceptron only!"); DEFALUT_VALUES_.put(SEED, "42"); COMMENTS_.put(SEED, "Random seed to use for shuffling. 0 for nondeterministic seed"); DEFALUT_VALUES_.put(OPTIMIZE_NUM_ITERATIONS, "false"); COMMENTS_.put(OPTIMIZE_NUM_ITERATIONS, "Whether to optimize the number of training iterations on the dev set."); } public Options() { super(); putAll(DEFALUT_VALUES_); } public Options(Options options) { this(); putAll(options); } public String toSimpleString() { String string = ""; Set<Object> key_set = keySet(); List<String> key_list = new ArrayList<String>(key_set.size()); for (Object key : keySet()) { key_list.add((String) (key)); } Collections.sort(key_list); for (String key : key_list) { String value = getProperty(key); string += String.format("%s = %s\n", key, value); } return string; } public void writePropertiesToFile(String filename) { try { Writer writer = FileUtils.openFileWriter(filename); writer.write(toSimpleString()); writer.close(); } catch (IOException e) { throw new RuntimeException(e); } } public void setPropertiesFromFile(String filename) { try { BufferedReader reader = FileUtils.openFile(filename); setPropertiesFromReader(reader); reader.close(); } catch (IOException e) { throw new RuntimeException(e); } } private String normalizeOption(String option) { return option.trim().replace("_", "-").toLowerCase(); } public void setPropertiesFromReader(BufferedReader reader) { Pattern p = Pattern.compile("([^:=]*)[:=](.*)"); try { while (reader.ready()) { String line = reader.readLine(); line = line.trim(); if (line.length() == 0) { continue; } Matcher m = p.matcher(line); if (!m.matches()) { throw new RuntimeException(String.format( "Invalid line: %s\n", line)); } String key = normalizeOption(m.group(1)); if (!this.containsKey(key)) { throw new RuntimeException(String.format( "Unknown property: %s\n", key)); } String value = m.group(2).trim(); if (value.endsWith(";")) { value = value.substring(0, value.length() - 1); } if (value.endsWith("\"") && value.startsWith("\"")) { value = value.substring(1, value.length() - 1); } value = new String(value); this.setProperty(key, value); } } catch (IOException e) { throw new RuntimeException(e); } } private static final Pattern OPTION_PATTERN = Pattern.compile("-*(.*)"); public void setPropertiesFromStrings(String[] args) { int index = 0; while (index < args.length) { String option = args[index++]; Matcher m = OPTION_PATTERN.matcher(option); if (!m.matches()) { throw new RuntimeException("Unexpected argument: " + option + ". Missing '-'?"); } option = normalizeOption(m.group(1)); if (option.equalsIgnoreCase("props")) { checkBoundaries(index, args); setPropertiesFromFile(args[index++]); } else if (containsKey(option)) { checkBoundaries(index, args); this.setProperty(option, args[index++]); } else { usage(); throw new RuntimeException(String.format( "Unknown property: %s\n", option)); } } if (getVerbose()) { System.err.print(toString()); } } @Override public String toString() { StringBuilder sb = new StringBuilder(); for (Map.Entry<Object, Object> prop : entrySet()) { sb.append(prop.getKey() + ": " + prop.getValue()); sb.append("\n"); } return sb.toString(); } private void checkBoundaries(int index, String[] args) { if (index >= args.length) { throw new RuntimeException("Missing argument"); } } public void dieIfPropertyIsEmpty(String property) { if (getProperty(property).isEmpty()) { usage(); System.err.format("Error: Property '%s' needs to be set!\n", property); System.exit(1); } } protected void usage(Map<String, String> defaults, Map<String, String> comments) { for (Map.Entry<String, String> entry : defaults.entrySet()) { System.err.format("\t%s:\n", entry.getKey()); String comment = comments.get(entry.getKey()); assert comment != null; System.err.format("\t\t%s\n", comment); System.err.format("\t\tDefault value: \"%s\"\n", entry.getValue() .replaceAll("\\\\", "\\\\\\\\")); } } protected void usage() { System.err.println("General Options:"); usage(DEFALUT_VALUES_, COMMENTS_); System.err.println(); } public boolean getPrune() { return Boolean.parseBoolean(getProperty(PRUNE)); } public int getBeamSize() { return Integer.parseInt(getProperty(BEAM_SIZE)); } public int getOrder() { return Integer.parseInt(getProperty(ORDER)); } public int getNumIterations() { return Integer.parseInt(getProperty(NUM_ITERATIONS)); } public double getPenalty() { return Double.parseDouble(getProperty(PENALTY)); } public double getProbThreshold() { return Double.parseDouble(getProperty(PROB_THRESHOLD)); } public boolean getShuffle() { return Boolean.parseBoolean(getProperty(SHUFFLE)); } public double[] getCandidatesPerState() { double[] array = StringUtils .parseDoubleArray(getProperty(CANDIDATES_PER_STATE), new Mutable<Integer>(0)); for (double element : array) { if (element < 1.0) { throw new InvalidParameterException("Candidates per state must be >= 1.0: " + getProperty(CANDIDATES_PER_STATE)); } } return array; } public int getEffectiveOrder() { return Integer.parseInt(getProperty(EFFECTIVE_ORDER)); } public int getInitialVectorSize() { return (int) Double.parseDouble(getProperty(VECTOR_SIZE)); } public boolean getVerbose() { return Boolean.parseBoolean(getProperty(VERBOSE)); } public double getQuadraticPenalty() { return Double.parseDouble(getProperty(QUADRATIC_PENALTY)); } public boolean getOracle() { return Boolean.parseBoolean(getProperty(ORACLE)); } public int getMaxTransitionFeatureLevel() { return Integer.parseInt(getProperty(MAX_TRANSITION_FEATURE_LEVEL)); } public boolean getVeryVerbose() { return Boolean.parseBoolean(getProperty(VERY_VERBOSE)); } public String getTrainer() { return getProperty(TRAINER); } public boolean getAveraging() { return Boolean.parseBoolean(getProperty(AVERAGING)); } public long getSeed() { return Long.parseLong(getProperty(SEED)); } public boolean getOptimizeNumIterations() { return Boolean.parseBoolean(getProperty(OPTIMIZE_NUM_ITERATIONS)); } }