package org.maltparser.core.symbol.trie;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.log4j.Logger;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.helper.HashMap;
import org.maltparser.core.io.dataformat.ColumnDescription;
import org.maltparser.core.symbol.SymbolException;
import org.maltparser.core.symbol.SymbolTable;
import org.maltparser.core.symbol.nullvalue.InputNullValues;
import org.maltparser.core.symbol.nullvalue.NullValues;
import org.maltparser.core.symbol.nullvalue.NullValues.NullValueId;
import org.maltparser.core.symbol.nullvalue.OutputNullValues;
/**
 * A {@link SymbolTable} whose symbols are stored in a {@link Trie}.
 * <p>
 * Regular symbols are stored as trie nodes and indexed by their integer code in
 * {@link #getCodeTable()}. Null-value symbols are handled separately by a
 * {@link NullValues} instance.
 * <p>
 * The table runs in one of two modes, fixed at construction:
 * <ul>
 * <li>{@code TrieSymbolTableHandler.ADD_NEW_TO_TRIE}: unseen symbols are added to
 * the trie and receive permanent codes.</li>
 * <li>{@code TrieSymbolTableHandler.ADD_NEW_TO_TMP_STORAGE}: the trie is only read.
 * Unseen symbols receive temporary codes, allocated above the current value
 * counter, which {@link #clearTmpStorage()} discards.</li>
 * </ul>
 *
 * @author Johan Hall
 * @since 1.0
 */
public class TrieSymbolTable implements SymbolTable {
    /** Table name; the only field that takes part in equals/hashCode. */
    private final String name;
    /** Trie holding the symbol strings of this table. */
    private final Trie trie;
    /** Maps each permanent code to the trie node holding its symbol, in ascending code order. */
    private final SortedMap<Integer, TrieNode> codeTable;
    /** Column category given at construction (left at 0 by the three-argument constructor). */
    private int columnCategory;
    /** Handler for the null-value symbols and codes of this table. */
    private final NullValues nullValues;
    /** Counter used when allocating permanent codes (see increaseValueCounter/updateValueCounter). */
    private int valueCounter;
    /** Cached hash code for the symbol table; 0 means "not computed yet". */
    private int cachedHash;
    /** Mode selecting how unseen symbols are handled; compared against TrieSymbolTableHandler constants. */
    private final int symbolTableMode;
    /** Temporary symbol -> code map; only allocated in ADD_NEW_TO_TMP_STORAGE mode. */
    private HashMap<String, Integer> tmpStorageStrIntMap;
    /** Temporary code -> symbol map; only allocated in ADD_NEW_TO_TMP_STORAGE mode. */
    private HashMap<Integer, String> tmpStorageIntStrMap;
    /** Last temporary code handed out, or -1 if none has been handed out since the last clear. */
    private int tmpStorageValueCounter;

    /**
     * Creates a symbol table whose null-value handler depends on the column category.
     *
     * @param name the table name
     * @param trie the trie that stores the symbols
     * @param columnCategory the column category. {@code ColumnDescription.DEPENDENCY_EDGE_LABEL}
     *        selects output null values; every other category selects input null values.
     * @param nullValueStrategy the strategy handed to the null-value handler
     * @param symbolTableMode {@code TrieSymbolTableHandler.ADD_NEW_TO_TRIE} or
     *        {@code TrieSymbolTableHandler.ADD_NEW_TO_TMP_STORAGE}
     * @throws MaltChainedException if creating or querying the null-value handler fails
     */
    public TrieSymbolTable(String name, Trie trie, int columnCategory, String nullValueStrategy, int symbolTableMode) throws MaltChainedException {
        this.name = name;
        this.trie = trie;
        this.columnCategory = columnCategory;
        codeTable = new TreeMap<Integer, TrieNode>();
        // INPUT and every other non-label category share the input null values.
        if (columnCategory == ColumnDescription.DEPENDENCY_EDGE_LABEL) {
            nullValues = new OutputNullValues(nullValueStrategy, this);
        } else {
            nullValues = new InputNullValues(nullValueStrategy, this);
        }
        // Presumably the first code not used by a null value; regular codes start here.
        valueCounter = nullValues.getNextCode();
        this.symbolTableMode = symbolTableMode;
        initTmpStorage();
    }

    /**
     * Creates a symbol table with input null values using the "one" strategy.
     * The value counter starts at 1 and the column category is left at 0.
     *
     * @param name the table name
     * @param trie the trie that stores the symbols
     * @param symbolTableMode {@code TrieSymbolTableHandler.ADD_NEW_TO_TRIE} or
     *        {@code TrieSymbolTableHandler.ADD_NEW_TO_TMP_STORAGE}
     */
    public TrieSymbolTable(String name, Trie trie, int symbolTableMode) {
        this.name = name;
        this.trie = trie;
        codeTable = new TreeMap<Integer, TrieNode>();
        nullValues = new InputNullValues("one", this);
        valueCounter = 1;
        this.symbolTableMode = symbolTableMode;
        initTmpStorage();
    }

