// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.morph;
import java.io.File;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import lemming.lemma.BackupLemmatizer;
import lemming.lemma.GoldLemmaGenerator;
import lemming.lemma.LemmaCandidate;
import lemming.lemma.LemmaCandidateGenerator;
import lemming.lemma.LemmaCandidateSet;
import lemming.lemma.LemmaInstance;
import lemming.lemma.LemmatizerGeneratorTrainer;
import lemming.lemma.SimpleLemmatizer;
import lemming.lemma.SimpleLemmatizerTrainer;
import lemming.lemma.SimpleLemmatizerTrainer.SimpleLemmatizerTrainerOptions;
import lemming.lemma.ranker.Ranker;
import lemming.lemma.ranker.RankerCandidate;
import lemming.lemma.ranker.RankerInstance;
import lemming.lemma.ranker.RankerModel;
import lemming.lemma.ranker.RankerTrainer;
import lemming.lemma.ranker.RankerTrainer.RankerTrainerOptions;
import lemming.lemma.toutanova.EditTreeAligner;
import lemming.lemma.toutanova.EditTreeAlignerTrainer;
import marmot.core.Model;
import marmot.core.Options;
import marmot.core.Sequence;
import marmot.core.State;
import marmot.core.Tagger;
import marmot.core.Token;
import marmot.core.Trainer;
import marmot.core.TrainerFactory;
import marmot.core.WeightVector;
import marmot.morph.analyzer.Analyzer;
import marmot.morph.signature.Trie;
import marmot.util.Copy;
import marmot.util.Counter;
import marmot.util.FeatUtil;
import marmot.util.FileUtils;
import marmot.util.StringUtils;
import marmot.util.StringUtils.Mode;
import marmot.util.SymbolTable;
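
/**
 * Statistical model for joint part-of-speech and morphological tagging.
 * Holds the symbol tables (word forms, characters, shapes, token
 * features), the training vocabulary with frequency counts, the observed
 * POS-to-morph transitions and, optionally, a lemmatizer ranking model.
 * A model is populated from training data via
 * {@link #init(MorphOptions, Collection)}; the static {@code train}
 * methods bundle model construction and tagger training.
 *
 * <p>A minimal training sketch (assumes {@code train_sentences} have
 * already been read into {@code marmot.core.Sequence} objects; default
 * options are used for brevity):
 *
 * <pre>{@code
 * MorphOptions options = new MorphOptions();
 * Tagger tagger = MorphModel.train(options, train_sentences);
 * }</pre>
 */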
public class MorphModel extends Model {
private static final long serialVersionUID = 2L;
	private static final int POS_INDEX_ = 0;
	private static final int MORPH_INDEX_ = 1;
	private static final String POS_NAME_ = "pos";
	private static final String MORPH_NAME_ = "morph";
private SymbolTable<String> word_table_;
private SymbolTable<String> shape_table_;
private SymbolTable<Character> char_table_;
private SymbolTable<String> token_feature_table_;
private SymbolTable<String> weighted_token_feature_table_;
private List<SymbolTable<String>> subtag_tables_;
	private transient Map<String, Integer> signature_cache_;
private int[] vocab_;
private int[][] tag_classes_;
private int[][] transitions_;
private int[][][] tag_to_subtag_;
private List<Set<Integer>> observed_sets_;
private int[][] word_to_observed_tags_;
private Trie trie_;
private boolean verbose_;
private boolean shape_;
private boolean tag_morph_;
private int num_folds_;
private int rare_word_max_freq_;
private boolean split_morphs_;
private boolean split_pos_;
private Mode normalize_forms_;
private Analyzer analyzer_;
private RankerModel lemma_model_;
private List<LemmaCandidateGenerator> generators_;
private transient Map<String, List<RankerInstance>> lemma_instance_map_;
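
	/**
	 * Builds the model from training data: creates the tag and subtag
	 * tables, optionally induces (or loads) the shape trie, indexes all
	 * tokens, and extracts the vocabulary, the possible POS-to-morph
	 * transitions, the observed word/tag combinations and, if enabled,
	 * the lemmatizer.
	 */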
public void init(MorphOptions options, Collection<Sequence> sentences) {
verbose_ = options.getVerbose();
rare_word_max_freq_ = options.getRareWordMaxFreq();
shape_ = options.getShape();
tag_morph_ = options.getTagMorph();
split_pos_ = options.getSplitPos();
split_morphs_ = options.getSplitMorphs();
normalize_forms_ = options.getNormalizeForms();
special_signature_ = options.getSpecialSignature();
num_folds_ = options.getNumFolds();
restrict_pos_tags_to_seen_combinations_ = options
.getRestrictPosTagsToSeenCombinations();
init(options, extractCategories(sentences));
subtag_tables_ = new ArrayList<SymbolTable<String>>();
subtag_tables_.add(null);
subtag_tables_.add(null);
if (split_pos_) {
subtag_tables_.set(POS_INDEX_, new SymbolTable<String>());
}
if (tag_morph_ && split_morphs_) {
subtag_tables_.set(MORPH_INDEX_, new SymbolTable<String>());
}
word_table_ = new SymbolTable<String>(true);
char_table_ = new SymbolTable<Character>();
if (shape_) {
shape_table_ = new SymbolTable<String>();
}
		signature_cache_ = new HashMap<String, Integer>();
token_feature_table_ = new SymbolTable<String>();
weighted_token_feature_table_ = new SymbolTable<String>();
String internal_analyzer = options.getInternalAnalyzer();
if (internal_analyzer != null) {
analyzer_ = Analyzer.create(internal_analyzer);
}
if (shape_) {
File file = null;
if (!options.getShapeTriePath().isEmpty()) {
file = new File(options.getShapeTriePath());
}
if (file == null || !file.exists()) {
if (verbose_) {
System.err.println("Inducing shape trie.");
}
trie_ = Trie.train(sentences, options.getVeryVerbose());
if (file != null) {
if (verbose_) {
System.err.format("Writing shape trie to: %s.\n",
options.getShapeTriePath());
}
FileUtils.saveToFile(trie_, options.getShapeTriePath());
}
			} else {
				if (verbose_) {
					System.err.format("Loading shape trie from: %s.\n",
							options.getShapeTriePath());
				}
				trie_ = FileUtils.loadFromFile(options.getShapeTriePath());
			}
}
if (trie_ == null) {
shape_ = false;
}
for (Sequence sentence : sentences) {
for (Token token : sentence) {
Word word = (Word) token;
addIndexes(word, true);
}
}
vocab_ = extractVocabulary(options, sentences);
transitions_ = extractPossibleTransitions(options, sentences);
observed_sets_ = extractObservedSets(sentences);
tag_classes_ = extractTagClasses(getTagTables());
tag_to_subtag_ = extractSubTags(options.getSubTagSeparator());
for (Sequence sentence : sentences) {
for (Token token : sentence) {
Word word = (Word) token;
addShape(word, word.getWordForm(), true);
}
}
if (options.getLemmatizer()) {
initLemmatizer(options, sentences);
}
}
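
	/**
	 * Sets up the lemmatizer: configures the ranker options from the
	 * tagger options, creates the lemma candidate generators (gold,
	 * Lemming-based, or the default set), collects ranker instances for
	 * all training word forms, trains an edit-tree aligner and
	 * initializes the ranker model. With lemma pretraining, lemma
	 * scoring is skipped until the first tagger training pass is done.
	 */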
private void initLemmatizer(MorphOptions options,
Collection<Sequence> sentences) {
lemma_use_morph_ = options.getLemmaUseMorph();
marginalize_lemmas_ = options.getMarginalizeLemmas();
lemma_prepruning_extraction_ = options.getLemmaPrePruningExtraction();
lemma_tag_dependent_ = options.getLemmaTagDependent();
RankerTrainerOptions roptions = new RankerTrainerOptions();
roptions.setOption(RankerTrainerOptions.UNIGRAM_FILE,
options.getLemmaUnigramFile());
roptions.setOption(RankerTrainerOptions.IGNORE_FEATURES,
options.getLemmaIgnoreFeatures());
roptions.setOption(RankerTrainerOptions.ASPELL_PATH,
options.getLemmaAspellPath());
roptions.setOption(RankerTrainerOptions.ASPELL_LANG,
options.getLemmaAspellLang());
roptions.setOption(RankerTrainerOptions.USE_SHAPE_LEXICON,
options.getLemmaUseShapeLexicon());
roptions.setOption(RankerTrainerOptions.CLUSTER_FILE,
options.getLemmaClusterFile());
roptions.setOption(RankerTrainerOptions.TAG_DEPENDENT,
lemma_tag_dependent_);
roptions.setOption(RankerTrainerOptions.OFFLINE_FEATURE_EXTRACTION,
false);
roptions.setOption(RankerTrainerOptions.USE_HASH_FEATURE_TABLE,
options.getUseHashFeatureTable());
List<LemmaInstance> instances = LemmaInstance.getInstances(sentences,
true, false);
if (options.getGoldLemma()) {
generators_ = Collections
.singletonList((LemmaCandidateGenerator) new GoldLemmaGenerator());
} else if (options.getLemmaUseLemmingGenerator() > 0) {
LemmatizerGeneratorTrainer trainer = new RankerTrainer();
RankerTrainerOptions new_roptions = new RankerTrainerOptions(
roptions);
new_roptions.setOption(RankerTrainerOptions.USE_MALLET, false);
new_roptions.setOption(RankerTrainerOptions.USE_PERCEPTRON, false);
new_roptions.setOption(RankerTrainerOptions.USE_MORPH, false);
new_roptions
.setOption(RankerTrainerOptions.USE_SHAPE_LEXICON, true);
new_roptions
.setOption(RankerTrainerOptions.USE_CORE_FEATURES, true);
new_roptions.setOption(RankerTrainerOptions.USE_ALIGNMENT_FEATURES,
true);
new_roptions.setOption(
RankerTrainerOptions.OFFLINE_FEATURE_EXTRACTION, false);
new_roptions.setOption(RankerTrainerOptions.TAG_DEPENDENT, true);
new_roptions.setOption(RankerTrainerOptions.USE_HASH_FEATURE_TABLE,
true);
((RankerTrainer) trainer).setOptions(new_roptions);
Ranker ranker = (Ranker) trainer.train(instances, null);
ranker.setNumCandidates(options.getLemmaUseLemmingGenerator());
trainer = new SimpleLemmatizerTrainer();
trainer.getOptions().setOption(
SimpleLemmatizerTrainerOptions.USE_BACKUP, false);
SimpleLemmatizer simple = (SimpleLemmatizer) trainer.train(
instances, null);
generators_ = Collections
.singletonList((LemmaCandidateGenerator) new BackupLemmatizer(
simple, ranker));
} else {
generators_ = roptions.getGenerators(instances);
}
SymbolTable<String> pos_table = getTagTables().get(POS_INDEX_);
for (Sequence sentence : sentences) {
for (Token token : sentence) {
Word word = (Word) token;
addRankerInstances(word);
}
}
SymbolTable<String> morph_table = null;
if (MORPH_INDEX_ < subtag_tables_.size()) {
morph_table = subtag_tables_.get(MORPH_INDEX_);
}
EditTreeAlignerTrainer trainer = new EditTreeAlignerTrainer(
roptions.getRandom(), false, 1, -1);
EditTreeAligner aligner = (EditTreeAligner) trainer.train(instances);
		List<RankerInstance> rinstances = new LinkedList<>();
		for (List<RankerInstance> list : lemma_instance_map_.values()) {
			for (RankerInstance instance : list) {
				if (instance != null) {
					rinstances.add(instance);
				}
			}
		}
lemma_model_ = new RankerModel();
lemma_model_
.init(roptions, rinstances, aligner, pos_table, morph_table);
if (options.getLemmaPretraining()) {
skip_lemma_ = true;
} else {
skip_lemma_ = false;
}
}
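
	/**
	 * Encodes a (word, tag) pair at the given level as a single integer:
	 * word * prod(tag table sizes up to level) + tag, where tag is the
	 * accumulated tag index over all levels up to and including level.
	 */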
private int getBiIndex(int word, int level, int tag) {
int length = 1;
for (int clevel = 0; clevel <= level; clevel++) {
length *= getTagTables().get(clevel).size();
}
assert tag < length;
return word * length + tag;
}
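
	/**
	 * Returns whether the (form, tag) combination was seen in training.
	 * Rare forms are collapsed onto the sentinel index
	 * word_table_.size(), which carries the open tag classes.
	 */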
public boolean hasBeenObserved(int form_index, int level, int tag_index) {
if (isRare(form_index)) {
form_index = word_table_.size();
}
Set<Integer> set = observed_sets_.get(level);
int index = getBiIndex(form_index, level, tag_index);
return set.contains(index);
}
private int[][][] extractSubTags(String subtag_separator) {
int[][][] tag_to_subtag = new int[subtag_tables_.size()][][];
int offset = 0;
for (int level = 0; level < subtag_tables_.size(); level++) {
if (level >= getTagTables().size())
break;
SymbolTable<String> table = getTagTables().get(level);
if (table != null && subtag_tables_.get(level) != null) {
tag_to_subtag[level] = new int[table.size()][];
for (Map.Entry<String, Integer> entry : table.entrySet()) {
tag_to_subtag[level][entry.getValue()] = getSubTags(
entry.getKey(), level, true, offset,
subtag_separator);
}
offset += subtag_tables_.get(level).size();
}
}
return tag_to_subtag;
}
private int[][] extractTagClasses(List<SymbolTable<String>> tag_tables) {
int[][] tag_classes = new int[tag_tables.size()][];
for (int level = 0; level < tag_tables.size(); level++) {
int num_tags = tag_tables.get(level).size();
tag_classes[level] = new int[num_tags - 1];
int index = 0;
for (int tag_index = 0; tag_index < num_tags; tag_index++) {
if (tag_index == getBoundaryIndex())
continue;
tag_classes[level][index] = tag_index;
index++;
}
}
return tag_classes;
}
private List<Set<Integer>> extractObservedSets(
Collection<Sequence> sentences) {
List<SymbolTable<String>> tag_tables = getTagTables();
List<Set<Integer>> observed_sets = new ArrayList<Set<Integer>>(
tag_tables.size());
List<Map<Integer, Set<Integer>>> wordform_to_candidates = new ArrayList<Map<Integer, Set<Integer>>>();
for (int level = 0; level < tag_tables.size(); level++) {
wordform_to_candidates.add(new HashMap<Integer, Set<Integer>>());
}
for (Sequence sentence : sentences) {
for (Token xtoken : sentence) {
Word token = (Word) xtoken;
int word_index = token.getWordFormIndex();
int tag_index = 0;
for (int level = 0; level < tag_tables.size(); level++) {
tag_index *= tag_tables.get(level).size();
tag_index += token.getTagIndexes()[level];
Set<Integer> tags = wordform_to_candidates.get(level).get(
word_index);
if (tags == null) {
tags = new HashSet<Integer>();
wordform_to_candidates.get(level).put(word_index, tags);
}
tags.add(tag_index);
}
}
}
if (restrict_pos_tags_to_seen_combinations_) {
word_to_observed_tags_ = new int[vocab_.length][];
for (Map.Entry<Integer, Set<Integer>> entry : wordform_to_candidates
.get(0).entrySet()) {
int word_index = entry.getKey();
if (!isRare(word_index)) {
Set<Integer> tag_set = entry.getValue();
int[] tags = new int[tag_set.size()];
int index = 0;
for (int tag : tag_set) {
tags[index++] = tag;
}
word_to_observed_tags_[word_index] = tags;
}
}
}
List<List<Integer>> open_tag_classes_per_level = getOpenPosTagClassesCrossValidation(
sentences, num_folds_, tag_tables);
for (int level = 0; level < tag_tables.size(); level++) {
Set<Integer> observed_set = new HashSet<Integer>();
observed_sets.add(observed_set);
List<Integer> open_tag_classes = open_tag_classes_per_level
.get(level);
for (int tag : open_tag_classes) {
int biindex = getBiIndex(word_table_.size(), level, tag);
observed_set.add(biindex);
}
for (Entry<Integer, Set<Integer>> entry : wordform_to_candidates
.get(level).entrySet()) {
int word_index = entry.getKey();
Set<Integer> set = entry.getValue();
if (!isRare(word_index)) {
int[] tags = new int[set.size()];
int index = 0;
for (int tag : set) {
tags[index++] = tag;
}
for (int tag : tags) {
int biindex = getBiIndex(word_index, level, tag);
observed_set.add(biindex);
}
}
}
}
return observed_sets;
}
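
	/**
	 * Estimates the open tag classes by cross-validation: for each fold,
	 * counts the tags of word forms that never occur outside the fold
	 * (simulated unknown words) and keeps every tag whose relative
	 * frequency among those tokens exceeds 0.01%.
	 */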
public static List<List<Integer>> getOpenPosTagClassesCrossValidation(
Collection<Sequence> sentences, int num_folds,
List<SymbolTable<String>> tag_tables) {
int sentences_per_fold = sentences.size() / num_folds;
if (sentences_per_fold == 0)
sentences_per_fold = 1;
Set<Integer> known = new HashSet<Integer>();
List<Counter<Integer>> counters = new ArrayList<Counter<Integer>>(
tag_tables.size());
for (int level = 0; level < tag_tables.size(); level++) {
counters.add(new Counter<Integer>());
}
int start_index = 0;
while (start_index < sentences.size()) {
known.clear();
int end_index = start_index + sentences_per_fold;
if (end_index + sentences_per_fold >= sentences.size()) {
end_index = sentences.size();
}
int index = 0;
for (Sequence sentence : sentences) {
if (index < start_index || index >= end_index) {
for (Token token : sentence) {
known.add(((Word) token).getWordFormIndex());
}
}
index++;
}
index = 0;
for (Sequence sentence : sentences) {
if (index >= start_index && index < end_index) {
for (Token token : sentence) {
int form = ((Word) token).getWordFormIndex();
if (!known.contains(form)) {
int tag_index = 0;
for (int level = 0; level < tag_tables.size(); level++) {
tag_index *= tag_tables.get(level).size();
tag_index += token.getTagIndexes()[level];
counters.get(level).increment(tag_index, 1.0);
}
}
}
}
index++;
}
start_index = end_index;
}
List<List<Integer>> list = new ArrayList<List<Integer>>(
tag_tables.size());
for (int level = 0; level < tag_tables.size(); level++) {
Counter<Integer> counter = counters.get(level);
double total_count = counter.totalCount();
List<Integer> open_tag_classes = new LinkedList<Integer>();
for (Map.Entry<Integer, Double> entry : counter.entrySet()) {
if (entry.getValue() / total_count > 0.0001) {
open_tag_classes.add(entry.getKey());
}
}
list.add(open_tag_classes);
}
return list;
}
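
	/** Counts the training frequency of each word form, indexed by form index. */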
private int[] extractVocabulary(MorphOptions options,
Collection<Sequence> sentences) {
Counter<Integer> vocab_counter = new Counter<Integer>();
for (Sequence sentence : sentences) {
for (Token token : sentence) {
Word word = (Word) token;
int word_index = word.getWordFormIndex();
vocab_counter.increment(word_index, 1.);
}
}
int[] vocab_array = new int[vocab_counter.size()];
for (Map.Entry<Integer, Double> entry : vocab_counter.entrySet()) {
vocab_array[entry.getKey()] = entry.getValue().intValue();
}
return vocab_array;
}
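
	/**
	 * Maps each POS tag to the sorted array of morphological tags it was
	 * observed with; index 0 is reserved for the boundary tag. Returns
	 * null unless transition restriction and morphological tagging are
	 * both enabled.
	 */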
private int[][] extractPossibleTransitions(MorphOptions options,
Collection<Sequence> sentences) {
if (!(options.getRestricTransitions() && tag_morph_))
return null;
Map<Integer, Set<Integer>> tag_to_morph = new HashMap<Integer, Set<Integer>>();
for (Sequence sentence : sentences) {
for (Token token : sentence) {
int from_index = token.getTagIndexes()[POS_INDEX_];
int to_index = token.getTagIndexes()[MORPH_INDEX_];
Set<Integer> tags = tag_to_morph.get(from_index);
if (tags == null) {
tags = new HashSet<Integer>();
tag_to_morph.put(from_index, tags);
}
tags.add(to_index);
}
}
int[][] transitions = new int[tag_to_morph.size() + 1][];
transitions[0] = new int[1];
for (Map.Entry<Integer, Set<Integer>> entry : tag_to_morph.entrySet()) {
int from_index = entry.getKey();
int[] to_indexes = new int[entry.getValue().size()];
int index = 0;
for (int to_index : entry.getValue()) {
to_indexes[index++] = to_index;
}
Arrays.sort(to_indexes);
assert transitions[from_index] == null;
transitions[from_index] = to_indexes;
}
return transitions;
}
	private SymbolTable<String> extractCategories(Collection<Sequence> sentences) {
		SymbolTable<String> category_table = new SymbolTable<String>(true);
		category_table.toIndex(POS_NAME_, true);
		if (tag_morph_) {
			category_table.toIndex(MORPH_NAME_, true);
		}
		return category_table;
	}
private transient Set<Character> unseen_char_set_;
private boolean special_signature_;
private boolean skip_lemma_;
private boolean marginalize_lemmas_;
private boolean lemma_use_morph_;
private boolean lemma_tag_dependent_;
private boolean restrict_pos_tags_to_seen_combinations_;
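
	/**
	 * Maps the form's characters to indexes, warning once per unknown
	 * character in verbose mode.
	 */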
private void addCharIndexes(Word word, String form, boolean insert) {
short[] char_indexes = FeatUtil.getCharIndexes(form, char_table_,
insert);
assert char_indexes != null;
for (int index = 0; index < form.length(); index++) {
char c = form.charAt(index);
if (char_indexes[index] < 0) {
if (verbose_) {
if (unseen_char_set_ == null) {
unseen_char_set_ = new HashSet<Character>();
}
if (!unseen_char_set_.contains(c)) {
System.err
.format("Warning: Unknown character: %c\n", c);
unseen_char_set_.add(c);
}
}
}
}
word.setCharIndexes(char_indexes);
}
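
	/** Computes the form's feature signature via FeatUtil, caching it per form. */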
	private void addSignature(Word word, String form, boolean insert) {
		if (signature_cache_ == null) {
			signature_cache_ = new HashMap<String, Integer>();
		}
		Integer signature = signature_cache_.get(form);
		if (signature == null) {
			signature = FeatUtil.getSignature(form, special_signature_);
			signature_cache_.put(form, signature);
		}
		word.setWordSignature(signature);
	}
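
	/**
	 * Indexes the word's precomputed token features, the readings of the
	 * internal analyzer (if any) and the weighted token features.
	 */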
private void addTokenFeatures(Word word, Word in_word, boolean insert) {
String[] token_features = in_word.getTokenFeatures();
List<String> readings = null;
if (analyzer_ != null) {
readings = analyzer_.analyze(in_word.getWordForm());
}
int indexes_length = 0;
if (token_features != null) {
indexes_length += token_features.length;
}
if (readings != null) {
indexes_length += readings.size();
}
if (indexes_length > 0) {
int[] indexes = new int[indexes_length];
int index = 0;
if (token_features != null) {
for (String feature : token_features) {
indexes[index] = token_feature_table_.toIndex(feature, -1,
insert);
index++;
}
}
if (readings != null) {
for (String feature : readings) {
indexes[index] = token_feature_table_.toIndex(feature, -1,
insert);
index++;
}
}
word.setTokenFeatureIndexes(indexes);
}
token_features = word.getWeightedTokenFeatures();
if (token_features != null && weighted_token_feature_table_ != null) {
int[] indexes = new int[token_features.length];
int index = 0;
for (String feature : token_features) {
indexes[index] = weighted_token_feature_table_.toIndex(feature,
-1, insert);
index++;
}
word.setWeightedTokenFeatureIndexes(indexes);
}
}
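
	/**
	 * Converts a word into model-internal indexes: tag indexes,
	 * signature, token features, shape, normalized form index and
	 * character indexes. With insert set, unknown symbols are added to
	 * the tables (training); otherwise they map to negative indexes
	 * (test).
	 */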
public void addIndexes(Word word, boolean insert) {
String word_form = word.getWordForm();
addTagIndexes(word, -1, insert);
addSignature(word, word_form, insert);
addTokenFeatures(word, word, insert);
addShape(word, word_form, insert);
String normalized_form = StringUtils.normalize(word_form,
normalize_forms_);
int word_index = word_table_.toIndex(normalized_form, -1, insert);
word.setWordIndex(word_index);
addCharIndexes(word, normalized_form, insert);
}
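
	/** Returns the cached ranker instance for the word and POS index (index 0 when lemmatization is not tag-dependent). */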
private RankerInstance getRankerInstance(Word word, int pos_index,
boolean training) {
List<RankerInstance> instances = word.getRankerIstances();
if (instances == null) {
instances = addRankerInstances(word);
}
if (!lemma_tag_dependent_) {
pos_index = 0;
}
RankerInstance instance = instances.get(pos_index);
assert instance != null;
return instance;
}
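
	/**
	 * Creates, and caches per word form, the ranker instances used for
	 * lemmatization: one per POS tag if lemmatization is tag-dependent,
	 * otherwise a single instance with the placeholder tag "_".
	 */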
private List<RankerInstance> addRankerInstances(Word word) {
if (lemma_instance_map_ == null) {
lemma_instance_map_ = new HashMap<>();
}
List<RankerInstance> instances = lemma_instance_map_.get(word
.getWordForm());
if (instances == null) {
SymbolTable<String> pos_table = getTagTables().get(0);
if (lemma_tag_dependent_) {
instances = new ArrayList<>(pos_table.size());
for (int index = 0; index < pos_table.size(); index++) {
instances.add(null);
}
LemmaCandidateSet total_set = new LemmaCandidateSet();
for (Map.Entry<String, Integer> entry : pos_table.entrySet()) {
int current_pos_index = entry.getValue();
String current_pos = entry.getKey();
if (restrict_pos_tags_to_seen_combinations_
&& !isRare(word.getWordFormIndex())
&& !hasBeenObserved(word.getWordFormIndex(), 0,
current_pos_index))
continue;
RankerInstance rinstance = getRankerInstance(word,
current_pos, total_set);
instances.set(current_pos_index, rinstance);
}
} else {
RankerInstance rinstance = getRankerInstance(word, "_", null);
instances = Collections.singletonList(rinstance);
}
lemma_instance_map_.put(word.getWordForm(), instances);
}
word.setRankerIstances(instances);
return instances;
}
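
	/**
	 * Builds a ranker instance for the word under the given POS tag,
	 * sharing candidate objects through total_set; if no generator
	 * proposes a candidate, the word form itself is added as the only
	 * lemma candidate.
	 */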
private RankerInstance getRankerInstance(Word word, String pos,
LemmaCandidateSet total_set) {
LemmaInstance instance = LemmaInstance.getInstance(word, false, false);
instance.setPosTag(pos);
RankerInstance rinstance = RankerInstance.getInstance(instance,
generators_);
instance.setPosTag(null);
if (total_set != null) {
LemmaCandidateSet new_set = new LemmaCandidateSet();
for (Map.Entry<String, LemmaCandidate> entry : rinstance
.getCandidateSet()) {
LemmaCandidate candidate = total_set.getCandidate(entry
.getKey());
new_set.addCandidate(entry.getKey(), candidate);
}
rinstance.setCandidateSet(new_set);
}
if (rinstance.getCandidateSet().size() == 0) {
if (total_set == null) {
rinstance.getCandidateSet().getCandidate(instance.getForm());
} else {
rinstance.getCandidateSet().addCandidate(instance.getForm(),
total_set.getCandidate(instance.getForm()));
}
}
return rinstance;
}
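
	/**
	 * Splits a composite tag into subtag indexes. Returns null for the
	 * boundary symbol, the empty tag "_", levels without a subtag table,
	 * and tags that do not contain the separator.
	 */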
private int[] getSubTags(String morph, int level, boolean insert,
int offset, String subtag_separator) {
if (morph.equals(BORDER_SYMBOL_)) {
return null;
}
if (morph.equals("_")) {
return null;
}
if (level >= subtag_tables_.size()) {
return null;
}
SymbolTable<String> subtag_table = subtag_tables_.get(level);
if (subtag_table == null) {
return null;
}
String[] sub_tags = morph.split(subtag_separator);
if (sub_tags.length == 1) {
return null;
}
List<Integer> indexes = new LinkedList<Integer>();
for (String sub_tag : sub_tags) {
if (sub_tag.length() > 0) {
int value = subtag_table.toIndex(sub_tag, -1, insert);
if (value >= 0) {
indexes.add(value);
}
}
}
int[] array = new int[indexes.size()];
int i = 0;
for (int index : indexes) {
array[i++] = index + offset;
}
return array;
}
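
	/** Indexes the word's POS tag and, if morphological tagging is enabled, its morph tag; missing tags map to -1. */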
private void addTagIndexes(Word word, int head, boolean insert) {
List<SymbolTable<String>> tag_tables = getTagTables();
String pos_tag = word.getPosTag();
String morph = word.getMorphTag();
int[] tag_indexes = new int[tag_tables.size()];
if (pos_tag == null) {
tag_indexes[0] = -1;
} else {
tag_indexes[0] = tag_tables.get(0).toIndex(pos_tag, -1, insert);
}
if (tag_morph_) {
if (morph == null) {
tag_indexes[1] = -1;
} else {
tag_indexes[1] = tag_tables.get(1).toIndex(morph, -1, insert);
}
}
word.setTagIndexes(tag_indexes);
}
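
	/** Assigns a shape-trie class to rare word forms when shape features are active. */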
private void addShape(Word word, String form, boolean insert) {
if (shape_) {
int word_index = word.getWordFormIndex();
if (vocab_ == null) {
return;
}
if (isRare(word_index)) {
int shape_index = -1;
if (trie_ != null) {
String shape = Integer.toString(trie_.classify(form));
shape_index = shape_table_.toIndex(shape, -1, insert);
}
word.setWordShapeIndex(shape_index);
}
}
}
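
	/** A form is rare if it is out of the vocabulary range or below the rare-word frequency threshold. */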
public boolean isRare(int word) {
if (word < 0 || word >= vocab_.length) {
return true;
}
return vocab_[word] < rare_word_max_freq_;
}
public SymbolTable<String> getWordTable() {
return word_table_;
}
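
	/**
	 * Grid search: recursively tries every combination of the given
	 * parameter values, trains a tagger for each, records the result and
	 * returns the tagger with the best test-set score.
	 */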
public static Tagger trainOptimal(MorphOptions options,
Collection<Sequence> train_sentences,
Collection<Sequence> test_sentences, List<String> parameters,
List<List<String>> values_list, List<MorphEntry> results) {
if (test_sentences == null) {
throw new InvalidParameterException("test_sentences is null!");
}
assert parameters.size() == values_list.size();
assert !parameters.isEmpty();
if (parameters.size() == 1) {
return trainOptimal(options, train_sentences, test_sentences,
parameters.get(0), values_list.get(0), results);
}
parameters = new LinkedList<String>(parameters);
values_list = new LinkedList<List<String>>(values_list);
Tagger best_tagger = null;
String parameter = ((LinkedList<String>) parameters).pollFirst();
Collection<String> values = ((LinkedList<List<String>>) values_list)
.pollFirst();
for (String value : values) {
options = Copy.clone(options);
options.setProperty(parameter, value);
Tagger tagger = trainOptimal(options, train_sentences,
test_sentences, parameters, values_list, results);
if (best_tagger == null) {
best_tagger = tagger;
} else {
if (tagger.getResult().getScore() > best_tagger.getResult()
.getScore()) {
best_tagger = tagger;
}
}
}
return best_tagger;
}
public static class MorphEntry implements Comparable<MorphEntry> {
private MorphOptions options_;
private MorphResult result_;
public MorphEntry(MorphOptions options, MorphResult result) {
options_ = options;
result_ = result;
}
@Override
public int compareTo(MorphEntry o) {
return -Double.compare(result_.getScore(), o.result_.getScore());
}
public MorphOptions getOptions() {
return options_;
}
public MorphResult getResult() {
return result_;
}
}
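
	/** Base case of the grid search: tries each value of a single parameter. */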
public static Tagger trainOptimal(MorphOptions options,
Collection<Sequence> train_sentences,
Collection<Sequence> test_sentences, String parameter,
Collection<String> values, List<MorphEntry> results) {
Tagger best_tagger = null;
if (test_sentences == null) {
throw new InvalidParameterException("test_sentebces is null!");
}
for (String value : values) {
options = Copy.clone(options);
options.setProperty(parameter, value);
Tagger tagger = train(Copy.clone(options), train_sentences,
test_sentences);
results.add(new MorphEntry(options, (MorphResult) tagger
.getResult()));
if (best_tagger == null) {
best_tagger = tagger;
} else {
if (tagger.getResult().getScore() > best_tagger.getResult()
.getScore()) {
best_tagger = tagger;
}
}
}
return best_tagger;
}
public static Tagger trainOptimal(MorphOptions options,
List<Sequence> train_sentences, List<Sequence> test_sentences) {
if (test_sentences == null) {
throw new InvalidParameterException("test_sentences is null!");
}
List<String> parameters = Arrays.asList(Options.ORDER, Options.SEED,
Options.PENALTY);
List<MorphEntry> results = new LinkedList<MorphEntry>();
List<List<String>> values_list = Arrays.asList(
Arrays.asList("1", "3", "5"), Arrays.asList("41", "42", "43"),
Arrays.asList("0.0", "0.05", "0.1", "0.5"));
Tagger tagger = MorphModel.trainOptimal(options, train_sentences,
test_sentences, parameters, values_list, results);
Collections.sort(results);
System.err.println("OPTIMAL OPTIONS AND RESULTS");
for (MorphEntry result : results) {
StringBuilder sb = new StringBuilder();
for (String param : parameters) {
if (sb.length() > 0) {
sb.append(',');
sb.append(' ');
}
sb.append(param);
sb.append(':');
sb.append(result.getOptions().getProperty(param));
}
sb.append('\t');
sb.append(result.getResult().getScore());
System.err.println(sb.toString());
}
return tagger;
}
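
	/**
	 * Trains a tagger: initializes the model on the training sentences,
	 * indexes the test sentences without extending the symbol tables,
	 * creates the weight vector and runs the trainer. With lemma
	 * pretraining enabled, a second pass is run with lemma scoring
	 * switched on.
	 */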
public static Tagger train(MorphOptions options,
Collection<Sequence> train_sentences,
Collection<Sequence> test_sentences) {
MorphModel model = new MorphModel();
model.init(options, train_sentences);
if (test_sentences != null) {
for (Sequence sentence : test_sentences) {
for (Token token : sentence) {
Word word = (Word) token;
model.addIndexes(word, false);
}
}
}
WeightVector weights = new MorphWeightVector(options);
weights.init(model, train_sentences);
Tagger tagger = new MorphTagger(model, model.getOrder(), weights);
Trainer trainer = TrainerFactory.create(options);
MorphEvaluator evaluator = null;
if (test_sentences != null) {
evaluator = new MorphEvaluator(test_sentences);
}
trainer.train(tagger, train_sentences, evaluator);
if (options.getLemmatizer() && options.getLemmaPretraining()) {
model.skip_lemma_ = false;
if (options.getVerbose()) {
System.err.format("Training with lemmatizer.\n");
}
trainer.train(tagger, train_sentences, evaluator);
}
return tagger;
}
public SymbolTable<Character> getCharTable() {
return char_table_;
}
public int getNumShapes() {
if (trie_ == null) {
return shape_table_.size();
} else {
return trie_.getIndex();
}
}
public SymbolTable<String> getShapeTable() {
return shape_table_;
}
	public boolean isOOV(int form_index) {
		return form_index < 0 || form_index >= vocab_.length
				|| vocab_[form_index] == 0;
	}
public int getNumSubTags() {
int total = 0;
if (subtag_tables_ != null) {
for (SymbolTable<String> table : subtag_tables_) {
if (table != null) {
total += table.size();
}
}
}
return total;
}
public SymbolTable<String> getTokenFeatureTable() {
return token_feature_table_;
}
public SymbolTable<String> getWeightedTokenFeatureTable() {
return weighted_token_feature_table_;
}
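
	/**
	 * Returns the candidate tags for a token: at the morph level, only
	 * the morph tags observed with the predicted POS tag (if transitions
	 * are restricted); at the POS level, only the tags seen with this
	 * word form (if enabled and the form is not rare); otherwise all
	 * tags of the level.
	 */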
@Override
public int[] getTagCandidates(Sequence sequence, int index, State state) {
int level = (state == null) ? 0 : state.getLevel() + 1;
if (transitions_ != null && level == MORPH_INDEX_) {
return transitions_[state.getIndex()];
}
if (level == 0 && restrict_pos_tags_to_seen_combinations_) {
Token token = sequence.get(index);
Word word = (Word) token;
int word_index = word.getWordFormIndex();
if (!isRare(word_index)) {
return word_to_observed_tags_[word_index];
}
}
return tag_classes_[level];
}
public int[][][] getTagToSubTags() {
return tag_to_subtag_;
}
public void setVerbose(boolean verbose) {
verbose_ = verbose;
}
public int getMaxSignature() {
return FeatUtil.getMaxSignature(special_signature_);
}
public static Tagger train(MorphOptions options,
List<Sequence> train_sequences) {
return train(options, train_sequences, null);
}
boolean lemma_prepruning_extraction_ = true;
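
	/**
	 * Scores all lemma candidates of a token given the POS state and
	 * attaches them, together with their correctness flags, to the
	 * state.
	 */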
@Override
public void setLemmaCandidates(Token token, State state, boolean preprune,
boolean training) {
if (lemma_model_ == null || preprune != lemma_prepruning_extraction_)
return;
int pos_index = state.getIndex();
Word word = (Word) token;
RankerInstance instance = getRankerInstance(word, pos_index, training);
assert instance != null;
LemmaCandidateSet candidate_set = instance.getCandidateSet();
List<RankerCandidate> candidates = new ArrayList<>(candidate_set.size());
assert state.getLevel() == 0;
int[] morph_indexes = RankerInstance.EMPTY_ARRAY;
String lemma = word.getLemma().toLowerCase();
for (Map.Entry<String, LemmaCandidate> entry : candidate_set) {
String plemma = entry.getKey();
boolean is_correct = plemma.equals(lemma);
LemmaCandidate candidate = entry.getValue();
assert candidate != null;
double score = getLemmaCandidateScore(candidate, candidate_set,
pos_index, morph_indexes, instance, training);
RankerCandidate rcandidate = new RankerCandidate(plemma, candidate,
is_correct, score);
assert rcandidate.getCandidate() != null;
candidates.add(rcandidate);
}
state.setLemmaCandidates(candidates);
state.setLemmaScoreSum();
assert state.getLemmaCandidates() != null;
}
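
	/**
	 * Scores a single lemma candidate, lazily running feature extraction
	 * for the whole candidate set on first use. EMPTY_ARRAY appears to
	 * mark candidates whose feature extraction was skipped (lemma
	 * pretraining); such entries are reset to null so that addIndexes
	 * extracts their features.
	 */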
private double getLemmaCandidateScore(LemmaCandidate candidate,
LemmaCandidateSet candidate_set, int pos_index,
int[] morph_indexes, RankerInstance instance, boolean training) {
if (skip_lemma_) {
candidate.setFeatureIndexes(RankerInstance.EMPTY_ARRAY);
return 0.0;
}
if (candidate.getFeatureIndexes() == null
|| candidate.getFeatureIndexes() == RankerInstance.EMPTY_ARRAY) {
for (Map.Entry<String, LemmaCandidate> entry : candidate_set) {
if (entry.getValue().getFeatureIndexes() == RankerInstance.EMPTY_ARRAY) {
entry.getValue().setFeatureIndexes(null);
}
}
candidate.setFeatureIndexes(null);
lemma_model_.addIndexes(instance, candidate_set, training);
for (Map.Entry<String, LemmaCandidate> entry : candidate_set) {
assert entry.getValue().getFeatureIndexes() != null;
assert entry.getValue().getFeatureIndexes() != RankerInstance.EMPTY_ARRAY;
}
assert candidate.getFeatureIndexes() != null;
assert candidate.getFeatureIndexes() != RankerInstance.EMPTY_ARRAY;
}
return lemma_model_.score(candidate, pos_index, morph_indexes);
}
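
	/**
	 * Re-scores the lemma candidates of the underlying POS state at the
	 * morph level, using the morph subtag indexes when morphological
	 * features are enabled for the lemmatizer.
	 */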
@Override
public void setLemmaCandidates(State state, boolean preprune) {
if (lemma_model_ == null || preprune != lemma_prepruning_extraction_)
return;
		assert state != null;
		assert state.getLevel() == 1;
		State previous_state = state.getSubLevelState();
		assert previous_state != null;
		assert previous_state.getOrder() == 1;
		assert state.getOrder() == 1;
List<RankerCandidate> prev_candidates = previous_state
.getLemmaCandidates();
assert prev_candidates != null;
assert previous_state.getLevel() == 0;
int pos_index = previous_state.getIndex();
int morph_index = state.getIndex();
int[] morph_indexes = getTagToSubTags()[state.getLevel()][morph_index];
if (morph_indexes == null)
morph_indexes = RankerInstance.EMPTY_ARRAY;
if (!lemma_use_morph_) {
morph_indexes = RankerInstance.EMPTY_ARRAY;
}
List<RankerCandidate> candidates = new ArrayList<>(
prev_candidates.size());
for (RankerCandidate prev_candidate : prev_candidates) {
String pLemma = prev_candidate.getLemma();
LemmaCandidate pcandidate = prev_candidate.getCandidate();
double score = lemma_model_.score(pcandidate, pos_index,
morph_indexes);
candidates.add(new RankerCandidate(pLemma, pcandidate,
prev_candidate.isCorrect(), score));
}
state.setLemmaCandidates(candidates);
state.setLemmaScoreSum();
}
public RankerModel getLemmaModel() {
return lemma_model_;
}
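
	// Name (including its spelling) is fixed by the overridden declaration
	// in the superclass.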
@Override
public boolean getMarganlizeLemmas() {
return marginalize_lemmas_;
}
public boolean getLemmaUseMorph() {
return lemma_use_morph_;
}
}