// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.hmm;
import hmmla.Properties;
import hmmla.io.Sentence;
import hmmla.io.Token;
import hmmla.util.SuffixTrie;
import hmmla.util.SymbolTable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
public class Model implements Serializable {
private static final long serialVersionUID = 1L;
public static final int BorderIndex = 0;
public static final String BorderSymbol = "<BREAK>";
protected SymbolTable<String> word_table_;
protected SymbolTable<String> tag_table_;
protected Map<String, Tree> top_level_;
protected Map<String, Tree> clustering_;
protected Statistics statistics_;
private List<String> open_tag_classes_;
private Map<String, List<String>> wordform_to_candidates_;
private Properties props_;
private Set<String> rare_;
private Set<String> vocab_;
private SuffixTrie suffix_trie_;
public Model(Model model) {
word_table_ = model.getWordTable();
tag_table_ = model.getTagTable();
top_level_ = model.getTopLevel();
clustering_ = model.getClustering();
statistics_ = model.getStatistics();
rare_ = model.rare_;
vocab_ = model.vocab_;
open_tag_classes_ = model.getOpenTagClasses();
wordform_to_candidates_ = model.getWordformToCandidates();
props_ = model.getProperties();
}
public Model(Iterable<Sentence> reader, Properties props) {
init(reader, props);
}
private void init(Iterable<Sentence> reader, Properties props) {
props_ = props;
word_table_ = new SymbolTable<String>();
tag_table_ = new SymbolTable<String>();
tag_table_.toIndex(BorderSymbol, true);
vocab_ = new HashSet<String>();
for (Sentence sentence : reader){
for (Token token : sentence){
tag_table_.toIndex(token.getTag(), true);
word_table_.toIndex(token.getWordForm(), true);
vocab_.add(token.getWordForm());
}
}
statistics_ = new Statistics(tag_table_, word_table_);
for (Sentence sentence : reader){
int fromIndex = BorderIndex;
for (Token token : sentence){
String tag = token.getTag();
int toIndex = tag_table_.toIndex(tag);
int output = word_table_.toIndex(token.getWordForm());
statistics_.addEmissions(toIndex, output, 1.0);
statistics_.addTransitions(fromIndex, toIndex, 1.0);
fromIndex = toIndex;
}
statistics_.addTransitions(fromIndex, BorderIndex, 1.0);
}
rare_ = new HashSet<String>();
for (Entry<String, Integer> entry : word_table_.entrySet()) {
double total = 0.0;
for (int index = 0; index < statistics_.getNumTags(); index++) {
total += statistics_.getEmissions(index, entry.getValue());
}
if (total + 0.5 < props_.getIsRareThreshold()) {
rare_.add(entry.getKey());
}
}
clustering_ = new HashMap<String, Tree>();
for (Entry<String, Integer> entry : tag_table_.entrySet()) {
Tree tree = new Tree(entry.getKey(), 0);
clustering_.put(entry.getKey(), tree);
}
top_level_ = new HashMap<String, Tree>(clustering_);
wordform_to_candidates_ = new HashMap<String, List<String>>();
Map<String, Set<Integer>> tag_to_words = new HashMap<String, Set<Integer>>();
for (Entry<String, Integer> entry : word_table_.entrySet()) {
String word_form = entry.getKey();
List<String> tags = new LinkedList<String>();
for (Entry<String, Integer> tag : tag_table_.entrySet()) {
if (statistics_.getEmissions(tag.getValue(), entry.getValue()) < 0.5) {
continue;
}
tags.add(tag.getKey());
Set<Integer> words = tag_to_words.get(tag.getKey());
if (words == null) {
words = new HashSet<Integer>();
tag_to_words.put(tag.getKey(), words);
}
words.add(entry.getValue());
}
if (tags.size() > 0) {
wordform_to_candidates_.put(word_form, tags);
}
}
open_tag_classes_ = new LinkedList<String>();
for (Entry<String, Set<Integer>> entry : tag_to_words.entrySet()) {
if (entry.getValue().size() > 40) {
open_tag_classes_.add(entry.getKey());
}
}
}
public Map<String, List<String>> getWordformToCandidates() {
return wordform_to_candidates_;
}
public List<Iterable<Integer>> getSentenceCandidates(Sentence sentence) {
List<Iterable<Integer>> candidates = new ArrayList<Iterable<Integer>>(sentence.size());
List<Tree> leaves = new LinkedList<Tree>();
for (Token token : sentence) {
List<String> candidate_parents = getCandidates(token.getWordForm());
List<Integer> candidate_list = new LinkedList<Integer>();
for (String parent_string : candidate_parents) {
leaves.clear();
Tree parent = top_level_.get(parent_string);
parent.getLeaves(leaves);
for (Tree leaf : leaves) {
candidate_list.add(tag_table_.toIndex(leaf.getName()));
}
}
candidates.add(candidate_list);
}
return candidates;
}
public List<String> getOpenTagClasses() {
return open_tag_classes_;
}
public List<String> getCandidates(String word_form) {
List<String> candidates = wordform_to_candidates_.get(word_form);
if (candidates != null) {
return candidates;
}
return open_tag_classes_;
}
public Properties getProperties() {
return props_;
}
public SymbolTable<String> getWordTable() {
return word_table_;
}
public SymbolTable<String> getTagTable() {
return tag_table_;
}
public void setTagTable(SymbolTable<String> tag_table) {
this.tag_table_ = tag_table;
}
public Map<String, Tree> getTopLevel() {
return top_level_;
}
public Map<String, Tree> getClustering() {
return clustering_;
}
public void setClustering(Map<String, Tree> clustering_) {
this.clustering_ = clustering_;
}
public Statistics getStatistics() {
return statistics_;
}
public void setStatistics(Statistics stats) {
this.statistics_ = stats;
}
public void saveToFile(String file_path) {
try {
FileOutputStream stream = new FileOutputStream(file_path);
GZIPOutputStream gzip = new GZIPOutputStream(stream);
ObjectOutputStream oos = new ObjectOutputStream(gzip);
oos.writeObject(this);
oos.close();
gzip.close();
stream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public static Model loadFromFile(String file_path) {
try {
FileInputStream stream = new FileInputStream(file_path);
GZIPInputStream gzip = new GZIPInputStream(stream);
ObjectInputStream ois = new ObjectInputStream(gzip);
Object object = ois.readObject();
ois.close();
gzip.close();
stream.close();
boolean is_model = object instanceof Model;
if (!is_model) {
throw new RuntimeException("Object at " + file_path
+ "is not of class Model");
}
Model model = (Model) object;
if (model == null) {
throw new NullPointerException();
}
return model;
} catch (IOException e) {
throw new RuntimeException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
public int getLevel() {
for (Tree tree : getClustering().values()) {
return tree.getLevel();
}
throw new RuntimeException("Clustering is empty!");
}
public boolean isRare(String word) {
return rare_.contains(word);
}
public boolean isKnown(String word) {
return vocab_.contains(word);
}
public void setVocab(Set<String> vocab) {
vocab_ = vocab;
}
public void setWordTable(SymbolTable<String> word_table) {
word_table_ = word_table;
}
public void setTopLevel(Map<String, Tree> top_level) {
top_level_ = top_level;
}
public void setSuffixTrie(SuffixTrie trie) {
suffix_trie_ = trie;
}
public SuffixTrie getSuffixTrie() {
return suffix_trie_;
}
public void setProperties(Properties props) {
props_ = props;
}
public int getNumTags() {
// We don't count the border symbol.
return tag_table_.size() - 1;
}
}