/* * Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.radialpoint.word2vec.query_expansion; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.util.List; import com.radialpoint.word2vec.Distance; import com.radialpoint.word2vec.OutOfVocabularyException; import com.radialpoint.word2vec.Vectors; import com.radialpoint.word2vec.VectorsException; import com.sampullara.cli.Args; import com.sampullara.cli.Argument; /** * This small program takes a file with comma-separated queries and query IDs and produce the expanded query terms in separate files. * Useful for running experiments. */ public class ExpandQuery { @Argument(alias = "v", description = "File containing word2vec vectors in binary format", required = true) private static String vectorsFileName; @Argument(alias = "c", description = "Whether to combine all the query terms or expand them independently", required = false) private static Boolean combineTerms = false; @Argument(alias = "s", description = "Term selection strategy", required = false) private static String termSelectionString = "ALL"; @Argument(alias = "i", description = "File containing the queries, one query per line", required = true) private static String inputFileName; @Argument(alias = "o", description = "Output folder", required = true) private static String outputFolderName; /** * @param args * , three arguments, whether all the words are combined or separated * @throws IOException * @throws OutOfVocabularyException */ public static void main(String[] args) throws VectorsException, IOException { // arguments try { Args.parse(ExpandQuery.class, args); } catch (IllegalArgumentException e) { Args.usage(ExpandQuery.class); System.exit(1); } QueryExpander queryExpander = new QueryExpander(new Vectors(new FileInputStream(new File(vectorsFileName))), combineTerms, QueryExpander.TermSelection.valueOf(termSelectionString)); // read queries, one per line BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(inputFileName), "UTF-8")); String line = br.readLine(); line = br.readLine(); while (line != null) { // qid and query String[] parts = line.split(","); String qid = parts[0]; String query = parts[1]; List<Distance.ScoredTerm> expansion = queryExpander.expand(query); PrintWriter pw = new PrintWriter(new FileWriter(new File(outputFolderName, qid + ".terms"))); for (Distance.ScoredTerm scoredTerm : expansion) pw.println(scoredTerm.getTerm() + "\t" + scoredTerm.getScore()); pw.close(); line = br.readLine(); } br.close(); } }