package edu.stanford.nlp.ie.machinereading.common;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.Map;
import java.util.Set;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.Generics;
/**
 * A named, bidirectional mapping between strings and integer indices that also
 * keeps a usage count per string.
 *
 * <p>The dictionary has two access modes, selected with {@link #setMode(boolean)}:
 * in "create" mode a lookup of an unknown string adds it (with the next free
 * index) and every lookup increments the string's count; otherwise lookups are
 * read-only and unknown strings are reported as missing.
 */
public class StringDictionary {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(StringDictionary.class);

  /** The index assigned to a string together with the number of times it was looked up. */
  public static class IndexAndCount {

    /** Index assigned to the entry; never changes once the entry exists. */
    public final int mIndex;

    /** Usage count of the entry; incremented by lookups performed in create mode. */
    public int mCount;

    IndexAndCount(int i, int c) {
      mIndex = i;
      mCount = c;
    }
  }

  /** Value returned by {@link #get(int)} for the special index -1. */
  public static final String NIL_VALUE = "nil";

  /** Name of this dictionary */
  private final String mName;

  /**
   * Access type: if true, a string lookup creates the entry when it does not
   * exist and increments its count; if false, lookups never modify the
   * dictionary and a missing entry is reported to the caller.
   */
  private boolean mCreate;

  /** The actual dictionary */
  private final Map<String, IndexAndCount> mDict;

  /** Inverse mapping from integer keys to the string values */
  private final Map<Integer, String> mInverse;

  /**
   * Creates an empty dictionary in read-only (non-create) mode.
   *
   * @param name name of the dictionary; used in messages and as the file-name suffix
   *             in {@link #save} and {@link #load}
   */
  public StringDictionary(String name) {
    mName = name;
    mCreate = false;
    mDict = Generics.newHashMap();
    mInverse = Generics.newHashMap();
  }

  /**
   * Sets the access mode.
   *
   * @param mode true to create missing entries (and count lookups), false for read-only lookups
   */
  public void setMode(boolean mode) {
    mCreate = mode;
  }

  /** Returns the number of strings stored in the dictionary. */
  public int size() {
    return mDict.size();
  }

  /**
   * Fetches the index of this string, throwing if it is unknown. Equivalent to
   * {@code get(s, true)}.
   */
  public int get(String s) {
    return get(s, true);
  }

  /**
   * Fetches the index/count record of this string. In create mode the entry is
   * created if it does not exist (its index is the dictionary size before
   * insertion) and its count is incremented on every call.
   *
   * @return the record for {@code s}, or null if the string is unknown and the
   *         dictionary is not in create mode
   */
  public IndexAndCount getIndexAndCount(String s) {
    IndexAndCount ic = mDict.get(s);
    if (mCreate) {
      if (ic == null) {
        ic = new IndexAndCount(mDict.size(), 0);
        mDict.put(s, ic);
        mInverse.put(Integer.valueOf(ic.mIndex), s);
      }
      ic.mCount++;
    }
    return ic;
  }

  /**
   * Fetches the index of this string. In create mode the entry is created if
   * it does not exist, and its count is incremented on every call.
   *
   * @param s the string to look up
   * @param shouldThrow whether an unknown string is an error
   * @return the index of {@code s}, or -1 if it is unknown and {@code shouldThrow} is false
   * @throws RuntimeException if the string is unknown and {@code shouldThrow} is true
   */
  public int get(String s, boolean shouldThrow) {
    // Shares the create-and-count logic with getIndexAndCount instead of duplicating it.
    IndexAndCount ic = getIndexAndCount(s);
    if (ic != null) {
      return ic.mIndex;
    }
    if (shouldThrow) {
      throw new RuntimeException("Unknown entry \"" + s + "\" in dictionary \"" + mName + "\"!");
    }
    return -1;
  }

  /**
   * Reverse mapping from integer key to string value.
   *
   * @return the string stored at {@code idx}, or {@link #NIL_VALUE} if {@code idx} is -1
   * @throws RuntimeException if no string is stored at {@code idx}
   */
  public String get(int idx) {
    if (idx == -1) {
      return NIL_VALUE;
    }
    String s = mInverse.get(idx);
    if (s == null) {
      throw new RuntimeException("Unknown index \"" + idx + "\" in dictionary \"" + mName + "\"!");
    }
    return s;
  }

