/*
* Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.radialpoint.word2vec.query_expansion;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.radialpoint.word2vec.Distance;
import com.radialpoint.word2vec.Distance.ScoredTerm;
import com.radialpoint.word2vec.OutOfVocabularyException;
import com.radialpoint.word2vec.Vectors;
/**
* Expand a query using a given word2vec vectors.
*/
public class QueryExpander {
/**
* word2vec Vectors used for expansion.
*/
private Vectors vectors;
/**
* Whether to consider the query terms independently or jointly
*/
private boolean combinedVector;
/**
* Term selection strategy
*/
private TermSelection termSelectionStrategy;
public static enum TermSelection {
/**
* Return all terms
*/
ALL {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return terms;
}
},
/**
* Return the top term
*/
TOP {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return selectTopN(terms, 1);
}
},
/**
* Return the top 2 term
*/
TOP2 {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return selectTopN(terms, 2);
}
},
/**
* Return the top 5 terms
*/
TOP5 {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return selectTopN(terms, 5);
}
},
/**
* Cut off at 75% cosine (absolute)
*/
CUT_75_ABS {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return selectWithThreshold(terms, 0.75f);
}
},
/**
* Cut off at 66% cosine (absolute)
*/
CUT_66_ABS {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
return selectWithThreshold(terms, 0.66f);
}
},
/**
* Cut off at 90% cosine relative to top term
*/
CUT_90_REL {
protected List<ScoredTerm> select(List<ScoredTerm> terms) {
if (terms.isEmpty())
return terms;
float thr = terms.get(0).getScore() * 0.9f;
return selectWithThreshold(terms, thr);
}
};
protected List<ScoredTerm> selectTopN(List<ScoredTerm> terms, int n) {
return terms.size() < n ? terms : terms.subList(0, n);
}
protected List<ScoredTerm> selectWithThreshold(List<ScoredTerm> terms, float thr) {
int idx = 0;
while (idx < terms.size() && terms.get(idx).getScore() > thr)
idx++;
return terms.subList(0, idx);
}
protected abstract List<ScoredTerm> select(List<ScoredTerm> terms);
}
public QueryExpander(Vectors vectors) {
this(vectors, true, TermSelection.ALL);
}
public QueryExpander(Vectors vectors, boolean combinedVector) {
this(vectors, combinedVector, TermSelection.ALL);
}
public QueryExpander(Vectors vectors, boolean combinedVector, TermSelection termSelectionStrategy) {
this.vectors = vectors;
this.combinedVector = combinedVector;
this.termSelectionStrategy = termSelectionStrategy;
}
/**
* Expand a query by combining all terms into a vector.
*
* TODO: this method and Distance.measure should be refactored to receive the lists and arrays as parameters and
* re-use them, for memory efficiency.
*/
public List<Distance.ScoredTerm> expand(String[] terms) {
// check for missing terms
List<String> goodTerms = new ArrayList<String>();
for (String term : terms)
if (vectors.hasTerm(term))
goodTerms.add(term);
if (goodTerms.size() != terms.length)
terms = goodTerms.toArray(new String[0]);
// expanded terms according to GENCO or GENW
List<Distance.ScoredTerm> expansion = null;
try {
expansion = Distance.measure(vectors, 50, terms);
} catch (OutOfVocabularyException e) {
// can't happen
throw new IllegalStateException(e);
}
return this.termSelectionStrategy.select(expansion);
}
/**
* Expand a query, either by combining all terms into a vector or by merging expansion lists for different vectors.
*/
public List<Distance.ScoredTerm> expand(String query) {
// calculate the list of terms
String[] terms = query.split("\\s+");
// check for missing terms
List<String> goodTerms = new ArrayList<String>();
for (String term : terms)
if (term != null && vectors.hasTerm(term))
goodTerms.add(term);
if (goodTerms.size() != terms.length)
terms = goodTerms.toArray(new String[0]);
// expanded terms according to GENCO or GENW
List<Distance.ScoredTerm> expansion = null;
try {
if (combinedVector)
expansion = Distance.measure(vectors, 50, terms);
else {
expansion = new ArrayList<Distance.ScoredTerm>();
for (String term : terms)
merge(expansion, Distance.measure(vectors, 50, new String[] { term }));
}
} catch (OutOfVocabularyException e) {
// can't happen
throw new IllegalStateException(e);
}
return this.termSelectionStrategy.select(expansion);
}
/**
* Whether the term is known
*
* @param term
* to check
* @return true if this query expander knows about that term (i.e., it is in vocabulary).
*/
public boolean isTermKnown(String term) {
return this.vectors.hasTerm(term);
}
private void merge(List<ScoredTerm> expansion, List<ScoredTerm> extra) {
Map<String, Integer> termPos = new HashMap<String, Integer>();
for (int i = 0; i < expansion.size(); i++)
termPos.put(expansion.get(i).getTerm(), i);
for (Distance.ScoredTerm scoredTerm : extra) {
Integer pos = termPos.get(scoredTerm.getTerm());
if (pos == null)
expansion.add(scoredTerm);
else
expansion.set(pos,
new ScoredTerm(scoredTerm.getTerm(), expansion.get(pos).getScore() + scoredTerm.getScore()));
}
Collections.sort(expansion, new Comparator<ScoredTerm>() {
public int compare(ScoredTerm o1, ScoredTerm o2) {
return new Float(o2.getScore()).compareTo(o1.getScore());
}
});
}
}