package edu.stanford.nlp.patterns;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;
import java.util.Map.Entry;
import java.util.concurrent.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures;
import edu.stanford.nlp.patterns.dep.DataInstanceDep;
import edu.stanford.nlp.patterns.dep.ExtractPhraseFromPattern;
import edu.stanford.nlp.patterns.dep.ExtractedPhrase;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.stats.*;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.ArgumentParser.Option;
import edu.stanford.nlp.util.concurrent.AtomicDouble;
import edu.stanford.nlp.util.concurrent.ConcurrentHashCounter;
import edu.stanford.nlp.util.logging.Redwood;
/**
* Learn a logistic regression classifier to combine weights to score a phrase.
*
* @author Sonal Gupta (sonalg@stanford.edu)
*
*/
public class ScorePhrasesLearnFeatWt<E extends Pattern> extends PhraseScorer<E> {
@Option(name = "scoreClassifierType")
private ClassifierType scoreClassifierType = ClassifierType.LR;
private static Map<String, double[]> wordVectors = null;
/**
 * Creates the scorer and, if any similarity-based sampling/evaluation option is enabled,
 * loads the word-vector file once into the static {@code wordVectors} cache.
 * Out-of-vocabulary feature weights are zeroed so unseen words contribute nothing.
 */
public ScorePhrasesLearnFeatWt(ConstantsAndVariables constvar) {
  super(constvar);
  // Load vectors only once (static cache) and only when some option actually needs them.
  if (constvar.useWordVectorsToComputeSim && (constvar.subsampleUnkAsNegUsingSim || constvar.expandPositivesWhenSampling || constvar.expandNegativesWhenSampling || constVars.usePhraseEvalWordVector) && wordVectors == null) {
    if (Data.rawFreq == null) {
      Data.rawFreq = new ClassicCounter<>();
      Data.computeRawFreqIfNull(PatternFactory.numWordsCompoundMax, constvar.batchProcessSents);
    }
    Redwood.log(Redwood.DBG, "Reading word vectors");
    wordVectors = new HashMap<>();
    for (String line : IOUtils.readLines(constVars.wordVectorFile)) {
      String[] tok = line.split("\\s+");
      String word = tok[0];
      CandidatePhrase p = CandidatePhrase.createOrGet(word);
      // Save the vector only if the word occurs in rawFreq, the seed set, stop words, or
      // English words; otherwise delete the phrase again to keep the phrase cache small.
      if (Data.rawFreq.containsKey(p) || constvar.getStopWords().contains(p) || constvar.getEnglishWords().contains(word) || constvar.hasSeedWordOrOtherSem(p)) {
        double[] d = new double[tok.length - 1];
        for (int i = 1; i < tok.length; i++) {
          // Fix: parseDouble avoids the needless Double boxing of Double.valueOf.
          d[i - 1] = Double.parseDouble(tok[i]);
        }
        wordVectors.put(word, d);
      } else {
        CandidatePhrase.deletePhrase(p);
      }
    }
    Redwood.log(Redwood.DBG, "Read " + wordVectors.size() + " word vectors");
  }
  // Out-of-vocabulary defaults: unseen words get zero weight for these external features.
  OOVExternalFeatWt = 0;
  OOVdictOdds = 0;
  OOVDomainNgramScore = 0;
  OOVGoogleNgramScore = 0;
}
/**
 * Classifier families selectable via the {@code scoreClassifierType} option.
 * Only LR, SVM, SHIFTLR and LINEAR are handled in {@code learnClassifier};
 * DT and RF are currently unsupported there (RF support is commented out).
 */
public enum ClassifierType {
DT, LR, RF, SVM, SHIFTLR, LINEAR
}
public TwoDimensionalCounter<CandidatePhrase, ScorePhraseMeasures> phraseScoresRaw = new TwoDimensionalCounter<>();
/**
 * Trains a classifier (of the configured {@code scoreClassifierType}) over datums sampled by
 * {@link #choosedatums}, logs its top feature weights, and dumps the raw per-phrase feature
 * values to {@code tempscorestrainer.txt} for debugging.
 *
 * @param label the answer label being learned
 * @param forLearningPatterns whether scoring is being done for pattern learning
 * @param wordsPatExtracted phrases extracted by patterns (may be null)
 * @param allSelectedPatterns patterns selected so far (may be null)
 * @return the trained classifier
 * @throws RuntimeException if {@code scoreClassifierType} is not LR/SVM/SHIFTLR/LINEAR
 */
public edu.stanford.nlp.classify.Classifier learnClassifier(String label, boolean forLearningPatterns,
    TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns) throws IOException, ClassNotFoundException {
  phraseScoresRaw.clear();
  learnedScores.clear();
  if (Data.domainNGramsFile != null)
    Data.loadDomainNGrams();
  boolean computeRawFreq = false;
  if (Data.rawFreq == null) {
    Data.rawFreq = new ClassicCounter<>();
    computeRawFreq = true;
  }
  GeneralDataset<String, ScorePhraseMeasures> dataset = choosedatums(forLearningPatterns, label, wordsPatExtracted, allSelectedPatterns, computeRawFreq);
  edu.stanford.nlp.classify.Classifier classifier;
  if (scoreClassifierType.equals(ClassifierType.LR)) {
    LogisticClassifierFactory<String, ScorePhraseMeasures> logfactory = new LogisticClassifierFactory<>();
    LogPrior lprior = new LogPrior();
    lprior.setSigma(constVars.LRSigma);
    classifier = logfactory.trainClassifier(dataset, lprior, false);
    LogisticClassifier logcl = ((LogisticClassifier) classifier);
    String l = (String) logcl.getLabelForInternalPositiveClass();
    Counter<String> weights = logcl.weightsAsCounter();
    // Flip the sign if the internal positive class is "false", so that the logged weights
    // always read as weights for the "true" class.
    if (l.equals(Boolean.FALSE.toString())) {
      Counters.multiplyInPlace(weights, -1);
    }
    List<Pair<String, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
    Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
  } else if (scoreClassifierType.equals(ClassifierType.SVM)) {
    SVMLightClassifierFactory<String, ScorePhraseMeasures> svmcf = new SVMLightClassifierFactory<>(true);
    classifier = svmcf.trainClassifier(dataset);
    Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
    List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((SVMLightClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
    Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
  } else if (scoreClassifierType.equals(ClassifierType.SHIFTLR)) {
    // Convert to a basic (non-RVF) dataset because ShiftParamsLR doesn't support RVFDatum.
    GeneralDataset<String, ScorePhraseMeasures> newdataset = new Dataset<>();
    Iterator<RVFDatum<String, ScorePhraseMeasures>> iter = dataset.iterator();
    while (iter.hasNext()) {
      RVFDatum<String, ScorePhraseMeasures> inst = iter.next();
      newdataset.add(new BasicDatum<>(inst.asFeatures(), inst.label()));
    }
    ShiftParamsLogisticClassifierFactory<String, ScorePhraseMeasures> factory = new ShiftParamsLogisticClassifierFactory<>();
    classifier = factory.trainClassifier(newdataset);
    // Print weights of the "true" class.
    MultinomialLogisticClassifier<String, ScorePhraseMeasures> logcl = ((MultinomialLogisticClassifier) classifier);
    Counter<ScorePhraseMeasures> weights = logcl.weightsAsGenericCounter().get("true");
    List<Pair<ScorePhraseMeasures, Double>> wtd = Counters.toDescendingMagnitudeSortedListWithCounts(weights);
    Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(wtd.subList(0, Math.min(wtd.size(), 600)), "\n"));
  } else if (scoreClassifierType.equals(ClassifierType.LINEAR)) {
    LinearClassifierFactory<String, ScorePhraseMeasures> lcf = new LinearClassifierFactory<>();
    classifier = lcf.trainClassifier(dataset);
    Set<String> labels = Generics.newHashSet(Arrays.asList("true"));
    List<Triple<ScorePhraseMeasures, String, Double>> topfeatures = ((LinearClassifier<String, ScorePhraseMeasures>) classifier).getTopFeatures(labels, 0, true, 600, true);
    Redwood.log(ConstantsAndVariables.minimaldebug, "The weights are " + StringUtils.join(topfeatures, "\n"));
  } else
    throw new RuntimeException("cannot identify classifier " + scoreClassifierType);
  // else if (scoreClassifierType.equals(ClassifierType.RF)) {
  // ClassifierFactory wekaFactory = new WekaDatumClassifierFactory<String, ScorePhraseMeasures>("weka.classifiers.trees.RandomForest", constVars.wekaOptions);
  // classifier = wekaFactory.trainClassifier(dataset);
  // Classifier cls = ((WekaDatumClassifier) classifier).getClassifier();
  // RandomForest rf = (RandomForest) cls;
  // }
  // Fix: try-with-resources ensures the writer is closed even if a write throws
  // (the original leaked the FileWriter on exception).
  try (BufferedWriter w = new BufferedWriter(new FileWriter("tempscorestrainer.txt"))) {
    System.out.println("size of learned scores is " + phraseScoresRaw.size());
    for (CandidatePhrase s : phraseScoresRaw.firstKeySet()) {
      w.write(s + "\t" + phraseScoresRaw.getCounter(s) + "\n");
    }
  }
  return classifier;
}
/** Logs, for each selected phrase, its final score and the raw feature values it was scored with. */
@Override
public void printReasonForChoosing(Counter<CandidatePhrase> phrases) {
  Redwood.log(Redwood.DBG, "Features of selected phrases");
  for (Entry<CandidatePhrase, Double> entry : phrases.entrySet()) {
    CandidatePhrase phrase = entry.getKey();
    Redwood.log(Redwood.DBG, phrase.getPhrase() + "\t" + entry.getValue() + "\t" + phraseScoresRaw.getCounter(phrase));
  }
}
/**
 * Learns a classifier for {@code label} and scores every extracted term with it.
 * Terms whose score comes back NaN or infinite are logged and omitted from the result.
 */
@Override
public Counter<CandidatePhrase> scorePhrases(String label, TwoDimensionalCounter<CandidatePhrase, E> terms,
    TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns,
    Set<CandidatePhrase> alreadyIdentifiedWords, boolean forLearningPatterns) throws IOException, ClassNotFoundException {
  getAllLabeledWordsCluster();
  Counter<CandidatePhrase> scores = new ClassicCounter<>();
  edu.stanford.nlp.classify.Classifier classifier = learnClassifier(label, forLearningPatterns, wordsPatExtracted, allSelectedPatterns);
  for (Entry<CandidatePhrase, ClassicCounter<E>> entry : terms.entrySet()) {
    CandidatePhrase phrase = entry.getKey();
    Double score = this.scoreUsingClassifer(classifier, phrase, label, forLearningPatterns, entry.getValue(), allSelectedPatterns);
    if (score.isNaN() || score.isInfinite()) {
      Redwood.log(Redwood.DBG, "Ignoring " + phrase + " because score is " + score);
    } else {
      scores.setCount(phrase, score);
    }
  }
  return scores;
}
/**
 * Learns a classifier for {@code label} (without pattern information) and scores the
 * given set of terms with it.
 */
@Override
public Counter<CandidatePhrase> scorePhrases(String label, Set<CandidatePhrase> terms, boolean forLearningPatterns) throws IOException, ClassNotFoundException {
  getAllLabeledWordsCluster();
  Counter<CandidatePhrase> scores = new ClassicCounter<>();
  edu.stanford.nlp.classify.Classifier classifier = learnClassifier(label, forLearningPatterns, null, null);
  for (CandidatePhrase term : terms) {
    double termScore = this.scoreUsingClassifer(classifier, term, label, forLearningPatterns, null, null);
    scores.setCount(term, termScore);
  }
  return scores;
}
/** Returns true with (approximately) probability {@code p}, drawing one float from {@code random}. */
public static boolean getRandomBoolean(Random random, double p) {
  final float roll = random.nextFloat();
  return roll < p;
}
/** Standard logistic (sigmoid) function: 1 / (1 + e^{-d}). */
static double logistic(double d) {
  final double expNeg = Math.exp(-d);
  return 1.0 / (1.0 + expNeg);
}
ConcurrentHashMap<CandidatePhrase, Counter<Integer>> wordClassClustersForPhrase = new ConcurrentHashMap<>();
/**
 * Maps a (possibly multi-word) phrase to a counter over the word-class cluster ids of its
 * tokens. For each token it tries, in order: the surface form, the lemma, the lowercased
 * surface form, and the lowercased lemma; tokens with no cluster are skipped.
 */
Counter<Integer> wordClass(String phrase, String phraseLemma) {
  Counter<Integer> clusterCounts = new ClassicCounter<>();
  String[] lemmaTokens = (phraseLemma == null) ? null : phraseLemma.split("\\s+");
  String[] tokens = phrase.split("\\s+");
  for (int idx = 0; idx < tokens.length; idx++) {
    String tok = tokens[idx];
    Integer cluster = constVars.getWordClassClusters().get(tok);
    if (cluster == null && lemmaTokens != null) {
      cluster = constVars.getWordClassClusters().get(lemmaTokens[idx]);
    }
    if (cluster == null) {
      // Fall back to the lowercased forms.
      cluster = constVars.getWordClassClusters().get(tok.toLowerCase());
      if (cluster == null && lemmaTokens != null) {
        cluster = constVars.getWordClassClusters().get(lemmaTokens[idx].toLowerCase());
      }
    }
    if (cluster != null) {
      clusterCounts.incrementCount(cluster);
    }
  }
  return clusterCounts;
}
/**
 * Precomputes and caches the word-class cluster counters for every learned word and every
 * seed phrase of every label, filling {@code wordClassClustersForPhrase}.
 */
void getAllLabeledWordsCluster() {
  for (String label : constVars.getLabels()) {
    for (Map.Entry<CandidatePhrase, Double> learned : constVars.getLearnedWords(label).entrySet()) {
      CandidatePhrase phrase = learned.getKey();
      wordClassClustersForPhrase.put(phrase, wordClass(phrase.getPhrase(), phrase.getPhraseLemma()));
    }
    for (CandidatePhrase seed : constVars.getSeedLabelDictionary().get(label)) {
      wordClassClustersForPhrase.put(seed, wordClass(seed.getPhrase(), seed.getPhraseLemma()));
    }
  }
}
/**
 * For each candidate phrase, computes cosine similarity (using the cached word vectors) to the
 * phrases in {@code otherPhrases}, returning the average over the top-k most similar ones
 * (k = constVars.expandPhrasesNumTopSimilar). Also updates, per (phrase, label), running
 * NUMITEMS/AVGSIM/MAXSIM statistics in the static {@code similaritiesWithLabeledPhrases} cache.
 * Pairwise cosine similarities are memoized in {@code cacheSimilarities}.
 * Candidates with no word vector (or matching the ignore regex when {@code ignoreWordRegex}
 * is set) get Double.MIN_VALUE as a "no similarity" sentinel.
 */
private Counter<CandidatePhrase> computeSimWithWordVectors(Collection<CandidatePhrase> candidatePhrases, Collection<CandidatePhrase> otherPhrases, boolean ignoreWordRegex, String label){
  Counter<CandidatePhrase> sims = new ClassicCounter<>(candidatePhrases.size());
  for(CandidatePhrase p : candidatePhrases) {
    // Per-phrase, per-label running stats; created lazily on first use.
    Map<String, double[]> simsAvgMaxAllLabels = similaritiesWithLabeledPhrases.get(p.getPhrase());
    if(simsAvgMaxAllLabels == null)
      simsAvgMaxAllLabels = new HashMap<>();
    double[] simsAvgMax = simsAvgMaxAllLabels.get(label);
    if (simsAvgMax == null) {
      // One slot per Similarities enum constant (NUMITEMS, AVGSIM, MAXSIM, ...).
      simsAvgMax = new double[Similarities.values().length];
      // Arrays.fill(simsAvgMax, 0); // not needed; Java arrays zero initialized
    }
    if(wordVectors.containsKey(p.getPhrase()) && (! ignoreWordRegex || !PatternFactory.ignoreWordRegex.matcher(p.getPhrase()).matches())){
      double[] d1 = wordVectors.get(p.getPhrase());
      // Keeps only the top-k most similar other phrases.
      BinaryHeapPriorityQueue<CandidatePhrase> topSimPhs = new BinaryHeapPriorityQueue<>(constVars.expandPhrasesNumTopSimilar);
      double allsum = 0;
      // NOTE(review): Double.MIN_VALUE is the smallest POSITIVE double, not negative infinity;
      // if every similarity were negative, 'max' would stay at this sentinel. Confirm intended.
      double max = Double.MIN_VALUE;
      boolean donotuse = false;
      for (CandidatePhrase other : otherPhrases) {
        // A phrase must not be scored against itself; mark and skip it entirely.
        if (p.equals(other)) {
          donotuse = true;
          break;
        }
        if (!wordVectors.containsKey(other.getPhrase()))
          continue;
        double sim;
        // PhrasePair is order-independent, so (a,b) and (b,a) share a cache entry.
        PhrasePair pair = new PhrasePair(p.getPhrase(), other.getPhrase());
        if (cacheSimilarities.containsKey(pair))
          sim = cacheSimilarities.getCount(pair);
        else {
          // Cosine similarity: dot(d1,d2) / (||d1|| * ||d2||).
          double[] d2 = wordVectors.get(other.getPhrase());
          double sum = 0;
          double d1sq = 0;
          double d2sq = 0;
          for (int i = 0; i < d1.length; i++) {
            sum += d1[i] * d2[i];
            d1sq += d1[i] * d1[i];
            d2sq += d2[i] * d2[i];
          }
          sim = sum / (Math.sqrt(d1sq) * Math.sqrt(d2sq));
          cacheSimilarities.setCount(pair, sim);
        }
        topSimPhs.add(other, sim);
        // Evict the lowest-priority entry once the heap exceeds k.
        if(topSimPhs.size() > constVars.expandPhrasesNumTopSimilar)
          topSimPhs.removeLastEntry();
        //avgSim /= otherPhrases.size();
        allsum += sim;
        if(sim > max)
          max = sim;
      }
      // Final score: mean similarity over the retained top-k phrases.
      // NOTE(review): if no other phrase had a vector, numEl is 0 and this yields 0/0 = NaN,
      // which then propagates into 'sims'. Confirm callers tolerate NaN here.
      double finalSimScore = 0;
      int numEl = 0;
      while(topSimPhs.hasNext()) {
        finalSimScore += topSimPhs.getPriority();
        topSimPhs.next();
        numEl++;
      }
      finalSimScore /= numEl;
      // Fold this batch into the running per-label stats (weighted average, running max).
      double prevNumItems = simsAvgMax[Similarities.NUMITEMS.ordinal()];
      double prevAvg = simsAvgMax[Similarities.AVGSIM.ordinal()];
      double prevMax = simsAvgMax[Similarities.MAXSIM.ordinal()];
      double newNumItems = prevNumItems + otherPhrases.size();
      double newAvg = (prevAvg*prevNumItems + allsum) /(newNumItems);
      double newMax = prevMax > max ? prevMax: max;
      simsAvgMax[Similarities.NUMITEMS.ordinal()] = newNumItems;
      simsAvgMax[Similarities.AVGSIM.ordinal()] = newAvg;
      simsAvgMax[Similarities.MAXSIM.ordinal()] = newMax;
      if(!donotuse){
        sims.setCount(p, finalSimScore);
      }
    }else{
      // Sentinel for "no vector / ignored": smallest positive double.
      sims.setCount(p, Double.MIN_VALUE);
    }
    simsAvgMaxAllLabels.put(label, simsAvgMax);
    similaritiesWithLabeledPhrases.put(p.getPhrase(), simsAvgMaxAllLabels);
  }
  return sims;
}
/**
 * Computes each candidate's similarity to the positive phrases for {@code label} and to all
 * possible negative phrase sets, then removes from the positive-similarity counter any
 * candidate that is strictly closer to some negative set than to the positives.
 *
 * @return pair of (similarity-to-positives, filtered; similarity-to-negatives)
 */
private Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>> computeSimWithWordVectors(List<CandidatePhrase> candidatePhrases, Collection<CandidatePhrase> positivePhrases,
    Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases, String label) {
  assert wordVectors != null : "Why are word vectors null?";
  Counter<CandidatePhrase> posSims = computeSimWithWordVectors(candidatePhrases, positivePhrases, true, label);
  Counter<CandidatePhrase> negSims = new ClassicCounter<>();
  for (Map.Entry<String, Collection<CandidatePhrase>> en : allPossibleNegativePhrases.entrySet())
    negSims.addAll(computeSimWithWordVectors(candidatePhrases, en.getValue(), true, en.getKey()));
  // Keep only candidates at least as similar to the positives as to any negative set.
  Function<CandidatePhrase, Boolean> retainPhrasesNotCloseToNegative =
      candidatePhrase -> negSims.getCount(candidatePhrase) <= posSims.getCount(candidatePhrase);
  Counters.retainKeys(posSims, retainPhrasesNotCloseToNegative);
  // Fix: diamond operator instead of a raw Pair type.
  return new Pair<>(posSims, negSims);
}
/**
 * Computes, for each candidate, the average Jaccard similarity between its word-class cluster
 * counter and those of the positive phrases (skipping the candidate itself). Cluster counters
 * are lazily computed and cached in {@code wordClassClustersForPhrase}. Updates
 * {@code allMaxSim} with the maximum average similarity seen.
 *
 * @return pair of (similarity counter, null) — similarity to negatives is not yet implemented
 */
Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>> computeSimWithWordCluster(Collection<CandidatePhrase> candidatePhrases, Collection<CandidatePhrase> positivePhrases, AtomicDouble allMaxSim) {
  Counter<CandidatePhrase> sims = new ClassicCounter<>(candidatePhrases.size());
  for (CandidatePhrase p : candidatePhrases) {
    Counter<Integer> feat = wordClassClustersForPhrase.get(p);
    if (feat == null) {
      feat = wordClass(p.getPhrase(), p.getPhraseLemma());
      wordClassClustersForPhrase.put(p, feat);
    }
    double avgSim = 0; // Double.MIN_VALUE;
    if (feat.size() > 0) {
      for (CandidatePhrase pos : positivePhrases) {
        if (p.equals(pos))
          continue;
        Counter<Integer> posfeat = wordClassClustersForPhrase.get(pos);
        if (posfeat == null) {
          posfeat = wordClass(pos.getPhrase(), pos.getPhraseLemma());
          // Fix: cache the positive phrase's OWN clusters. The original cached 'feat'
          // (the candidate's clusters) under the positive phrase's key, corrupting the
          // cache for all later similarity computations involving 'pos'.
          wordClassClustersForPhrase.put(pos, posfeat);
        }
        if (posfeat.size() > 0) {
          Double j = Counters.jaccardCoefficient(posfeat, feat);
          if (!j.isInfinite() && !j.isNaN()) {
            avgSim += j;
          }
        }
      }
      avgSim /= positivePhrases.size();
    }
    sims.setCount(p, avgSim);
    if (allMaxSim.get() < avgSim)
      allMaxSim.set(avgSim);
  }
  //TODO: compute similarity with neg phrases
  return new Pair<>(sims, null);
}
/**
 * Callable that scores one batch of candidate phrases for similarity to the known positive
 * phrases — via word vectors (which also considers negatives) or word-class clusters,
 * depending on {@code constVars.useWordVectorsToComputeSim}.
 */
class ComputeSim implements Callable<Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>>> {
  List<CandidatePhrase> candidatePhrases;
  String label;
  AtomicDouble allMaxSim;
  Collection<CandidatePhrase> positivePhrases;
  Map<String, Collection<CandidatePhrase>> knownNegativePhrases;

  public ComputeSim(String label, List<CandidatePhrase> candidatePhrases, AtomicDouble allMaxSim, Collection<CandidatePhrase> positivePhrases, Map<String, Collection<CandidatePhrase>> knownNegativePhrases) {
    this.label = label;
    this.candidatePhrases = candidatePhrases;
    this.allMaxSim = allMaxSim;
    this.positivePhrases = positivePhrases;
    this.knownNegativePhrases = knownNegativePhrases;
  }

  @Override
  public Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>> call() throws Exception {
    if (!constVars.useWordVectorsToComputeSim) {
      //TODO: knownnegaitvephrases
      return computeSimWithWordCluster(candidatePhrases, positivePhrases, allMaxSim);
    }
    Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>> sims = computeSimWithWordVectors(candidatePhrases, positivePhrases, knownNegativePhrases, label);
    Redwood.log(Redwood.DBG, "Computed similarities with positive and negative phrases");
    return sims;
  }
}
//this chooses the ones that are not close to the positive phrases!
/**
 * From the unknown candidate phrases, keeps as negatives those that are NOT close to the
 * positive phrases: similarities are computed in parallel (one {@link ComputeSim} task per
 * batch), then all candidates whose similarity exceeds
 * {@code constVars.positiveSimilarityThresholdLowPrecision} are removed.
 * Removed candidates with word vectors are optionally appended to {@code logFile}.
 *
 * @return the candidate phrases retained as negatives
 */
Set<CandidatePhrase> chooseUnknownAsNegatives(Set<CandidatePhrase> candidatePhrases, String label, Collection<CandidatePhrase> positivePhrases, Map<String,
    Collection<CandidatePhrase>> knownNegativePhrases, BufferedWriter logFile) throws IOException {
  List<List<CandidatePhrase>> threadedCandidates = GetPatternsFromDataMultiClass.getThreadBatches(CollectionUtils.toList(candidatePhrases), constVars.numThreads);
  Counter<CandidatePhrase> sims = new ClassicCounter<>();
  // Shared across worker threads; Double.MIN_VALUE means "no similarity recorded yet".
  AtomicDouble allMaxSim = new AtomicDouble(Double.MIN_VALUE);
  ExecutorService executor = Executors.newFixedThreadPool(constVars.numThreads);
  List<Future<Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>>>> list = new ArrayList<>();
  //multi-threaded choose positive, negative and unknown
  for (List<CandidatePhrase> keys : threadedCandidates) {
    Callable<Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>>> task = new ComputeSim(label, keys, allMaxSim, positivePhrases, knownNegativePhrases);
    Future<Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>>> submit = executor.submit(task);
    list.add(submit);
  }
  // Now retrieve the result; only the positive-similarity half of each pair is used here.
  for (Future<Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>>> future : list) {
    try {
      sims.addAll(future.get().first());
    } catch (Exception e) {
      // Abort remaining tasks and propagate the failure.
      executor.shutdownNow();
      throw new RuntimeException(e);
    }
  }
  executor.shutdown();
  if(allMaxSim.get() == Double.MIN_VALUE){
    Redwood.log(Redwood.DBG, "No similarity recorded between the positives and the unknown!");
  }
  CandidatePhrase k = Counters.argmax(sims);
  System.out.println("Maximum similarity was " + sims.getCount(k) + " for word " + k);
  // Phrases too similar to the positives are removed — they should not serve as negatives.
  Counter<CandidatePhrase> removed = Counters.retainBelow(sims, constVars.positiveSimilarityThresholdLowPrecision);
  System.out.println("removing phrases as negative phrases that were higher that positive similarity threshold of " + constVars.positiveSimilarityThresholdLowPrecision + removed);
  if(logFile != null && wordVectors != null){
    for(Entry<CandidatePhrase, Double> en: removed.entrySet())
      if(wordVectors.containsKey(en.getKey().getPhrase()))
        logFile.write(en.getKey()+"-PN " + ArrayUtils.toString(wordVectors.get(en.getKey().getPhrase()), " ")+"\n");
  }
  //Collection<CandidatePhrase> removed = Counters.retainBottom(sims, (int) (sims.size() * percentage));
  //System.out.println("not choosing " + removed + " as the negative phrases. percentage is " + percentage + " and allMaxsim was " + allMaxSim);
  return sims.keySet();
}
/**
 * Samples up to {@code maxNum} phrases from the sentence that are NOT labeled with
 * {@code label}, to serve as "unknown" examples. For dependency patterns it extracts phrases
 * from subgraphs rooted at sampled heads; for surface patterns it samples windows of a random
 * length around randomly chosen tokens. Windows containing a positively-labeled token, or
 * that are empty/function words, are rejected.
 *
 * @param perSelect per-token probability of sampling
 */
Set<CandidatePhrase> chooseUnknownPhrases(DataInstance sent, Random random, double perSelect, Class positiveClass, String label, int maxNum){
  Set<CandidatePhrase> unknownSamples = new HashSet<>();
  if(maxNum == 0)
    return unknownSamples;
  // Accept a token only if it is not labeled positive and is not a function word.
  Function<CoreLabel, Boolean> acceptWord = coreLabel -> {
    if(coreLabel.get(positiveClass).equals(label) || constVars.functionWords.contains(coreLabel.word()))
      return false;
    else
      return true;
  };
  // NOTE(review): fixed seed — the phrase LENGTH is sampled deterministically and does not use
  // the 'random' parameter (only token selection does). Presumably for reproducibility; confirm.
  Random r = new Random(0);
  List<Integer> lengths = new ArrayList<>();
  for(int i = 1;i <= PatternFactory.numWordsCompoundMapped.get(label); i++)
    lengths.add(i);
  int length = CollectionUtils.sample(lengths, r);
  if(constVars.patternType.equals(PatternFactory.PatternType.DEP)){
    ExtractPhraseFromPattern extract = new ExtractPhraseFromPattern(true, length);
    SemanticGraph g = ((DataInstanceDep) sent).getGraph();
    // Sample head tokens, bounded both by maxNum and by perSelect * sentence length.
    Collection<CoreLabel> sampledHeads = CollectionUtils.sampleWithoutReplacement(sent.getTokens(), Math.min(maxNum, (int) (perSelect * sent.getTokens().size())), random);
    //TODO: change this for more efficient implementation
    List<String> textTokens = sent.getTokens().stream().map(x -> x.word()).collect(Collectors.toList());
    for(CoreLabel l: sampledHeads) {
      if(!acceptWord.apply(l))
        continue;
      IndexedWord w = g.getNodeByIndex(l.index());
      List<String> outputPhrases = new ArrayList<>();
      List<ExtractedPhrase> extractedPhrases = new ArrayList<>();
      List<IntPair> outputIndices = new ArrayList<>();
      extract.printSubGraph(g, w, new ArrayList<>(), textTokens, outputPhrases, outputIndices, new ArrayList<>(), new ArrayList<>(),
          false, extractedPhrases, null, acceptWord);
      for(ExtractedPhrase p :extractedPhrases){
        unknownSamples.add(CandidatePhrase.createOrGet(p.getValue(), null, p.getFeatures()));
      }
    }
  }else if(constVars.patternType.equals(PatternFactory.PatternType.SURFACE)){
    CoreLabel[] tokens = sent.getTokens().toArray(new CoreLabel[0]);
    for(int i =0; i < tokens.length; i++){
      // Each token independently seeds a window with probability perSelect.
      if(random.nextDouble() < perSelect){
        // Window of 'length' tokens roughly centered at i.
        int left = (int)((length -1) /2.0);
        int right = length -1 -left;
        String ph = "";
        boolean haspositive = false;
        for(int j = Math.max(0, i - left); j < tokens.length && j <= i+right; j++){
          // Reject the whole window if any token in it is labeled positive.
          if(tokens[j].get(positiveClass).equals(label)){
            haspositive = true;
            break;
          }
          ph += " " + tokens[j].word();
        }
        ph = ph.trim();
        if(!haspositive && !ph.trim().isEmpty() && !constVars.functionWords.contains(ph)){
          unknownSamples.add(CandidatePhrase.createOrGet(ph));
        }
      }
    }
  } else
    throw new RuntimeException("not yet implemented");
  return unknownSamples;
}
/**
 * Returns true iff {@code value} occurs in any of the map's collections whose key is not
 * equal to {@code ignoreLabel}.
 */
private static<E,F> boolean hasElement(Map<E, Collection<F>> values, F value, E ignoreLabel) {
  for (Map.Entry<E, Collection<F>> entry : values.entrySet()) {
    boolean ignored = entry.getKey().equals(ignoreLabel);
    if (!ignored && entry.getValue().contains(value)) {
      return true;
    }
  }
  return false;
}
/**
 * Counts, over all data sentences, how many tokens carry each answer label
 * (as determined by the per-label answer-class annotations).
 */
Counter<String> numLabeledTokens() {
  Counter<String> labelCounts = new ClassicCounter<>();
  ConstantsAndVariables.DataSentsIterator dataIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
  while (dataIter.hasNext()) {
    Map<String, DataInstance> sentences = dataIter.next().first();
    for (Entry<String, DataInstance> sentEntry : sentences.entrySet()) {
      for (CoreLabel token : sentEntry.getValue().getTokens()) {
        for (Entry<String, Class<? extends TypesafeMap.Key<String>>> answer : constVars.getAnswerClass().entrySet()) {
          if (token.get(answer.getValue()).equals(answer.getKey())) {
            labelCounts.incrementCount(answer.getKey());
          }
        }
      }
    }
  }
  return labelCounts;
}
Counter<CandidatePhrase> closeToPositivesFirstIter = null;
Counter<CandidatePhrase> closeToNegativesFirstIter = null;
/**
 * Callable that scans one batch of sentences and partitions candidate phrases into positive,
 * negative, and unknown sets for a given answer label, optionally also collecting phrases
 * that are distributionally close to the positives/negatives (for expansion).
 * Returns a Quintuple of (positives, negatives, unknowns, close-to-positives, close-to-negatives).
 */
public class ChooseDatumsThread implements Callable {
  Collection<String> keys;                 // sentence ids in this batch
  Map<String, DataInstance> sents;         // sentence id -> sentence
  Class answerClass;                       // annotation class holding this label's answer
  String answerLabel;
  // NOTE(review): wordsPatExtracted and allSelectedPatterns are stored but not read in call();
  // presumably kept for symmetry with other datum-choosing code — confirm before removing.
  TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted;
  Counter<E> allSelectedPatterns;
  Counter<Integer> wordClassClustersOfPositive;
  Map<String, Collection<CandidatePhrase>> allPossiblePhrases;  // label/"NEGATIVE"/"OTHERSEM" -> phrases considered negative
  boolean expandPos;                       // whether to collect phrases close to positives
  boolean expandNeg;                       // whether to collect phrases close to negatives

  public ChooseDatumsThread(String label, Map<String, DataInstance> sents, Collection<String> keys, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, Counter<E> allSelectedPatterns,
      Counter<Integer> wordClassClustersOfPositive, Map<String, Collection<CandidatePhrase>> allPossiblePhrases, boolean expandPos, boolean expandNeg){
    this.answerLabel = label;
    this.sents = sents;
    this.keys = keys;
    this.wordsPatExtracted = wordsPatExtracted;
    this.allSelectedPatterns = allSelectedPatterns;
    this.wordClassClustersOfPositive = wordClassClustersOfPositive;
    this.allPossiblePhrases = allPossiblePhrases;
    answerClass = constVars.getAnswerClass().get(answerLabel);
    this.expandNeg = expandNeg;
    this.expandPos = expandPos;
  }

  @Override
  public Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>> call() throws Exception {
    // Fixed seeds make the per-batch sampling deterministic across runs.
    Random r = new Random(10);
    Random rneg = new Random(10);
    Set<CandidatePhrase> allPositivePhrases = new HashSet<>();
    Set<CandidatePhrase> allNegativePhrases = new HashSet<>();
    Set<CandidatePhrase> allUnknownPhrases = new HashSet<>();
    Counter<CandidatePhrase> allCloseToPositivePhrases = new ClassicCounter<>();
    Counter<CandidatePhrase> allCloseToNegativePhrases = new ClassicCounter<>();
    // Everything currently known positive for this label: learned words + seed dictionary.
    Set<CandidatePhrase> knownPositivePhrases = CollectionUtils.unionAsSet(constVars.getLearnedWords(answerLabel).keySet(), constVars.getSeedLabelDictionary().get(answerLabel));
    // Avoids recomputing similarity for the same candidate more than once.
    Set<CandidatePhrase> allConsideredPhrases = new HashSet<>();
    Map<Class, Object> otherIgnoreClasses = constVars.getIgnoreWordswithClassesDuringSelection().get(answerLabel);
    int numlabeled = 0;  // counted for bookkeeping; not used in the return value
    for (String sentid : keys) {
      DataInstance sentInst = sents.get(sentid);
      List<CoreLabel> value = sentInst.getTokens();
      CoreLabel[] sent = value.toArray(new CoreLabel[value.size()]);
      for (int i = 0; i < sent.length; i++) {
        CoreLabel l = sent[i];
        if (l.get(answerClass).equals(answerLabel)) {
          // Token labeled with the answer label -> candidate positive example.
          numlabeled++;
          CandidatePhrase candidate = l.get(PatternsAnnotations.LongestMatchedPhraseForEachLabel.class).get(answerLabel);
          if (candidate == null) {
            throw new RuntimeException("for sentence id " + sentid + " and token id " + i + " candidate is null for " + l.word() + " and longest matching" + l.get(PatternsAnnotations.LongestMatchedPhraseForEachLabel.class) + " and matched phrases are " + l.get(PatternsAnnotations.MatchedPhrases.class));
            //candidate = CandidatePhrase.createOrGet(l.word());
          }
          //If the phrase does not exist in its form in the datset (happens when fuzzy matching etc).
          if(!Data.rawFreq.containsKey(candidate)){
            candidate = CandidatePhrase.createOrGet(l.word());
          }
          //Do not add to positive if the word is a "negative" (stop word, english word, ...)
          if(hasElement(allPossiblePhrases, candidate, answerLabel) || PatternFactory.ignoreWordRegex.matcher(candidate.getPhrase()).matches())
            continue;
          allPositivePhrases.add(candidate);
        } else {
          // Token NOT labeled with the answer label: decide among negative / ignore / unknown.
          Map<String, CandidatePhrase> longestMatching = l.get(PatternsAnnotations.LongestMatchedPhraseForEachLabel.class);
          boolean ignoreclass = false;
          CandidatePhrase candidate = CandidatePhrase.createOrGet(l.word());
          // Ignore tokens flagged with any configured ignore-class; prefer the OTHERSEM match if any.
          for (Class cl : otherIgnoreClasses.keySet()) {
            if ((Boolean) l.get(cl)) {
              ignoreclass = true;
              candidate = longestMatching.containsKey("OTHERSEM")? longestMatching.get("OTHERSEM") : candidate;
              break;
            }
          }
          if(!ignoreclass) {
            ignoreclass = constVars.functionWords.contains(l.word());
          }
          // A token matched by a phrase of a DIFFERENT label is a negative example.
          boolean negative = false;
          boolean add= false;
          for (Map.Entry<String, CandidatePhrase> lo : longestMatching.entrySet()) {
            //assert !lo.getValue().getPhrase().isEmpty() : "How is the longestmatching phrase for " + l.word() + " empty ";
            if (!lo.getKey().equals(answerLabel) && lo.getValue() != null) {
              negative = true;
              add = true;
              //If the phrase does not exist in its form in the datset (happens when fuzzy matching etc).
              if(Data.rawFreq.containsKey(lo.getValue())){
                candidate = lo.getValue();
              }
            }
          }
          if (!negative && ignoreclass) {
            add = true;
          }
          // Subsample negatives with probability perSelectNeg.
          if(add && rneg.nextDouble() < constVars.perSelectNeg){
            assert !candidate.getPhrase().isEmpty();
            allNegativePhrases.add(candidate);
          }
          // Expansion: for genuinely unknown tokens, measure similarity to known positives
          // (and negatives) and collect those above the high-precision threshold.
          if(!negative && !ignoreclass && (expandPos || expandNeg) && !hasElement(allPossiblePhrases, candidate, answerLabel) && !PatternFactory.ignoreWordRegex.matcher(candidate.getPhrase()).matches()) {
            if (!allConsideredPhrases.contains(candidate)) {
              Pair<Counter<CandidatePhrase>, Counter<CandidatePhrase>> sims;
              assert candidate != null;
              if(constVars.useWordVectorsToComputeSim)
                sims = computeSimWithWordVectors(Arrays.asList(candidate), knownPositivePhrases, allPossiblePhrases, answerLabel);
              else
                sims = computeSimWithWordCluster(Arrays.asList(candidate), knownPositivePhrases, new AtomicDouble());
              boolean addedAsPos = false;
              if(expandPos)
              {
                double sim = sims.first().getCount(candidate);
                if (sim > constVars.similarityThresholdHighPrecision){
                  allCloseToPositivePhrases.setCount(candidate, sim);
                  addedAsPos = true;
                }
              }
              // A phrase already expanded as positive is never also expanded as negative.
              if(expandNeg && !addedAsPos) {
                double simneg = sims.second().getCount(candidate);
                if (simneg > constVars.similarityThresholdHighPrecision)
                  allCloseToNegativePhrases.setCount(candidate, simneg);
              }
              allConsideredPhrases.add(candidate);
            }
          }
        }
      }
      // Sample unknown phrases from this sentence (maxNum is effectively unbounded here).
      allUnknownPhrases.addAll(chooseUnknownPhrases(sentInst, r, constVars.perSelectRand, constVars.getAnswerClass().get(answerLabel), answerLabel, Math.max(0, Integer.MAX_VALUE)));
      //
      //        if (negative && getRandomBoolean(rneg, perSelectNeg)) {
      //          numneg++;
      //        } else if (getRandomBoolean(r, perSelectRand)) {
      //          candidate = CandidatePhrase.createOrGet(l.word());
      //          numneg++;
      //        } else {
      //          continue;
      //        }
      //
      //
      //        chosen.add(new Pair<String, Integer>(en.getKey(), i));
    }
    return new Quintuple(allPositivePhrases, allNegativePhrases, allUnknownPhrases, allCloseToPositivePhrases, allCloseToNegativePhrases);
  }
}
/**
 * An unordered pair of phrase strings, used as a key when caching pairwise similarities.
 * The two phrases are stored in lexicographic order so that (a, b) and (b, a) are equal and
 * hash identically; the hash is precomputed once in the constructor.
 */
static private class PhrasePair {
  final String p1;
  final String p2;
  final int hashCode;

  public PhrasePair(String p1, String p2) {
    // Canonicalize the order so the pair is symmetric.
    if (p1.compareTo(p2) <= 0) {
      this.p1 = p1;
      this.p2 = p2;
    } else {
      this.p1 = p2;
      this.p2 = p1;
    }
    // The sum is commutative, so the hash is independent of argument order.
    this.hashCode = p1.hashCode() + p2.hashCode() + 331;
  }

  @Override
  public int hashCode() {
    return hashCode;
  }

  @Override
  public boolean equals(Object o) {
    if (!(o instanceof PhrasePair))
      return false;
    PhrasePair other = (PhrasePair) o;
    return other.getPhrase1().equals(this.getPhrase1()) && other.getPhrase2().equals(this.getPhrase2());
  }

  public String getPhrase1() {
    return p1;
  }

  public String getPhrase2() {
    return p2;
  }
}
static Counter<PhrasePair> cacheSimilarities = new ConcurrentHashCounter<>();
//First map is phrase, second map is label to similarity stats
static Map<String, Map<String, double[]>> similaritiesWithLabeledPhrases = new ConcurrentHashMap<>();
/**
 * Builds the map of phrases to treat as negatives for {@code answerLabel}:
 * "NEGATIVE" -> stop words, function words, and common English words;
 * each other label -> its learned + seed phrases;
 * "OTHERSEM" -> phrases of other semantic classes.
 */
Map<String, Collection<CandidatePhrase>> getAllPossibleNegativePhrases(String answerLabel) {
  Map<String, Collection<CandidatePhrase>> allPossiblePhrases = new HashMap<>();
  Collection<CandidatePhrase> generalNegatives = new HashSet<>();
  //negPhrases.addAll(constVars.getOtherSemanticClassesWords());
  generalNegatives.addAll(constVars.getStopWords());
  generalNegatives.addAll(CandidatePhrase.convertStringPhrases(constVars.functionWords));
  generalNegatives.addAll(CandidatePhrase.convertStringPhrases(constVars.getEnglishWords()));
  allPossiblePhrases.put("NEGATIVE", generalNegatives);
  for (String label : constVars.getLabels()) {
    if (label.equals(answerLabel))
      continue;
    Collection<CandidatePhrase> otherLabelPhrases = new HashSet<>();
    if (constVars.getLearnedWordsEachIter().containsKey(label))
      otherLabelPhrases.addAll(constVars.getLearnedWords(label).keySet());
    otherLabelPhrases.addAll(constVars.getSeedLabelDictionary().get(label));
    allPossiblePhrases.put(label, otherLabelPhrases);
  }
  allPossiblePhrases.put("OTHERSEM", constVars.getOtherSemanticClassesWords());
  return allPossiblePhrases;
}
/**
 * Builds the labeled training dataset ("true"/"false" classes) of phrase feature vectors used to
 * learn the phrase-scoring classifier for {@code answerLabel}.
 *
 * Positives come from the sentence sampling threads plus the already-learned words for the label
 * (optionally expanded by phrases close to known positives); negatives come from the sampled
 * negatives, optionally expanded by phrases close to known negatives, plus unknown phrases
 * (either all of them or a similarity-based subsample). If negatives outnumber positives, they
 * are capped at the number of positives.
 *
 * @param forLearningPattern if true, use the pattern-eval feature set instead of the phrase-eval one
 * @param answerLabel        the label being learned
 * @param wordsPatExtracted  phrase -> pattern extraction counts
 * @param allSelectedPatterns weights of the patterns selected so far
 * @param computeRawFreq     whether to (re)compute raw phrase frequencies per sentence batch
 * @return an RVFDataset of ScorePhraseMeasures features, thresholded by featureCountThreshold
 * @throws IOException if writing the similarity/feature log files fails
 */
public GeneralDataset<String, ScorePhraseMeasures> choosedatums(boolean forLearningPattern, String answerLabel,
    TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted,
    Counter<E> allSelectedPatterns, boolean computeRawFreq) throws IOException {
  // Expansion happens only on the first iteration, i.e. when the "closeTo*FirstIter"
  // counters have not been initialized yet.
  boolean expandNeg = false;
  if (closeToNegativesFirstIter == null) {
    closeToNegativesFirstIter = new ClassicCounter<>();
    if (constVars.expandNegativesWhenSampling)
      expandNeg = true;
  }
  boolean expandPos = false;
  if (closeToPositivesFirstIter == null) {
    closeToPositivesFirstIter = new ClassicCounter<>();
    if (constVars.expandPositivesWhenSampling)
      expandPos = true;
  }
  // Dist-sim clusters of the known positive phrases; used by the sampling threads to find
  // nearby unlabeled phrases when word vectors are NOT used for similarity.
  Counter<Integer> distSimClustersOfPositive = new ClassicCounter<>();
  if ((expandPos || expandNeg) && !constVars.useWordVectorsToComputeSim) {
    for (CandidatePhrase s : CollectionUtils.union(constVars.getLearnedWords(answerLabel).keySet(), constVars.getSeedLabelDictionary().get(answerLabel))) {
      String[] toks = s.getPhrase().split("\\s+");
      Integer num = constVars.getWordClassClusters().get(s.getPhrase());
      if (num == null)
        num = constVars.getWordClassClusters().get(s.getPhrase().toLowerCase());
      if (num == null) {
        // No cluster for the full phrase: back off to clusters of its individual tokens.
        for (String tok : toks) {
          Integer toknum = constVars.getWordClassClusters().get(tok);
          if (toknum == null)
            toknum = constVars.getWordClassClusters().get(tok.toLowerCase());
          if (toknum != null) {
            distSimClustersOfPositive.incrementCount(toknum);
          }
        }
      } else
        distSimClustersOfPositive.incrementCount(num);
    }
  }
  //computing this regardless of expandpos and expandneg because we reject all positive words that occur in negatives (can happen in multi word phrases etc)
  Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases = getAllPossibleNegativePhrases(answerLabel);
  GeneralDataset<String, ScorePhraseMeasures> dataset = new RVFDataset<>();
  int numpos = 0;
  Set<CandidatePhrase> allNegativePhrases = new HashSet<>();
  Set<CandidatePhrase> allUnknownPhrases = new HashSet<>();
  Set<CandidatePhrase> allPositivePhrases = new HashSet<>();
  // Sample positive, negative and unknown phrases from each batch of sentences, multi-threaded.
  ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
  while (sentsIter.hasNext()) {
    Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
    Map<String, DataInstance> sents = sentsf.first();
    Redwood.log(Redwood.DBG, "Sampling datums from " + sentsf.second());
    if (computeRawFreq)
      Data.computeRawFreqIfNull(sents, PatternFactory.numWordsCompoundMax);
    List<List<String>> threadedSentIds = GetPatternsFromDataMultiClass.getThreadBatches(new ArrayList<>(sents.keySet()), constVars.numThreads);
    ExecutorService executor = Executors.newFixedThreadPool(constVars.numThreads);
    List<Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>>> list = new ArrayList<>();
    //multi-threaded choose positive, negative and unknown
    for (List<String> keys : threadedSentIds) {
      Callable<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> task = new ChooseDatumsThread(answerLabel, sents, keys,
          wordsPatExtracted, allSelectedPatterns, distSimClustersOfPositive, allPossibleNegativePhrases, expandPos, expandNeg);
      Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> submit = executor.submit(task);
      list.add(submit);
    }
    // Now retrieve the results: (positives, negatives, unknowns, close-to-positives, close-to-negatives)
    for (Future<Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>>> future : list) {
      try {
        Quintuple<Set<CandidatePhrase>, Set<CandidatePhrase>, Set<CandidatePhrase>, Counter<CandidatePhrase>, Counter<CandidatePhrase>> result = future.get();
        allPositivePhrases.addAll(result.first());
        allNegativePhrases.addAll(result.second());
        allUnknownPhrases.addAll(result.third());
        if (expandPos)
          for (Entry<CandidatePhrase, Double> en : result.fourth().entrySet())
            closeToPositivesFirstIter.setCount(en.getKey(), en.getValue());
        if (expandNeg)
          for (Entry<CandidatePhrase, Double> en : result.fifth().entrySet())
            closeToNegativesFirstIter.setCount(en.getKey(), en.getValue());
      } catch (Exception e) {
        executor.shutdownNow();
        throw new RuntimeException(e);
      }
    }
    executor.shutdown();
  }
  //TODO: this is kinda not nice; how is allpositivephrases different from positivephrases again?
  allPositivePhrases.addAll(constVars.getLearnedWords(answerLabel).keySet());
  BufferedWriter logFile = null;
  BufferedWriter logFileFeat = null;
  // try/finally so the log writers are closed even when a write or feature computation
  // throws (the previous version leaked them on the exception path).
  try {
    if (constVars.logFileVectorSimilarity != null) {
      logFile = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity));
      logFileFeat = new BufferedWriter(new FileWriter(constVars.logFileVectorSimilarity + "_feat"));
      if (wordVectors != null) {
        for (CandidatePhrase p : allPositivePhrases) {
          if (wordVectors.containsKey(p.getPhrase())) {
            logFile.write(p.getPhrase() + "-P " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
          }
        }
      }
    }
    if (constVars.expandPositivesWhenSampling) {
      //TODO: patwtbyfrew
      Redwood.log("Expanding positives by adding " + Counters.toSortedString(closeToPositivesFirstIter, closeToPositivesFirstIter.size(), "%1$s:%2$f", "\t") + " phrases");
      allPositivePhrases.addAll(closeToPositivesFirstIter.keySet());
      // write log. BUG FIX: this condition used to test expandNeg (copy-paste from the
      // negatives branch below), which suppressed these "-PP" log lines unless negatives
      // were also being expanded.
      if (logFile != null && wordVectors != null && expandPos) {
        for (CandidatePhrase p : closeToPositivesFirstIter.keySet()) {
          if (wordVectors.containsKey(p.getPhrase())) {
            logFile.write(p.getPhrase() + "-PP " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
          }
        }
      }
    }
    if (constVars.expandNegativesWhenSampling) {
      //TODO: patwtbyfrew
      Redwood.log("Expanding negatives by adding " + Counters.toSortedString(closeToNegativesFirstIter , closeToNegativesFirstIter.size(), "%1$s:%2$f","\t") + " phrases");
      allNegativePhrases.addAll(closeToNegativesFirstIter.keySet());
      //write log
      if (logFile != null && wordVectors != null && expandNeg) {
        for (CandidatePhrase p : closeToNegativesFirstIter.keySet()) {
          if (wordVectors.containsKey(p.getPhrase())) {
            logFile.write(p.getPhrase() + "-NN " + ArrayUtils.toString(wordVectors.get(p.getPhrase()), " ") + "\n");
          }
        }
      }
    }
    System.out.println("all positive phrases of size " + allPositivePhrases.size() + " are " + allPositivePhrases);
    // Add each positive phrase to the dataset with class "true".
    for (CandidatePhrase candidate : allPositivePhrases) {
      Counter<ScorePhraseMeasures> feat;
      if (forLearningPattern) {
        feat = getPhraseFeaturesForPattern(answerLabel, candidate);
      } else {
        feat = getFeatures(answerLabel, candidate, wordsPatExtracted.getCounter(candidate), allSelectedPatterns);
      }
      RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "true");
      dataset.add(datum);
      numpos += 1;
      if (logFileFeat != null) {
        logFileFeat.write("POSITIVE " + candidate.getPhrase() +"\t" + Counters.toSortedByKeysString(feat,"%1$s:%2$.0f",";","%s")+"\n");
      }
    }
    Redwood.log(Redwood.DBG, "Number of pure negative phrases is " + allNegativePhrases.size());
    Redwood.log(Redwood.DBG, "Number of unknown phrases is " + allUnknownPhrases.size());
    // Unknowns become negatives — either all of them, or a subsample chosen by similarity
    // to the positives.
    if (constVars.subsampleUnkAsNegUsingSim) {
      Set<CandidatePhrase> chosenUnknown = chooseUnknownAsNegatives(allUnknownPhrases, answerLabel, allPositivePhrases, allPossibleNegativePhrases, logFile);
      Redwood.log(Redwood.DBG, "Choosing " + chosenUnknown.size() + " unknowns as negative based to their similarity to the positive phrases");
      allNegativePhrases.addAll(chosenUnknown);
    } else {
      allNegativePhrases.addAll(allUnknownPhrases);
    }
    // Cap negatives at the number of positives to keep the dataset roughly balanced.
    if (allNegativePhrases.size() > numpos) {
      Redwood.log(Redwood.WARN, "Num of negative (" + allNegativePhrases.size() + ") is higher than number of positive phrases (" + numpos + ") = " +
          (allNegativePhrases.size() / (double)numpos) + ". " +
          "Capping the number by taking the first numPositives as negative. Consider decreasing perSelectRand");
      int i = 0;
      Set<CandidatePhrase> selectedNegPhrases = new HashSet<>();
      for (CandidatePhrase p : allNegativePhrases) {
        if (i >= numpos)
          break;
        selectedNegPhrases.add(p);
        i++;
      }
      allNegativePhrases = selectedNegPhrases;
    }
    System.out.println("all negative phrases are " + allNegativePhrases);
    // Add each negative phrase to the dataset with class "false".
    for (CandidatePhrase negative : allNegativePhrases) {
      Counter<ScorePhraseMeasures> feat;
      if (forLearningPattern) {
        feat = getPhraseFeaturesForPattern(answerLabel, negative);
      } else {
        feat = getFeatures(answerLabel, negative, wordsPatExtracted.getCounter(negative), allSelectedPatterns);
      }
      RVFDatum<String, ScorePhraseMeasures> datum = new RVFDatum<>(feat, "false");
      dataset.add(datum);
      if (logFile != null && wordVectors != null && wordVectors.containsKey(negative.getPhrase())) {
        logFile.write(negative.getPhrase()+"-N"+" " + ArrayUtils.toString(wordVectors.get(negative.getPhrase()), " ")+"\n");
      }
      if (logFileFeat != null)
        logFileFeat.write("NEGATIVE " + negative.getPhrase() +"\t" + Counters.toSortedByKeysString(feat,"%1$s:%2$.0f",";","%s")+"\n");
    }
    System.out.println("Before feature count threshold, dataset stats are ");
    dataset.summaryStatistics();
    dataset.applyFeatureCountThreshold(constVars.featureCountThreshold);
    System.out.println("AFTER feature count threshold of " + constVars.featureCountThreshold + ", dataset stats are ");
    dataset.summaryStatistics();
    Redwood.log(Redwood.DBG, "Eventually, number of positive datums: " + numpos + " and number of negative datums: " + allNegativePhrases.size());
    return dataset;
  } finally {
    IOUtils.closeIgnoringExceptions(logFile);
    IOUtils.closeIgnoringExceptions(logFileFeat);
  }
}
/**
 * Returns the cached similarity statistics for {@code phrase}: a map from label to an array of
 * values — num_items, avg similarity, max similarity — or null if nothing has been computed yet.
 */
private static Map<String, double[]> getSimilarities(String phrase) {
  final Map<String, double[]> cached = similaritiesWithLabeledPhrases.get(phrase);
  return cached;
}
/**
 * Computes (and caches in {@code phraseScoresRaw}) the feature vector used to score {@code word}
 * when evaluating/learning patterns for {@code label}. Which features are included is controlled
 * by the {@code usePatternEval*} flags on {@code constVars}.
 *
 * @param label the label whose patterns are being learned
 * @param word  the candidate phrase to featurize
 * @return a counter of ScorePhraseMeasures features (cached; subsequent calls return the same counter)
 */
Counter<ScorePhraseMeasures> getPhraseFeaturesForPattern(String label, CandidatePhrase word) {
  // Return the cached feature vector if we have featurized this phrase before.
  if (phraseScoresRaw.containsFirstKey(word))
    return phraseScoresRaw.getCounter(word);
  Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<>();
  // Add features on the word, if any!
  if (word.getFeatures() != null) {
    scoreslist.addAll(Counters.transform(word.getFeatures(), ScorePhraseMeasures::create));
  } else {
    Redwood.log(ConstantsAndVariables.extremedebug, "features are null for " + word);
  }
  if (constVars.usePatternEvalSemanticOdds) {
    double dscore = this.getDictOddsScore(word, label, 0);
    scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, dscore);
  }
  if (constVars.usePatternEvalGoogleNgram) {
    Double gscore = getGoogleNgramScore(word);
    if (gscore.isInfinite() || gscore.isNaN()) {
      throw new RuntimeException("how is the google ngrams score " + gscore + " for " + word);
    }
    scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
  }
  if (constVars.usePatternEvalDomainNgram) {
    Double gscore = getDomainNgramScore(word.getPhrase());
    if (gscore.isInfinite() || gscore.isNaN()) {
      throw new RuntimeException("how is the domain ngrams score " + gscore + " for " + word + " when domain raw freq is " + Data.domainNGramRawFreq.getCount(word)
          + " and raw freq is " + Data.rawFreq.getCount(word));
    }
    scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, gscore);
  }
  if (constVars.usePatternEvalWordClass) {
    // Dist-sim cluster id becomes part of the feature name (DISTSIM-<cluster>).
    Integer wordclass = constVars.getWordClassClusters().get(word.getPhrase());
    if (wordclass == null) {
      wordclass = constVars.getWordClassClusters().get(word.getPhrase().toLowerCase());
    }
    scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.DISTSIM.toString()+"-"+wordclass), 1.0);
  }
  if (constVars.usePatternEvalEditDistSame) {
    double ed = constVars.getEditDistanceScoresThisClass(label, word.getPhrase());
    assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
    scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, ed);
  }
  if (constVars.usePatternEvalEditDistOther) {
    double ed = constVars.getEditDistanceScoresOtherClass(label, word.getPhrase());
    assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
    scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, ed);
  }
  if (constVars.usePatternEvalWordShape) {
    scoreslist.setCount(ScorePhraseMeasures.WORDSHAPE, this.getWordShapeScore(word.getPhrase(), label));
  }
  if (constVars.usePatternEvalWordShapeStr) {
    scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.WORDSHAPESTR + "-" + this.wordShape(word.getPhrase())), 1.0);
  }
  if (constVars.usePatternEvalFirstCapital) {
    scoreslist.setCount(ScorePhraseMeasures.ISFIRSTCAPITAL, StringUtils.isCapitalized(word.getPhrase())? 1.0 :0);
  }
  if (constVars.usePatternEvalBOW) {
    // Bag-of-words features, one per whitespace-separated token.
    for (String s : word.getPhrase().split("\\s+"))
      scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.BOW +"-"+ s), 1.0);
  }
  phraseScoresRaw.setCounter(word, scoreslist);
  return scoreslist;
}
/*
Counter<ScorePhraseMeasures> getPhraseFeaturesForPattern(String label, CandidatePhrase word) {
if (phraseScoresRaw.containsFirstKey(word))
return phraseScoresRaw.getCounter(word);
Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<ScorePhraseMeasures>();
if (constVars.usePatternEvalSemanticOdds) {
assert constVars.dictOddsWeights != null : "usePatternEvalSemanticOdds is true but dictOddsWeights is null for the label " + label;
double dscore = this.getDictOddsScore(word, label, 0);
dscore = logistic(dscore);
scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, dscore);
}
if (constVars.usePatternEvalGoogleNgram) {
Double gscore = getGoogleNgramScore(word);
if (gscore.isInfinite() || gscore.isNaN()) {
throw new RuntimeException("how is the google ngrams score " + gscore + " for " + word);
}
gscore = logistic(gscore);
scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
}
if (constVars.usePatternEvalDomainNgram) {
Double gscore = getDomainNgramScore(word.getPhrase());
if (gscore.isInfinite() || gscore.isNaN()) {
throw new RuntimeException("how is the domain ngrams score " + gscore + " for " + word + " when domain raw freq is " + Data.domainNGramRawFreq.getCount(word)
+ " and raw freq is " + Data.rawFreq.getCount(word));
}
gscore = logistic(gscore);
scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, gscore);
}
if (constVars.usePatternEvalWordClass) {
double distSimWt = getDistSimWtScore(word.getPhrase(), label);
distSimWt = logistic(distSimWt);
scoreslist.setCount(ScorePhraseMeasures.DISTSIM, distSimWt);
}
if (constVars.usePatternEvalEditDistSame) {
scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, constVars.getEditDistanceScoresThisClass(label, word.getPhrase()));
}
if (constVars.usePatternEvalEditDistOther)
scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, constVars.getEditDistanceScoresOtherClass(label, word.getPhrase()));
if(constVars.usePatternEvalWordShape){
scoreslist.setCount(ScorePhraseMeasures.WORDSHAPE, this.getWordShapeScore(word.getPhrase(), label));
}
if(constVars.usePatternEvalWordShapeStr){
scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.WORDSHAPE +"-"+ this.wordShape(word.getPhrase())), 1.0);
}
if(constVars.usePatternEvalFirstCapital){
scoreslist.setCount(ScorePhraseMeasures.ISFIRSTCAPITAL, StringUtils.isCapitalized(word.getPhrase())?1.0:0.0);
}
if(constVars.usePatternEvalBOW){
for(String s: word.getPhrase().split("\\s+"))
scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.BOW +"-"+ s.toLowerCase()), 1.0);
}
phraseScoresRaw.setCounter(word, scoreslist);
return scoreslist;
}
*/
/**
 * Scores {@code word} with the learned classifier, returning the probability/score of the
 * positive ("true") class. Scores are cached in {@code learnedScores}.
 *
 * @param classifier the trained classifier (its concrete type must match {@code scoreClassifierType})
 * @param word       the phrase to score
 * @param label      the label being learned (used for feature computation)
 * @param forLearningPatterns if true, use the pattern-eval feature set
 * @param patternsThatExtractedPat patterns that extracted this phrase
 * @param allSelectedPatterns weights of all selected patterns
 * @return the positive-class score for the phrase (cached on subsequent calls)
 * @throws RuntimeException if {@code scoreClassifierType} is not a supported type
 */