    /** Allocates the temporary storage when the table runs in ADD_NEW_TO_TMP_STORAGE mode. */
    private void initTmpStorage() {
        if (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TMP_STORAGE) {
            tmpStorageStrIntMap = new HashMap<String, Integer>();
            tmpStorageIntStrMap = new HashMap<Integer, String>();
            tmpStorageValueCounter = -1;
        }
    }

    /**
     * Returns the code of {@code symbol}, adding the symbol if it is unknown.
     * <p>
     * Null-value symbols are mapped by the null-value handler. In ADD_NEW_TO_TRIE mode,
     * unknown symbols are added to the trie. In the other mode they get a temporary code.
     *
     * @param symbol the symbol to add
     * @return the code of the symbol
     * @throws MaltChainedException if the symbol is null or empty (and not a null value)
     */
    public int addSymbol(String symbol) throws MaltChainedException {
        // Check for null before consulting the null-value handler so a null never reaches it.
        if (symbol == null) {
            throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
        }
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        if (symbol.length() == 0) {
            throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
        }
        if (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TRIE) {
            return registerNode(trie.addValue(symbol, this, -1));
        }
        return addToTmpStorage(symbol);
    }

    /**
     * {@link StringBuilder} variant of {@link #addSymbol(String)} with identical semantics.
     *
     * @param symbol the symbol to add
     * @return the code of the symbol
     * @throws MaltChainedException if the symbol is null or empty (and not a null value)
     */
    public int addSymbol(StringBuilder symbol) throws MaltChainedException {
        if (symbol == null) {
            throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
        }
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        if (symbol.length() == 0) {
            throw new SymbolException("Symbol table error: empty string cannot be added to the symbol table");
        }
        if (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TRIE) {
            return registerNode(trie.addValue(symbol, this, -1));
        }
        // The temporary maps are keyed by String. A StringBuilder key never matches,
        // since StringBuilder uses identity equals, so convert before looking up.
        return addToTmpStorage(symbol.toString());
    }

    /**
     * Records {@code node} in the code table under this table's entry code,
     * unless that code is already present.
     *
     * @param node a trie node holding a symbol of this table
     * @return the node's code for this table
     */
    private int registerNode(TrieNode node) throws MaltChainedException {
        final int code = node.getEntry(this);
        if (!codeTable.containsKey(code)) {
            codeTable.put(code, node);
        }
        return code;
    }

    /**
     * Resolves a symbol in ADD_NEW_TO_TMP_STORAGE mode.
     * <p>
     * Lookup order: the permanent code from the trie, then an existing temporary code,
     * and finally a newly allocated temporary code.
     *
     * @param symbol a non-empty, non-null-value symbol
     * @return the permanent or temporary code of the symbol
     */
    private int addToTmpStorage(String symbol) throws MaltChainedException {
        final Integer entry = trie.getEntry(symbol, this);
        if (entry != null) {
            return entry.intValue();
        }
        final Integer tmpCode = tmpStorageStrIntMap.get(symbol);
        if (tmpCode != null) {
            return tmpCode.intValue();
        }
        // The first temporary code after a clear is valueCounter + 1; later ones count up from there.
        if (tmpStorageValueCounter == -1) {
            tmpStorageValueCounter = valueCounter + 1;
        } else {
            tmpStorageValueCounter++;
        }
        tmpStorageStrIntMap.put(symbol, tmpStorageValueCounter);
        tmpStorageIntStrMap.put(tmpStorageValueCounter, symbol);
        return tmpStorageValueCounter;
    }

    /**
     * Returns the symbol for {@code code}.
     * <p>
     * Null-value codes are mapped by the null-value handler. In temporary-storage mode,
     * codes missing from the code table are looked up in the temporary storage; that
     * lookup returns null if the code is unknown there too.
     *
     * @param code the symbol code
     * @return the symbol string
     * @throws MaltChainedException if the code is negative or the trie is missing
     */
    public String getSymbolCodeToString(int code) throws MaltChainedException {
        if (code < 0) {
            throw new SymbolException("The symbol code '" + code + "' cannot be found in the symbol table. ");
        }
        if (nullValues != null && nullValues.isNullValue(code)) {
            return nullValues.codeToSymbol(code);
        }
        if (trie == null) {
            throw new SymbolException("The symbol table is corrupt. ");
        }
        final TrieNode node = codeTable.get(code);
        if (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TRIE || node != null) {
            return trie.getValue(node, this);
        }
        return tmpStorageIntStrMap.get(code);
    }

