// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.morph.signature;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import marmot.core.Sequence;
import marmot.core.Token;
import marmot.morph.Word;
import marmot.morph.io.SentenceReader;
import marmot.util.Counter;
import marmot.util.FileUtils;
import marmot.util.SymbolTable;
public class Trie implements Serializable {
private static final long serialVersionUID = 1L;
protected transient List<String> words_;
protected transient List<List<List<Integer>>> tags_;
private transient boolean[] feature_map_;
private List<Trie> children_;
private double[] entropy_;
private Feature feature_;
private int child_index_;
private Trie parent_;
private int index_;
private Set<String> no_signature_;
private boolean verbose_;
public Trie(Trie trie, int feature_index, int child_index) {
this(null, trie.verbose_);
feature_map_ = new boolean[trie.feature_map_.length];
System.arraycopy(trie.feature_map_, 0, feature_map_, 0,
feature_map_.length);
feature_map_[feature_index] = false;
child_index_ = child_index;
parent_ = trie;
}
public Trie(Set<String> no_signature, boolean verbose) {
entropy_ = null;
words_ = new ArrayList<String>();
tags_ = new ArrayList<List<List<Integer>>>();
child_index_ = -1;
parent_ = null;
no_signature_ = no_signature;
verbose_ = verbose;
}
public void add(List<List<Integer>> tags, String word) {
words_.add(word);
tags_.add(tags);
}
public void split(int limit, Set<String> vocab) {
children_ = null;
List<Feature> features = getFeatures(vocab);
int num_leaves = 1;
PriorityQueue<Split> splits = new PriorityQueue<Split>();
feature_map_ = new boolean[features.size()];
Arrays.fill(feature_map_, true);
List<Trie> tries = new LinkedList<Trie>();
tries.add(this);
while (num_leaves < limit && !tries.isEmpty()) {
for (Trie trie : tries) {
for (int feature_index = 0; feature_index < features.size(); feature_index++) {
if (trie.feature_map_[feature_index]) {
Split split = new Split(features, trie, feature_index);
if (split.valid_) {
splits.add(split);
} else {
trie.feature_map_[feature_index] = false;
}
}
}
}
tries.clear();
Split split;
while (true) {
split = splits.poll();
if (split == null)
break;
if (split.trie_.isLeaf()) {
break;
}
}
if (split == null)
break;
Trie trie = split.trie_;
assert trie.children_ == null;
trie.children_ = split.children_;
trie.feature_ = features.get(split.feature_index_);
for (Trie child : trie.children_) {
tries.add(child);
}
num_leaves += 1;
}
List<Trie> leaves = new LinkedList<Trie>();
this.getLeafes(leaves);
int words = 0;
for (Trie leaf : leaves) {
words += leaf.words_.size();
}
assert words_.size() == words;
clear(0);
}
private List<Feature> getFeatures(Set<String> vocab) {
List<Feature> features = new ArrayList<Feature>();
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
for (int index = 0; index < word.length(); index++) {
char c = word.charAt(index);
if (Character.isDigit(c)) {
return true;
}
}
return false;
}
@Override
String getName() {
return "HasDigit";
}
});
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
for (int index = 0; index < word.length(); index++) {
char c = word.charAt(index);
if (Character.isLetter(c)) {
return true;
}
}
return false;
}
@Override
String getName() {
return "HasLetter";
}
});
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
for (int index = 0; index < word.length(); index++) {
char c = word.charAt(index);
if (Character.isUpperCase(c)) {
return true;
}
}
return false;
}
@Override
String getName() {
return "HasUpper";
}
});
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
for (int index = 0; index < word.length(); index++) {
char c = word.charAt(index);
if (Character.isLowerCase(c)) {
return true;
}
}
return false;
}
@Override
String getName() {
return "HasLower";
}
});
for (int length = 1; length < 10; length++) {
final int length_ = length;
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
return word.length() > length_;
}
@Override
String getName() {
return "Length>" + length_;
}
});
}
Counter<Character> alphabet = new Counter<Character>();
for (String word : words_) {
for (int index = 0; index < word.length(); index++) {
char c = Character.toLowerCase(word.charAt(index));
alphabet.increment(c, 1.0);
}
}
for (Map.Entry<Character, Double> entry : alphabet.entrySet()) {
if (entry.getValue() > 50) {
final char C = entry.getKey();
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
for (int index = 0; index < word.length(); index++) {
char c = Character.toLowerCase(word.charAt(index));
if (c == C) {
return true;
}
}
return false;
}
@Override
String getName() {
return "Contains=" + C;
}
});
}
}
for (int position = 1; position <= 5; position++) {
final int POSITION = position;
for (Map.Entry<Character, Double> entry : alphabet.entrySet()) {
if (entry.getValue() > 50) {
final char C = entry.getKey();
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
int index = word.length() - POSITION;
if (index < 0)
return false;
return Character.toLowerCase(word.charAt(index)) == C;
}
@Override
String getName() {
return "Char[-" + POSITION + "]=" + C;
}
});
}
}
}
for (int position = 0; position < 5; position++) {
final int POSITION = position;
for (Map.Entry<Character, Double> entry : alphabet.entrySet()) {
if (entry.getValue() > 50) {
final char C = entry.getKey();
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
boolean feature(String word) {
int index = POSITION;
if (index >= word.length())
return false;
return Character.toLowerCase(word.charAt(index)) == C;
}
@Override
String getName() {
return "Char[" + POSITION + "]=" + C;
}
});
}
}
}
final Set<String> known_lowercase_words = new HashSet<String>();
for (String word : vocab) {
if (word.toLowerCase().equals(word)) {
known_lowercase_words.add(word);
}
}
features.add(new Feature() {
private static final long serialVersionUID = 1L;
@Override
String getName() {
return "LowerIsKnown";
}
@Override
boolean feature(String word) {
String lower = word.toLowerCase();
if (lower.equals(word)) {
return true;
}
return known_lowercase_words.contains(lower);
}
});
return features;
}
public boolean isLeaf() {
return children_ == null;
}
public double[] getEntropy() {
if (entropy_ == null) {
if (tags_.isEmpty())
return null;
int K = tags_.get(0).size();
entropy_ = new double[K];
for (int k = 0; k < K; k++) {
double entropy = 0.;
Counter<Integer> counter = new Counter<Integer>();
assert !tags_.get(k).isEmpty();
for (List<List<Integer>> tag_list : tags_) {
for (int tag : tag_list.get(k)) {
counter.increment(tag, 1.0);
}
}
// assert counter.size() > 0;
for (double count : counter.counts()) {
double prob = count / counter.totalCount();
entropy -= prob * Math.log(prob);
}
entropy_[k] = entropy;
}
}
return entropy_;
}
public String signature() {
if (parent_ == null) {
return "";
}
if (isLeaf()) {
assert feature_ == null;
}
StringBuilder sb = new StringBuilder();
sb.append(parent_.signature());
Feature feature = parent_.feature_;
if (sb.length() > 0)
sb.append(',');
sb.append(feature.getName());
sb.append('=');
sb.append((child_index_ == 0) ? 't' : 'f');
return sb.toString();
}
public int classify(String word) {
if (no_signature_.contains(word))
return -1;
return classify_(word);
}
private int classify_(String word) {
if (isLeaf()) {
return index_;
}
assert feature_ != null;
int value = feature_.feature(word) ? 0 : 1;
return children_.get(value).classify_(word);
}
public void getLeafes(List<Trie> leaves) {
if (isLeaf()) {
leaves.add(this);
} else {
for (Trie child : children_) {
child.getLeafes(leaves);
}
}
}
public int clear(int index) {
if (isLeaf()) {
if (verbose_) {
System.err.println(index);
System.err.println(Arrays.toString(getEntropy()));
System.err.println(signature());
System.err.println("words " + words_.size() + " "
+ Split.shorten(words_));
System.err.println();
}
index_ = index;
index += 1;
} else {
index_ = -1;
for (Trie child : children_) {
index = child.clear(index);
}
}
words_ = null;
tags_ = null;
feature_map_ = null;
if (parent_ == null) {
index_ = index;
}
return index;
}
public int getIndex() {
return index_;
}
public static Trie train(String trainfile, boolean verbose) {
return train(trainfile, verbose, 20, 1);
}
public static Trie train(String trainfile, boolean verbose, int num_folds, int K) {
List<Sequence> sentences = new LinkedList<Sequence>();
for (Sequence sentence : new SentenceReader(trainfile)) {
sentences.add(sentence);
}
return train(sentences, verbose, num_folds, K);
}
public static Trie train(Collection<Sequence> sentences, boolean verbose) {
return train(sentences, verbose, 20, 1);
}
public static Trie train(Collection<Sequence> sentences, boolean verbose, int num_folds, int K) {
int sentences_per_fold = sentences.size() / num_folds;
if (sentences.size() < num_folds) {
throw new RuntimeException("Training set is to small: |sentences| = " + sentences.size() + " num folds =" + num_folds);
}
Set<String> known = new HashSet<String>();
Map<String, List<List<Integer>>> map = new HashMap<String, List<List<Integer>>>();
SymbolTable<String> tags = new SymbolTable<String>();
Set<String> vocab = new HashSet<String>();
for (Sequence sentence : sentences) {
for (Token token : sentence) {
Word word = (Word) token;
vocab.add(word.getWordForm());
tags.toIndex(word.getPosTag(), true);
}
}
int start_index = 0;
while (start_index < sentences.size()) {
known.clear();
int end_index = start_index + sentences_per_fold;
if (end_index + sentences_per_fold >= sentences.size()) {
end_index = sentences.size();
}
int index = 0;
for (Sequence sentence : sentences) {
if (index >= start_index && index < end_index) {
for (Token token : sentence) {
Word word = (Word) token;
known.add(word.getWordForm());
}
}
index++;
}
vocab.retainAll(known);
start_index = end_index;
}
for (Sequence sentence : sentences) {
for (int i = 0; i < sentence.size(); i++) {
Word word = (Word) sentence.get(i);
String form = word.getWordForm();
if (!vocab.contains(form)) {
List<List<Integer>> tag_list = map.get(form);
if (tag_list == null) {
tag_list = new LinkedList<List<Integer>>();
map.put(form, tag_list);
for (int k = 0; k < K; k++) {
tag_list.add(new LinkedList<Integer>());
}
}
for (int k = 0; k < K; k++) {
int shifted_index = i + k - K / 2;
if (shifted_index >= sentence.size()
|| shifted_index < 0) {
continue;
}
int tag = tags.toIndex(((Word) sentence
.get(shifted_index)).getPosTag());
tag_list.get(k).add(tag);
}
}
}
}
Trie trie = new Trie(vocab, verbose);
for (Map.Entry<String, List<List<Integer>>> entry : map.entrySet()) {
trie.add(entry.getValue(), entry.getKey());
}
trie.split(100, vocab);
return trie;
}
public static void main(String[] args) {
if (args.length != 2) {
System.err.println("Usage: Trie form-index=?,tag-index=?,train-file outputfile");
System.exit(1);
}
Trie trie = train(args[0], true);
FileUtils.saveToFile(trie, args[1]);
}
}