package hu.u_szeged.kpe;

import hu.u_szeged.kpe.candidates.NGram;
import hu.u_szeged.kpe.main.ExtractionModelBuilder;
import hu.u_szeged.kpe.main.KPEFilter;
import hu.u_szeged.kpe.main.KeyPhraseExtractor;
import hu.u_szeged.kpe.readers.DocumentSet;
import hu.u_szeged.kpe.readers.KpeReader;
import hu.u_szeged.ml.mallet.MalletClassifier;
import hu.u_szeged.utils.NLPUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * Command line entry point of the keyphrase extractor: reads the configuration (from a parameter file, from the
 * command line or interactively), sets up the document readers, trains a model per fold when training data is
 * available and finally extracts keyphrases from the test documents.
 */
public class KpeMain {

  /** Total number of folds, as passed to {@link #setReaders}. */
  private int totalFolds;
  /** The fold currently being processed, as set by {@link #setFold(int)}. */
  private int actualFold;
  /** Zero-padded binary string of the feature coding; one character per line of the feature list file. */
  private String mode;
  /** Names of the features whose bit is set in {@link #mode}. */
  private List<String> selectedFeatures;
  /** Result of {@code NGram.initWordNet}; only used here to tag the model file name with "wn_". */
  private boolean wordNetUsage;
  private boolean noStopWordPruning;
  private boolean noPosEndingPruning;
  /** True if "MweFeature" is among the selected features. */
  private boolean mweFeatureIsOn;
  /** True if "StrangeOrthographyFeature" is among the selected features. */
  private boolean neFeatureIsOn;
  /** Never assigned in this class, hence always passed to the readers as {@code false}. */
  private boolean syntacticFeatureIsOn;
  private ExtractionModelBuilder km;
  private KeyPhraseExtractor ke;
  /** Start time (epoch millis) of the fold currently processed by {@link #main(String[])}. */
  public static long time;

  /**
   * Creates the model builder and the extractor and decodes the set of features to use.
   *
   * @param phrases2Return number of keyphrases the extractor is asked to return
   * @param prune          forwarded to {@code KeyPhraseExtractor.setPrune}
   * @param featureCoding  integer whose binary form selects features; after left-padding with zeros to the number
   *                       of lines in {@code <user.dir>/resources/features}, the i-th character being '1' selects
   *                       the feature named in the second tab-separated column of the i-th line
   * @param prunings       {@code [0]}: do not use POS ending pruning, {@code [1]}: do not use stopword pruning
   * @param wnLocation     forwarded to {@code NGram.initWordNet}
   */
  public KpeMain(int phrases2Return, boolean prune, int featureCoding, boolean[] prunings, String wnLocation) {
    wordNetUsage = NGram.initWordNet(wnLocation);
    noPosEndingPruning = prunings[0];
    noStopWordPruning = prunings[1];

    km = new ExtractionModelBuilder();
    km.setMaxPhraseLength(4);
    km.setMinNumOccur(1);

    ke = new KeyPhraseExtractor();
    ke.setNumPhrases(phrases2Return);
    ke.setPrune(prune);

    List<String> featureClasses = new LinkedList<>();
    NLPUtils.readDocToCollection(System.getProperty("user.dir") + "/resources/features", featureClasses);
    selectedFeatures = new LinkedList<String>();
    mode = extendToDesiredLength(Integer.toBinaryString(featureCoding), featureClasses.size());
    String escapedFeatureEncoding = mode.replaceAll("_", "");

    int position = 0;
    for (String featureLine : featureClasses) {
      String feature = featureLine.split("\t")[1];
      if (escapedFeatureEncoding.charAt(position) == '1') {
        if (feature.equals("MweFeature")) {
          mweFeatureIsOn = true;
        } else if (feature.equals("StrangeOrthographyFeature")) {
          neFeatureIsOn = true;
        }
        selectedFeatures.add(feature);
      }
      ++position;
    }
    System.err.println("Features used:\t" + selectedFeatures);
  }

  /** @return true if stopword-based candidate pruning is switched off */
  public boolean getNoStopWordPruning() {
    return noStopWordPruning;
  }

  /** @return true if POS-ending-based candidate pruning is switched off */
  public boolean getNoPosEndingPruning() {
    return noPosEndingPruning;
  }

