package edu.stanford.nlp.patterns; import java.io.IOException; import java.util.HashSet; import java.util.Set; import java.util.Map.Entry; import edu.stanford.nlp.process.WordShapeClassifier; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.stats.TwoDimensionalCounter; import edu.stanford.nlp.util.ArgumentParser.Option; import edu.stanford.nlp.util.GoogleNGramsSQLBacked; import edu.stanford.nlp.util.logging.Redwood; public abstract class PhraseScorer<E extends Pattern> { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(PhraseScorer.class); ConstantsAndVariables constVars; //these get overwritten in ScorePhrasesLearnFeatWt class double OOVExternalFeatWt = 0.5; double OOVdictOdds = 1e-10; double OOVDomainNgramScore = 1e-10; double OOVGoogleNgramScore = 1e-10; @Option(name = "usePatternWeights") public boolean usePatternWeights = true; @Option(name = "wordFreqNorm") Normalization wordFreqNorm = Normalization.valueOf("LOG"); /** * For phrases, some phrases are evaluated as a combination of their * individual words. Default is taking minimum of all the words. This flag * takes average instead of the min. */ @Option(name = "useAvgInsteadofMinPhraseScoring") boolean useAvgInsteadofMinPhraseScoring = false; public enum Normalization { NONE, SQRT, LOG }; static public enum Similarities{NUMITEMS, AVGSIM, MAXSIM}; public PhraseScorer(ConstantsAndVariables constvar) { this.constVars = constvar; } Counter<CandidatePhrase> learnedScores = new ClassicCounter<>(); abstract Counter<CandidatePhrase> scorePhrases(String label, TwoDimensionalCounter<CandidatePhrase, E> terms, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns, Set<CandidatePhrase> alreadyIdentifiedWords, boolean forLearningPatterns) throws IOException, ClassNotFoundException; Counter<CandidatePhrase> getLearnedScores() { return learnedScores; } double getPatTFIDFScore(CandidatePhrase word, Counter<E> patsThatExtractedThis, Counter<E> allSelectedPatterns) { if(Data.processedDataFreq.getCount(word) == 0.0) { Redwood.log(Redwood.WARN, "How come the processed corpus freq has count of " + word + " 0. The count in raw freq is " + Data.rawFreq.getCount(word) + " and the Data.rawFreq size is " + Data.rawFreq.size()); return 0; } else { double total = 0; Set<E> rem = new HashSet<>(); for (Entry<E, Double> en2 : patsThatExtractedThis.entrySet()) { double weight = 1.0; if (usePatternWeights) { weight = allSelectedPatterns.getCount(en2.getKey()); if (weight == 0){ Redwood.log(Redwood.FORCE, "Warning: Weight zero for " + en2.getKey() + ". May be pattern was removed when choosing other patterns (if subsumed by another pattern)."); rem.add(en2.getKey()); } } total += weight; } Counters.removeKeys(patsThatExtractedThis, rem); double score = total / Data.processedDataFreq.getCount(word); return score; } } public static double getGoogleNgramScore(CandidatePhrase g) { double count = GoogleNGramsSQLBacked.getCount(g.getPhrase().toLowerCase()) + GoogleNGramsSQLBacked.getCount(g.getPhrase()); if (count != -1) { if(!Data.rawFreq.containsKey(g)) //returning 1 because usually lower this tf-idf score the better. if we don't have raw freq info, give it a bad score return 1; else return (1 + Data.rawFreq.getCount(g) * Math.sqrt(Data.ratioGoogleNgramFreqWithDataFreq)) / count; } return 0; } public double getDomainNgramScore(String g) { String gnew = g; if(!Data.domainNGramRawFreq.containsKey(gnew)){ gnew = g.replaceAll(" ",""); } if(!Data.domainNGramRawFreq.containsKey(gnew)){ gnew = g.replaceAll("-",""); }else g = gnew; if(!Data.domainNGramRawFreq.containsKey(gnew)){ log.info("domain count 0 for " + g); return 0; } else g = gnew; return ((1 + Data.rawFreq.getCount(g) * Math.sqrt(Data.ratioDomainNgramFreqWithDataFreq)) / Data.domainNGramRawFreq .getCount(g)); } public double getDistSimWtScore(String ph, String label) { Integer num = constVars.getWordClassClusters().get(ph); if(num == null){ num = constVars.getWordClassClusters().get(ph.toLowerCase()); } if (num != null && constVars.distSimWeights.get(label).containsKey(num)) { return constVars.distSimWeights.get(label).getCount(num); } else { String[] t = ph.split("\\s+"); if (t.length < 2) { return OOVExternalFeatWt; } double totalscore = 0; double minScore = Double.MAX_VALUE; for (String w : t) { double score = OOVExternalFeatWt; Integer numw = constVars.getWordClassClusters().get(w); if(num == null){ num = constVars.getWordClassClusters().get(w.toLowerCase()); } if (numw != null && constVars.distSimWeights.get(label).containsKey(numw)) score = constVars.distSimWeights.get(label).getCount(numw); if (score < minScore) minScore = score; totalscore += score; } if (useAvgInsteadofMinPhraseScoring) return totalscore / ph.length(); else return minScore; } } public String wordShape(String word){ String wordShape = constVars.getWordShapeCache().get(word); if(wordShape == null){ wordShape = WordShapeClassifier.wordShape(word, constVars.wordShaper); constVars.getWordShapeCache().put(word, wordShape); } return wordShape; } public double getWordShapeScore(String word, String label){ String wordShape = wordShape(word); double thislabel = 0, alllabels =0; for(Entry<String, Counter<String>> en: constVars.getWordShapesForLabels().entrySet()){ if(en.getKey().equals(label)) thislabel = en.getValue().getCount(wordShape); alllabels += en.getValue().getCount(wordShape); } double score = thislabel/ (alllabels + 1); return score; } public double getDictOddsScore(CandidatePhrase word, String label, double defaultWt) { double dscore; Counter<CandidatePhrase> dictOddsWordWeights = constVars.dictOddsWeights.get(label); assert dictOddsWordWeights != null : "dictOddsWordWeights is null for label " + label; if (dictOddsWordWeights.containsKey(word)) { dscore = dictOddsWordWeights.getCount(word); } else dscore = getPhraseWeightFromWords(dictOddsWordWeights, word, defaultWt); return dscore; } public double getPhraseWeightFromWords(Counter<CandidatePhrase> weights, CandidatePhrase ph, double defaultWt) { String[] t = ph.getPhrase().split("\\s+"); if (t.length < 2) { if (weights.containsKey(ph)) return weights.getCount(ph); else return defaultWt; } double totalscore = 0; double minScore = Double.MAX_VALUE; for (String w : t) { double score = defaultWt; if (weights.containsKey(CandidatePhrase.createOrGet(w))) score = weights.getCount(w); if (score < minScore) minScore = score; totalscore += score; } if (useAvgInsteadofMinPhraseScoring) return totalscore / ph.getPhrase().length(); else return minScore; } abstract public Counter<CandidatePhrase> scorePhrases(String label, Set<CandidatePhrase> terms, boolean forLearningPatterns) throws IOException, ClassNotFoundException; public abstract void printReasonForChoosing(Counter<CandidatePhrase> phrases); }