package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.stats.*;

import java.io.*;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * A Lexicon class that computes the score of word|tag according to a maxent model
 * of tag|word (divided by an MLE estimate of P(tag)).
 * <p/>
 * It would be nice to factor out a superclass MaxentLexicon that takes a WordFeatureExtractor.
 *
 * @author Galen Andrew
 */
public class ChineseMaxentLexicon implements Lexicon {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(ChineseMaxentLexicon.class);

  private static final long serialVersionUID = 238834703409896852L;

  private static final boolean verbose = true;
  public static final boolean seenTagsOnly = false;

  private ChineseWordFeatureExtractor featExtractor;

  public static final boolean fixUnkFunctionWords = false;

  // Patterns matching the feature classes produced by ChineseWordFeatureExtractor,
  // each paired with the count threshold below which such features are discarded.
  private static final Pattern wordPattern = Pattern.compile(".*-W");
  private static final Pattern charPattern = Pattern.compile(".*-.C");
  private static final Pattern bigramPattern = Pattern.compile(".*-.B");
  private static final Pattern conjPattern = Pattern.compile(".*&&.*");

  private final Pair<Pattern, Integer> wordThreshold = new Pair<>(wordPattern, 0);
  private final Pair<Pattern, Integer> charThreshold = new Pair<>(charPattern, 2);
  private final Pair<Pattern, Integer> bigramThreshold = new Pair<>(bigramPattern, 3);
  private final Pair<Pattern, Integer> conjThreshold = new Pair<>(conjPattern, 3);
  private final List<Pair<Pattern, Integer>> featureThresholds = new ArrayList<>();
  private final int universalThreshold = 0;

  private LinearClassifier scorer;
  private Map<String, String> functionWordTags = Generics.newHashMap();
  private Distribution<String> tagDist;
  private final Index<String> wordIndex;
  private final Index<String> tagIndex;

  // Cached log-probabilities for the most recently scored word (see ensureProbs).
  private transient Counter<String> logProbs;
  private double iteratorCutoffFactor = 4;
  private transient int lastWord = -1;

  String initialWeightFile = null;
  boolean trainFloat = false;

  private static final String featureDir = "gbfeatures";

  private double tol = 1e-4;
  private double sigma = 0.4;
  static final boolean tuneSigma = false;
  static final int trainCountThreshold = 5;

  final int featureLevel;
  static final int DEFAULT_FEATURE_LEVEL = 2;

  private boolean trainOnLowCount = false;
  private boolean trainByType = false;

  private final TreebankLangParserParams tlpParams;
  private final TreebankLanguagePack ctlp;
  private final Options op;

  public boolean isKnown(int word) {
    return isKnown(wordIndex.get(word));
  }

  public boolean isKnown(String word) {
    return tagsForWord.containsKey(word);
  }

  /** {@inheritDoc} */
  @Override
  public Set<String> tagSet(Function<String,String> basicCategoryFunction) {
    Set<String> tagSet = new HashSet<>();
    for (String tag : tagIndex.objectsList()) {
      tagSet.add(basicCategoryFunction.apply(tag));
    }
    return tagSet;
  }

  private void ensureProbs(int word) {
    ensureProbs(word, true);
  }
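  /*
   * Why "divided by an MLE estimate of P(tag)" works (a brief derivation, not
   * present in the original source): by Bayes' rule,
   *   log P(word|tag) = log P(tag|word) + log P(word) - log P(tag).
   * For a fixed word, log P(word) is identical for every tag, so subtracting
   * log P(tag) from the classifier's log P(tag|word) ranks tags exactly as
   * log P(word|tag) would. That subtraction is what ensureProbs performs below
   * when subtractTagScore is true.
   */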
  private void ensureProbs(int word, boolean subtractTagScore) {
    if (word == lastWord) {
      return;
    }
    lastWord = word;
    if (functionWordTags.containsKey(wordIndex.get(word))) {
      logProbs = new ClassicCounter<>();
      String trueTag = functionWordTags.get(wordIndex.get(word));
      for (String tag : tagIndex.objectsList()) {
        if (ctlp.basicCategory(tag).equals(trueTag)) {
          logProbs.setCount(tag, 0);
        } else {
          logProbs.setCount(tag, Double.NEGATIVE_INFINITY);
        }
      }
      return;
    }
    Datum datum = new BasicDatum(featExtractor.makeFeatures(wordIndex.get(word)));
    logProbs = scorer.logProbabilityOf(datum);
    if (subtractTagScore) {
      Set<String> tagSet = logProbs.keySet();
      for (String tag : tagSet) {
        logProbs.incrementCount(tag, -Math.log(tagDist.probabilityOf(tag)));
      }
    }
  }

  public CollectionValuedMap<String, String> tagsForWord = new CollectionValuedMap<>();

  public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
    ensureProbs(word);
    List<IntTaggedWord> rules = new ArrayList<>();
    if (seenTagsOnly) {
      String wordString = wordIndex.get(word);
      Collection<String> tags = tagsForWord.get(wordString);
      for (String tag : tags) {
        rules.add(new IntTaggedWord(wordString, tag, wordIndex, tagIndex));
      }
    } else {
      double max = Counters.max(logProbs);
      for (int tag = 0; tag < tagIndex.size(); tag++) {
        IntTaggedWord iTW = new IntTaggedWord(word, tag);
        double score = logProbs.getCount(tagIndex.get(tag));
        if (score > max - iteratorCutoffFactor) {
          rules.add(iTW);
        }
      }
    }
    return rules.iterator();
  }

  public Iterator<IntTaggedWord> ruleIteratorByWord(String word, int loc, String featureSpec) {
    return ruleIteratorByWord(wordIndex.indexOf(word), loc, featureSpec);
  }
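  /*
   * Example of the cutoff used above (illustrative numbers only): with
   * iteratorCutoffFactor = 4.0, if the best tag has log-probability -1.0, only
   * tags scoring above -5.0 are returned, i.e. tags whose probability is
   * within a factor of e^4 (roughly 55x) of the most likely tag's.
   */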
  /**
   * Returns the number of rules (tag rewrites as word) in the Lexicon.
   * Here this is computed by iterating over every word in the index and
   * counting the tags its rule iterator returns.
   */
  public int numRules() {
    int accumulated = 0;
    for (int w = 0, tot = wordIndex.size(); w < tot; w++) {
      Iterator<IntTaggedWord> iter = ruleIteratorByWord(w, 0, null);
      while (iter.hasNext()) {
        iter.next();
        accumulated++;
      }
    }
    return accumulated;
  }

  private String getTag(String word) {
    int iW = wordIndex.addToIndex(word);
    ensureProbs(iW, false);
    return Counters.argmax(logProbs);
  }

  private void verbose(String s) {
    if (verbose) {
      log.info(s);
    }
  }

  public ChineseMaxentLexicon(Options op, Index<String> wordIndex, Index<String> tagIndex, int featureLevel) {
    this.op = op;
    this.tlpParams = op.tlpParams;
    this.ctlp = op.tlpParams.treebankLanguagePack();
    this.wordIndex = wordIndex;
    this.tagIndex = tagIndex;
    this.featureLevel = featureLevel;
    if (fixUnkFunctionWords) {
      String filename = "unknown_function_word-simple.gb";
      try (BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "GB18030"))) {
        for (String line = in.readLine(); line != null; line = in.readLine()) {
          String[] parts = line.split("\\s+", 2);
          functionWordTags.put(parts[0], parts[1]);
        }
      } catch (IOException e) {
        throw new RuntimeException("Couldn't read function word file " + filename, e);
      }
    }
  }

  // only used at training time
  transient IntCounter<TaggedWord> datumCounter;

  @Override
  public void initializeTraining(double numTrees) {
    verbose("Training ChineseMaxentLexicon.");
    verbose("trainOnLowCount = " + trainOnLowCount + ", trainByType = " + trainByType +
            ", featureLevel = " + featureLevel + ", tuneSigma = " + tuneSigma);
    verbose("Making dataset...");
    if (featExtractor == null) {
      featExtractor = new ChineseWordFeatureExtractor(featureLevel);
    }
    this.datumCounter = new IntCounter<>();
  }

  /**
   * Add the given collection of trees to the statistics counted. Can
   * be called multiple times with different trees.
   */
  public final void train(Collection<Tree> trees) {
    train(trees, 1.0);
  }

  /**
   * Add the given collection of trees to the statistics counted. Can
   * be called multiple times with different trees.
   */
  @Override
  public void train(Collection<Tree> trees, double weight) {
    for (Tree tree : trees) {
      train(tree, weight);
    }
  }

  /**
   * Add the given tree to the statistics counted. Can
   * be called multiple times with different trees.
   */
  @Override
  public void train(Tree tree, double weight) {
    train(tree.taggedYield(), weight);
  }
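  /*
   * Typical training lifecycle, mirroring the main() method below (a sketch,
   * not additional API):
   *   lex.initializeTraining(trees.size());
   *   lex.train(trees);      // may be called repeatedly with more batches
   *   lex.finishTraining();  // builds the tag distribution and maxent classifier
   */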
  /**
   * Add the given sentence to the statistics counted. Can
   * be called multiple times with different sentences.
   */
  @Override
  public void train(List<TaggedWord> sentence, double weight) {
    featExtractor.train(sentence, weight);
    for (TaggedWord word : sentence) {
      datumCounter.incrementCount(word, weight);
      tagsForWord.add(word.word(), word.tag());
    }
  }

  @Override
  public void trainUnannotated(List<TaggedWord> sentence, double weight) {
    // TODO: for now we just punt on these
    throw new UnsupportedOperationException("This version of the parser does not support non-tree training data");
  }

  @Override
  public void incrementTreesRead(double weight) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void train(TaggedWord tw, int loc, double weight) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void finishTraining() {
    IntCounter<String> tagCounter = new IntCounter<>();
    WeightedDataset data = new WeightedDataset(datumCounter.size());
    for (TaggedWord word : datumCounter.keySet()) {
      int count = datumCounter.getIntCount(word);
      if (trainOnLowCount && count > trainCountThreshold) {
        continue;
      }
      if (functionWordTags.containsKey(word.word())) {
        continue;
      }
      tagCounter.incrementCount(word.tag());
      if (trainByType) {
        count = 1;
      }
      data.add(new BasicDatum(featExtractor.makeFeatures(word.word()), word.tag()), count);
    }
    datumCounter = null;
    tagDist = Distribution.laplaceSmoothedDistribution(tagCounter, tagCounter.size(), 0.5);
    tagCounter = null;
    applyThresholds(data);

    verbose("Making classifier...");
    QNMinimizer minim = new QNMinimizer(); //new ResultStoringMonitor(5, "weights"));
    // minim.shutUp();
    LinearClassifierFactory factory = new LinearClassifierFactory(minim);
    factory.setTol(tol);
    factory.setSigma(sigma);
    if (tuneSigma) {
      factory.setTuneSigmaHeldOut();
    }
    scorer = factory.trainClassifier(data);
    verbose("Done training.");
  }

  private void applyThresholds(WeightedDataset data) {
    if (wordThreshold.second > 0) {
      featureThresholds.add(wordThreshold);
    }
    if (featExtractor.chars && charThreshold.second > 0) {
      featureThresholds.add(charThreshold);
    }
    if (featExtractor.bigrams && bigramThreshold.second > 0) {
      featureThresholds.add(bigramThreshold);
    }
    if ((featExtractor.conjunctions || featExtractor.mildConjunctions) && conjThreshold.second > 0) {
      featureThresholds.add(conjThreshold);
    }
    int types = data.numFeatureTypes();
    if (universalThreshold > 0) {
      data.applyFeatureCountThreshold(universalThreshold);
    }
    if ( ! featureThresholds.isEmpty()) {
      data.applyFeatureCountThreshold(featureThresholds);
    }
    int numRemoved = types - data.numFeatureTypes();
    if (numRemoved > 0) {
      verbose("Thresholding removed " + numRemoved + " features.");
    }
  }
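  /*
   * Command-line usage of the test driver below (argument roles inferred from
   * main(); the concrete values here are only examples):
   *   java edu.stanford.nlp.parser.lexparser.ChineseMaxentLexicon \
   *       treebankPath trainFileRanges testFileRanges [featureLevel]
   * e.g. treebankPath = a CTB directory, trainFileRanges = "1-270",
   * testFileRanges = "301-325", with featureLevel defaulting to 2 if omitted.
   */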
  public static void main(String[] args) {
    TreebankLangParserParams tlpParams = new ChineseTreebankParserParams();
    TreebankLanguagePack ctlp = tlpParams.treebankLanguagePack();
    Options op = new Options(tlpParams);
    TreeAnnotator ta = new TreeAnnotator(tlpParams.headFinder(), tlpParams, op);

    log.info("Reading trees...");
    FileFilter trainFilter = new NumberRangesFileFilter(args[1], true);
    Treebank trainTreebank = tlpParams.memoryTreebank();
    trainTreebank.loadPath(args[0], trainFilter);

    log.info("Annotating trees...");
    Collection<Tree> trainTrees = new ArrayList<>();
    for (Tree tree : trainTreebank) {
      trainTrees.add(ta.transformTree(tree));
    }
    trainTreebank = null; // saves memory

    log.info("Training lexicon...");
    Index<String> wordIndex = new HashIndex<>();
    Index<String> tagIndex = new HashIndex<>();
    int featureLevel = DEFAULT_FEATURE_LEVEL;
    if (args.length > 3) {
      featureLevel = Integer.parseInt(args[3]);
    }
    ChineseMaxentLexicon lex = new ChineseMaxentLexicon(op, wordIndex, tagIndex, featureLevel);
    lex.initializeTraining(trainTrees.size());
    lex.train(trainTrees);
    lex.finishTraining();

    log.info("Testing...");
    FileFilter testFilter = new NumberRangesFileFilter(args[2], true);
    Treebank testTreebank = tlpParams.memoryTreebank();
    testTreebank.loadPath(args[0], testFilter);
    List<TaggedWord> testWords = new ArrayList<>();
    for (Tree t : testTreebank) {
      testWords.addAll(t.taggedYield());
    }
    int[] totalAndCorrect = lex.testOnTreebank(testWords);
    log.info("done.");
    System.out.println(totalAndCorrect[1] + " correct out of " + totalAndCorrect[0] +
        " -- ACC: " + ((double) totalAndCorrect[1]) / totalAndCorrect[0]);
  }

  private int[] testOnTreebank(Collection<TaggedWord> testWords) {
    int[] totalAndCorrect = new int[2];
    totalAndCorrect[0] = 0;
    totalAndCorrect[1] = 0;
    for (TaggedWord word : testWords) {
      String goldTag = word.tag();
      String guessTag = ctlp.basicCategory(getTag(word.word()));
      totalAndCorrect[0]++;
      if (goldTag.equals(guessTag)) {
        totalAndCorrect[1]++;
      }
    }
    return totalAndCorrect;
  }

  public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
    ensureProbs(iTW.word());
    double max = Counters.max(logProbs);
    double score = logProbs.getCount(iTW.tagString(tagIndex));
    if (score > max - iteratorCutoffFactor) {
      return (float) score;
    } else {
      return Float.NEGATIVE_INFINITY;
    }
  }

  public void writeData(Writer w) throws IOException {
    throw new UnsupportedOperationException();
  }

  public void readData(BufferedReader in) throws IOException {
    throw new UnsupportedOperationException();
  }

  public UnknownWordModel getUnknownWordModel() {
    // TODO Auto-generated method stub
    return null;
  }

  public void setUnknownWordModel(UnknownWordModel uwm) {
    // TODO Auto-generated method stub
  }

  @Override
  public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
    train(trees);
  }

}