/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;
import java.util.List;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Collections;
import java.util.Iterator;
import java.util.Comparator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A generic, re-usable n-gram estimation class that supports partial aggregations.
* The algorithm is based on the heuristic from the following paper:
* Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
* J. Machine Learning Research 11 (2010), pp. 849--872.
*
* In particular, it is guaranteed that frequencies will be under-counted. With large
* data and a reasonable precision factor, this undercounting appears to be on the order
* of 5%.
*/
public class NGramEstimator {
  /* Class private variables */
  private int k;   // number of top n-grams ultimately returned to the caller
  private int pf;  // precision factor; the estimation buffer holds up to k*pf n-grams
  private int n;   // n-gram length, fixed by the first n-gram fed to the estimator
  private HashMap<ArrayList<String>, Double> ngrams; // n-gram -> estimated frequency

  /**
   * Creates a new n-gram estimator object. The 'n' for n-grams is computed dynamically
   * when data is fed to the object.
   */
  public NGramEstimator() {
    k = 0;
    pf = 0;
    n = 0;
    ngrams = new HashMap<ArrayList<String>, Double>();
  }

  /**
   * Returns true if the 'k' and 'pf' parameters have been set.
   */
  public boolean isInitialized() {
    return (k != 0);
  }

  /**
   * Sets the 'k' and 'pf' parameters.
   *
   * @param pk  number of top n-grams to return
   * @param ppf precision factor controlling the size of the estimation buffer
   * @param pn  length of each n-gram
   */
  public void initialize(int pk, int ppf, int pn) throws HiveException {
    assert(pk > 0 && ppf > 0 && pn > 0);
    k = pk;
    pf = ppf;
    n = pn;

    // Enforce a minimum total buffer size of roughly 1000 n-grams so that small
    // precision factors do not degrade the estimate.
    if (k * pf < 1000) {
      pf = 1000 / k;
    }
  }

  /**
   * Resets an n-gram estimator object to its initial state.
   */
  public void reset() {
    ngrams.clear();
    n = pf = k = 0;
  }

  /**
   * Returns the final top-k n-grams in a format suitable for returning to Hive.
   *
   * @return a list of {tokens, frequency} pairs sorted by descending frequency, where
   *         tokens is an ArrayList of Text and frequency is a DoubleWritable; or null
   *         if no n-grams were added (SQL standard - return null for zero elements).
   */
  public ArrayList<Object[]> getNGrams() throws HiveException {
    trim(true);
    if (ngrams.size() < 1) { // SQL standard - return null for zero elements
      return null;
    }

    // Sort the n-gram list by frequency in descending order, breaking frequency ties
    // lexicographically on the n-gram tokens so the output order is deterministic.
    ArrayList<Map.Entry<ArrayList<String>, Double>> list =
        new ArrayList<Map.Entry<ArrayList<String>, Double>>(ngrams.entrySet());
    Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>, Double>>() {
      @Override
      public int compare(Map.Entry<ArrayList<String>, Double> o1,
                         Map.Entry<ArrayList<String>, Double> o2) {
        int result = o2.getValue().compareTo(o1.getValue());
        if (result != 0) {
          return result;
        }
        ArrayList<String> key1 = o1.getKey();
        ArrayList<String> key2 = o2.getKey();
        for (int i = 0; i < key1.size() && i < key2.size(); i++) {
          result = key1.get(i).compareTo(key2.get(i));
          if (result != 0) {
            return result;
          }
        }
        // Equal prefixes: shorter n-gram sorts first.
        return key1.size() - key2.size();
      }
    });

    // Convert the n-gram list to a format suitable for Hive.
    ArrayList<Object[]> result = new ArrayList<Object[]>(list.size());
    for (Map.Entry<ArrayList<String>, Double> entry : list) {
      ArrayList<String> key = entry.getKey();
      ArrayList<Text> ng = new ArrayList<Text>(key.size());
      for (int j = 0; j < key.size(); j++) {
        ng.add(new Text(key.get(j)));
      }
      Object[] curGram = new Object[2];
      curGram[0] = ng;
      curGram[1] = new DoubleWritable(entry.getValue().doubleValue());
      result.add(curGram);
    }