    /**
     * Returns the code of an existing symbol without adding it.
     *
     * @param symbol the symbol to look up
     * @return its null-value, permanent or (in temporary-storage mode) temporary code
     * @throws MaltChainedException if the symbol is null, unknown, or the trie is missing
     */
    public int getSymbolStringToCode(String symbol) throws MaltChainedException {
        if (symbol == null) {
            throw new SymbolException("The symbol code '" + symbol + "' cannot be found in the symbol table. ");
        }
        if (nullValues != null && nullValues.isNullValue(symbol)) {
            return nullValues.symbolToCode(symbol);
        }
        if (trie == null) {
            throw new SymbolException("The symbol table is corrupt. ");
        }
        final Integer entry = trie.getEntry(symbol, this);
        if (entry != null) {
            return entry.intValue();
        }
        // Only a table in temporary-storage mode can know symbols that are absent from the trie.
        final Integer tmpEntry = (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TRIE)
                ? null
                : tmpStorageStrIntMap.get(symbol);
        if (tmpEntry == null) {
            throw new SymbolException("Could not find the symbol '" + symbol + "' in the symbol table. ");
        }
        return tmpEntry.intValue();
    }

    /** Discards all temporary symbols and resets temporary code allocation (temporary-storage mode only). */
    public void clearTmpStorage() {
        if (symbolTableMode == TrieSymbolTableHandler.ADD_NEW_TO_TMP_STORAGE) {
            tmpStorageIntStrMap.clear();
            tmpStorageStrIntMap.clear();
            tmpStorageValueCounter = -1;
        }
    }

    /** @return the null-value strategy name, or null if the table has no null-value handler */
    public String getNullValueStrategy() {
        if (nullValues == null) {
            return null;
        }
        return nullValues.getNullValueStrategy();
    }

    /** @return the column category given at construction */
    public int getColumnCategory() {
        return columnCategory;
    }

    /**
     * Logs every permanent (code, symbol) pair at INFO level, in ascending code order.
     *
     * @param logger the logger to write to
     * @throws MaltChainedException declared for interface compatibility
     */
    public void printSymbolTable(Logger logger) throws MaltChainedException {
        for (Map.Entry<Integer, TrieNode> e : codeTable.entrySet()) {
            logger.info(e.getKey() + "\t" + trie.getValue(e.getValue(), this) + "\n");
        }
    }

    /**
     * Writes the header line {@code \tname\tcolumnCategory\tnullValueStrategy\n}.
     *
     * @param out the destination writer
     * @throws MaltChainedException if writing fails
     */
    public void saveHeader(BufferedWriter out) throws MaltChainedException {
        try {
            out.append('\t');
            out.append(getName());
            out.append('\t');
            out.append(Integer.toString(getColumnCategory()));
            out.append('\t');
            out.append(getNullValueStrategy());
            out.append('\n');
        } catch (IOException e) {
            throw new SymbolException("Could not save the symbol table. ", e);
        }
    }

    /** @return the number of permanent symbols (excludes null values and temporary symbols) */
    public int size() {
        return codeTable.size();
    }

    /**
     * Writes the table in the format read back by {@link #load(BufferedReader)}.
     * The output is the name line, then one {@code code\tsymbol} line per permanent
     * symbol, then a terminating empty line.
     *
     * @param out the destination writer
     * @throws MaltChainedException if writing fails
     */
    public void save(BufferedWriter out) throws MaltChainedException {
        try {
            out.write(name);
            out.write('\n');
            for (Map.Entry<Integer, TrieNode> e : codeTable.entrySet()) {
                out.write(Integer.toString(e.getKey()));
                out.write('\t');
                out.write(trie.getValue(e.getValue(), this));
                out.write('\n');
            }
            out.write('\n');
        } catch (IOException e) {
            throw new SymbolException("Could not save the symbol table. ", e);
        }
    }

    /**
     * Reads {@code code\tsymbol} lines and adds each symbol with its stored code.
     * <p>
     * Reading stops at an empty line, a line without a tab, or end of input.
     * The name line written by {@link #save(BufferedWriter)} is not read here;
     * presumably the caller consumes it first. Afterwards the value counter is
     * one past the largest code seen, and never below its previous value.
     * This matches {@link #copy(SymbolTable)}.
     *
     * @param in the source reader
     * @throws MaltChainedException if a code is not an integer or reading fails
     */
    public void load(BufferedReader in) throws MaltChainedException {
        // Start from the current counter so an empty table keeps the codes reserved for null values.
        int max = getValueCounter() - 1;
        try {
            String fileLine;
            while ((fileLine = in.readLine()) != null) {
                final int index = fileLine.indexOf('\t');
                if (fileLine.length() == 0 || index == -1) {
                    break;
                }
                final int code = Integer.parseInt(fileLine.substring(0, index));
                final String str = fileLine.substring(index + 1);
                final TrieNode node = trie.addValue(str, this, code);
                codeTable.put(node.getEntry(this), node);
                if (max < code) {
                    max = code;
                }
            }
        } catch (NumberFormatException e) {
            throw new SymbolException("The symbol table file (.sym) contains a non-integer value in the first column. ", e);
        } catch (IOException e) {
            throw new SymbolException("Could not load the symbol table. ", e);
        }
        // Set on both exits (terminator line or EOF); previously EOF skipped this update.
        setValueCounter(max + 1);
    }

