package org.wikibrain.core.nlp;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.TLongIntMap;
import gnu.trove.map.TLongObjectMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TLongIntHashMap;
import gnu.trove.map.hash.TLongObjectHashMap;
import gnu.trove.procedure.TLongIntProcedure;
import gnu.trove.set.TLongSet;
import gnu.trove.set.hash.TLongHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.utils.*;
import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* A class to remember counts for unigrams and (optionally) bigrams.
*
* This class uses a hashing trick so that a word's counts can be kept in 12
* bytes of memory: an 8-byte murmur hash plus a 4-byte count.
*
* All methods that count words are mutually threadsafe, and all methods that
* read counts are mutually threadsafe. Counting and reading may not be
* interleaved, however: the two groups of methods are not threadsafe with
* respect to each other.
*
* This class also remembers the number of mentions for each article.
* A mention must be in the format "foo:/w/en/1000" or "foo:/w/en/1000/Hercule_Poirot"
* where foo is the phrase mentioning the article and 1000 is the Wikipedia article id
* of the article with title "Hercule_Poirot."
*
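* A minimal usage sketch (file names are illustrative):
*
* <pre>{@code
* Dictionary dict = new Dictionary(Language.getByLangCode("en"), WordStorage.ON_DISK);
* dict.countRawFile(new File("corpus.txt"));        // count words and mentions
* System.out.println(dict.getUnigramCount("the"));  // read counts only after counting finishes
* dict.write(new File("dictionary.txt"));           // persist in the "t" / "w" / "m" line format
* }</pre>
*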
* @author Shilad Sen
*/
public class Dictionary implements Closeable {
public static final int MAX_DICTIONARY_SIZE = 20000000; // 20M unigrams + bigrams by default.
public static int PRUNE_INTERVAL = 10000; // Consider pruning every PRUNE_INTERVAL increments
public static Logger LOG = LoggerFactory.getLogger(Dictionary.class);
/**
* Matches mentions like:
* foo:/w/en/1000
* foo:/w/en/1000/Hercule_Poirot
*/
public static final Pattern PATTERN_MENTION = Pattern.compile("(.*?):/w/([^/]+)/(-?\\d+)(/[^ ]*($| ))?");
/**
* How should words be stored? ON_DISK spools each newly seen word to a
* temporary file; IN_MEMORY keeps a hash-to-word map and requires much more
* memory; NONE keeps only hashes and counts, so {@link #write} is unsupported.
*/
public enum WordStorage {
ON_DISK,
IN_MEMORY,
NONE
}
private final Language language;
/**
* Whether the text to be tokenized contains mentions of articles.
*/
private boolean containsMentions = true;
/**
* Whether bigrams should be counted.
*/
private boolean countBigrams = false;
private final WordStorage wordStorage;
private AtomicLong totalWords = new AtomicLong();
private AtomicLong totalBigrams = new AtomicLong();
private AtomicLong totalNgrams = new AtomicLong();
private final TLongIntMap unigramCounts = new TLongIntHashMap();
private final TLongIntMap bigramCounts = new TLongIntHashMap();
private final TLongIntMap ngramCounts = new TLongIntHashMap();
private StringTokenizer tokenizer = new StringTokenizer();
private NGramCreator nGramCreator = new NGramCreator();
private BufferedWriter wordWriter;
private File wordFile;
/**
* The maximum number of unigrams + bigrams (not including mentions)
*/
private int maxDictionarySize = MAX_DICTIONARY_SIZE;
/**
* Things with fewer than this many occurrences will be pruned if
* the dictionary size exceeds maxDictionarySize.
*
* This is incremented BEFORE every pruning (e.g. the first pruning will have
* minPruneCount = 2).
*/
private int minPruneCount = 1;
/**
* Map of Wikipedia article id -> number of mentions in unigrams.
* Only calculated if containsMentions is true.
*/
private final TIntIntMap mentionCounts = new TIntIntHashMap();
/**
* Hashes of interesting ngrams, plus the hashes of every word-prefix of an
* interesting ngram (the latter let countNgrams abandon a scan early).
*/
private TLongSet interestingNGrams = null;
private TLongSet interestingSubGrams = null;
/**
* Map from word hashes to actual words.
* Only maintained if wordStorage is IN_MEMORY.
*/
private final TLongObjectMap<String> words = new TLongObjectHashMap<String>();
public Dictionary(Language language) {
this(language, WordStorage.NONE);
}
public Dictionary(Language language, WordStorage wordMode) {
this.language = language;
this.wordStorage = wordMode;
if (wordMode == WordStorage.ON_DISK) {
try {
wordFile = File.createTempFile("words", ".txt");
wordFile.deleteOnExit();
wordFile.delete(); // we only need a unique name; openWriter recreates the file
wordWriter = WpIOUtils.openWriter(wordFile);
} catch (IOException e) {
throw new RuntimeException(e); // shouldn't happen for a temp file...
}
}
}
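/**
* Registers the ngrams counted by countNgrams. The hash of every
* word-prefix of each ngram is also stored, so that countNgrams can stop
* scanning as soon as a growing phrase cannot match any registered ngram.
*
* @param ngrams An iterator over phrases; each is tokenized with the
* dictionary's tokenizer.
*/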
public void setInterestingNgrams(Iterator<String> ngrams) {
interestingSubGrams = new TLongHashSet();
interestingNGrams = new TLongHashSet();
while (ngrams.hasNext()) {
List<String> tokens = tokenizer.getWords(language, ngrams.next());
if (tokens.isEmpty()) {
continue;
}
StringBuilder b = new StringBuilder();
long hash = -1;
for (int i = 0; i < tokens.size(); i++) {
if (i > 0) {
b.append(' ');
}
b.append(tokens.get(i));
hash = hashWord(b.toString());
interestingSubGrams.add(hash); // every prefix, including the full ngram itself
}
interestingNGrams.add(hash);
}
}
/**
* Counts words in a file whose text has not been normalized, i.e. the
* separators between words are those traditionally found in plain text.
* Each line is presumed to be a sentence, so bigrams may not span lines.
*
* @param corpus A UTF-8 text file with one sentence per line.
* @throws IOException
*/
public void countRawFile(File corpus) throws IOException {
LineIterator lineIterator = FileUtils.lineIterator(corpus, "UTF-8");
ParallelForEach.iterate(
lineIterator,
Math.min(3, WpThreadUtils.getMaxThreads()), // 3 seems the optimal number of threads here...
1000,
new Procedure<String>() {
@Override
public void call(String line) throws Exception {
countRawText(line);
}
},
Integer.MAX_VALUE);
lineIterator.close();
}
/**
* Counts words in a file whose text has already been normalized, i.e. words
* are separated by single spaces, with no punctuation.
* Each line is presumed to be a sentence, so bigrams may not span lines.
*
* @param corpus A UTF-8 text file with one sentence per line.
* @throws IOException
*/
public void countNormalizedFile(File corpus) throws IOException {
LineIterator lineIterator = FileUtils.lineIterator(corpus, "UTF-8");
ParallelForEach.iterate(
lineIterator,
Math.min(3, WpThreadUtils.getMaxThreads()), // 3 seems the optimal number of threads here...
1000,
new Procedure<String>() {
@Override
public void call(String line) throws Exception {
countNormalizedText(line);
}
},
Integer.MAX_VALUE);
lineIterator.close();
}
/**
* Counts words in text that has not been normalized, i.e. the separators
* between words are those traditionally found in plain text. If
* containsMentions is true, mentions are counted and stripped from the
* text before it is tokenized.
*
* @param text
*/
public void countRawText(String text) {
// Count and extract mentions if necessary.
if (containsMentions) {
Matcher m = PATTERN_MENTION.matcher(text);
while (m.find()) {
int wpId = Integer.valueOf(m.group(3));
synchronized (mentionCounts) {
mentionCounts.adjustOrPutValue(wpId, 1, 1);
}
}
text = PATTERN_MENTION.matcher(text).replaceAll("$1 ");
}
countWords(tokenizer.getWords(language, text));
}
/**
* Counts words in text that has already been normalized, i.e. words are
* separated ONLY by spaces (e.g. {@code "the 2 quick brown foxes"}).
*
* @param text
*/
public void countNormalizedText(String text) {
countWords(Arrays.asList(text.split(" +")));
}
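/**
* Counts the unigrams in a tokenized sentence, along with bigrams and
* interesting ngrams when those features are enabled.
*/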
private void countWords(List<String> tokens) {
for (String word : tokens) {
countUnigram(word);
}
if (countBigrams) {
for (String bigram : nGramCreator.getNGrams(tokens, 2, 2)) {
countBigram(bigram);
}
}
if (interestingNGrams != null) {
countNgrams(tokens);
}
}
/**
* Increments the count for a particular unigram.
*
* If this.containsMentions is true, it scans for, counts, and removes any
* article mention, e.g. "foo:/w/en/1000" becomes "foo".
*
* If words are being stored and this word hasn't been seen before, it
* remembers the word.
*
* @param word
*/
public void countUnigram(String word) {
word = word.trim();
if (word.isEmpty()) {
return;
}
if (containsMentions) {
Matcher m = PATTERN_MENTION.matcher(word);
if (m.matches()) {
word = m.group(1);
int wpId = Integer.valueOf(m.group(3));
synchronized (mentionCounts) {
mentionCounts.adjustOrPutValue(wpId, 1, 1);
}
}
}
long hash = getHash(word);
if (wordStorage == WordStorage.IN_MEMORY) {
synchronized (words) {
if (!words.containsKey(hash)) {
words.put(hash, word);
}
}
}
int n;
synchronized (unigramCounts) {
n = unigramCounts.adjustOrPutValue(hash, 1, 1);
}
if (n == 1 && wordStorage == WordStorage.ON_DISK) { // count went 0 -> 1: record the new word
try {
wordWriter.write(word + "\n");
} catch (IOException e) {
throw new RuntimeException(e); // shouldn't really happen
}
}
if (totalWords.incrementAndGet() % PRUNE_INTERVAL == 0) {
pruneIfNecessary();
}
}
/**
* Counts a bigram (two words separated by a single space).
* @param bigram
*/
public void countBigram(String bigram) {
bigram = bigram.trim();
if (bigram.isEmpty()) {
return;
}
if (containsMentions) {
Matcher m = PATTERN_MENTION.matcher(bigram);
if (m.matches()) {
bigram = m.group(1);
}
}
long h = getHash(bigram);
synchronized (bigramCounts) {
bigramCounts.adjustOrPutValue(h, 1, 1);
}
if (totalBigrams.incrementAndGet() % PRUNE_INTERVAL == 0) {
pruneIfNecessary();
}
}
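/**
* Counts occurrences of the interesting ngrams registered via
* setInterestingNgrams. Starting at each token, the phrase grows one word
* at a time; the scan stops once the phrase is no longer a prefix of any
* registered ngram.
*
* @param words A tokenized sentence.
*/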
public void countNgrams(List<String> words) {
for (int i = 0; i < words.size(); i++) {
StringBuilder buffer = new StringBuilder();
for (int j = i; j < words.size(); j++) {
if (j > i) {
buffer.append(' ');
}
buffer.append(words.get(j));
long hash = hashWord(buffer.toString());
if (interestingNGrams.contains(hash)) {
synchronized (ngramCounts) {
ngramCounts.adjustOrPutValue(hash, 1, 1);
}
totalNgrams.incrementAndGet();
}
if (!interestingNGrams.contains(hash) && !interestingSubGrams.contains(hash)) {
break; // the phrase cannot grow into any interesting ngram
}
}
}
}
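/**
* If the number of unigrams plus bigrams exceeds maxDictionarySize,
* repeatedly raises minPruneCount and drops entries that occur fewer than
* minPruneCount times until the dictionary fits.
*/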
public synchronized void pruneIfNecessary() {
while (true) {
int n1, n2;
synchronized (unigramCounts) {
n1 = unigramCounts.size();
}
synchronized (bigramCounts) {
n2 = bigramCounts.size();
}
if (n1 + n2 <= maxDictionarySize) {
return;
}
minPruneCount++;
LOG.info("pruning dictionary entries with frequency less than " + minPruneCount);
synchronized (unigramCounts) {
unigramCounts.retainEntries(new TLongIntProcedure() {
@Override
public boolean execute(long hash, int count) {
return (count >= minPruneCount);
}
});
n1 = unigramCounts.size();
}
synchronized (bigramCounts) {
bigramCounts.retainEntries(new TLongIntProcedure() {
@Override
public boolean execute(long hash, int count) {
return (count >= minPruneCount);
}
});
n2 = bigramCounts.size();
}
LOG.info("after pruning dictionary size is " + (n1 + n2));
// TODO: clear out words, but we need a triple lock... ugh.
}
}
/**
* Writes all unigrams and mentions in a dictionary to a file.
* @param output
* @throws IOException
*/
public void write(File output) throws IOException {
write(output, 1);
}
/**
* Writes all unigrams with at least minCount frequency and all mentions to a file.
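*
* Each line of the output is one record:
* <pre>
* t {totalWordCount} _
* w {count} {phrase}
* m {articleId} {mentionCount}
* </pre>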
* @param output
* @param minCount
* @throws IOException
*/
public void write(File output, int minCount) throws IOException {
if (wordStorage == WordStorage.NONE) {
throw new UnsupportedOperationException("write() requires word storage ON_DISK or IN_MEMORY");
}
IOUtils.closeQuietly(this); // close the word spool (if any) so it can be re-read below
BufferedWriter writer = WpIOUtils.openWriter(output);
writer.write("t " + totalWords.get() + " _\n");
if (wordStorage == WordStorage.ON_DISK) {
BufferedReader reader = WpIOUtils.openBufferedReader(this.wordFile);
while (true) {
String line = reader.readLine();
if (line == null) {
break;
}
String phrase = line.trim();
long hash = getHash(phrase);
int c = unigramCounts.get(hash);
if (c < minCount) {
continue;
}
writer.write("w " + c + " " + phrase + "\n");
}
reader.close();
} else if (wordStorage == WordStorage.IN_MEMORY) {
for (String phrase : words.valueCollection()) {
long hash = getHash(phrase);
int c = unigramCounts.get(hash);
if (c < minCount) {
continue;
}
writer.write("w " + c + " " + phrase + "\n");
}
} else {
throw new IllegalStateException();
}
for (int wpId : mentionCounts.keys()) {
writer.write("m " + wpId + " " + mentionCounts.get(wpId) + "\n");
}
writer.close();
}
/**
* Reads an entire unigram dictionary back from a file.
* @param file
* @throws IOException
*/
public void read(File file) throws IOException {
read(file, Integer.MAX_VALUE, 1);
}
/**
* Reads unigrams from a file in the format produced by {@link #write}.
*
* The top (by frequency) maxWords unigrams are retained. If several words
* tie at the cutoff frequency, only enough of them are kept to stay within
* maxWords.
*
* @param file
* @param maxWords Maximum number of unigrams to retain.
* @param minCount Unigrams with a smaller count are discarded.
* @throws IOException
*/
public void read(File file, int maxWords, int minCount) throws IOException {
if (wordStorage == WordStorage.ON_DISK) {
throw new UnsupportedOperationException("Cannot read into dictionaries using disk storage");
}
// Pass 1: store all available counts to figure out cutoff for tracking words
TIntList counts = new TIntArrayList();
BufferedReader reader = WpIOUtils.openBufferedReader(file);
while (true) {
String line = reader.readLine();
if (line == null) {
break;
}
String[] tokens = line.trim().split(" ", 3);
if (tokens[0].equals("w")) {
int c = Integer.valueOf(tokens[1]);
if (c >= minCount) {
counts.add(c);
}
}
}
reader.close();
counts.sort();
counts.reverse();
/*
* Figure out the frequency threshold for keeping words, and how many words
* at exactly that threshold can be kept without exceeding maxWords.
*/
int threshold = 0;
int numSavedAtThreshold = Integer.MAX_VALUE;
if (counts.size() > maxWords) {
threshold = counts.get(maxWords - 1);
numSavedAtThreshold = 0;
for (int i = maxWords - 1; i >= 0; i--) {
if (counts.get(i) == threshold) {
numSavedAtThreshold++;
} else {
break;
}
}
}
/*
* Pass 2: restore the retained words, mentions, and total count.
*/
reader = WpIOUtils.openBufferedReader(file);
totalWords.set(0);
while (true) {
String line = reader.readLine();
if (line == null) {
break;
}
String[] tokens = line.trim().split(" ", 3);
if (tokens[0].equals("w")) {
int c = Integer.valueOf(tokens[1]);
if (c < threshold || c < minCount) {
continue;
}
if (c == threshold && numSavedAtThreshold == 0) {
continue;
}
if (c == threshold) {
numSavedAtThreshold--;
}
String phrase = tokens[2].trim();
long hash = getHash(phrase);
unigramCounts.put(hash, c);
if (wordStorage == WordStorage.IN_MEMORY) {
words.put(hash, phrase);
}
} else if (tokens[0].equals("m")) {
mentionCounts.put(Integer.valueOf(tokens[1]), Integer.valueOf(tokens[2]));
} else if (tokens[0].equals("t")) {
totalWords.set(Long.valueOf(tokens[1]));
} else {
throw new IOException("unexpected line: " + line);
}
}
reader.close();
}
public int getUnigramCount(String word) {
return unigramCounts.get(getHash(word));
}
public int getBigramCount(String word1, String word2) {
return bigramCounts.get(getHash(word1 + " " + word2));
}
public int getBigramCount(String bigram) {
return bigramCounts.get(getHash(bigram));
}
public int getMentionCount(int wpId) {
return mentionCounts.get(wpId);
}
/**
* Returns the number of times the article corresponding to the url was mentioned.
* @param mentionUrl Mention in the format /w/langCode/articleId/ArticleTitle,
* e.g. {@code "/w/en/1000/Hercule_Poirot"}.
* @return The number of mentions, or 0 if the article was never mentioned.
*/
public int getMentionCount(String mentionUrl) {
if (!mentionUrl.startsWith("/w/")) {
throw new IllegalArgumentException("format for mentionUrl must be /w/langCode/articleId/ArticleTitle");
}
String[] tokens = mentionUrl.split("/", 5);
if (tokens.length != 5) {
throw new IllegalArgumentException("format for mentionUrl must be /w/langCode/articleId/ArticleTitle");
}
return mentionCounts.get(Integer.valueOf(tokens[3]));
}
public final long getHash(String ngram) {
return hashWord(ngram);
}
public long getTotalCount() {
return totalWords.get();
}
public void setContainsMentions(boolean containsMentions) {
this.containsMentions = containsMentions;
}
public void setCountBigrams(boolean countBigrams) {
this.countBigrams = countBigrams;
}
public void setTokenizer(StringTokenizer tokenizer) {
this.tokenizer = tokenizer;
}
public void setCreator(NGramCreator creator) {
this.nGramCreator = creator;
}
public int getNumUnigrams() {
return unigramCounts.size();
}
public int getNumBigrams() {
return bigramCounts.size();
}
public int getNumMentionedArticles() {
return mentionCounts.size();
}
/**
* Return the n most frequently used unigrams, in decreasing order.
* @param n
* @return
*/
public List<String> getFrequentUnigrams(int n) {
if (wordStorage != WordStorage.IN_MEMORY) {
throw new UnsupportedOperationException("WordStorage must be in memory to return strings");
}
int threshold = 0;
if (n < unigramCounts.size()) {
int[] counts = unigramCounts.values();
Arrays.sort(counts);
threshold = counts[counts.length - n];
}
final List<String> top = new ArrayList<String>();
final int finalThreshold = threshold;
unigramCounts.forEachEntry(new TLongIntProcedure() {
@Override
public boolean execute(long hash, int count) {
if (count >= finalThreshold) {
top.add(words.get(hash));
}
return true;
}
});
Collections.sort(top, new Comparator<String>() {
@Override
public int compare(String w1, String w2) {
int r = getUnigramCount(w2) - getUnigramCount(w1);
// order is deterministic for testing ease
if (r == 0) {
r = w1.compareTo(w2);
}
return r;
}
});
if (top.size() > n) {
return top.subList(0, n);
} else {
return top;
}
}
/**
* Return the n most frequently used unigrams and mentions, in decreasing order of frequency.
*
* Mentions are encoded as words with the format "/w/langCode/articleId/ArticleTitle"
*
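* For example (the arguments here are illustrative):
*
* <pre>{@code
* // up to 50000 words seen at least 5 times, plus articles mentioned at least twice
* List<String> top = dict.getFrequentUnigramsAndMentions(dao, 50000, 5, 2);
* }</pre>
*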
* @param lpd
* @param maxWords
* @param minWordFreq
* @param minMentionFreq
* @return
*/
public List<String> getFrequentUnigramsAndMentions(LocalPageDao lpd, int maxWords, int minWordFreq, int minMentionFreq) throws DaoException {
if (wordStorage != WordStorage.IN_MEMORY) {
throw new UnsupportedOperationException("WordStorage must be in memory to return strings");
}
final int threshold;
if (maxWords < unigramCounts.size()) {
int[] counts = unigramCounts.values();
Arrays.sort(counts);
threshold = Math.max(minWordFreq, counts[counts.length - maxWords]);
} else {
threshold = 0;
}
final List<String> topWords = new ArrayList<String>();
unigramCounts.forEachEntry(new TLongIntProcedure() {
@Override
public boolean execute(long hash, int count) {
if (count >= threshold) {
topWords.add(words.get(hash));
}
return true;
}
});
Collections.sort(topWords, new Comparator<String>() {
@Override
public int compare(String w1, String w2) {
int r = getUnigramCount(w2) - getUnigramCount(w1);
// order is deterministic for testing ease
if (r == 0) {
r = w1.compareTo(w2);
}
return r;
}
});
List<String> result = topWords;
if (result.size() > maxWords) result = result.subList(0, maxWords);
for (int wpId : mentionCounts.keys()) {
if (mentionCounts.get(wpId) >= minMentionFreq) {
LocalPage lp = lpd.getById(language, wpId);
if (lp != null) {
result.add(makeMentionUrl(lp));
}
}
}
Collections.sort(result, new Comparator<String>() {
@Override
public int compare(String w1, String w2) {
int n1 = (w1.startsWith("/w/")) ? getMentionCount(w1) : getUnigramCount(w1);
int n2 = (w2.startsWith("/w/")) ? getMentionCount(w2) : getUnigramCount(w2);
int r = n2 - n1;
// order is deterministic for testing ease
if (r == 0) {
r = w1.compareTo(w2);
}
return r;
}
});
return result;
}
private String makeMentionUrl(LocalPage page) {
return "/w/" + language.getLangCode() + "/" + page.getLocalId() + "/" + page.getTitle().getCanonicalTitle().replaceAll(" ", "_");
}
@Override
public void close() throws IOException {
if (this.wordWriter != null) {
wordWriter.close();
}
}
/**
* Returns a hashcode for a particular word.
* The hashCode 0 will NEVER be returned.
* @param w
* @return
*/
public static long hashWord(String w) {
long h = MurmurHash.hash64(w);
if (h == 0) h = 1; // hack: h == 0 is reserved.
return h;
}
public void setMaxDictionarySize(int maxDictionarySize) {
this.maxDictionarySize = maxDictionarySize;
}
}