/*
 * Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.radialpoint.word2vec.lucene;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;

import com.radialpoint.word2vec.Distance.ScoredTerm;
import com.radialpoint.word2vec.query_expansion.QueryExpander;

/**
 * Token filter that emits, after each original token, the expansion terms returned by a {@link QueryExpander} for a
 * window of up to {@code size} tokens starting at that token. Expansion terms are emitted with a position increment
 * of 0 and the type {@link #TYPE_SYNONYM}; their relative order is unspecified (they are de-duplicated in a hash
 * set).
 * <p>
 * This filter is not intended for indexing, just for query expansion. Trying to use this analyzer for indexing will
 * result in a very high indexing time and an unnecessarily large index.
 */
public class Word2VecFilter extends TokenFilter {

    /** Token type assigned to every expansion term produced by this filter. */
    public static final String TYPE_SYNONYM = "SYNONYM";

    private final QueryExpander expander;
    private final int size;
    private final boolean multiword;

    /** De-duplicated expansion terms for the token most recently returned from the input. */
    private final Set<String> output;
    private Iterator<String> outputIt;

    /** Terms of the current look-ahead window; {@code null} where the input ran out. */
    private final String[] terms;
    /** Same window, but with terms unknown to the expander replaced by {@code null}. */
    private final String[] termsTmp;

    /**
     * Attribute states of tokens already pulled from the input but not yet returned to the consumer. Restoring a
     * state does not rewind the input, so look-ahead tokens must be kept here or they would be lost.
     */
    private final Deque<State> pendingStates = new ArrayDeque<State>();
    /** Term text of each entry in {@link #pendingStates}, in the same order. */
    private final Deque<String> pendingTerms = new ArrayDeque<String>();
    /** Set once the input has returned {@code false}, so it is never asked for more tokens afterwards. */
    private boolean inputExhausted;

    /**
     * Expand the terms in a token stream using word2vec vectors.
     *
     * @param input
     *            the original tokenstream
     * @param expander
     *            the query expander using word2vec vectors
     * @param size
     *            the number of terms to look-ahead to perform the word2vec queries
     * @param multiword
     *            whether to also combine the terms using underscores and to split returned terms on underscores.
     */
    public Word2VecFilter(TokenStream input, QueryExpander expander, int size, boolean multiword) {
        super(input);
        this.expander = expander;
        this.size = size;
        this.multiword = multiword;
        this.output = new HashSet<String>();
        this.outputIt = null;
        this.terms = new String[size];
        this.termsTmp = new String[size];
    }

    private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
    private final PositionIncrementAttribute posIncrAtt = addAttribute(PositionIncrementAttribute.class);
    private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);

    /**
     * Returns the next token: either a pending expansion term for the last original token, or the next original
     * token from the input (after computing its expansions).
     *
     * @return {@code false} once the input is exhausted and all expansion terms have been emitted
     * @throws IOException
     *             if the wrapped stream fails
     */
    @Override
    public final boolean incrementToken() throws IOException {
        if (outputIt != null && outputIt.hasNext()) {
            // Emit an expansion on top of the attributes of the last original token.
            String next = outputIt.next();
            termAtt.copyBuffer(next.toCharArray(), 0, next.length());
            posIncrAtt.setPositionIncrement(0);
            typeAtt.setType(TYPE_SYNONYM);
            return true;
        }

        fillLookAhead();
        if (pendingStates.isEmpty()) {
            return false;
        }

        // then expand and record the expansions
        collectExpansions();

        // Hand out the oldest buffered token; the remaining ones stay queued for the following calls.
        pendingTerms.removeFirst();
        restoreState(pendingStates.removeFirst());
        this.outputIt = output.iterator();
        return true;
    }

    /**
     * Pulls tokens from the input until the look-ahead buffer holds {@code size} tokens (at least one) or the input
     * is exhausted.
     */
    private void fillLookAhead() throws IOException {
        // A window of zero terms still has to pass the original tokens through.
        final int depth = Math.max(1, size);
        while (!inputExhausted && pendingStates.size() < depth) {
            if (input.incrementToken()) {
                pendingTerms.addLast(termAtt.toString());
                pendingStates.addLast(captureState());
            } else {
                inputExhausted = true;
            }
        }
    }

    /**
     * Recomputes {@link #output} for the window that starts at the oldest buffered token: first the expansion of
     * the individual known terms and then, if multiword handling is enabled, the expansion of every adjacent pair
     * joined with an underscore that the expander knows.
     */
    private void collectExpansions() throws IOException {
        output.clear();

        Iterator<String> buffered = pendingTerms.iterator();
        for (int i = 0; i < size; i++) {
            String term = buffered.hasNext() ? buffered.next() : null;
            terms[i] = term;
            // Never hand a null term to the expander.
            termsTmp[i] = (term != null && expander.isTermKnown(term)) ? term : null;
        }
        addToOutput(expander.expand(termsTmp));

        if (!multiword) {
            return;
        }

        // combine pairs of words
        for (int current = 0; current < size - 1; current++) {
            if (terms[current] == null || terms[current + 1] == null) {
                continue;
            }
            String mwe = terms[current] + "_" + terms[current + 1];
            if (!expander.isTermKnown(mwe)) {
                continue;
            }
            // Temporarily replace the two single terms by the combined one, expand, then put them back.
            String savedCurrent = termsTmp[current];
            String savedNext = termsTmp[current + 1];
            termsTmp[current] = mwe;
            termsTmp[current + 1] = null;
            addToOutput(expander.expand(termsTmp));
            termsTmp[current] = savedCurrent;
            termsTmp[current + 1] = savedNext;
        }
    }

    /**
     * Adds the terms of an expansion to {@link #output}. With multiword handling enabled, terms containing
     * underscores are split and each non-empty part is added separately.
     */
    private void addToOutput(List<ScoredTerm> expansion) {
        for (ScoredTerm scoredTerm : expansion) {
            String term = scoredTerm.getTerm();
            if (multiword && term.indexOf('_') >= 0) {
                for (String subTerm : term.split("_")) {
                    // A leading or doubled underscore yields empty parts; do not emit empty tokens.
                    if (!subTerm.isEmpty()) {
                        output.add(subTerm);
                    }
                }
            } else {
                output.add(term);
            }
        }
    }

    /**
     * Resets the wrapped stream and discards pending expansion terms and buffered look-ahead tokens.
     */
    @Override
    public void reset() throws IOException {
        super.reset();
        this.output.clear();
        this.outputIt = null;
        this.pendingStates.clear();
        this.pendingTerms.clear();
        this.inputExhausted = false;
    }
}