// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.

package hmmla;

import hmmla.hmm.Model;
import hmmla.io.PosFileOptions;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * String-keyed configuration for HMMLA, pre-populated with default values.
 *
 * <p>Every known option is registered in a static table together with a
 * default value and a help text. A new instance starts out containing all
 * defaults; values can then be overridden from a file
 * ({@link #setPropertiesFromFile(String)}), a reader
 * ({@link #setPropertiesFromReader(BufferedReader)}) or a command line
 * ({@link #setPropertiesFromStrings(String[])}). Unknown option names are
 * rejected with a {@link RuntimeException}.
 *
 * <p>Option names are normalized before lookup: surrounding whitespace is
 * removed, {@code '_'} is replaced by {@code '-'} and the name is lower-cased.
 */
public class Properties extends java.util.Properties {

	public static final String SEED = "seed";
	public static final String NUM_THREADS = "num-threads";
	public static final String RANDOMNESS = "randomness";
	public static final String MERGE_FACTOR = "merge-factor";
	public static final String EM_STEPS = "em-steps";
	public static final String NUM_TAGS = "num-tags";
	public static final String EXACT_LOSS = "exact-loss";
	public static final String MERGE = "merge";
	public static final String SAMPLE = "sample";
	public static final String SAMPLING_FRACTION = "sampling-fraction";
	public static final String TRAIN_FILE = "train-file";
	/**
	 * Boolean switch "run test". This used to share the key
	 * {@code "test-file"} with {@link #TEST_FILE}, so its default and help
	 * text were overwritten and {@link #getTest()} parsed the test file name
	 * as a boolean. It now has a key of its own.
	 */
	public static final String TEST = "test";
	public static final String TEST_FILE = "test-file";
	public static final String HMM_TRAINER = "hmm-trainer";
	public static final String IS_RARE_THRESHOLD = "is-rare-threshold";
	public static final String UNIVERSAL_POS_FILE = "universal-pos-file";
	public static final String LANGUAGE = "language";
	public static final String UNIVERSAL_POS = "universal-pos";
	public static final String MODEL_NAME = "model-name";
	public static final String SMOOTHER = "smoother";
	public static final String REFINE = "refine";
	public static final String COARSE_DECODER = "coarse-decoder";
	public static final String DUMP_INTERMEDIATE_MODEL = "dump-intermediate-model";
	public static final String PRED_FILE = "pred-file";

	private static final long serialVersionUID = 1L;

	/** Splits a "key = value" or "key: value" line into key and value. */
	private static final Pattern PROPERTY_LINE_PATTERN = Pattern
			.compile("([^:=]*)[:=](.*)");

	/**
	 * Strips any number of leading dashes from a command line option.
	 * NOTE(review): because the dashes are optional this pattern also accepts
	 * arguments without a leading '-'; it is left as is so that existing
	 * command lines keep working.
	 */
	private static final Pattern OPTION_PATTERN = Pattern.compile("-*(.*)");

	/** Default value per option; sorted so that usage output is stable. */
	private static final Map<String, String> defaultValues = new TreeMap<String, String>();

	/** Help text per option, shown by {@link #usage()}. */
	private static final Map<String, String> comments = new HashMap<>();

	/** Registers an option together with its default value and help text. */
	private static void register(String option, String defaultValue,
			String comment) {
		defaultValues.put(option, defaultValue);
		comments.put(option, comment);
	}

	static {
		register(NUM_THREADS, "1", "Number of threads");
		register(RANDOMNESS, "0.1", "Randomness (cf. to the paper)");
		register(MERGE_FACTOR, "0.75", "Merge factor (cf. to the paper)");
		register(EM_STEPS, "10", "Number of EM steps");
		register(EXACT_LOSS, "false",
				"Use exact loss. This is a testing option.");
		register(MERGE, "true", "Whether to merge.");
		register(SAMPLE, "true",
				"Whether to sample. (Uses different parts of the training set at every EM step)");
		register(SAMPLING_FRACTION, "0.1", "Sampling fraction. See option "
				+ SAMPLE);
		register(HMM_TRAINER, "signaturehmmtrainer",
				"Which trainer to use: signaturehmmtrainer or simplehmmtrainer");
		register(IS_RARE_THRESHOLD, "5", "Word form rareness threshold.");
		register(LANGUAGE, "none", "To enable language specific behavior.");
		register(UNIVERSAL_POS, "false", "Use universal POS.");
		register(SMOOTHER, "wb",
				"Smoother to use: Can be none, linear(x) or wb. Where x is a real number with 0 < x < 1");
		register(REFINE, "false", "Refine. (cf. to the paper)");
		register(COARSE_DECODER, "false", "Use a coarse to fine decoder.");
		// NOTE(review): the default was declared as "true", but because of the
		// former key collision with TEST_FILE the effective value was always
		// false. "false" keeps existing invocations working unchanged;
		// confirm with the callers before switching the default to "true".
		register(TEST, "false", "Run test. Needs option: " + TEST_FILE);
		register(DUMP_INTERMEDIATE_MODEL, "false",
				"Write a model after each iteration");
		register(SEED, "", "Random seed to use. Empty for random");
		register(TRAIN_FILE, "", "Train file");
		register(TEST_FILE, "", "Test file");
		register(PRED_FILE, "", "Pred file");
		register(NUM_TAGS, "",
				"Number of tags to induce. Define the number of iterations");
		register(UNIVERSAL_POS_FILE, "",
				"File containing mapping of treebank POS to universal POS");
		register(MODEL_NAME, "", "Model file to store the model to");
	}

	/** Creates a property set that contains every option at its default. */
	public Properties() {
		super();
		putAll(defaultValues);
	}

	/**
	 * @return the random seed
	 * @throws NumberFormatException
	 *             if the seed is still empty (its default until
	 *             {@link #check(String)} has filled it in)
	 */
	public long getSeed() {
		return Long.parseLong(getProperty(SEED));
	}

	/** @return the value of {@link #NUM_THREADS} parsed as an int */
	public int getNumThreads() {
		return Integer.parseInt(getProperty(NUM_THREADS));
	}

	/** @return the value of {@link #RANDOMNESS} parsed as a double */
	public double getRandomness() {
		return Double.parseDouble(getProperty(RANDOMNESS));
	}

	/** @return the value of {@link #MERGE_FACTOR} parsed as a double */
	public double getMergeFactor() {
		return Double.parseDouble(getProperty(MERGE_FACTOR));
	}

	/** @return the value of {@link #EM_STEPS} parsed as an int */
	public int getEmSteps() {
		return Integer.parseInt(getProperty(EM_STEPS));
	}

	/**
	 * @return the value of {@link #NUM_TAGS} parsed as an int
	 * @throws NumberFormatException
	 *             if the option is unset (its default is the empty string)
	 */
	public int getNumTags() {
		return Integer.parseInt(getProperty(NUM_TAGS));
	}

	/** @return the value of {@link #EXACT_LOSS} parsed as a boolean */
	public boolean getExactLoss() {
		return Boolean.parseBoolean(getProperty(EXACT_LOSS));
	}

	/** @return the value of {@link #MERGE} parsed as a boolean */
	public boolean getMerge() {
		return Boolean.parseBoolean(getProperty(MERGE));
	}

	/** @return the value of {@link #SAMPLE} parsed as a boolean */
	public boolean getSample() {
		return Boolean.parseBoolean(getProperty(SAMPLE));
	}

	/** @return the value of {@link #SAMPLING_FRACTION} parsed as a double */
	public double getSamplingFraction() {
		return Double.parseDouble(getProperty(SAMPLING_FRACTION));
	}

	/** @return the raw value of {@link #TRAIN_FILE} */
	public String getTrainFile() {
		return getProperty(TRAIN_FILE);
	}

	/** @return the raw value of {@link #TEST_FILE} */
	public String getTestFile() {
		return getProperty(TEST_FILE);
	}

	/**
	 * Renders all properties as "key = value" lines, sorted by key.
	 *
	 * @return one line per property, each terminated by {@code '\n'}
	 */
	public String toSimpleString() {
		List<String> keys = new ArrayList<String>(size());
		for (Object key : keySet()) {
			keys.add((String) key);
		}
		Collections.sort(keys);

		StringBuilder sb = new StringBuilder();
		for (String key : keys) {
			sb.append(String.format("%s = %s\n", key, getProperty(key)));
		}
		return sb.toString();
	}

	/** @return the raw value of {@link #HMM_TRAINER} */
	public String getHmmTrainer() {
		return getProperty(HMM_TRAINER);
	}

	/** @return the value of {@link #IS_RARE_THRESHOLD} parsed as an int */
	public int getIsRareThreshold() {
		return Integer.parseInt(getProperty(IS_RARE_THRESHOLD));
	}

	/** @return the raw value of {@link #UNIVERSAL_POS_FILE} */
	public String getUniversalPosFile() {
		return getProperty(UNIVERSAL_POS_FILE);
	}

	/**
	 * Writes the output of {@link #toSimpleString()} to a file.
	 *
	 * @param filename
	 *            path of the file to (over)write
	 * @throws RuntimeException
	 *             wrapping any {@link IOException}
	 */
	public void writePropertiesToFile(String filename) {
		// try-with-resources closes the writer even if write() fails.
		try (BufferedWriter writer = new BufferedWriter(new FileWriter(
				filename))) {
			writer.write(toSimpleString());
		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Reads property overrides from a file, see
	 * {@link #setPropertiesFromReader(BufferedReader)} for the format.
	 *
	 * @param filename
	 *            path of the file to read
	 * @throws RuntimeException
	 *             wrapping any {@link IOException} (including a missing
	 *             file), or reporting an invalid line or unknown property
	 */
	public void setPropertiesFromFile(String filename) {
		// try-with-resources closes the reader even if parsing fails.
		try (BufferedReader reader = new BufferedReader(new FileReader(
				filename))) {
			setPropertiesFromReader(reader);
		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Brings an option name into its canonical form: trimmed, with
	 * {@code '_'} replaced by {@code '-'} and lower-cased. The lower-casing
	 * uses {@link Locale#ROOT} so that the result does not depend on the
	 * default locale of the JVM.
	 */
	private String normalizeOption(String option) {
		return option.trim().replace("_", "-").toLowerCase(Locale.ROOT);
	}

	/**
	 * Reads property overrides, one "key = value" or "key: value" pair per
	 * line, until the reader is exhausted. Blank lines are skipped. A single
	 * trailing {@code ';'} and a pair of enclosing double quotes are removed
	 * from the value. The reader is not closed.
	 *
	 * @param reader
	 *            source of the lines
	 * @throws RuntimeException
	 *             if a line has no separator, names an unknown property, or
	 *             reading fails
	 */
	public void setPropertiesFromReader(BufferedReader reader) {
		try {
			String line;
			// readLine() returning null is the only reliable end-of-input
			// signal; ready() merely says whether a read would block.
			while ((line = reader.readLine()) != null) {
				line = line.trim();
				if (line.length() == 0) {
					continue;
				}

				Matcher m = PROPERTY_LINE_PATTERN.matcher(line);
				if (!m.matches()) {
					throw new RuntimeException(String.format(
							"Invalid line: %s\n", line));
				}

				String key = normalizeOption(m.group(1));
				if (!this.containsKey(key)) {
					throw new RuntimeException(String.format(
							"Unknown property: %s\n", key));
				}

				String value = m.group(2).trim();
				if (value.endsWith(";")) {
					value = value.substring(0, value.length() - 1);
				}
				if (value.endsWith("\"") && value.startsWith("\"")) {
					value = value.substring(1, value.length() - 1);
				}
				this.setProperty(key, value);
			}
		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	/** @return the value of {@link #UNIVERSAL_POS} parsed as a boolean */
	public boolean getUniversalPos() {
		return Boolean.parseBoolean(getProperty(UNIVERSAL_POS));
	}

	/** @return the raw value of {@link #LANGUAGE} */
	public String getLanguage() {
		return getProperty(LANGUAGE);
	}

	/** @return the raw value of {@link #MODEL_NAME} */
	public String getModelFile() {
		return getProperty(MODEL_NAME);
	}

	/** @return the raw value of {@link #SMOOTHER} */
	public String getSmoother() {
		return getProperty(SMOOTHER);
	}

	/** @return the value of {@link #REFINE} parsed as a boolean */
	public boolean getRefine() {
		return Boolean.parseBoolean(getProperty(REFINE));
	}

	/** @return the value of {@link #COARSE_DECODER} parsed as a boolean */
	public boolean getCoarseDecoder() {
		return Boolean.parseBoolean(getProperty(COARSE_DECODER));
	}

	/**
	 * Applies command line arguments of the form
	 * {@code -option value -option value ...}. The special option
	 * {@code props} takes a file name and loads it via
	 * {@link #setPropertiesFromFile(String)}.
	 *
	 * @param args
	 *            alternating option names and values
	 * @throws RuntimeException
	 *             if an option is unknown or its value is missing
	 */
	public void setPropertiesFromStrings(String[] args) {
		int index = 0;
		while (index < args.length) {
			String option = args[index++];

			Matcher m = OPTION_PATTERN.matcher(option);
			if (!m.matches()) {
				throw new RuntimeException("Unexpected argument: " + option
						+ ". Missing '-'?");
			}
			option = normalizeOption(m.group(1));

			if (option.equalsIgnoreCase("props")) {
				checkBoundaries(index, args);
				setPropertiesFromFile(args[index++]);
			} else if (this.containsKey(option)) {
				checkBoundaries(index, args);
				this.setProperty(option, args[index++]);
			} else {
				throw new RuntimeException(String.format(
						"Unknown property: %s\n", option));
			}
		}
	}

	/** Fails if {@code index} is not a valid position in {@code args}. */
	private void checkBoundaries(int index, String[] args) {
		if (index >= args.length) {
			throw new RuntimeException("Missing argument");
		}
	}

	/**
	 * Validates the configuration for the given entry point and fills in a
	 * random seed if none was set.
	 *
	 * @param class_name
	 *            simple class name of the caller; compared against the simple
	 *            names of {@code Tagger} and {@code Trainer}
	 * @throws RuntimeException
	 *             if a required option is empty, a required file cannot be
	 *             read, or the test file specification lacks a tag index
	 */
	public void check(String class_name) {
		if (class_name.equals(Tagger.class.getSimpleName()) || getTest()) {
			checkNotEmpty(TEST_FILE);
			PosFileOptions options = new PosFileOptions(
					getProperty(TEST_FILE));
			checkFileExists(options.getFile());

			if (getTest() || getRefine()) {
				if (options.getTagIndex() < 0) {
					// The message used to contain a stray, never formatted
					// "%s" placeholder.
					throw new RuntimeException("No tag index specified in "
							+ getProperty(TEST_FILE) + "!");
				}
			}
		}

		checkNotEmpty(MODEL_NAME);

		if (getProperty(Properties.SEED).isEmpty()) {
			long seed = hmmla.util.Random.getRandomSeed();
			setProperty(Properties.SEED, Long.toString(seed));
		}

		if (class_name.equals(Trainer.class.getSimpleName())) {
			checkNotEmpty(NUM_TAGS);
			checkNotEmpty(TRAIN_FILE);
			checkFileExists(new PosFileOptions(getProperty(TRAIN_FILE))
					.getFile());
		}

		if (getUniversalPos()) {
			checkNotEmpty(UNIVERSAL_POS_FILE);
			checkFileExists(new File(getProperty(UNIVERSAL_POS_FILE)));
		}
	}

	/** Fails if the given option currently holds the empty string. */
	private void checkNotEmpty(String option) {
		if (getProperty(option).isEmpty()) {
			throw new RuntimeException("Property: \"" + option
					+ "\" has to be set!");
		}
	}

	/** Fails if the given file cannot be read. */
	private void checkFileExists(File file) {
		if (!file.canRead()) {
			throw new RuntimeException("Can't read from: "
					+ file.getAbsolutePath());
		}
	}

	/**
	 * Derives a file name for an intermediate model from the configured
	 * model name: {@code <base>_<zero padded tag count>.<extension>}. The
	 * extension is taken from the model name, or is {@code "tagger"} if the
	 * model name contains no {@code '.'}.
	 *
	 * <p>The tag count is left-padded with zeros to a fixed width derived
	 * from the target number of tags. The previous implementation estimated
	 * the digit count as {@code ceil(log10(n))}, which is one too small for
	 * powers of ten and produced names of inconsistent width (and an
	 * effectively endless padding loop for a tag count of 0).
	 *
	 * @param model
	 *            model whose current tag count and configured target tag
	 *            count are used
	 * @return the intermediate model file name
	 */
	public String getIntermediateModelName(Model model) {
		StringBuilder sb = new StringBuilder();

		String model_name = getModelFile();
		int index = model_name.lastIndexOf('.');
		String extension;
		if (index >= 0) {
			extension = model_name.substring(index + 1, model_name.length());
			sb.append(model_name.substring(0, index));
		} else {
			sb.append(model_name);
			extension = "tagger";
		}

		sb.append('_');

		int width = paddedWidth(model.getProperties().getNumTags());
		String number = Integer.toString(model.getNumTags());
		for (int i = number.length(); i < width; i++) {
			sb.append('0');
		}
		sb.append(number);

		if (!extension.isEmpty()) {
			sb.append('.');
			sb.append(extension);
		}

		return sb.toString();
	}

	/**
	 * Width to which tag counts are padded in intermediate model names. This
	 * keeps the width the original code produced for tag counts that are not
	 * powers of ten ({@code ceil(log10(target)) + 1}); targets below 1 yield
	 * no padding.
	 */
	private static int paddedWidth(int targetTags) {
		if (targetTags < 1) {
			return 1;
		}
		return (int) Math.ceil(Math.log10(targetTags)) + 1;
	}

	/** @return the value of {@link #TEST} parsed as a boolean */
	public boolean getTest() {
		return Boolean.parseBoolean(getProperty(TEST));
	}

	/** @return the value of {@link #DUMP_INTERMEDIATE_MODEL} as a boolean */
	public boolean getDumpIntermediateModels() {
		return Boolean.parseBoolean(getProperty(DUMP_INTERMEDIATE_MODEL));
	}

	/** @return the raw value of {@link #PRED_FILE} */
	public String getPredFile() {
		return getProperty(PRED_FILE);
	}

	/** Prints all general options with help text and default to stderr. */
	public void usage() {
		System.err.println("General Options:");
		usage(defaultValues, comments);
		System.err.println();
	}

	/**
	 * Prints one help entry per option to stderr, in the iteration order of
	 * {@code defaults}.
	 *
	 * @param defaults
	 *            default value per option
	 * @param comments
	 *            help text per option; expected to contain every key of
	 *            {@code defaults}
	 */
	protected void usage(Map<String, String> defaults,
			Map<String, String> comments) {
		for (Map.Entry<String, String> entry : defaults.entrySet()) {
			System.err.format("\t%s:\n", entry.getKey());
			String comment = comments.get(entry.getKey());
			assert comment != null;
			System.err.format("\t\t%s\n", comment);
			// Backslashes in the default value are doubled for display.
			System.err.format("\t\tDefault value: \"%s\"\n", entry.getValue()
					.replaceAll("\\\\", "\\\\\\\\"));
		}
	}

}