public double scoreUsingClassifer(edu.stanford.nlp.classify.Classifier classifier, CandidatePhrase word, String label, boolean forLearningPatterns,
    Counter<E> patternsThatExtractedPat, Counter<E> allSelectedPatterns) {
  if (learnedScores.containsKey(word))
    return learnedScores.getCount(word);
  // Every classifier type uses the same feature vector, so compute it once up front
  // (the original duplicated this in each branch).
  Counter<ScorePhraseMeasures> feat;
  if (forLearningPatterns)
    feat = getPhraseFeaturesForPattern(label, word);
  else
    feat = this.getFeatures(label, word, patternsThatExtractedPat, allSelectedPatterns);
  double score;
  if (scoreClassifierType.equals(ClassifierType.LR)) {
    LogisticClassifier logcl = ((LogisticClassifier) classifier);
    RVFDatum<String, ScorePhraseMeasures> d = new RVFDatum<>(feat, Boolean.TRUE.toString());
    score = logcl.probabilityOf(d);
  } else if (scoreClassifierType.equals(ClassifierType.SHIFTLR)) {
    //convert to basicdatum -- restriction of ShiftLR right now
    BasicDatum<String, ScorePhraseMeasures> d = new BasicDatum<>(feat.keySet(), Boolean.FALSE.toString());
    Counter<String> sc = ((MultinomialLogisticClassifier) classifier).probabilityOf(d);
    score = sc.getCount(Boolean.TRUE.toString());
  } else if (scoreClassifierType.equals(ClassifierType.DT) || scoreClassifierType.equals(ClassifierType.SVM)
      || scoreClassifierType.equals(ClassifierType.RF) || scoreClassifierType.equals(ClassifierType.LINEAR)) {
    // These types all score an RVFDatum the same way (the DT branch was byte-identical
    // to the SVM/RF/LINEAR branch in the original).
    RVFDatum<String, ScorePhraseMeasures> d = new RVFDatum<>(feat, Boolean.FALSE.toString());
    Counter<String> sc = classifier.scoresOf(d);
    score = sc.getCount(Boolean.TRUE.toString());
  } else
    throw new RuntimeException("cannot identify classifier " + scoreClassifierType);
  this.learnedScores.setCount(word, score);
  return score;
}
/**
 * Computes (and caches in {@code phraseScoresRaw}) the feature vector used to score {@code word}
 * for {@code label} during phrase learning. Which features are included is controlled by the
 * {@code usePhraseEval*} flags on {@code constVars}.
 *
 * @param word                 the candidate phrase to featurize
 * @param patThatExtractedWord patterns that extracted this phrase, with counts
 * @param allSelectedPatterns  weights of all selected patterns
 * @return a counter of ScorePhraseMeasures features (cached; subsequent calls return the same counter)
 */
