// Copyright 2015 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package lemming.lemma.toutanova; import java.io.IOException; import java.io.ObjectInputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.logging.Logger; import lemming.lemma.LemmaInstance; import lemming.lemma.toutanova.Aligner.Pair; import lemming.lemma.toutanova.ToutanovaTrainer.ToutanovaOptions; import marmot.core.Feature; import marmot.util.DynamicWeights; import marmot.util.Encoder; import marmot.util.SymbolTable; public class ToutanovaModel implements Serializable { private static final long serialVersionUID = 1L; private String alphabet_[]; private SymbolTable<String> output_table_; private SymbolTable<String> pos_table_; private int max_input_segment_length_; private int num_output_bits; private SymbolTable<Character> char_table; private Set<String> form_vocab_; transient private Encoder encoder; transient private Encoder.State encoder_state; private int num_char_bits; private int num_pos_bits; private IndexScorer scorer_; private IndexUpdater updater_; private boolean use_zero_order_; private int max_input_segment_length_bits_; private DynamicWeights weights_; private static final int length_bits_ = 6; private final static int FEATURE_BITS = Encoder.bitsNeeded(2); private final static int TRANS_FEAT = 0; private final static int OUTPUT_FEAT = 1; private final static int PAIR_FEAT = 2; private final static String COPY_SYMBOL = "<COPY>"; public void init(ToutanovaOptions options, List<ToutanovaInstance> train_instances, List<ToutanovaInstance> test_instances) { Logger logger = Logger.getLogger(getClass().getName()); max_window = options.getMaxWindowSize(); createOutputTable(options, train_instances); logger.info("Output alphabet size: " + output_table_.size()); logger.info("Max input segment length: " + max_input_segment_length_); if (options.getFilterAlphabet() > 0) { filterRareOutputSymbols(options, train_instances); createOutputTable(options, train_instances); logger.info("Output alphabet size: " + output_table_.size()); logger.info("Max input segment length: " + max_input_segment_length_); } char_table = new SymbolTable<>(); if (options.getUsePos()) { pos_table_ = new SymbolTable<>(); } form_vocab_ = new HashSet<>(); for (ToutanovaInstance instance : train_instances) { form_vocab_.add(instance.getInstance().getForm()); } addIndexes(train_instances, true); if (test_instances != null) addIndexes(test_instances, false); num_output_bits = Encoder.bitsNeeded(output_table_.size()); alphabet_ = new String[output_table_.size()]; for (Map.Entry<String, Integer> entry : output_table_.entrySet()) alphabet_[entry.getValue()] = entry.getKey(); output_table_.setBidirectional(false); num_char_bits = Encoder.bitsNeeded(char_table.size()); num_pos_bits = -1; if (pos_table_ != null) num_pos_bits = Encoder.bitsNeeded(pos_table_.size()); weights_ = new DynamicWeights(options.getRandom()); SymbolTable<Feature> feature_map = new SymbolTable<>(); scorer_ = new IndexScorer(weights_, feature_map, num_pos_bits); updater_ = new IndexUpdater(weights_, feature_map, num_pos_bits); use_zero_order_ = options.getDecoderInstance().getOrder() < 1; setupTemp(); } private void setupTemp() { encoder = new Encoder(10); encoder_state = new Encoder.State(); } private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { ois.defaultReadObject(); setupTemp(); } private void createOutputTable(ToutanovaOptions options, List<ToutanovaInstance> train_instances) { output_table_ = new SymbolTable<>(true); output_table_.insert(COPY_SYMBOL); max_input_segment_length_ = 0; for (ToutanovaInstance instance : train_instances) { if (instance.isRare()) { instance.setResult(null); continue; } String form = instance.getInstance().getForm(); assert instance.getAlignment() != null; List<Pair> pairs = Aligner.Pair.toPairs(form, instance .getInstance().getLemma(), instance.getAlignment()); List<Integer> form_indexes = new ArrayList<>(pairs.size()); List<Integer> lemma_segments = new ArrayList<>(pairs.size()); int start_index = 0; for (Pair pair : pairs) { int current_input_length = pair.getInputSegment().length(); max_input_segment_length_ = Math.max(max_input_segment_length_, current_input_length); start_index += current_input_length; form_indexes.add(start_index); int output_segment_index = 0; if (!pair.getInputSegment().equals(pair.getOutputSegment())) { output_segment_index = output_table_.toIndex( pair.getOutputSegment(), true); } lemma_segments.add(output_segment_index); } Result result = new Result(this, lemma_segments, form_indexes, form); assert (result.getOutput() .equals(instance.getInstance().getLemma())); instance.setResult(result); } max_input_segment_length_bits_ = Encoder .bitsNeeded(max_input_segment_length_); } private void filterRareOutputSymbols(ToutanovaOptions options, List<ToutanovaInstance> train_instances) { Logger logger = Logger.getLogger(getClass().getName()); int[] count = new int[output_table_.size()]; for (ToutanovaInstance instance : train_instances) { for (int output_index : instance.getResult().getOutputs()) { count[output_index]++; } } int rare_output_symbols = 0; for (int i = 0; i < count.length; i++) { if (count[i] == 1) { rare_output_symbols++; } } logger.info(String.format("Num rare output symbols (< %d): %d", options.getFilterAlphabet(), rare_output_symbols)); for (ToutanovaInstance instance : train_instances) { boolean instance_is_rare = false; for (int output_index : instance.getResult().getOutputs()) { if (count[output_index] <= options.getFilterAlphabet()) { instance_is_rare = true; break; } } instance.setRare(instance_is_rare); } } public SymbolTable<String> getOutputTable() { return output_table_; } public int getMaxInputSegmentLength() { return max_input_segment_length_; } public String getOutput(int o) { if (alphabet_ == null) { return output_table_.toSymbol(o); } return alphabet_[o]; } public void consumeTransitionFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int last_o, int o) { if (last_o < 0) { return; } encoder.reset(); encoder.append(TRANS_FEAT, FEATURE_BITS); encoder.append(last_o, num_output_bits); encoder.append(o, num_output_bits); // if (use_context_feature_) { // encoder.append(l_start == 0); // encoder.append(l_end == instance.getFormCharIndexes().length); // } consumer.consume(instance, encoder); addAffixes(instance, consumer, l_start, l_end); } private int max_window = 2; private void addAffixes(ToutanovaInstance instance, IndexConsumer consumer, int l_start, int l_end) { for (int window = 1; window <= max_window; window++) { encoder.storeState(encoder_state); addSegment(instance.getFormCharIndexes(), l_start - window, l_start); addSegment(instance.getFormCharIndexes(), l_end + 1, l_end + window + 1); consumer.consume(instance, encoder); encoder.restoreState(encoder_state); } for (int window = 1; window <= max_window; window++) { encoder.storeState(encoder_state); addSegment(instance.getFormCharIndexes(), l_start - window, l_start); consumer.consume(instance, encoder); encoder.restoreState(encoder_state); } for (int window = 1; window <= max_window; window++) { encoder.storeState(encoder_state); addSegment(instance.getFormCharIndexes(), l_end + 1, l_end + window + 1); consumer.consume(instance, encoder); encoder.restoreState(encoder_state); } } private void addSegment(int[] chars, int start, int end) { encoder.append(end - start, length_bits_); for (int i = start; i < end; i++) { int c; if (i >= 0 && i < chars.length) { c = chars[i]; } else { c = char_table.size(); } if (c < 0) return; encoder.append(c, num_char_bits); } } public void consumeOutputFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) { encoder.reset(); encoder.append(OUTPUT_FEAT, FEATURE_BITS); encoder.append(o, num_output_bits); // if (use_context_feature_) { // encoder.append(l_start == 0); // encoder.append(l_end == instance.getFormCharIndexes().length); // } consumer.consume(instance, encoder); addAffixes(instance, consumer, l_start, l_end); } public void consumePairFeature(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) { int[] chars = instance.getFormCharIndexes(); encoder.reset(); encoder.append(PAIR_FEAT, FEATURE_BITS); encoder.append(o, num_output_bits); encoder.append(l_end - l_start, max_input_segment_length_bits_); // if (use_context_feature_) { // encoder.append(l_start == 0); // encoder.append(l_end == instance.getFormCharIndexes().length); // } encoder.append(l_end - l_start, 4); for (int l = l_start; l < l_end; l++) { int c = chars[l]; if (c < 0) { return; } encoder.append(c, num_char_bits); } consumer.consume(instance, encoder); addAffixes(instance, consumer, l_start, l_end); } private void consumeOutputPair(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int o) { consumePairFeature(consumer, instance, l_start, l_end, o); consumeOutputFeature(consumer, instance, l_start, l_end, o); } private void consumeTransition(IndexConsumer consumer, ToutanovaInstance instance, int l_start, int l_end, int last_o, int o) { if (use_zero_order_) { return; } consumeTransitionFeature(consumer, instance, l_start, l_end, last_o, o); } public double getPairScore(ToutanovaInstance instance, int l_start, int l_end, int o) { scorer_.reset(); consumeOutputPair(scorer_, instance, l_start, l_end, o); return scorer_.getScore(); } public double getTransitionScore(ToutanovaInstance instance, int last_o, int o, int l_start, int l_end) { scorer_.reset(); consumeTransition(scorer_, instance, l_start, l_end, last_o, o); return scorer_.getScore(); } public double getScore(ToutanovaInstance instance, Result result) { scorer_.reset(); Iterator<Integer> output_iterator = result.getOutputs().iterator(); Iterator<Integer> input_iterator = result.getInputs().iterator(); int last_o = -1; int l_start = 0; while (output_iterator.hasNext()) { int o = output_iterator.next(); int l_end = input_iterator.next(); if (last_o >= 0) { consumeTransition(scorer_, instance, l_start, l_end, last_o, o); } consumeOutputPair(scorer_, instance, l_start, l_end, o); last_o = o; l_start = l_end; } return scorer_.getScore(); } public void update(ToutanovaInstance instance, Result result, double update) { updater_.setUpdate(update); Iterator<Integer> output_iterator = result.getOutputs().iterator(); Iterator<Integer> input_iterator = result.getInputs().iterator(); int last_o = -1; int l_start = 0; while (output_iterator.hasNext()) { int o = output_iterator.next(); int l_end = input_iterator.next(); if (last_o >= 0) { consumeTransition(updater_, instance, l_start, l_end, last_o, o); } consumeOutputPair(updater_, instance, l_start, l_end, o); last_o = o; l_start = l_end; } } public void addIndexes(List<ToutanovaInstance> instances, boolean insert) { for (ToutanovaInstance instance : instances) { addIndexes(instance, insert); } } public void addIndexes(ToutanovaInstance instance, boolean insert) { if (!instance.isRare()) { String form = instance.getInstance().getForm(); int[] char_indexes = new int[form.length()]; for (int i = 0; i < form.length(); i++) { char_indexes[i] = char_table .toIndex(form.charAt(i), -1, insert); } instance.setFormCharIndexes(char_indexes); if (pos_table_ != null) { String pos_tag = instance.getInstance().getPosTag(); if (pos_tag != null) { int index = pos_table_.toIndex(pos_tag, -1, insert); instance.setPosTagIndex(index); } } } } public DynamicWeights getWeights() { return weights_; } public void setWeights(DynamicWeights weights) { weights_ = weights; scorer_.setWeights(weights); updater_.setWeights(weights); } public boolean isOOV(LemmaInstance instance) { return !form_vocab_.contains(instance.getForm()); } }