/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.utils.vectors;
import com.google.common.base.Function;
import com.google.common.collect.Collections2;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.lucene.util.PriorityQueue;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.FileLineIterator;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
public final class VectorHelper {
/** Splits a text dictionary line into its tab-separated fields (term, doc freq, index). */
private static final Pattern TAB_PATTERN = Pattern.compile("\t");
/** Not instantiable: this class only exposes static helpers. */
private VectorHelper() { }
/**
 * Renders a vector as CSV text and returns it as a string.
 *
 * <p>Convenience overload of {@link #vectorToCSVString(Vector, boolean, Appendable)} that
 * writes into a freshly allocated buffer.
 *
 * @param vector the vector to render
 * @param namesAsComments passed through unchanged to the three-argument overload
 * @return everything the three-argument overload appended to the buffer
 * @throws IOException declared because the underlying overload writes to an {@link Appendable}
 */
public static String vectorToCSVString(Vector vector, boolean namesAsComments) throws IOException {
  StringBuilder csv = new StringBuilder(2048);
  vectorToCSVString(vector, namesAsComments, csv);
  return csv.toString();
}
/**
 * Serializes the given (term, weight) pairs using a new 2048-character buffer.
 *
 * @param iterable the pairs to serialize, in iteration order
 * @return the result of {@link #buildJson(Iterable, StringBuilder)} on the new buffer
 */
public static String buildJson(Iterable<Pair<String,Double>> iterable) {
  StringBuilder json = new StringBuilder(2048);
  return buildJson(iterable, json);
}
/**
 * Appends the given (term, weight) pairs to {@code bldr} as a brace-delimited,
 * comma-separated list of {@code term:weight} items and returns the builder's content.
 *
 * <p>Terms are written verbatim (no quoting or escaping), so the output is JSON-like rather
 * than strict JSON. An empty {@code iterable} yields {@code {}}; any text already present in
 * {@code bldr} is left untouched and precedes the appended object.
 *
 * @param iterable the pairs to serialize, in iteration order
 * @param bldr the builder to append to
 * @return the full content of {@code bldr} after appending
 */
public static String buildJson(Iterable<Pair<String,Double>> iterable, StringBuilder bldr) {
  bldr.append('{');
  boolean first = true;
  for (Pair<String,Double> p : iterable) {
    // Write the separator before every item except the first, so there is never a trailing
    // comma to patch up and the closing brace can always be appended unconditionally.
    if (first) {
      first = false;
    } else {
      bldr.append(',');
    }
    bldr.append(p.getFirst()).append(':').append(p.getSecond());
  }
  bldr.append('}');
  return bldr.toString();
}
/**
 * Renders every entry of the vector through
 * {@link #vectorToJson(Vector, String[], int, boolean)} with sorting enabled and the entry
 * limit set to {@link Integer#MAX_VALUE}.
 *
 * @param vector the vector to render
 * @param dictionary maps a vector index to its term
 * @return the rendered string
 */
public static String vectorToSortedString(Vector vector, String[] dictionary) {
  final int noLimit = Integer.MAX_VALUE;
  final boolean sorted = true;
  return vectorToJson(vector, dictionary, noLimit, sorted);
}
/**
 * Returns the (index, value) pairs of the {@code maxEntries} largest non-zero entries of the
 * vector, ordered by value descending (ties broken by ascending index).
 *
 * @param vector the vector whose non-zero entries are ranked
 * @param maxEntries the maximum number of entries to return; values {@code <= 0} yield an
 *     empty list
 * @return a new mutable list holding at most {@code maxEntries} pairs
 */
public static List<Pair<Integer, Double>> topEntries(Vector vector, int maxEntries) {
  List<Pair<Integer, Double>> entries = Lists.newArrayList();
  if (maxEntries <= 0) {
    return entries;
  }

  // The queue pre-allocates (and pre-fills with sentinels) its whole capacity, so never ask
  // for more slots than there are entries to rank. Counting stops as soon as the limit is hit.
  int capacity = 0;
  Iterator<Vector.Element> counter = vector.iterateNonZero();
  while (capacity < maxEntries && counter.hasNext()) {
    counter.next();
    capacity++;
  }
  if (capacity == 0) {
    return entries;
  }

  PriorityQueue<Pair<Integer,Double>> queue = new TDoublePQ<Integer>(-1, capacity);
  Iterator<Vector.Element> it = vector.iterateNonZero();
  while (it.hasNext()) {
    Vector.Element e = it.next();
    queue.insertWithOverflow(Pair.of(e.index(), e.get()));
  }

  Pair<Integer, Double> pair;
  while ((pair = queue.pop()) != null) {
    // Skip any sentinel that was never displaced; guard against a null index so unboxing
    // cannot throw.
    Integer index = pair.getFirst();
    if (index != null && index > -1) {
      entries.add(pair);
    }
  }

  // Rank by weight, not by the pair's natural (index-first) ordering.
  Collections.sort(entries, new Ordering<Pair<Integer, Double>>() {
    @Override
    public int compare(Pair<Integer, Double> a, Pair<Integer, Double> b) {
      int byWeight = b.getSecond().compareTo(a.getSecond());
      return byWeight != 0 ? byWeight : a.getFirst().compareTo(b.getFirst());
    }
  });
  return entries;
}
/**
 * Collects up to {@code maxEntries} (index, value) pairs, in the order in which
 * {@code vector.iterateNonZero()} produces them.
 *
 * @param vector the vector to read entries from
 * @param maxEntries the maximum number of entries to collect
 * @return a new mutable list with the collected pairs
 */
public static List<Pair<Integer, Double>> firstEntries(Vector vector, int maxEntries) {
  List<Pair<Integer, Double>> collected = Lists.newArrayList();
  Iterator<Vector.Element> nonZeros = vector.iterateNonZero();
  for (int taken = 0; nonZeros.hasNext() && taken < maxEntries; taken++) {
    Vector.Element element = nonZeros.next();
    collected.add(Pair.of(element.index(), element.get()));
  }
  return collected;
}
/**
 * Converts (index, weight) pairs into (term, weight) pairs by looking each index up in the
 * dictionary array.
 *
 * @param entries the (index, weight) pairs to convert
 * @param dictionary maps an index to its term; must cover every index in {@code entries}
 * @return a new mutable list of (term, weight) pairs in the iteration order of {@code entries}
 */
public static List<Pair<String, Double>> toWeightedTerms(Collection<Pair<Integer, Double>> entries,
    final String[] dictionary) {
  List<Pair<String, Double>> terms = Lists.newArrayList();
  for (Pair<Integer, Double> entry : entries) {
    String term = dictionary[entry.getFirst()];
    terms.add(Pair.of(term, entry.getSecond()));
  }
  return terms;
}
/**
 * Renders up to {@code maxEntries} entries of the vector as a term/weight string.
 *
 * <p>Entries are picked with {@link #topEntries(Vector, int)} when {@code sort} is true and
 * with {@link #firstEntries(Vector, int)} otherwise, translated to terms with
 * {@link #toWeightedTerms(Collection, String[])} and serialized with
 * {@link #buildJson(Iterable)}.
 *
 * @param vector the vector to render
 * @param dictionary maps a vector index to its term
 * @param maxEntries the maximum number of entries to include
 * @param sort selects {@code topEntries} (true) or {@code firstEntries} (false)
 * @return the serialized entries
 */
public static String vectorToJson(Vector vector, String[] dictionary, int maxEntries, boolean sort) {
  List<Pair<Integer, Double>> selected;
  if (sort) {
    selected = topEntries(vector, maxEntries);
  } else {
    selected = firstEntries(vector, maxEntries);
  }
  List<Pair<String, Double>> terms = toWeightedTerms(selected, dictionary);
  return buildJson(terms);
}
/**
 * Appends the vector to {@code bldr} as a single comma-separated line terminated by a
 * newline, one value per element returned by {@code vector.iterator()}.
 *
 * <p>When {@code namesAsComments} is true and the vector is a {@link NamedVector}, a line
 * consisting of {@code #} followed by the vector's name is written first.
 *
 * @param vector the vector to render
 * @param namesAsComments whether to emit the name of a {@link NamedVector} as a comment line
 * @param bldr the destination to append to
 * @throws IOException if appending to {@code bldr} fails
 */
public static void vectorToCSVString(Vector vector,
    boolean namesAsComments,
    Appendable bldr) throws IOException {
  if (namesAsComments && vector instanceof NamedVector) {
    String name = ((NamedVector) vector).getName();
    bldr.append('#').append(name).append('\n');
  }
  Iterator<Vector.Element> elements = vector.iterator();
  // The first value is written bare; every later value is preceded by a comma.
  if (elements.hasNext()) {
    bldr.append(String.valueOf(elements.next().get()));
  }
  while (elements.hasNext()) {
    bldr.append(',');
    bldr.append(String.valueOf(elements.next().get()));
  }
  bldr.append('\n');
}
/**
 * Read in a dictionary file. The first line holds the number of entries; every following
 * line has the tab-separated format:
 *
 * <pre>
 * term DocFreq Index
 * </pre>
 *
 * @param dictFile the dictionary file to read
 * @return an array in which position {@code Index} holds the corresponding term
 * @throws IOException if the file cannot be opened, read or closed
 */
public static String[] loadTermDictionary(File dictFile) throws IOException {
  InputStream in = new FileInputStream(dictFile);
  try {
    return loadTermDictionary(in);
  } finally {
    // Release the file handle even when parsing fails part-way through the file.
    in.close();
  }
}
/**
 * Read a dictionary in {@link SequenceFile} generated by
 * {@link org.apache.mahout.vectorizer.DictionaryVectorizer}
 *
 * <p>The returned array is large enough to hold the largest index that was read, so the
 * indices do not have to be contiguous; positions without a term are {@code null}.
 *
 * @param conf the configuration used to read the sequence files
 * @param filePattern
 *          {@code <PATH TO DICTIONARY>/dictionary.file-*}
 * @return an array in which position {@code i} holds the term whose index is {@code i}
 */
public static String[] loadTermDictionary(Configuration conf, String filePattern) {
  OpenObjectIntHashMap<String> dict = new OpenObjectIntHashMap<String>();
  int maxIndex = -1;
  for (Pair<Text,IntWritable> record :
      new SequenceFileDirIterable<Text,IntWritable>(new Path(filePattern), PathType.GLOB,
          null, null, true, conf)) {
    int index = record.getSecond().get();
    dict.put(record.getFirst().toString(), index);
    if (index > maxIndex) {
      maxIndex = index;
    }
  }
  // Size by the largest index rather than by the entry count: the two differ when the
  // indices are not exactly 0..n-1. Never shrink below the entry count, so inputs that
  // worked before produce an array of the same length.
  String[] dictionary = new String[Math.max(dict.size(), maxIndex + 1)];
  for (String feature : dict.keys()) {
    dictionary[dict.get(feature)] = feature;
  }
  return dictionary;
}
/**
 * Read in a dictionary from a stream. The first line is the number of entries; each
 * following line has the tab-separated format:
 *
 * <pre>
 * term DocFreq Index
 * </pre>
 *
 * Lines starting with {@code #} and lines with fewer than three fields are skipped.
 *
 * @param is the stream to read the dictionary from
 * @return an array of the declared size in which position {@code Index} holds the term
 * @throws IOException declared for the line reader built on top of {@code is}
 */
private static String[] loadTermDictionary(InputStream is) throws IOException {
  FileLineIterator lines = new FileLineIterator(is);
  String[] terms = new String[Integer.parseInt(lines.next())];
  while (lines.hasNext()) {
    String line = lines.next();
    if (!line.startsWith("#")) {
      String[] fields = TAB_PATTERN.split(line);
      if (fields.length >= 3) {
        // fields[0] is the term, fields[1] the doc freq, fields[2] the index
        terms[Integer.parseInt(fields[2])] = fields[0];
      }
    }
  }
  return terms;
}
/**
 * Priority queue of (key, weight) pairs ordered by weight, smallest weight on top.
 *
 * <p>Sentinel objects pair the key given at construction time with
 * {@link Double#NEGATIVE_INFINITY}.
 */
private static class TDoublePQ<T> extends PriorityQueue<Pair<T, Double>> {
  private final T sentinel;

  /**
   * @param sentinel the key stored in every sentinel pair
   * @param size the capacity handed to {@code initialize}
   */
  private TDoublePQ(T sentinel, int size) {
    // Assign the sentinel key first: initialize() may call getSentinelObject(), which would
    // otherwise see the field's default null value and build sentinels with a null key.
    this.sentinel = sentinel;
    initialize(size);
  }

  /** Orders pairs by their weight only; the key is ignored. */
  @Override
  protected boolean lessThan(Pair<T, Double> a,
      Pair<T, Double> b) {
    return a.getSecond().compareTo(b.getSecond()) < 0;
  }

  /** Creates a new sentinel pair: the sentinel key with the lowest possible weight. */
  @Override
  protected Pair<T, Double> getSentinelObject() {
    return Pair.of(sentinel, Double.NEGATIVE_INFINITY);
  }
}
}