  /**
   * Returns the usage count of the entry stored at the given index. This is a
   * pure query: it never changes the count, whatever the access mode.
   *
   * @return the count of the entry, or 0 if {@code idx} is -1
   * @throws RuntimeException if no string is stored at {@code idx}
   */
  public int getCount(int idx) {
    if (idx == -1) {
      return 0;
    }
    String s = mInverse.get(idx);
    if (s == null) {
      throw new RuntimeException("Unknown index \"" + idx + "\" in dictionary \"" + mName + "\"!");
    }
    // Read mDict directly: going through getIndexAndCount() would increment
    // the count in create mode, so merely asking for a count would change it.
    IndexAndCount ic = mDict.get(s);
    return ic == null ? 0 : ic.mCount;
  }

  /**
   * Saves all dictionary entries that appeared {@literal >} threshold times to the
   * file {@code path/prefix.name}, one entry per line as "string index count".
   * Note: the indices written to the file are renumbered to contiguous values
   * starting at 0 (the in-memory indices are left untouched). This is needed in
   * order to minimize the memory allocated for the expanded feature vectors
   * (average perceptron).
   *
   * @param path directory of the output file
   * @param prefix file name prefix; the dictionary name is appended after a '.'
   * @param threshold only entries with a count strictly greater than this are saved
   * @throws java.io.IOException if the file cannot be opened or written
   */
  public void save(String path, String prefix, int threshold) throws java.io.IOException {
    String fileName = path + java.io.File.separator + prefix + "." + mName;
    int index = 0;
    // NOTE(review): the stream uses the platform default charset; confirm it
    // matches the encoding load() reads with.
    try (java.io.PrintStream os = new java.io.PrintStream(new java.io.FileOutputStream(fileName))) {
      for (Map.Entry<String, IndexAndCount> entry : mDict.entrySet()) {
        IndexAndCount ic = entry.getValue();
        if (ic.mCount > threshold) {
          os.println(entry.getKey() + ' ' + index + ' ' + ic.mCount);
          index++;
        }
      }
      // PrintStream swallows IOExceptions; surface write failures explicitly
      // rather than reporting a truncated file as saved.
      if (os.checkError()) {
        throw new java.io.IOException("Error writing dictionary \"" + mName + "\" to " + fileName);
      }
    }
    log.info("Saved " + index + "/" + mDict.size() + " entries for dictionary \"" + mName + "\".");
  }

  /** Removes all entries from the dictionary. */
  public void clear() {
    mDict.clear();
    mInverse.clear();
  }

  /** Returns the set of strings in the dictionary (the key set of the underlying map). */
  public Set<String> keySet() {
    return mDict.keySet();
  }

  /**
   * Loads all saved dictionary entries from the file {@code path/prefix.name}.
   * Each line must consist of exactly three tokens: string, index (non-negative)
   * and count (positive). Loaded entries are added to the current content; the
   * dictionary is not cleared first.
   *
   * @param path directory of the input file
   * @param prefix file name prefix; the dictionary name is appended after a '.'
   * @throws java.io.IOException if the file cannot be read
   * @throws RuntimeException if a line is malformed (this includes the
   *         NumberFormatException raised for non-numeric index or count)
   */
  public void load(String path, String prefix) throws java.io.IOException {
    String fileName = path + java.io.File.separator + prefix + "." + mName;
    // try-with-resources: the reader is closed even when a malformed line aborts the load.
    try (BufferedReader is = IOUtils.readerFromString(fileName)) {
      for (String line; (line = is.readLine()) != null; ) {
        ArrayList<String> tokens = SimpleTokenize.tokenize(line);
        if (tokens.size() != 3) {
          throw new RuntimeException("Invalid dictionary line: " + line);
        }
        int index = Integer.parseInt(tokens.get(1));
        int count = Integer.parseInt(tokens.get(2));
        if (index < 0 || count <= 0) {
          throw new RuntimeException("Invalid dictionary line: " + line);
        }
        IndexAndCount ic = new IndexAndCount(index, count);
        mDict.put(tokens.get(0), ic);
        mInverse.put(Integer.valueOf(index), tokens.get(0));
      }
    }
    log.info("Loaded " + mDict.size() + " entries for dictionary \"" + mName + "\".");
  }

  /** Returns the set of strings in the dictionary; same as {@link #keySet()}. */
  public java.util.Set<String> keys() {
    return keySet();
  }
}