  /**
   * Splits a reader specification of the form {@code ReaderClass->path1|path2|...}.
   *
   * @param line the specification; {@code null}, the literal "null" or a text without "->" mean "nothing given"
   * @return the reader name followed by the paths, or an empty list if nothing was given
   */
  private List<String> parseLine(String line) {
    List<String> readerAndLocations = new ArrayList<String>();
    if (line == null || line.equals("null")) {
      return readerAndLocations;
    }
    int arrowIndex = line.indexOf("->");
    if (arrowIndex == -1) {
      return readerAndLocations;
    }
    // Surrounding whitespace can never be part of a class name, so it is safe to drop it here.
    readerAndLocations.add(line.substring(0, arrowIndex).trim());
    readerAndLocations.addAll(Arrays.asList(line.substring(arrowIndex + 2).split("\\|")));
    return readerAndLocations;
  }

  /**
   * Parses both reader specifications.
   *
   * @return map with the parsed training specification under {@code true} and the test one under {@code false}
   */
  private Map<Boolean, List<String>> parseReaderSettings(String trainingData, String testData) {
    Map<Boolean, List<String>> readerSettings = new HashMap<Boolean, List<String>>();
    readerSettings.put(true, parseLine(trainingData));
    readerSettings.put(false, parseLine(testData));
    return readerSettings;
  }