    /** @return the table name */
    public String getName() {
        return name;
    }

    /** @return the current value counter */
    public int getValueCounter() {
        return valueCounter;
    }

    private void setValueCounter(int valueCounter) {
        this.valueCounter = valueCounter;
    }

    /**
     * Raises the value counter to {@code code} if it is currently lower.
     * Not called within this class; presumably used by the trie when assigning explicit codes.
     *
     * @param code a code that has been assigned
     */
    protected void updateValueCounter(int code) {
        if (code > valueCounter) {
            valueCounter = code;
        }
    }

    /**
     * Returns the current value counter and then increments it.
     * Not called within this class; presumably used by the trie to allocate new codes.
     *
     * @return the counter value before the increment
     */
    protected int increaseValueCounter() {
        return valueCounter++;
    }

    /**
     * @param nullValueIdentifier the null value to look up
     * @return the code of that null value
     * @throws MaltChainedException if the table has no null-value handler
     */
    public int getNullValueCode(NullValueId nullValueIdentifier) throws MaltChainedException {
        if (nullValues == null) {
            throw new SymbolException("The symbol table does not have any null-values. ");
        }
        return nullValues.nullvalueToCode(nullValueIdentifier);
    }

    /**
     * @param nullValueIdentifier the null value to look up
     * @return the symbol of that null value
     * @throws MaltChainedException if the table has no null-value handler
     */
    public String getNullValueSymbol(NullValueId nullValueIdentifier) throws MaltChainedException {
        if (nullValues == null) {
            throw new SymbolException("The symbol table does not have any null-values. ");
        }
        return nullValues.nullvalueToSymbol(nullValueIdentifier);
    }

    /** @return true if {@code symbol} is one of this table's null-value symbols */
    public boolean isNullValue(String symbol) throws MaltChainedException {
        if (nullValues != null) {
            return nullValues.isNullValue(symbol);
        }
        return false;
    }

    /** @return true if {@code code} is one of this table's null-value codes */
    public boolean isNullValue(int code) throws MaltChainedException {
        if (nullValues != null) {
            return nullValues.isNullValue(code);
        }
        return false;
    }

    /**
     * Adds every permanent symbol of {@code fromTable} to this table under the same code.
     * Afterwards the value counter is one past the largest code, and never below its
     * previous value.
     *
     * @param fromTable the source table; must be a {@link TrieSymbolTable}
     *        (otherwise a ClassCastException is thrown)
     * @throws MaltChainedException if adding a symbol fails
     */
    public void copy(SymbolTable fromTable) throws MaltChainedException {
        final TrieSymbolTable fromTrieTable = (TrieSymbolTable) fromTable;
        // The source nodes belong to the source table's trie, so resolve them there.
        // The original resolved them through this table's trie.
        final Trie fromTrie = fromTrieTable.getTrie();
        int max = getValueCounter() - 1;
        for (Map.Entry<Integer, TrieNode> e : fromTrieTable.getCodeTable().entrySet()) {
            final int code = e.getKey();
            final String str = fromTrie.getValue(e.getValue(), fromTrieTable);
            final TrieNode node = trie.addValue(str, this, code);
            codeTable.put(node.getEntry(this), node);
            if (max < code) {
                max = code;
            }
        }
        setValueCounter(max + 1);
    }

    /** @return the live (mutable) code -> trie node map of this table */
    public SortedMap<Integer, TrieNode> getCodeTable() {
        return codeTable;
    }

    /** @return the live key set of permanent codes, in ascending order */
    public Set<Integer> getCodes() {
        return codeTable.keySet();
    }

    /** @return the trie that stores this table's symbols */
    protected Trie getTrie() {
        return trie;
    }

    /** Two tables are equal when they are of the same class and have equal names. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
        }
        final TrieSymbolTable other = (TrieSymbolTable) obj;
        return ((name == null) ? other.name == null : name.equals(other.name));
    }

    /** Hash of the name, cached after the first call (consistent with {@link #equals(Object)}). */
    @Override
    public int hashCode() {
        if (cachedHash == 0) {
            cachedHash = 217 + (null == name ? 0 : name.hashCode());
        }
        return cachedHash;
    }

    /** @return {@code "<name> <valueCounter>"} */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(name);
        sb.append(' ');
        sb.append(valueCounter);
        return sb.toString();
    }
}