package edu.stanford.nlp.patterns; import java.io.IOException; import java.util.*; import java.util.Map.Entry; import java.util.function.Function; import edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures; import edu.stanford.nlp.patterns.GetPatternsFromDataMultiClass.PatternScoring; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.stats.TwoDimensionalCounter; import edu.stanford.nlp.util.ArgumentParser; import edu.stanford.nlp.util.Pair; public class ScorePatternsRatioModifiedFreq<E> extends ScorePatterns<E> { public ScorePatternsRatioModifiedFreq( ConstantsAndVariables constVars, PatternScoring patternScoring, String label, Set<CandidatePhrase> allCandidatePhrases, TwoDimensionalCounter<E, CandidatePhrase> patternsandWords4Label, TwoDimensionalCounter<E, CandidatePhrase> negPatternsandWords4Label, TwoDimensionalCounter<E, CandidatePhrase> unLabeledPatternsandWords4Label, TwoDimensionalCounter<CandidatePhrase, ScorePhraseMeasures> phInPatScores, ScorePhrases scorePhrases, Properties props) { super(constVars, patternScoring, label, allCandidatePhrases, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, props); this.phInPatScores = phInPatScores; this.scorePhrases = scorePhrases; } // cached values private TwoDimensionalCounter<CandidatePhrase, ScorePhraseMeasures> phInPatScores; private ScorePhrases scorePhrases; @Override public void setUp(Properties props) { } @Override public Counter<E> score() throws IOException, ClassNotFoundException { Counter<CandidatePhrase> externalWordWeightsNormalized = null; if (constVars.dictOddsWeights.containsKey(label)) externalWordWeightsNormalized = GetPatternsFromDataMultiClass .normalizeSoftMaxMinMaxScores(constVars.dictOddsWeights.get(label), true, true, false); Counter<E> currentPatternWeights4Label = new ClassicCounter<>(); boolean useFreqPhraseExtractedByPat = false; if (patternScoring.equals(PatternScoring.SqrtAllRatio)) useFreqPhraseExtractedByPat = true; Function<Pair<E, CandidatePhrase>, Double> numeratorScore = x -> patternsandWords4Label.getCount(x.first(), x.second()); Counter<E> numeratorPatWt = this.convert2OneDim(label, numeratorScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false, null, useFreqPhraseExtractedByPat); Counter<E> denominatorPatWt = null; Function<Pair<E, CandidatePhrase>, Double> denoScore; if (patternScoring.equals(PatternScoring.PosNegUnlabOdds)) { denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second()); denominatorPatWt = this.convert2OneDim(label, denoScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false, externalWordWeightsNormalized, useFreqPhraseExtractedByPat); } else if (patternScoring.equals(PatternScoring.RatioAll)) { denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second()) + patternsandWords4Label.getCount(x.first(), x.second()); denominatorPatWt = this.convert2OneDim(label, denoScore,allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false, externalWordWeightsNormalized, useFreqPhraseExtractedByPat); } else if (patternScoring.equals(PatternScoring.PosNegOdds)) { denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()); denominatorPatWt = this.convert2OneDim(label, denoScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, false, externalWordWeightsNormalized, useFreqPhraseExtractedByPat); } else if (patternScoring.equals(PatternScoring.PhEvalInPat) || patternScoring.equals(PatternScoring.PhEvalInPatLogP) || patternScoring.equals(PatternScoring.LOGREG) || patternScoring.equals(PatternScoring.LOGREGlogP)) { denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second()); denominatorPatWt = this.convert2OneDim(label, denoScore, allCandidatePhrases, patternsandWords4Label, constVars.sqrtPatScore, true, externalWordWeightsNormalized, useFreqPhraseExtractedByPat); } else if (patternScoring.equals(PatternScoring.SqrtAllRatio)) { denoScore = x -> negPatternsandWords4Label.getCount(x.first(), x.second()) + unLabeledPatternsandWords4Label.getCount(x.first(), x.second()); denominatorPatWt = this.convert2OneDim(label, denoScore, allCandidatePhrases, patternsandWords4Label, true, false, externalWordWeightsNormalized, useFreqPhraseExtractedByPat); } else throw new RuntimeException("Cannot understand patterns scoring"); currentPatternWeights4Label = Counters.divisionNonNaN(numeratorPatWt, denominatorPatWt); //Multiplying by logP if (patternScoring.equals(PatternScoring.PhEvalInPatLogP) || patternScoring.equals(PatternScoring.LOGREGlogP)) { Counter<E> logpos_i = new ClassicCounter<>(); for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label .entrySet()) { logpos_i.setCount(en.getKey(), Math.log(en.getValue().size())); } Counters.multiplyInPlace(currentPatternWeights4Label, logpos_i); } Counters.retainNonZeros(currentPatternWeights4Label); return currentPatternWeights4Label; } Counter<E> convert2OneDim(String label, Function<Pair<E, CandidatePhrase>, Double> scoringFunction, Set<CandidatePhrase> allCandidatePhrases, TwoDimensionalCounter<E, CandidatePhrase> positivePatternsAndWords, boolean sqrtPatScore, boolean scorePhrasesInPatSelection, Counter<CandidatePhrase> dictOddsWordWeights, boolean useFreqPhraseExtractedByPat) throws IOException, ClassNotFoundException { // if (Data.googleNGram.size() == 0 && Data.googleNGramsFile != null) { // Data.loadGoogleNGrams(); // } Counter<E> patterns = new ClassicCounter<>(); Counter<CandidatePhrase> googleNgramNormScores = new ClassicCounter<>(); Counter<CandidatePhrase> domainNgramNormScores = new ClassicCounter<>(); Counter<CandidatePhrase> externalFeatWtsNormalized = new ClassicCounter<>(); Counter<CandidatePhrase> editDistanceFromOtherSemanticBinaryScores = new ClassicCounter<>(); Counter<CandidatePhrase> editDistanceFromAlreadyExtractedBinaryScores = new ClassicCounter<>(); double externalWtsDefault = 0.5; Counter<String> classifierScores = null; if ((patternScoring.equals(PatternScoring.PhEvalInPat) || patternScoring .equals(PatternScoring.PhEvalInPatLogP)) && scorePhrasesInPatSelection) { for (CandidatePhrase gc : allCandidatePhrases) { String g = gc.getPhrase(); if (constVars.usePatternEvalEditDistOther) { editDistanceFromOtherSemanticBinaryScores.setCount(gc, constVars.getEditDistanceScoresOtherClassThreshold(label, g)); } if (constVars.usePatternEvalEditDistSame) { editDistanceFromAlreadyExtractedBinaryScores.setCount(gc, 1 - constVars.getEditDistanceScoresThisClassThreshold(label, g)); } if (constVars.usePatternEvalGoogleNgram) googleNgramNormScores .setCount(gc, PhraseScorer.getGoogleNgramScore(gc)); if (constVars.usePatternEvalDomainNgram) { // calculate domain-ngram wts if (Data.domainNGramRawFreq.containsKey(g)) { assert (Data.rawFreq.containsKey(gc)); domainNgramNormScores.setCount(gc, scorePhrases.phraseScorer.getDomainNgramScore(g)); } } if (constVars.usePatternEvalWordClass) { Integer num = constVars.getWordClassClusters().get(g); if(num == null){ num = constVars.getWordClassClusters().get(g.toLowerCase()); } if (num != null && constVars.distSimWeights.get(label).containsKey(num)) { externalFeatWtsNormalized.setCount(gc, constVars.distSimWeights.get(label).getCount(num)); } else externalFeatWtsNormalized.setCount(gc, externalWtsDefault); } } if (constVars.usePatternEvalGoogleNgram) googleNgramNormScores = GetPatternsFromDataMultiClass .normalizeSoftMaxMinMaxScores(googleNgramNormScores, true, true, false); if (constVars.usePatternEvalDomainNgram) domainNgramNormScores = GetPatternsFromDataMultiClass .normalizeSoftMaxMinMaxScores(domainNgramNormScores, true, true, false); if (constVars.usePatternEvalWordClass) externalFeatWtsNormalized = GetPatternsFromDataMultiClass .normalizeSoftMaxMinMaxScores(externalFeatWtsNormalized, true, true, false); } else if ((patternScoring.equals(PatternScoring.LOGREG) || patternScoring.equals(PatternScoring.LOGREGlogP)) && scorePhrasesInPatSelection) { Properties props2 = new Properties(); props2.putAll(props); props2.setProperty("phraseScorerClass", "edu.stanford.nlp.patterns.ScorePhrasesLearnFeatWt"); ScorePhrases scoreclassifier = new ScorePhrases(props2, constVars); System.out.println("file is " + props.getProperty("domainNGramsFile")); ArgumentParser.fillOptions(Data.class, props2); classifierScores = scoreclassifier.phraseScorer.scorePhrases(label, allCandidatePhrases, true); } Counter<CandidatePhrase> cachedScoresForThisIter = new ClassicCounter<>(); for (Map.Entry<E, ClassicCounter<CandidatePhrase>> en: positivePatternsAndWords.entrySet()) { for(Entry<CandidatePhrase, Double> en2: en.getValue().entrySet()) { CandidatePhrase word = en2.getKey(); Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<>(); double score = 1; if ((patternScoring.equals(PatternScoring.PhEvalInPat) || patternScoring .equals(PatternScoring.PhEvalInPatLogP)) && scorePhrasesInPatSelection) { if (cachedScoresForThisIter.containsKey(word)) { score = cachedScoresForThisIter.getCount(word); } else { if (constVars.getOtherSemanticClassesWords().contains(word) || constVars.getCommonEngWords().contains(word)) score = 1; else { if (constVars.usePatternEvalSemanticOdds) { double semanticClassOdds = 1; if (dictOddsWordWeights.containsKey(word)) semanticClassOdds = 1 - dictOddsWordWeights.getCount(word); scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, semanticClassOdds); } if (constVars.usePatternEvalGoogleNgram) { double gscore = 0; if (googleNgramNormScores.containsKey(word)) { gscore = 1 - googleNgramNormScores.getCount(word); } scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore); } if (constVars.usePatternEvalDomainNgram) { double domainscore; if (domainNgramNormScores.containsKey(word)) { domainscore = 1 - domainNgramNormScores.getCount(word); } else domainscore = 1 - scorePhrases.phraseScorer .getPhraseWeightFromWords(domainNgramNormScores, word, scorePhrases.phraseScorer.OOVDomainNgramScore); scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, domainscore); } if (constVars.usePatternEvalWordClass) { double externalFeatureWt = externalWtsDefault; if (externalFeatWtsNormalized.containsKey(word)) externalFeatureWt = 1 - externalFeatWtsNormalized.getCount(word); scoreslist.setCount(ScorePhraseMeasures.DISTSIM, externalFeatureWt); } if (constVars.usePatternEvalEditDistOther) { assert editDistanceFromOtherSemanticBinaryScores.containsKey(word) : "How come no edit distance info for word " + word + ""; scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, editDistanceFromOtherSemanticBinaryScores.getCount(word)); } if (constVars.usePatternEvalEditDistSame) { scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, editDistanceFromAlreadyExtractedBinaryScores.getCount(word)); } // taking average score = Counters.mean(scoreslist); phInPatScores.setCounter(word, scoreslist); } cachedScoresForThisIter.setCount(word, score); } } else if ((patternScoring.equals(PatternScoring.LOGREG) || patternScoring.equals(PatternScoring.LOGREGlogP)) && scorePhrasesInPatSelection) { score = 1 - classifierScores.getCount(word); // score = 1 - scorePhrases.scoreUsingClassifer(classifier, // e.getKey(), label, true, null, null, dictOddsWordWeights); // throw new RuntimeException("not implemented yet"); } if (useFreqPhraseExtractedByPat) score = score * scoringFunction.apply(new Pair(en.getKey(), word)); if (constVars.sqrtPatScore) patterns.incrementCount(en.getKey(), Math.sqrt(score)); else patterns.incrementCount(en.getKey(), score); } } return patterns; } }