  /**
   * Instantiates the readers named in the specifications and loads the training and test document sets.
   *
   * @param trainingData reader specification ({@code ReaderClass->path1|path2}) of the training data, or "null"
   * @param testData     reader specification of the test data
   * @param foldNum      total number of folds
   * @param goldAnn      {@code [0]}: use gold annotation for training documents, {@code [1]}: for test documents
   * @param adaptation   forwarded to the {@code DocumentSet} constructor
   * @param serialize    whether serialized grammatical analyses are going to be saved (only triggers a notice here)
   * @param lang         language code forwarded to {@code KpeReader.initGrammar}
   * @return returns the fact whether a new model is going to be trained
   */
  public boolean setReaders(String trainingData, String testData, int foldNum, boolean[] goldAnn, int adaptation,
      boolean serialize, String lang) {
    totalFolds = foldNum;
    Map<Boolean, List<String>> readerSettings = parseReaderSettings(trainingData, testData);
    for (Entry<Boolean, List<String>> entry : readerSettings.entrySet()) {
      boolean train = entry.getKey();
      List<String> readerAndLocations = entry.getValue();
      if (readerAndLocations.isEmpty()) {
        // "null" (or an unparsable specification) yields an empty list; not having training data is a regular
        // use case (an existing model gets loaded), so it must not end up in an IndexOutOfBoundsException.
        if (!train) {
          System.err.println("No reader could be determined for the test data: " + testData);
        }
        continue;
      }
      try {
        String readerName = readerAndLocations.get(0);
        if ((readerName.equals("null") || readerAndLocations.size() < 2) && train) {
          continue;
        }
        KpeReader reader = Class.forName("hu.u_szeged.kpe.readers." + readerName).asSubclass(KpeReader.class)
            .getDeclaredConstructor().newInstance();
        reader.initGrammar(mweFeatureIsOn, neFeatureIsOn, syntacticFeatureIsOn, lang);
        DocumentSet ds = new DocumentSet(adaptation, reader);
        reader.setUseGoldAnnotation(train ? goldAnn[0] : goldAnn[1]);
        for (String path : readerAndLocations.subList(1, readerAndLocations.size())) {
          ds.setBaseDir(path);
          if (serialize) {
            System.err
                .println("Note that (due to the config parameters) files containing serializations of grammatic analysis will be saved to location: "
                    + path + "grammar/");
          }
          reader.addDirectoryOfFiles(path, train, ds);
        }
        if (train) {
          km.setDocSet(ds);
        } else {
          if (ds.size() == 0) {
            System.err.println("No test documents were added. The program will exit now.");
            System.exit(2);
          }
          ke.setDocSet(ds);
        }
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
    return km.getDocSet() != null;
  }

  /** Sets the fold that subsequent model building and extraction calls refer to. */
  protected void setFold(int fold) {
    actualFold = fold;
  }

  /**
   * Assembles the path of the model file belonging to the given settings and the current fold. Model building and
   * keyphrase extraction have to agree on this name, hence it is built in one place only.
   *
   * @param trainGoldAnn the gold annotation flag of the training data (last component of the file name)
   */
  private String buildModelName(String classifier, boolean trainGoldAnn, boolean[] employBIESmarkup,
      double selectedFeatureRatio, String[] loc) {
    StringBuilder name = new StringBuilder("models/").append(loc[0]).append('/');
    if (loc[1] != null) {
      name.append(loc[1]).append('/');
    }
    name.append(mode).append('_');
    if (selectedFeatureRatio < 1.0d) {
      name.append("fs_").append(selectedFeatureRatio).append('_');
    }
    if (employBIESmarkup[0]) {
      name.append("BIES_pos_");
    }
    if (employBIESmarkup[1]) {
      name.append("BIES_ne_");
    }
    if (employBIESmarkup[2]) {
      name.append("BIES_suffix_");
    }
    if (wordNetUsage) {
      name.append("wn_");
    }
    if (!noStopWordPruning) {
      name.append("sw_");
    }
    if (!noPosEndingPruning) {
      name.append("pos_");
    }
    if (!classifier.equals("MaxEntL1")) {
      name.append('_').append(classifier);
    }
    if (totalFolds > 1) {
      name.append("fold").append(actualFold).append('_');
    }
    return name.append(trainGoldAnn).append(".model").toString();
  }

  /**
   * Trains a model for the current fold and serializes it below the {@code models/} directory. For classifiers
   * whose name contains "MaxEnt" a {@code <model>_statistics.txt} file is written next to the model as well.
   *
   * @param classifier           name of the classifier to train
   * @param synonyms             unused, kept for interface compatibility
   * @param goldAnn              gold annotation flag of the training data (part of the model file name)
   * @param employBIESmarkup     three flags (pos, ne, suffix) of the BIES feature markup
   * @param commonWords          forwarded to {@code ExtractionModelBuilder.buildModel}
   * @param selectedFeatureRatio forwarded to {@code buildModel}; values below 1.0 also show up in the file name
   * @param loc                  {@code [0]}: directory below {@code models/}, {@code [1]}: optional sub-directory
   * @param serialize            forwarded to {@code buildModel}
   * @throws Exception if building or saving the model fails
   */
  protected void createModel(String classifier, boolean synonyms, boolean goldAnn, boolean[] employBIESmarkup,
      double commonWords, double selectedFeatureRatio, String[] loc, boolean serialize) throws Exception {
    String modelName = buildModelName(classifier, goldAnn, employBIESmarkup, selectedFeatureRatio, loc);
    System.err.println(modelName);
    String log = km.buildModel(actualFold, totalFolds, selectedFeatures, classifier, commonWords,
        selectedFeatureRatio, employBIESmarkup, ke.getDocSet(), noStopWordPruning, noPosEndingPruning, serialize);

    // Create the full target directory (including the optional loc[1] part) for every classifier, otherwise
    // writing the statistics file fails whenever a model prefix is in use.
    File modelDir = new File(modelName).getParentFile();
    if (modelDir != null && !modelDir.isDirectory() && !modelDir.mkdirs()) {
      throw new IOException("Could not create model directory " + modelDir);
    }

    KPEFilter kf = km.getKPEFilter();
    if (kf.getClassifierName().contains("MaxEnt")) {
      try (PrintWriter logger = new PrintWriter(modelName + "_statistics.txt")) {
        logger.println(new Date());
        logger.println(log);
        ((MalletClassifier) kf.getModel()).printModel(logger, 50);
      }
    }
    NLPUtils.serialize(km.getKPEFilter(), modelName);
  }

  /**
   * Loads the model belonging to the given settings and the current fold and extracts keyphrases with it. The
   * output file is named after the model file, with ".model" replaced by {@code _<goldAnn[1]>.out}. The JVM is
   * terminated with exit code 1 if the model file does not exist.
   *
   * @param synonyms unused, kept for interface compatibility
   * @param goldAnn  {@code [0]}: gold annotation flag the model was trained with, {@code [1]}: flag of the test data
   */
  private void extractKeyphrases(String classifier, boolean synonyms, boolean[] goldAnn, boolean[] employBIESmarkup,
      double selectedFeatureRatio, String[] loc, boolean serialize) {
    String modelName = buildModelName(classifier, goldAnn[0], employBIESmarkup, selectedFeatureRatio, loc);
    try {
      if (!new File(modelName).exists()) {
        System.err.println("The desired model (" + modelName
            + ") cannot be found on the computer, the config needs to be modified in order to generate it.");
        System.exit(1);
      }
      System.err.println("Extrcation of keyphrases begins. Output will be located at ./models/" + loc[0]
          + "/ directory.");
      ke.loadModel(modelName);
      ke.extractKeyphrases(actualFold, totalFolds, modelName.replace(".model", "_" + goldAnn[1] + ".out"),
          serialize);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  /**
   * Left-pads the given text with '0' characters up to the requested length.
   *
   * @return the padded text, or the text itself if it is already at least {@code length} characters long
   */
  private String extendToDesiredLength(String toExtend, int length) {
    if (toExtend.length() >= length) {
      return toExtend;
    }
    StringBuilder padded = new StringBuilder(length);
    for (int i = toExtend.length(); i < length; ++i) {
      padded.append('0');
    }
    return padded.append(toExtend).toString();
  }

  /**
   * Reads {@code key = value} settings from a parameter file. Everything from "//" to the end of a line is treated
   * as a comment; keys are matched case-insensitively.
   *
   * @param file location of the parameter file
   * @return the values in the order main expects them; options not (properly) set in the file are {@code null}
   */
  private static String[] processParamFile(String file) {
    String[] options = { "train", "test", "classifier", "featureEncoding", "numOfKeyphrases", "wordnet_dir",
        "posEndPrune", "stopWordPrune", "beisMarkup", "serializeAnnotations", "numOfFolds", "language" };
    List<String> lines = new LinkedList<String>();
    NLPUtils.readDocToCollection(file, lines);
    String[] newArgs = new String[options.length];
    for (String rawLine : lines) {
      int commentIndex = rawLine.indexOf("//");
      String line = commentIndex == -1 ? rawLine : rawLine.substring(0, commentIndex);
      // Split at the first '=' only, so that values may contain '=' themselves; lines without '=' (blank or
      // malformed ones) are ignored instead of causing an ArrayIndexOutOfBoundsException.
      int separatorIndex = line.indexOf('=');
      if (separatorIndex == -1) {
        continue;
      }
      String key = line.substring(0, separatorIndex).trim();
      String value = line.substring(separatorIndex + 1).trim();
      if (value.isEmpty()) {
        // Leave the option unset; main asks for missing values interactively.
        continue;
      }
      for (int p = 0; p < options.length; ++p) {
        if (options[p].equalsIgnoreCase(key)) {
          newArgs[p] = value;
          break;
        }
      }
    }
    return newArgs;
  }

  /**
   * Brings the arguments to the expected length and asks the user on the console for every missing value. The JVM
   * is terminated with exit code 1 if a missing value cannot be read (e.g. standard input reached its end).
   *
   * @param args    the arguments known so far; {@code null} elements count as missing
   * @param prompts one question per expected argument
   * @return an array of {@code prompts.length} non-null values
   */
  private static String[] completeArguments(String[] args, String[] prompts) {
    String[] completed = Arrays.copyOf(args, prompts.length);
    // A single reader is shared by all prompts: a BufferedReader reads ahead, so creating a new one for every
    // question would lose input whenever several answers are already available (e.g. piped stdin).
    BufferedReader console = null;
    for (int i = 0; i < completed.length; ++i) {
      if (completed[i] != null) {
        continue;
      }
      System.err.print(prompts[i] + "\t");
      try {
        if (console == null) {
          console = new BufferedReader(new InputStreamReader(System.in));
        }
        String answer = console.readLine();
        if (answer != null) {
          completed[i] = answer.trim();
        }
      } catch (IOException e) {
        e.printStackTrace();
      }
      if (completed[i] == null) {
        System.err.println("No value could be read for: " + prompts[i]);
        System.exit(1);
      }
    }
    return completed;
  }

  /**
   * Runs training and keyphrase extraction for every given feature encoding and fold.
   *
   * @param args either {@code -paramFile <configFile> [modelPrefix]} or the twelve settings in the order of the
   *             questions below; missing settings are asked for interactively
   */
  public static void main(String[] args) {
    String modelPrefix = null;
    if (args.length > 0 && args[0].equalsIgnoreCase("-paramFile")) {
      if (args.length < 2) {
        System.err.println("The -paramFile switch has to be followed by the location of the config file.");
        System.exit(1);
      }
      if (args.length > 2) {
        modelPrefix = args[2];
      }
      System.err.println("Configuration read from config file: " + args[1]);
      args = processParamFile(args[1]);
    }

    String[] params = { "Reader class and location(s) for training:", "Reader class and location(s) for testing:",
        "Classifier to use:", "Feature encoding to use:", "Number of keyphrases to extract?",
        "Location of WordNet dict directory (or type in 'FALSE' in case you do not wish to use it)?",
        "pos ending-based candidate phrase pruning _not_ to be used?",
        "stopword candidate phrase pruning _not_ to be used?", "B(egin)I(nside)E(nd)S(ingle) feature markup?",
        "serialize grammar files?", "Number of test folds?",
        "What is the language of the input data ('en' or 'hu')?" };
    args = completeArguments(args, params);
    for (int i = 0; i < args.length; ++i) {
      System.err.println(params[i] + "\t--->\t" + args[i]);
    }

    // Everything starting at "->" is cut off, i.e. the name of the test reader identifies the model directory.
    String locationPrefix = args[1].replaceAll("->.*[\\/]?", "");
    String classifier = args[2];
    String[] featureCodings = args[3].split(",");
    int numOfKeyphrases = Integer.parseInt(args[4].trim());
    String wordNetParameter = args[5];
    boolean[] filtration = { Boolean.parseBoolean(args[6].trim()), Boolean.parseBoolean(args[7].trim()) };
    String[] beisSettings = args[8].split(",");
    boolean[] employBIESmarkup = new boolean[3];
    // Surplus values are ignored rather than overflowing the three supported flags.
    for (int i = 0; i < Math.min(beisSettings.length, employBIESmarkup.length); ++i) {
      employBIESmarkup[i] = Boolean.parseBoolean(beisSettings[i].trim());
    }
    boolean serializeGrammar = Boolean.parseBoolean(args[9].trim());
    int numOfFolds = Integer.parseInt(args[10].trim());
    String lang = args[11].replaceAll("^[\\p{Punct}]+|[\\p{Punct}]+$", "").toLowerCase();

    // these are just dummy, burned-in values in order to neutralize experimental and/or not well-tested
    // enough features temporarily
    boolean finalPrune = false;
    boolean[] goldAnn = { true, true };
    boolean[] useSynonyms = new boolean[2];
    int adaptation = -1;
    double selectedFeatureRatio = 1.0d;

    String trainParameters = args[0];
    String testParameters = args[1];
    String[] location = { locationPrefix, modelPrefix };
    for (String featureCoding : featureCodings) {
      int encodedFeatures = Integer.parseInt(featureCoding.trim());
      KpeMain kpe = new KpeMain(numOfKeyphrases, finalPrune, encodedFeatures, filtration, wordNetParameter);
      boolean newModel = kpe.setReaders(trainParameters, testParameters, numOfFolds, goldAnn, adaptation,
          serializeGrammar, lang);
      for (int fold = 1; fold <= numOfFolds; ++fold) {
        System.err.println("Fold #" + fold);
        time = System.currentTimeMillis();
        kpe.setFold(fold);
        if (newModel) {
          try {
            kpe.createModel(classifier, useSynonyms[0], goldAnn[0], employBIESmarkup, 0.1d, selectedFeatureRatio,
                location, serializeGrammar);
          } catch (Exception e) {
            // Without a model there is nothing to extract with in this fold.
            e.printStackTrace();
            continue;
          }
          if (fold == 1) {
            System.err.println(encodedFeatures);
            System.err.println("Reader phrases " + (goldAnn[0] ? "" : "not ") + "used.");
            System.err.println("WordNet is " + (wordNetParameter.equalsIgnoreCase("false") ? "not " : "")
                + "used.");
            System.err.println("POS ending pruning is " + (kpe.getNoPosEndingPruning() ? "not " : "") + "used.");
            System.err.println("Stopword pruning is " + (kpe.getNoStopWordPruning() ? "not " : "") + "used.");
            System.err.println(args[1] + " will be keyphrased with a " + classifier + " classifier"
                + " and a new " + numOfFolds + "-fold model is being created.");
          }
          System.err.println("Model done: " + (System.currentTimeMillis() - time) / 1000.0 + " secs");
        }
        kpe.extractKeyphrases(classifier, useSynonyms[0], goldAnn, employBIESmarkup, selectedFeatureRatio, location,
            serializeGrammar);
        System.err.println((System.currentTimeMillis() - time) / 1000.0);
      }
    }
  }
}