Counter<ScorePhraseMeasures> getFeatures(String label, CandidatePhrase word, Counter<E> patThatExtractedWord, Counter<E> allSelectedPatterns) {
  if (phraseScoresRaw.containsFirstKey(word))
    return phraseScoresRaw.getCounter(word);
  Counter<ScorePhraseMeasures> scoreslist = new ClassicCounter<>();
  // Add features on the word, if any!
  if (word.getFeatures() != null) {
    scoreslist.addAll(Counters.transform(word.getFeatures(), ScorePhraseMeasures::create));
  } else {
    Redwood.log(ConstantsAndVariables.extremedebug, "features are null for " + word);
  }
  if (constVars.usePhraseEvalPatWtByFreq) {
    double tfscore = getPatTFIDFScore(word, patThatExtractedWord, allSelectedPatterns);
    scoreslist.setCount(ScorePhraseMeasures.PATWTBYFREQ, tfscore);
  }
  if (constVars.usePhraseEvalSemanticOdds) {
    double dscore = this.getDictOddsScore(word, label, 0);
    scoreslist.setCount(ScorePhraseMeasures.SEMANTICODDS, dscore);
  }
  if (constVars.usePhraseEvalGoogleNgram) {
    Double gscore = getGoogleNgramScore(word);
    if (gscore.isInfinite() || gscore.isNaN()) {
      throw new RuntimeException("how is the google ngrams score " + gscore + " for " + word);
    }
    scoreslist.setCount(ScorePhraseMeasures.GOOGLENGRAM, gscore);
  }
  if (constVars.usePhraseEvalDomainNgram) {
    Double gscore = getDomainNgramScore(word.getPhrase());
    if (gscore.isInfinite() || gscore.isNaN()) {
      throw new RuntimeException("how is the domain ngrams score " + gscore + " for " + word + " when domain raw freq is " + Data.domainNGramRawFreq.getCount(word)
          + " and raw freq is " + Data.rawFreq.getCount(word));
    }
    scoreslist.setCount(ScorePhraseMeasures.DOMAINNGRAM, gscore);
  }
  if (constVars.usePhraseEvalWordClass) {
    // Dist-sim cluster id becomes part of the feature name (DISTSIM-<cluster>).
    Integer wordclass = constVars.getWordClassClusters().get(word.getPhrase());
    if (wordclass == null) {
      wordclass = constVars.getWordClassClusters().get(word.getPhrase().toLowerCase());
    }
    scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.DISTSIM.toString()+"-"+wordclass), 1.0);
  }
  if (constVars.usePhraseEvalWordVector) {
    Map<String, double[]> sims = getSimilarities(word.getPhrase());
    if (sims == null) {
      //TODO: make more efficient
      Map<String, Collection<CandidatePhrase>> allPossibleNegativePhrases = getAllPossibleNegativePhrases(label);
      Set<CandidatePhrase> knownPositivePhrases = CollectionUtils.unionAsSet(constVars.getLearnedWords(label).keySet(), constVars.getSeedLabelDictionary().get(label));
      computeSimWithWordVectors(Arrays.asList(word), knownPositivePhrases, allPossibleNegativePhrases, label);
      sims = getSimilarities(word.getPhrase());
    }
    assert sims != null : " Why are there no similarities for " + word;
    double avgPosSim = sims.get(label)[Similarities.AVGSIM.ordinal()];
    double maxPosSim = sims.get(label)[Similarities.MAXSIM.ordinal()];
    // BUG FIX: seed the running maximum with NEGATIVE_INFINITY, not Double.MIN_VALUE.
    // MIN_VALUE is the smallest *positive* double, so a maximum over all-negative
    // similarity values would never have been found.
    double sumNeg = 0, maxNeg = Double.NEGATIVE_INFINITY;
    double allNumItems = 0;
    for (Entry<String, double[]> simEn : sims.entrySet()) {
      if (simEn.getKey().equals(label))
        continue;
      double numItems = simEn.getValue()[Similarities.NUMITEMS.ordinal()];
      sumNeg += simEn.getValue()[Similarities.AVGSIM.ordinal()] * numItems;
      allNumItems += numItems;
      double maxNegLabel = simEn.getValue()[Similarities.MAXSIM.ordinal()];
      if (maxNeg < maxNegLabel)
        maxNeg = maxNegLabel;
    }
    // NOTE(review): NaN if sims contains no entries other than label — TODO confirm that
    // cannot happen (NEGATIVE is always present in the similarity map).
    double avgNegSim = sumNeg / allNumItems;
    scoreslist.setCount(ScorePhraseMeasures.WORDVECPOSSIMAVG, avgPosSim);
    scoreslist.setCount(ScorePhraseMeasures.WORDVECPOSSIMMAX, maxPosSim);
    scoreslist.setCount(ScorePhraseMeasures.WORDVECNEGSIMAVG, avgNegSim);
    // BUG FIX: this was WORDVECNEGSIMAVG again, so the max overwrote the average set on
    // the previous line and the max feature was never emitted under its own name.
    scoreslist.setCount(ScorePhraseMeasures.WORDVECNEGSIMMAX, maxNeg);
  }
  if (constVars.usePhraseEvalEditDistSame) {
    double ed = constVars.getEditDistanceScoresThisClass(label, word.getPhrase());
    assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
    scoreslist.setCount(ScorePhraseMeasures.EDITDISTSAME, ed);
  }
  if (constVars.usePhraseEvalEditDistOther) {
    double ed = constVars.getEditDistanceScoresOtherClass(label, word.getPhrase());
    assert ed <= 1 : " how come edit distance from the true class is " + ed + " for word " + word;
    scoreslist.setCount(ScorePhraseMeasures.EDITDISTOTHER, ed);
  }
  if (constVars.usePhraseEvalWordShape) {
    scoreslist.setCount(ScorePhraseMeasures.WORDSHAPE, this.getWordShapeScore(word.getPhrase(), label));
  }
  if (constVars.usePhraseEvalWordShapeStr) {
    scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.WORDSHAPESTR + "-" + this.wordShape(word.getPhrase())), 1.0);
  }
  if (constVars.usePhraseEvalFirstCapital) {
    scoreslist.setCount(ScorePhraseMeasures.ISFIRSTCAPITAL, StringUtils.isCapitalized(word.getPhrase())? 1.0 :0);
  }
  if (constVars.usePhraseEvalBOW) {
    // Bag-of-words features, one per whitespace-separated token.
    for (String s : word.getPhrase().split("\\s+"))
      scoreslist.setCount(ScorePhraseMeasures.create(ScorePhraseMeasures.BOW +"-"+ s), 1.0);
  }
  phraseScoresRaw.setCounter(word, scoreslist);
  return scoreslist;
}
}