package edu.stanford.nlp.parser.lexparser; import edu.stanford.nlp.util.logging.Redwood; import edu.stanford.nlp.io.EncodingPrintWriter; import edu.stanford.nlp.ling.TaggedWord; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.util.Index; public class ArabicUnknownWordModelTrainer extends AbstractUnknownWordModelTrainer { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(BaseUnknownWordModelTrainer.class); // Records the number of times word/tag pair was seen in training data. ClassicCounter<IntTaggedWord> seenCounter; ClassicCounter<IntTaggedWord> unSeenCounter; double indexToStartUnkCounting; private static final boolean DOCUMENT_UNKNOWNS = false; // if UNK were a word, counts would merge private static final String UNKNOWN_WORD = "UNK"; // boundary tag -- assumed not a real tag private static final String BOUNDARY_TAG = ".$$."; UnknownWordModel model; @Override public void initializeTraining(Options op, Lexicon lex, Index<String> wordIndex, Index<String> tagIndex, double totalTrees) { super.initializeTraining(op, lex, wordIndex, tagIndex, totalTrees); this.totalTrees = totalTrees; indexToStartUnkCounting = (totalTrees * op.trainOptions.fractionBeforeUnseenCounting); seenCounter = new ClassicCounter<>(20000); unSeenCounter = new ClassicCounter<>(20000); model = new ArabicUnknownWordModel(op, lex, wordIndex, tagIndex, unSeenCounter); if (DOCUMENT_UNKNOWNS) { log.info("Collecting " + UNKNOWN_WORD + " from trees " + (indexToStartUnkCounting + 1) + " to " + totalTrees); } } /** * Trains this lexicon on the Collection of trees. */ @Override public void train(TaggedWord tw, int loc, double weight) { IntTaggedWord iTW = new IntTaggedWord(tw.word(), tw.tag(), wordIndex, tagIndex); IntTaggedWord iT = new IntTaggedWord(nullWord, iTW.tag); IntTaggedWord iW = new IntTaggedWord(iTW.word, nullTag); seenCounter.incrementCount(iW, weight); IntTaggedWord i = NULL_ITW; if (treesRead > indexToStartUnkCounting) { // start doing this once some way through trees; // treesRead is 1 based counting if (seenCounter.getCount(iW) < 2) { // it's an entirely unknown word int s = model.getSignatureIndex(iTW.word, loc, wordIndex.get(iTW.word)); if (DOCUMENT_UNKNOWNS) { String wStr = wordIndex.get(iTW.word); String tStr = tagIndex.get(iTW.tag); String sStr = wordIndex.get(s); EncodingPrintWriter.err.println("Unknown word/tag/sig:\t" + wStr + '\t' + tStr + '\t' + sStr, "UTF-8"); } IntTaggedWord iTS = new IntTaggedWord(s, iTW.tag); IntTaggedWord iS = new IntTaggedWord(s, nullTag); unSeenCounter.incrementCount(iTS, weight); unSeenCounter.incrementCount(iT, weight); unSeenCounter.incrementCount(iS, weight); unSeenCounter.incrementCount(i, weight); } // else { } } @Override public UnknownWordModel finishTraining() { // make sure the unseen counter isn't empty! If it is, put in // a uniform unseen over tags if (unSeenCounter.isEmpty()) { int numTags = tagIndex.size(); for (int tt = 0; tt < numTags; tt++) { if ( ! BOUNDARY_TAG.equals(tagIndex.get(tt))) { IntTaggedWord iT = new IntTaggedWord(nullWord, tt); IntTaggedWord i = NULL_ITW; unSeenCounter.incrementCount(iT); unSeenCounter.incrementCount(i); } } } // index the possible tags for each word // numWords = wordIndex.size(); // unknownWordIndex = wordIndex.indexOf(Lexicon.UNKNOWN_WORD, true); // initRulesWithWord(); return model; } }