package edu.stanford.nlp.semparse.open.ling; import java.util.*; import edu.stanford.nlp.semparse.open.ling.LingData.POSType; import edu.stanford.nlp.semparse.open.model.FeatureVector; import edu.stanford.nlp.semparse.open.util.VectorAverager; /** * AveragedWordVector computes and stores averaged neural net word vectors. */ public class AveragedWordVector { // The average (mean) of the word vectors of all tokens public double[] averaged; // Divide each word vector by word frequency public double[] freqWeighted; // Use only open POS class words public double[] openPOSOnly; // Use only open POS class words and divide each word vector by word frequency public double[] freqWeightedOpenPOSOnly; // Term-wise minimum and maximum public double[] min, max, minmax; public AveragedWordVector(Collection<String> phrases) { VectorAverager normalAverager = new VectorAverager(WordVectorTable.numDimensions), freqWeightedAverager = new VectorAverager(WordVectorTable.numDimensions), openPOSOnlyAverager = new VectorAverager(WordVectorTable.numDimensions), freqWeightedOpenPOSAverager = new VectorAverager(WordVectorTable.numDimensions); for (String phrase : phrases) { LingData lingData = LingData.get(phrase); for (int i = 0; i < lingData.length; i++) { String token = lingData.tokens.get(i); int freq = BrownClusterTable.getSmoothedFrequency(token); double[] vector = WordVectorTable.getVector(token); normalAverager.add(vector); freqWeightedAverager.add(vector, 1.0 / freq); if (lingData.posTypes.get(i) == POSType.OPEN) { openPOSOnlyAverager.add(vector); freqWeightedOpenPOSAverager.add(vector, 1.0 / freq); } } } averaged = normalAverager.getAverage(); freqWeighted = freqWeightedAverager.getAverage(); openPOSOnly = openPOSOnlyAverager.getAverage(); freqWeightedOpenPOSOnly = freqWeightedOpenPOSAverager.getAverage(); min = normalAverager.getMin(); max = normalAverager.getMax(); minmax = normalAverager.getMinmax(); } public AveragedWordVector(String phrase) { // Slightly inefficient, but will not be called often. this(Arrays.asList(phrase)); } public double[] get(boolean freqWeighted, boolean openPOSOnly) { if (freqWeighted) { return openPOSOnly ? this.freqWeightedOpenPOSOnly : this.freqWeighted; } else { return openPOSOnly ? this.openPOSOnly : this.averaged; } } /** * Add general features of the form name...[i] * with value = each element of the averaged word vector. */ @Deprecated public void addTermwiseFeatures(FeatureVector v, String domain, String name) { if (averaged != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "[" + i + "]", averaged[i]); if (freqWeighted != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "-freq-weighted[" + i + "]", freqWeighted[i]); if (openPOSOnly != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "-open-pos[" + i + "]", openPOSOnly[i]); } /** * Add general features of the form name...[i] * with value = term-wise product between the averaged word vector and the given vector. */ @Deprecated public void addTermwiseFeatures(FeatureVector v, String domain, String name, double[] factor) { if (factor == null) return; if (averaged != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "[" + i + "]", averaged[i] * factor[i]); if (freqWeighted != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "-freq-weighted[" + i + "]", freqWeighted[i] * factor[i]); if (openPOSOnly != null) for (int i = 0; i < WordVectorTable.numDimensions; i++) v.add(domain, name + "-open-pos[" + i + "]", openPOSOnly[i] * factor[i]); } // Too slow and memory consuming @Deprecated public static void addCrossProductFeatures(FeatureVector v, String domain, String name1, String name2, double[] factor1, double[] factor2) { if (factor1 == null || factor2 == null) return; for (int i = 0; i < WordVectorTable.numDimensions; i++) { for (int j = 0; j < WordVectorTable.numDimensions; j++) { v.add(domain, name1 + "[" + i + "]*" + name2 + "[" + j + "]", factor1[i] * factor2[j]); } } } }