// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.lemma.ranker;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import lemming.lemma.LemmaCandidate;
import lemming.lemma.LemmaCandidateSet;
import lemming.lemma.LemmaInstance;
import lemming.lemma.ranker.RankerTrainer.RankerTrainerOptions;
import lemming.lemma.toutanova.EditTreeAligner;
import marmot.morph.HashDictionary;
import marmot.morph.MorphDictionaryOptions;
import marmot.util.AspellLexicon;
import marmot.util.Converter;
import marmot.util.Encoder;
import marmot.util.FeatureTable;
import marmot.util.HashLexicon;
import marmot.util.Lexicon;
import marmot.util.StringUtils.Mode;
import marmot.util.SymbolTable;
import marmot.util.edit.EditTree;
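// Linear ranking model over lemma candidates. Features are bit-encoded with
// an Encoder, mapped to indexes through a FeatureTable and hashed into a
// fixed-size weight vector, with optional POS and morph conjunctions.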
public class RankerModel implements Serializable {
private static final long serialVersionUID = 1L;
private double[] weights_;
private SymbolTable<String> form_table_;
private SymbolTable<String> lemma_table_;
private SymbolTable<String> pos_table_;
private SymbolTable<Character> char_table_;
private SymbolTable<EditTree> tree_table_;
private EditTreeAligner aligner_;
// Maximum size of the character context window used in addWindow.
private static final int max_window = 5;
private static final int window_bits_ = Encoder.bitsNeeded(max_window);
// Maximum prefix / suffix length considered by the affix features.
private static final int max_affix_length_ = 10;
private enum Features {
lemma_feature, lemma_form_feature, align_feature, align_copy_feature,
tree_feature, affix_feature, lexicon_feature, align_feature_output,
tree_form_cluster_feature, lemma_cluster_feature,
form_cluster_lemma_cluster_feature
}
private static final int feature_bits_ = Encoder.bitsNeeded(Features
.values().length - 1);
private static final int unigram_count_position_bits_ = Encoder
.bitsNeeded(HashLexicon.ARRAY_LENGTH - 1);
// Random relatively large prime number
private static final long max_weights_length_ = 11_549_873;
private static final int encoder_capacity_ = 15;
private Set<Integer> ignores_indexes_;
private int cluster_bits_;
private int lemma_bits_;
private int form_bits_;
private int char_bits_;
private int tree_bits_;
private List<Lexicon> unigram_lexicons_;
private int unigram_lexicons_bits_;
private SymbolTable<String> morph_table_;
private FeatureTable feature_table_;
private transient Encoder encoder_;
private transient Context context_;
private int real_capacity_;
private long pos_length_;
private long feat_length_;
private boolean use_shape_lexicon_;
private boolean use_core_features_;
private boolean use_alignment_features_;
private transient double[] accumulated_penalties_;
private double accumulated_penalty_;
private boolean copy_conjunctions_;
private HashDictionary cluster_dict_;
private final static double EPSILON = 1e-10;
private static class Context {
public List<Integer> list;
public boolean insert;
public Context() {
list = new ArrayList<>();
}
}
private static final int length_bits_ = Encoder
.bitsNeeded(2 * max_window + 10);
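// Builds the symbol tables, lexicons and the weight vector from the
// training instances. This overload allocates fresh POS / morph tables
// when the corresponding options are enabled and delegates to the full
// init below.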
public void init(RankerTrainerOptions options,
List<RankerInstance> instances, EditTreeAligner aligner) {
SymbolTable<String> pos_table = null;
if (options.getUsePos()) {
pos_table = new SymbolTable<>();
}
SymbolTable<String> morph_table = null;
if (options.getUseMorph()) {
morph_table = new SymbolTable<>();
}
init(options, instances, aligner, pos_table, morph_table);
}
public void init(RankerTrainerOptions options,
List<RankerInstance> instances, EditTreeAligner aligner,
SymbolTable<String> pos_table, SymbolTable<String> morph_table) {
Logger logger = Logger.getLogger(getClass().getName());
aligner_ = aligner;
form_table_ = new SymbolTable<>();
lemma_table_ = new SymbolTable<>();
char_table_ = new SymbolTable<>();
tree_table_ = new SymbolTable<>();
pos_table_ = pos_table;
morph_table_ = morph_table;
for (RankerInstance instance : instances) {
fillTables(instance, instance.getCandidateSet());
}
form_bits_ = Encoder.bitsNeeded(form_table_.size() - 1);
lemma_bits_ = Encoder.bitsNeeded(lemma_table_.size() - 1);
// One extra value is reserved here: char_table_.size() marks out-of-bounds
// positions in addSegment, so we do not subtract 1.
char_bits_ = Encoder.bitsNeeded(char_table_.size());
tree_bits_ = Encoder.bitsNeeded(tree_table_.size() - 1);
logger.info(String.format("Number of edit trees: %5d",
tree_table_.size()));
if (pos_table_ != null) {
logger.info(String.format("Number of POS features: %3d",
pos_table_.size()));
logger.info(String.format("POS features: %s", pos_table_.keySet()));
}
if (morph_table_ != null) {
logger.info(String.format("Number of morph features: %3d",
morph_table_.size()));
logger.info(String.format("Morph features: %s",
morph_table_.keySet()));
}
int num_candidates = 0;
for (RankerInstance instance : instances) {
num_candidates += instance.getCandidateSet().size();
}
logger.info(String.format("Candidates per token: %d / %d = %g",
num_candidates, instances.size(), num_candidates
/ (double) instances.size()));
List<Object> unigram_files = options.getUnigramFile();
unigram_lexicons_ = new LinkedList<>();
for (Object unigram_file : unigram_files)
prepareUnigramFeature((String) unigram_file);
String cluster_file = options.getClusterFile();
cluster_dict_ = null;
if (!cluster_file.isEmpty()) {
prepareClusterFeature(cluster_file);
}
String aspell_path = options.getAspellPath();
if (!aspell_path.isEmpty()) {
String aspell_lang = options.getAspellLang();
logger.info(String.format("Adding aspell dictionary: %s",
aspell_lang));
unigram_lexicons_.add(new AspellLexicon(Mode.lower, aspell_path,
aspell_lang));
}
unigram_lexicons_bits_ = Encoder.bitsNeeded(unigram_lexicons_.size());
if (cluster_dict_ != null) {
cluster_bits_ = Encoder.bitsNeeded(cluster_dict_.numTags());
}
use_shape_lexicon_ = options.getUseShapeLexicon();
use_core_features_ = options.getUseCoreFeatures();
use_alignment_features_ = options.getUseAlignmentFeatures();
copy_conjunctions_ = options.getCopyConjunctions();
boolean use_hash_feature_table = options.getUseHashFeatureTable();
feature_table_ = FeatureTable.StaticMethods
.create(use_hash_feature_table);
logger.info("Starting feature index extraction.");
if (options.getUseOfflineFeatureExtraction()) {
for (RankerInstance instance : instances) {
addIndexes(instance, instance.getCandidateSet(), true);
}
}
// A random relatively large prime number
feat_length_ = 1_254_997L;
if (feature_table_ != null && options.getUseOfflineFeatureExtraction())
feat_length_ = feature_table_.size();
pos_length_ = (pos_table_ == null) ? 1 : pos_table_.size() + 1;
long morph_length = (morph_table_ == null) ? 1
: morph_table_.size() + 1;
long actual_length = feat_length_ * pos_length_ * morph_length;
logger.info(String.format("Actual weights length: %12d", actual_length));
int length = (int) Math.min(actual_length, max_weights_length_);
weights_ = new double[length];
if (feature_table_ != null && options.getUseOfflineFeatureExtraction())
logger.info(String.format("Number of features: %10d",
feature_table_.size()));
logger.info(String.format("Weights length: %6d", weights_.length));
logger.info(String.format("Real encoder capacity: %2d", real_capacity_));
String ignore_string = options.getIgnoreFeatures();
if (!ignore_string.isEmpty() && morph_table_ != null) {
ignores_indexes_ = new HashSet<>();
logger.info(String.format("Ignore-string: %s (%s)", ignore_string,
morph_table_));
for (String feat : ignore_string.split("\\|")) {
int index = morph_table_.toIndex(feat, -1);
ignores_indexes_.add(index);
logger.info(String
.format("Ignoring feature: %s (%d)", feat, index));
}
}
}
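// Loads the word-cluster dictionary (column 1, umlaut normalization) that
// backs the cluster-based features in addIndexes.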
private void prepareClusterFeature(String cluster_file) {
MorphDictionaryOptions options = MorphDictionaryOptions
.parse(cluster_file + ",indexes=[1],norm=umlaut");
cluster_dict_ = new HashDictionary();
cluster_dict_.init(options);
Logger logger = Logger.getLogger(getClass().getName());
logger.info(String
.format("Creating cluster lexicon from file: %s with %d entries and %d tags",
cluster_file, cluster_dict_.size(),
cluster_dict_.numTags()));
}
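// Parses a unigram lexicon specification of the form
// "filename[,min-count=N]" and adds the resulting lexicon.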
private void prepareUnigramFeature(String unigram_file) {
Logger logger = Logger.getLogger(getClass().getName());
if (unigram_file.isEmpty()) {
return;
}
String filename = null;
int min_count = 0;
for (String argument : unigram_file.split(",")) {
int index = argument.indexOf('=');
if (index >= 0) {
String argname = argument.substring(0, index);
String value = argument.substring(index + 1);
if (argname.equalsIgnoreCase("min-count")) {
min_count = Integer.parseInt(value);
} else {
throw new RuntimeException(String.format(
"Unknown option: %s", argname));
}
} else {
filename = argument;
}
}
if (filename == null) {
throw new RuntimeException(String.format(
"No filename specified: %s", unigram_file));
}
logger.info(String
.format("Creating unigram lexicon from file %s with min-count %d.",
filename, min_count));
HashLexicon lexicon = HashLexicon.readFromFile(filename, min_count);
logger.info(String.format("Created unigram lexicon with %7d entries.",
lexicon.size()));
unigram_lexicons_.add(lexicon);
}
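// Registers an instance in the symbol tables. The getFormChars /
// getPosIndex / getMorphIndexes calls are made for their side effect of
// inserting unseen symbols (insert flag set to true).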
private void fillTables(RankerInstance instance, LemmaCandidateSet set) {
String form = instance.getInstance().getForm();
form_table_.insert(form);
instance.getFormChars(char_table_, true);
instance.getPosIndex(pos_table_, true);
instance.getMorphIndexes(morph_table_, true);
for (Map.Entry<String, LemmaCandidate> candidate_pair : set) {
String lemma = candidate_pair.getKey();
LemmaCandidate candidate = candidate_pair.getValue();
if (use_alignment_features_) {
candidate.getLemmaChars(char_table_, lemma, true);
candidate.getAlignment(aligner_, form, lemma);
}
candidate.getTreeIndex(aligner_.getBuilder(), form, lemma,
tree_table_, true);
lemma_table_.insert(lemma);
}
}
public void removeIndexes(LemmaCandidateSet set) {
for (Map.Entry<String, LemmaCandidate> candidate_pair : set) {
LemmaCandidate candidate = candidate_pair.getValue();
candidate.setFeatureIndexes(null);
}
}
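// Extracts all feature indexes for every candidate of the instance and
// caches them on the candidate. With insert == false, features that were
// never seen during training are silently dropped.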
public void addIndexes(RankerInstance instance, LemmaCandidateSet set,
boolean insert) {
int[] form_cluster_indexes = null;
if (cluster_dict_ != null)
form_cluster_indexes = cluster_dict_.getIndexes(instance
.getInstance().getForm());
if (context_ == null) {
context_ = new Context();
encoder_ = new Encoder(encoder_capacity_);
}
String form = instance.getInstance().getForm();
int form_index = form_table_.toIndex(form, -1);
context_.insert = insert;
short[] form_chars = instance.getFormChars(char_table_, false);
for (Map.Entry<String, LemmaCandidate> candidate_pair : set) {
if (candidate_pair.getValue().getFeatureIndexes() == null) {
context_.list.clear();
String lemma = candidate_pair.getKey();
int[] lemma_cluster_indexes = null;
if (cluster_dict_ != null)
lemma_cluster_indexes = cluster_dict_.getIndexes(lemma);
if (lemma_cluster_indexes != null) {
for (int lemma_cluster_index : lemma_cluster_indexes) {
encoder_.append(
Features.lemma_cluster_feature.ordinal(),
feature_bits_);
encoder_.append(lemma_cluster_index, cluster_bits_);
addFeature();
}
if (form_cluster_indexes != null) {
for (int form_cluster_index : form_cluster_indexes) {
for (int lemma_cluster_index : lemma_cluster_indexes) {
encoder_.append(
Features.form_cluster_lemma_cluster_feature
.ordinal(), feature_bits_);
encoder_.append(lemma_cluster_index,
cluster_bits_);
encoder_.append(form_cluster_index,
cluster_bits_);
addFeature();
}
}
}
}
int lemma_index = lemma_table_.toIndex(lemma, -1, false);
LemmaCandidate candidate = candidate_pair.getValue();
if (use_core_features_) {
if (lemma_index >= 0) {
encoder_.append(Features.lemma_feature.ordinal(),
feature_bits_);
encoder_.append(lemma_index, lemma_bits_);
addFeature();
}
if (lemma_index >= 0 && form_index >= 0) {
encoder_.append(Features.lemma_form_feature.ordinal(),
feature_bits_);
encoder_.append(lemma_index, lemma_bits_);
encoder_.append(form_index, form_bits_);
addFeature();
}
int tree_index = candidate.getTreeIndex(
aligner_.getBuilder(), form, lemma, tree_table_,
false);
if (tree_index >= 0) {
encoder_.append(Features.tree_feature.ordinal(),
feature_bits_);
encoder_.append(tree_index, tree_bits_);
addFeature();
encoder_.append(Features.tree_feature.ordinal(),
feature_bits_);
encoder_.append(tree_index, tree_bits_);
addPrefixFeatures(form_chars);
encoder_.reset();
encoder_.append(Features.tree_feature.ordinal(),
feature_bits_);
encoder_.append(tree_index, tree_bits_);
addSuffixFeatures(form_chars);
encoder_.reset();
if (form_cluster_indexes != null) {
for (int cluster_index : form_cluster_indexes) {
encoder_.append(
Features.tree_form_cluster_feature
.ordinal(), feature_bits_);
encoder_.append(tree_index, tree_bits_);
encoder_.append(cluster_index, cluster_bits_);
addFeature();
}
}
}
}
if (use_alignment_features_) {
short[] lemma_chars = candidate.getLemmaChars(char_table_,
lemma, false);
List<Integer> alignment = candidate.getAlignment(aligner_,
form, lemma);
addAlignmentIndexes(form_chars, lemma_chars, alignment);
addAffixIndexes(lemma_chars);
}
int lexicon_index = 0;
for (Lexicon lexicon : unigram_lexicons_) {
addUnigramFeature(lexicon_index, lexicon, lemma);
lexicon_index++;
}
candidate
.setFeatureIndexes(Converter.toIntArray(context_.list));
}
}
}
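// Fires a lexicon feature if the lemma occurs in the given lexicon; with
// the shape lexicon option, one feature per non-zero count position.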
private void addUnigramFeature(int lexicon_index, Lexicon unigram_lexicon,
String lemma) {
int[] counts = unigram_lexicon.getCount(lemma);
if (counts == null)
return;
if (use_shape_lexicon_) {
for (int i = 0; i < HashLexicon.ARRAY_LENGTH; i++) {
int count = counts[i];
if (count > 0) {
encoder_.append(Features.lexicon_feature.ordinal(),
feature_bits_);
encoder_.append(lexicon_index, unigram_lexicons_bits_);
encoder_.append(i, unigram_count_position_bits_);
addFeature();
}
}
} else {
int count = counts[HashLexicon.ARRAY_LENGTH - 1];
if (count > 0) {
encoder_.append(Features.lexicon_feature.ordinal(),
feature_bits_);
encoder_.append(lexicon_index, unigram_lexicons_bits_);
addFeature();
}
}
}
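// Emits one conjoined feature per prefix length (up to max_affix_length_);
// the leading boolean distinguishes prefix (false) from suffix (true)
// features. addSuffixFeatures below is the mirror image.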
private void addPrefixFeatures(short[] chars) {
encoder_.append(false);
for (int i = 0; i < Math.min(chars.length, max_affix_length_); i++) {
int c = chars[i];
if (c < 0)
return;
encoder_.append(c, char_bits_);
addFeature(false);
}
}
private void addSuffixFeatures(short[] chars) {
encoder_.append(true);
for (int i = chars.length - 1; i >= Math.max(0, chars.length
- max_affix_length_); i--) {
int c = chars[i];
if (c < 0)
return;
encoder_.append(c, char_bits_);
addFeature(false);
}
}
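// Prefix and suffix features over the lemma characters.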
private void addAffixIndexes(short[] lemma_chars) {
encoder_.append(Features.affix_feature.ordinal(), feature_bits_);
addPrefixFeatures(lemma_chars);
encoder_.reset();
encoder_.append(Features.affix_feature.ordinal(), feature_bits_);
addSuffixFeatures(lemma_chars);
encoder_.reset();
}
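// The alignment is a flat list of (input_length, output_length) pairs;
// decode it into character segments and fire features for each segment.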
private void addAlignmentIndexes(short[] form_chars, short[] lemma_chars,
List<Integer> alignment) {
Iterator<Integer> iterator = alignment.iterator();
int input_start = 0;
int output_start = 0;
while (iterator.hasNext()) {
int input_length = iterator.next();
int output_length = iterator.next();
int input_end = input_start + input_length;
int output_end = output_start + output_length;
addAlignmentSegmentIndexes(form_chars, lemma_chars, input_start,
input_end, output_start, output_end);
input_start = input_end;
output_start = output_end;
}
}
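// Copy segments (a character mapped onto itself) get a dedicated feature;
// unless copy_conjunctions_ is set, the generic alignment features are
// skipped for them.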
private void addAlignmentSegmentIndexes(short[] form_chars,
short[] lemma_chars, int input_start, int input_end,
int output_start, int output_end) {
if (isCopySegment(form_chars, lemma_chars, input_start, input_end,
output_start, output_end)) {
encoder_.append(Features.align_copy_feature.ordinal(),
feature_bits_);
addFeature();
if (!copy_conjunctions_) {
return;
}
}
encoder_.append(Features.align_feature.ordinal(), feature_bits_);
addSegment(form_chars, input_start, input_end);
addSegment(lemma_chars, output_start, output_end);
addFeature(false);
addWindow(form_chars, lemma_chars, input_start, input_end,
output_start, output_end);
encoder_.reset();
encoder_.append(Features.align_feature_output.ordinal(), feature_bits_);
addSegment(lemma_chars, output_start, output_end);
addFeature(false);
addWindow(form_chars, lemma_chars, input_start, input_end,
output_start, output_end);
encoder_.reset();
}
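// Context features around the aligned segment: for each window size, the
// characters on both sides, the right context alone and the left context
// alone, each conjoined with whatever is currently in the encoder.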
private void addWindow(short[] form_chars, short[] lemma_chars,
int input_start, int input_end, int output_start, int output_end) {
encoder_.storeState();
// Local bit width for the three window feature sub-types below;
// deliberately named so it does not shadow the static feature_bits_.
int window_feature_bits = 3;
for (int window = 1; window <= max_window; window++) {
int index = 0;
encoder_.append(index++, window_feature_bits);
encoder_.append(window, window_bits_);
addSegment(form_chars, input_start - window, input_start);
addSegment(form_chars, input_end + 1, input_end + window + 1);
addFeature(false);
encoder_.restoreState();
encoder_.append(index++, window_feature_bits);
encoder_.append(window, window_bits_);
addSegment(form_chars, input_end + 1, input_end + window + 1);
addFeature(false);
encoder_.restoreState();
encoder_.append(index++, window_feature_bits);
encoder_.append(window, window_bits_);
addSegment(form_chars, input_start - window, input_start);
addFeature(false);
encoder_.restoreState();
}
}
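// True iff the segment is a 1:1 mapping of a character onto itself.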
private boolean isCopySegment(short[] form_chars, short[] lemma_chars,
int input_start, int input_end, int output_start, int output_end) {
if (input_end - input_start != 1)
return false;
if (output_end - output_start != 1) {
return false;
}
return form_chars[input_start] == lemma_chars[output_start];
}
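// Length-prefixed character encoding of chars[start, end). Positions
// outside the word encode the reserved value char_table_.size(); unknown
// characters (negative values) abort the encoding.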
private void addSegment(short[] chars, int start, int end) {
encoder_.append(end - start, length_bits_);
for (int i = start; i < end; i++) {
int c;
if (i >= 0 && i < chars.length) {
c = chars[i];
} else {
c = char_table_.size();
}
if (c < 0)
return;
encoder_.append(c, char_bits_);
}
}
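// Maps the current encoder state to a feature index (inserting it if
// context_.insert is set) and records it; also tracks the maximal encoder
// length seen, for diagnostics.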
private void addFeature(boolean reset) {
real_capacity_ = Math.max(real_capacity_, encoder_.getCurrentLength());
int index = feature_table_.getFeatureIndex(encoder_, context_.insert);
if (index >= 0) {
context_.list.add(index);
}
if (reset)
encoder_.reset();
}
private void addFeature() {
addFeature(true);
}
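// Returns the highest-scoring lemma candidate. Assumes a non-empty
// candidate set.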
public String select(RankerInstance instance) {
Map.Entry<String, LemmaCandidate> best_pair = null;
for (Map.Entry<String, LemmaCandidate> candidate_pair : instance
.getCandidateSet()) {
LemmaCandidate candidate = candidate_pair.getValue();
double score = score(candidate,
instance.getPosIndex(pos_table_, false),
instance.getMorphIndexes(morph_table_, false));
candidate.setScore(score);
if (best_pair == null || score > best_pair.getValue().getScore()) {
best_pair = candidate_pair;
}
}
return best_pair.getKey();
}
public double score(LemmaCandidate candidate, int pos_index,
int[] morph_indexes) {
assert candidate != null;
double score = 0.0;
for (int index : candidate.getFeatureIndexes()) {
score += updateScore(index, pos_index, morph_indexes, 0.0);
}
return score;
}
public void update(RankerInstance instance, String lemma, double update) {
LemmaCandidate candidate = instance.getCandidateSet().getCandidate(
lemma);
update(candidate, instance.getPosIndex(pos_table_, false),
instance.getMorphIndexes(morph_table_, false), update);
}
public void update(LemmaCandidate candidate, int pos_index,
int[] morph_indexes, double update) {
for (int index : candidate.getFeatureIndexes()) {
updateScore(index, pos_index, morph_indexes, update);
}
}
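// Offset scheme for conjoined weights: the plain feature lives at index,
// the POS conjunction at index + feat_length_ * (pos + 1), and each morph
// conjunction at a further offset of feat_length_ * pos_length_ *
// (morph + 1). Morph indexes listed in ignores_indexes_ are skipped.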
private double updateScore(long index, long pos_index, int[] morph_indexes,
double update) {
double score = 0.0;
long f_index = index;
score += updateScore(f_index, update);
if (pos_index >= 0) {
long p_index = f_index + feat_length_ * (pos_index + 1L);
score += updateScore(p_index, update);
for (long morph_index : morph_indexes) {
if (ignores_indexes_ != null) {
if (ignores_indexes_.contains((int) morph_index)) {
continue;
}
}
long m_index = p_index + (morph_index + 1L) * feat_length_
* pos_length_;
score += updateScore(m_index, update);
}
}
return score;
}
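// Feature hashing: indexes are reduced modulo the weight vector length,
// so distinct features may collide by design.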
private double updateScore(long index, double update) {
int int_index = (int) (index % (long) weights_.length);
weights_[int_index] += update;
if (accumulated_penalties_ != null && Math.abs(update) > EPSILON) {
applyPenalty(weights_[int_index], int_index);
}
return weights_[int_index];
}
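// Enables (or disables) the accumulated L1 penalty; accumulated_penalty
// is the total penalty mass accrued so far.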
public void setPenalty(boolean penalize, double accumulated_penalty) {
if (penalize) {
accumulated_penalty_ = accumulated_penalty;
if (accumulated_penalties_ == null) {
accumulated_penalties_ = new double[weights_.length];
}
} else {
accumulated_penalties_ = null;
}
}
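// Cumulative L1 penalty in the style of Tsuruoka et al. (2009): clip the
// weight towards zero by the outstanding penalty and record how much was
// actually applied, so no weight is penalized twice.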
private void applyPenalty(double weight, int index) {
double old_weight = weight;
if (old_weight - EPSILON > 0.) {
weight = Math.max(0, old_weight
- (accumulated_penalty_ + accumulated_penalties_[index]));
} else if (old_weight + EPSILON < 0.) {
weight = Math.min(0, old_weight
+ (accumulated_penalty_ - accumulated_penalties_[index]));
}
accumulated_penalties_[index] += weight - old_weight;
weights_[index] = weight;
}
public double[] getWeights() {
return weights_;
}
public void setWeights(double[] weights) {
weights_ = weights;
}
public SymbolTable<String> getPosTable() {
return pos_table_;
}
public SymbolTable<String> getMorphTable() {
return morph_table_;
}
public boolean isOOV(LemmaInstance instance) {
return form_table_.toIndex(instance.getForm(), -1) == -1;
}
public List<Double> scores(RankerInstance rinstance) {
LemmaCandidateSet set = rinstance.getCandidateSet();
List<Double> scores = new ArrayList<>(set.size());
for (LemmaCandidate candidate : set.getCandidates()) {
double score = score(candidate, rinstance.getPosIndex(pos_table_, false), rinstance.getMorphIndexes(morph_table_, false));
scores.add(score);
candidate.setScore(score);
}
return scores;
}
}