/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.utils.vectors.lucene;
import java.io.File;
import java.io.IOException;
import java.io.Writer;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.utils.vectors.TermInfo;
import org.apache.mahout.utils.vectors.io.DelimitedTermInfoWriter;
import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter;
import org.apache.mahout.utils.vectors.io.VectorWriter;
import org.apache.mahout.vectorizer.TF;
import org.apache.mahout.vectorizer.TFIDF;
import org.apache.mahout.vectorizer.Weight;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Command-line tool that turns one field of a Lucene index into Mahout vectors.
 *
 * <p>The vectors are written to a Hadoop {@link SequenceFile} keyed by {@link LongWritable} with
 * {@link VectorWritable} values. The term dictionary used to build them is written as a delimited,
 * UTF-8 text file.
 *
 * <p>Configure an instance through its setters and call {@link #dumpVectors()}, or run
 * {@link #main(String[])}.
 */
public final class Driver {

  private static final Logger log = LoggerFactory.getLogger(Driver.class);

  /** Path of the Lucene index directory to read (required). */
  private String luceneDir;
  /** Path of the vector SequenceFile to write (required). */
  private String outFile;
  /** Index field whose terms are vectorized. */
  private String field;
  /** Optional index field used as the document id; may be null. */
  private String idField;
  /** Path of the dictionary text file to write (required). */
  private String dictOut;
  /** "tf" or "tfidf" (case-insensitive). */
  private String weightType = "tfidf";
  /** Separator used in the dictionary output. */
  private String delimiter = "\t";
  /** Vector norm, or LuceneIterable.NO_NORMALIZING to skip normalization. */
  private double norm = LuceneIterable.NO_NORMALIZING;
  /** Upper bound on the number of vectors written. */
  private long maxDocs = Long.MAX_VALUE;
  /** Minimum document frequency for a term to be kept. */
  private int minDf = 1;
  /** Maximum document-frequency percentage for a term to be kept. */
  private int maxDFPercent = 99;
  /** Fraction (0..1) of documents allowed to have no term vector. */
  private double maxPercentErrorDocs = 0.0;

  /**
   * Reads the configured Lucene index and writes the outputs.
   *
   * <p>Writes at most {@code maxDocs} vectors to {@code outFile}, then writes the term dictionary
   * to {@code dictOut}. All settings are validated before any file is opened.
   *
   * @throws NullPointerException if the Lucene directory, output file or dictionary file is unset
   * @throws IllegalArgumentException if the Lucene directory does not exist, a numeric setting is
   *     out of range, or the weight type is neither "tf" nor "tfidf"
   * @throws IOException if reading the index or writing either output fails
   */
  public void dumpVectors() throws IOException {
    Preconditions.checkNotNull(luceneDir, "luceneDir must be set");
    Preconditions.checkNotNull(outFile, "outFile must be set");
    // Checked up front: previously a missing dictOut only failed after all vectors were written.
    Preconditions.checkNotNull(dictOut, "dictOut must be set");
    File file = new File(luceneDir);
    Preconditions.checkArgument(file.isDirectory(),
        "Lucene directory: " + file.getAbsolutePath()
            + " does not exist or is not a directory");
    Preconditions.checkArgument(maxDocs >= 0, "maxDocs must be >= 0");
    Preconditions.checkArgument(minDf >= 1, "minDf must be >= 1");
    Preconditions.checkArgument(maxDFPercent <= 99, "maxDFPercent must be <= 99");
    Preconditions.checkArgument(norm == LuceneIterable.NO_NORMALIZING || norm >= 0,
        "norm must be >= 0");
    Preconditions.checkArgument(maxPercentErrorDocs >= 0.0 && maxPercentErrorDocs <= 1.0,
        "maxPercentErrorDocs must be between 0 and 1");
    // Resolve the weight before opening the index so a bad weight type cannot leak a reader.
    Weight weight = createWeight(weightType);

    Directory dir = FSDirectory.open(file);
    try {
      IndexReader reader = IndexReader.open(dir, true);
      try {
        TermInfo termInfo = new CachedTermInfo(reader, field, minDf, maxDFPercent);
        VectorMapper mapper = new TFDFMapper(reader, weight, termInfo);
        // norm is passed straight through: NO_NORMALIZING already means "do not normalize".
        LuceneIterable iterable =
            new LuceneIterable(reader, idField, field, mapper, norm, maxPercentErrorDocs);
        writeVectors(iterable);
        // The reader stays open until the dictionary is written, as in the original ordering.
        writeDictionary(termInfo);
      } finally {
        reader.close();
      }
    } finally {
      dir.close();
    }
  }

  /**
   * Maps a weight name to its {@link Weight} implementation.
   *
   * @param weightType "tf" or "tfidf", case-insensitive
   * @return a new {@link TF} or {@link TFIDF}
   * @throws IllegalArgumentException for any other name
   */
  private static Weight createWeight(String weightType) {
    if ("tf".equalsIgnoreCase(weightType)) {
      return new TF();
    }
    if ("tfidf".equalsIgnoreCase(weightType)) {
      return new TFIDF();
    }
    throw new IllegalArgumentException("Weight type " + weightType + " is not supported");
  }

  /** Writes at most {@code maxDocs} vectors from {@code iterable} to {@code outFile}. */
  private void writeVectors(LuceneIterable iterable) throws IOException {
    log.info("Output File: {}", outFile);
    VectorWriter vectorWriter = getSeqFileWriter(outFile);
    boolean threw = true;
    try {
      long numDocs = vectorWriter.write(iterable, maxDocs);
      log.info("Wrote: {} vectors", numDocs);
      threw = false;
    } finally {
      // Closing flushes the file, so close failures are propagated on the success path.
      // Swallowing them would report success for truncated output. After an earlier failure
      // they are swallowed so they do not mask that original exception.
      Closeables.close(vectorWriter, threw);
    }
  }

  /** Writes {@code termInfo} as {@code delimiter}-separated UTF-8 text to {@code dictOut}. */
  private void writeDictionary(TermInfo termInfo) throws IOException {
    File dictOutFile = new File(dictOut);
    log.info("Dictionary Output file: {}", dictOutFile);
    Writer writer = Files.newWriter(dictOutFile, Charsets.UTF_8);
    DelimitedTermInfoWriter tiWriter = new DelimitedTermInfoWriter(writer, delimiter, field);
    boolean threw = true;
    try {
      tiWriter.write(termInfo);
      threw = false;
    } finally {
      // Same rationale as writeVectors: do not hide a failed final flush.
      Closeables.close(tiWriter, threw);
    }
  }

  /**
   * Command-line entry point.
   *
   * <p>Parses the options, configures a {@link Driver} and calls {@link #dumpVectors()}. Prints
   * usage when {@code --help} is given or when option parsing fails.
   *
   * @param args command-line arguments
   * @throws IOException if reading the index or writing the outputs fails
   */
  public static void main(String[] args) throws IOException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option inputOpt = obuilder.withLongName("dir").withRequired(true).withArgument(
        abuilder.withName("dir").withMinimum(1).withMaximum(1).create())
        .withDescription("The Lucene directory").withShortName("d").create();

    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription("The output file")
        .withShortName("o").create();

    Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument(
        abuilder.withName("field").withMinimum(1).withMaximum(1).create()).withDescription(
        "The field in the index").withShortName("f").create();

    Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument(
        abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription(
        "The field in the index containing the document id. If null, then the Lucene internal doc "
            + "id is used which is prone to error if the underlying index changes").create();

    Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument(
        abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).withDescription(
        "The output of the dictionary").withShortName("t").create();

    Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument(
        abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).withDescription(
        "The kind of weight to use. Currently TF or TFIDF").withShortName("w").create();

    Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument(
        abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).withDescription(
        "The delimiter for outputting the dictionary").withShortName("l").create();

    Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument(
        abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).withDescription(
        "The norm to use, expressed as either a double or \"INF\" if you want to use the Infinite norm.  "
            + "Must be greater or equal to 0.  The default is not to normalize").withShortName("n").create();

    Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument(
        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
        "The maximum number of vectors to output.  If not specified, then it will loop over all docs")
        .withShortName("m").create();

    Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument(
        abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).withDescription(
        "The minimum document frequency.  Default is 1").withShortName("md").create();

    // The stated range matches the maxDFPercent <= 99 check in dumpVectors.
    Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument(
        abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).withDescription(
        "The max percentage of docs for the DF.  Can be used to remove really high frequency terms."
            + "  Expressed as an integer between 0 and 99.  Default is 99.").withShortName("x").create();

    Option maxPercentErrorDocsOpt = obuilder.withLongName("maxPercentErrorDocs").withRequired(false).withArgument(
        abuilder.withName("maxPercentErrorDocs").withMinimum(1).withMaximum(1).create()).withDescription(
        "The max percentage of docs that can have a null term vector. These are noise documents and can occur if the "
            + "analyzer used strips out all terms in the target field. This percentage is expressed as a value between 0 and 1. "
            + "The default is 0.").withShortName("err").create();

    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
        .create();

    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(idFieldOpt).withOption(
        outputOpt).withOption(delimiterOpt).withOption(helpOpt).withOption(fieldOpt).withOption(maxOpt)
        .withOption(dictOutOpt).withOption(powerOpt).withOption(maxDFPercentOpt)
        .withOption(weightOpt).withOption(minDFOpt).withOption(maxPercentErrorDocsOpt).create();

    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);

      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }

      if (cmdLine.hasOption(inputOpt)) { // Lucene case
        Driver luceneDriver = new Driver();
        luceneDriver.setLuceneDir(cmdLine.getValue(inputOpt).toString());

        if (cmdLine.hasOption(maxOpt)) {
          luceneDriver.setMaxDocs(Long.parseLong(cmdLine.getValue(maxOpt).toString()));
        }

        if (cmdLine.hasOption(weightOpt)) {
          luceneDriver.setWeightType(cmdLine.getValue(weightOpt).toString());
        }

        luceneDriver.setField(cmdLine.getValue(fieldOpt).toString());

        if (cmdLine.hasOption(minDFOpt)) {
          luceneDriver.setMinDf(Integer.parseInt(cmdLine.getValue(minDFOpt).toString()));
        }

        if (cmdLine.hasOption(maxDFPercentOpt)) {
          luceneDriver.setMaxDFPercent(Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString()));
        }

        if (cmdLine.hasOption(powerOpt)) {
          String power = cmdLine.getValue(powerOpt).toString();
          if ("INF".equals(power)) {
            luceneDriver.setNorm(Double.POSITIVE_INFINITY);
          } else {
            luceneDriver.setNorm(Double.parseDouble(power));
          }
        }

        if (cmdLine.hasOption(idFieldOpt)) {
          luceneDriver.setIdField(cmdLine.getValue(idFieldOpt).toString());
        }

        if (cmdLine.hasOption(maxPercentErrorDocsOpt)) {
          luceneDriver.setMaxPercentErrorDocs(Double.parseDouble(cmdLine.getValue(maxPercentErrorDocsOpt).toString()));
        }

        luceneDriver.setOutFile(cmdLine.getValue(outputOpt).toString());

        luceneDriver.setDelimiter(cmdLine.hasOption(delimiterOpt) ? cmdLine.getValue(delimiterOpt).toString() : "\t");

        luceneDriver.setDictOut(cmdLine.getValue(dictOutOpt).toString());

        luceneDriver.dumpVectors();
      }
    } catch (OptionException e) {
      log.error("Exception", e);
      CommandLineUtil.printHelp(group);
    }
  }

  /**
   * Creates a {@link SequenceFile}-backed vector writer for {@code outFile}.
   *
   * <p>The writer uses {@link LongWritable} keys and {@link VectorWritable} values.
   */
  private static VectorWriter getSeqFileWriter(String outFile) throws IOException {
    Path path = new Path(outFile);
    Configuration conf = new Configuration();
    // Resolve the filesystem from the path itself. A scheme-less path still gets the default FS,
    // while a fully-qualified path on another FS no longer fails with "Wrong FS".
    FileSystem fs = path.getFileSystem(conf);
    // TODO: Make this parameter driven
    SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class,
        VectorWritable.class);
    return new SequenceFileVectorWriter(seqWriter);
  }

  /** Sets the Lucene index directory to read. */
  public void setLuceneDir(String luceneDir) {
    this.luceneDir = luceneDir;
  }

  /** Sets the maximum number of vectors to write; must be >= 0. */
  public void setMaxDocs(long maxDocs) {
    this.maxDocs = maxDocs;
  }

  /** Sets the weighting scheme: "tf" or "tfidf" (case-insensitive). */
  public void setWeightType(String weightType) {
    this.weightType = weightType;
  }

  /** Sets the index field to vectorize. */
  public void setField(String field) {
    this.field = field;
  }

  /** Sets the minimum document frequency; must be >= 1. */
  public void setMinDf(int minDf) {
    this.minDf = minDf;
  }

  /** Sets the maximum document-frequency percentage; must be <= 99. */
  public void setMaxDFPercent(int maxDFPercent) {
    this.maxDFPercent = maxDFPercent;
  }

  /** Sets the vector norm; must be >= 0 or LuceneIterable.NO_NORMALIZING. */
  public void setNorm(double norm) {
    this.norm = norm;
  }

  /** Sets the optional index field holding document ids. */
  public void setIdField(String idField) {
    this.idField = idField;
  }

  /** Sets the path of the vector SequenceFile to write. */
  public void setOutFile(String outFile) {
    this.outFile = outFile;
  }

  /** Sets the separator used in the dictionary file. */
  public void setDelimiter(String delimiter) {
    this.delimiter = delimiter;
  }

  /** Sets the path of the dictionary file to write. */
  public void setDictOut(String dictOut) {
    this.dictOut = dictOut;
  }

  /** Sets the fraction (0..1) of documents allowed to have no term vector. */
  public void setMaxPercentErrorDocs(double maxPercentErrorDocs) {
    this.maxPercentErrorDocs = maxPercentErrorDocs;
  }
}