    return result;
  }

  /**
   * Returns the number of n-grams in our buffer.
   */
  public int size() {
    return ngrams.size();
  }

  /**
   * Adds a new n-gram to the estimation.
   *
   * NOTE(review): the caller's list is stored directly as a HashMap key; this assumes
   * the caller never mutates it after the call -- confirm against call sites.
   *
   * @param ng The n-gram to add to the estimation
   * @throws HiveException if the n-gram length differs from previously added n-grams
   */
  public void add(ArrayList<String> ng) throws HiveException {
    assert(ng != null && ng.size() > 0 && ng.get(0) != null);

    // Increment the frequency for this n-gram, inserting it with count 1 if new.
    Double curFreq = ngrams.get(ng);
    if (curFreq == null) {
      curFreq = Double.valueOf(1.0);
    } else {
      curFreq = Double.valueOf(curFreq.doubleValue() + 1.0);
    }
    ngrams.put(ng, curFreq);

    // Set 'n' if we haven't done so before; otherwise verify it is consistent.
    if (n == 0) {
      n = ng.size();
    } else if (n != ng.size()) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
          + ", which usually is caused by a non-constant expression. Found '" + n + "' and '"
          + ng.size() + "'.");
    }

    // Trim down the total number of n-grams if we've exceeded the maximum amount of
    // memory allowed.
    //
    // NOTE: Although 'k'*'pf' specifies the size of the estimation buffer, we don't want
    // to keep performing N.log(N) trim operations each time the maximum hashmap size is
    // exceeded. To handle this, we *actually* maintain an estimation buffer of size
    // 2*'k'*'pf', and trim down to 'k'*'pf' whenever the hashmap size exceeds 2*'k'*'pf'.
    // This really has a significant effect when 'k'*'pf' is very high.
    if (ngrams.size() > k * pf * 2) {
      trim(false);
    }
  }

  /**
   * Trims an n-gram estimation down to either 'pf' * 'k' n-grams, or 'k' n-grams if
   * finalTrim is true.
   */
  private void trim(boolean finalTrim) throws HiveException {
    ArrayList<Map.Entry<ArrayList<String>, Double>> list =
        new ArrayList<Map.Entry<ArrayList<String>, Double>>(ngrams.entrySet());

    // Sort ascending by frequency so the least-frequent n-grams come first.
    Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>, Double>>() {
      @Override
      public int compare(Map.Entry<ArrayList<String>, Double> o1,
                         Map.Entry<ArrayList<String>, Double> o2) {
        return o1.getValue().compareTo(o2.getValue());
      }
    });

    // Remove the least-frequent entries until at most the target number remain.
    // (If the buffer is already small enough, the bound is <= 0 and nothing is removed.)
    int target = finalTrim ? k : pf * k;
    for (int i = 0; i < list.size() - target; i++) {
      ngrams.remove(list.get(i).getKey());
    }
  }

  /**
   * Takes a serialized n-gram estimator object created by the serialize() method and
   * merges it with the current n-gram object.
   *
   * The serialized layout is: [k, n, pf, then for each n-gram its 'n' tokens followed
   * by its frequency] -- see {@link #serialize}.
   *
   * @param other A serialized n-gram object created by the serialize() method
   * @throws HiveException if 'k', 'n' or 'pf' disagree with already-set values
   */
  public void merge(List other) throws HiveException {
    if (other == null) {
      return;
    }

    // Get estimation parameters from the serialized header.
    int otherK = Integer.parseInt(other.get(0).toString());
    int otherN = Integer.parseInt(other.get(1).toString());
    int otherPF = Integer.parseInt(other.get(2).toString());
    if (k > 0 && k != otherK) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'k'"
          + ", which usually is caused by a non-constant expression. Found '" + k + "' and '"
          + otherK + "'.");
    }
    if (n > 0 && otherN != n) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
          + ", which usually is caused by a non-constant expression. Found '" + n + "' and '"
          + otherN + "'.");
    }
    if (pf > 0 && otherPF != pf) {
      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'pf'"
          + ", which usually is caused by a non-constant expression. Found '" + pf + "' and '"
          + otherPF + "'.");
    }
    k = otherK;
    pf = otherPF;
    n = otherN;

    // Merge the other estimation into the current one. Each record occupies n tokens
    // plus one frequency value; 'i' is advanced past the tokens inside the body and
    // past the frequency by the loop increment.
    for (int i = 3; i < other.size(); i++) {
      ArrayList<String> key = new ArrayList<String>(n);
      for (int j = 0; j < n; j++) {
        key.add(other.get(i + j).toString());
      }
      i += n;
      double val = Double.parseDouble(other.get(i).toString());
      Double myval = ngrams.get(key);
      if (myval == null) {
        myval = Double.valueOf(val);
      } else {
        myval = Double.valueOf(myval.doubleValue() + val);
      }
      ngrams.put(key, myval);
    }

    trim(false);
  }

  /**
   * In preparation for a Hive merge() call, serializes the current n-gram estimator
   * object into an ArrayList of Text objects. This list is deserialized and merged by
   * the merge method.
   *
   * @return An ArrayList of Hadoop Text objects that represents the current
   *         n-gram estimation: [k, n, pf, (n tokens + frequency) per n-gram].
   * @see #merge
   */
  public ArrayList<Text> serialize() throws HiveException {
    ArrayList<Text> result = new ArrayList<Text>();

    result.add(new Text(Integer.toString(k)));
    result.add(new Text(Integer.toString(n)));
    result.add(new Text(Integer.toString(pf)));

    // Emit each n-gram as its tokens followed by its frequency. Iterate the entry set
    // to avoid a redundant per-key hash lookup.
    for (Map.Entry<ArrayList<String>, Double> entry : ngrams.entrySet()) {
      ArrayList<String> mykey = entry.getKey();
      assert(mykey.size() > 0);
      for (int i = 0; i < mykey.size(); i++) {
        result.add(new Text(mykey.get(i)));
      }
      result.add(new Text(entry.getValue().toString()));
    }

    return result;
  }
}