package chipmunk.segmenter;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import marmot.core.Feature;
import marmot.util.DynamicWeights;
import marmot.util.Encoder;
import marmot.util.SymbolTable;

/**
 * Feature model of the segmenter.
 *
 * <p>The model turns (segment, tag) pairs and tag transitions of a
 * {@link SegmentationInstance} into encoded features and hands each of them,
 * together with the subtag indexes of the current tag, to an
 * {@link IndexConsumer}. Two consumers are kept: {@code scorer_} (used by the
 * {@code get*Score} methods) and {@code updater_} (used by the {@code update}
 * methods).
 *
 * <p>The scratch encoder and its saved state are {@code transient}; they are
 * created lazily by {@link #prepareEncoder()}, which also covers a freshly
 * deserialized model.
 *
 * <p>NOTE(review): all feature extraction shares one scratch encoder and one
 * scorer, so an instance is presumably not safe for concurrent use — confirm
 * before sharing a model between threads.
 */
public class SegmenterModel implements Serializable {

    private static final long serialVersionUID = 1L;

    private SymbolTable<String> tag_table_;
    private SymbolTable<Character> char_table_;
    private int max_segment_length_;

    // Scratch state, rebuilt on demand by prepareEncoder().
    private transient Encoder encoder_;
    private transient Encoder.State encoder_state_;

    private int window_length_bits_;
    private int num_char_bits_;
    private int num_tag_bits_;
    private int max_segment_length_bits_;

    private IndexScorer scorer_;
    private IndexUpdater updater_;

    // Feature template ids; every encoded feature starts with one of these.
    private static final int TRANS_FEAT = 0;
    private static final int TAG_FEAT = 1;
    private static final int PAIR_FEAT = 2;
    private static final int CHARACTER_FEAT = 3;
    private static final int DICT_FEAT = 4;

    // Sized from the largest template id so the two cannot drift apart.
    private static final int FEATURE_BITS = Encoder.bitsNeeded(DICT_FEAT);

    private List<Dictionary> dictionaries_;
    private int dictionary_bits_;

    // For every tag index: the index of the full tag followed by the indexes
    // of its subtags (only when the tag splits into more than one subtag).
    private List<List<Integer>> tags_to_subtags_;

    private SegmenterOptions options_;

    /**
     * Builds all symbol tables, dictionaries and bit widths from the training
     * words and creates a fresh scorer/updater pair sharing one feature map.
     *
     * @param options segmenter configuration; kept for later feature extraction
     * @param words   words whose readings supply tags, segments and characters
     */
    public void init(SegmenterOptions options, Collection<Word> words) {
        options_ = options;

        collectSymbols(words);
        int num_subtags = buildSubtagIndex();

        if (options_.getBoolean(SegmenterOptions.VERBOSE)) {
            System.err.println("Tag table: " + tag_table_);
            System.err.println("Num tags: " + tag_table_.size());
        }

        loadDictionaries();

        num_tag_bits_ = Encoder.bitsNeeded(num_subtags);
        num_char_bits_ = Encoder.bitsNeeded(char_table_.size());
        max_segment_length_bits_ = Encoder.bitsNeeded(max_segment_length_);
        window_length_bits_ = Encoder.bitsNeeded(
                options_.getInt(SegmenterOptions.MAX_CHARACTER_WINDOW) + 1);

        if (!dictionaries_.isEmpty()) {
            dictionary_bits_ = Encoder.bitsNeeded(dictionaries_.size());
        }

        SymbolTable<Feature> feature_map = new SymbolTable<>();
        scorer_ = new IndexScorer(null, feature_map, num_tag_bits_);
        updater_ = new IndexUpdater(null, feature_map, num_tag_bits_);
    }

    /** Fills the tag and character tables and records the longest segment seen. */
    private void collectSymbols(Collection<Word> words) {
        tag_table_ = new SymbolTable<>(true);
        char_table_ = new SymbolTable<>();
        max_segment_length_ = 0;

        for (Word word : words) {
            for (SegmentationReading reading : word.getReadings()) {
                for (String segment : reading.getSegments()) {
                    assert segment.length() > 0;
                    max_segment_length_ = Math.max(max_segment_length_, segment.length());
                }
                for (String tag : reading.getTags()) {
                    tag_table_.toIndex(tag, true);
                }
            }

            String form = word.getWord();
            for (int i = 0; i < form.length(); i++) {
                char c = form.charAt(i);
                char_table_.toIndex(c, true);
            }
        }
    }

    /**
     * Builds {@code tags_to_subtags_} from the tag table.
     *
     * @return the number of distinct entries (full tags and subtags) indexed
     */
    private int buildSubtagIndex() {
        int num_tags = tag_table_.size();
        tags_to_subtags_ = new ArrayList<>(num_tags);
        for (int i = 0; i < num_tags; i++) {
            tags_to_subtags_.add(null);
        }

        SymbolTable<String> subtag_table = new SymbolTable<>();
        for (Map.Entry<String, Integer> entry : tag_table_.entrySet()) {
            String tag = entry.getKey();
            String[] subtags = TagSet.split(tag);

            // These lists are iterated for every extracted feature, so an
            // array-backed list is the cheaper choice.
            List<Integer> indexes = new ArrayList<>(subtags.length + 1);
            indexes.add(subtag_table.toIndex(tag, true));
            if (subtags.length > 1) {
                for (String subtag : subtags) {
                    indexes.add(subtag_table.toIndex(subtag, true));
                }
            }
            tags_to_subtags_.set(entry.getValue(), indexes);
        }
        return subtag_table.size();
    }

    /** Creates one {@link Dictionary} per configured dictionary path. */
    private void loadDictionaries() {
        dictionaries_ = new LinkedList<>();
        Collection<String> dictionary_paths = options_.getDictionaries();
        for (String path : dictionary_paths) {
            Dictionary dictionary = new Dictionary(path,
                    options_.getString(SegmenterOptions.LANG), max_segment_length_);
            if (options_.getBoolean(SegmenterOptions.VERBOSE)) {
                System.err.format("Created dictionary with %d entries from %s\n", dictionary.size(), path);
            }
            dictionaries_.add(dictionary);
        }
    }

    /** Lazily creates the scratch encoder (and its state holder) and resets it. */
    private void prepareEncoder() {
        if (encoder_ == null) {
            encoder_ = new Encoder(10);
            encoder_state_ = new Encoder.State();
        }
        encoder_.reset();
    }

    /** @return the number of entries in the tag table */
    public int getNumTags() {
        return tag_table_.size();
    }

    /** @return the length of the longest segment seen by {@link #init} */
    public int getMaxSegmentLength() {
        return max_segment_length_;
    }

    /**
     * Scores tagging the span {@code [l_start, l_end)} of the instance with
     * {@code tag}.
     *
     * @return the scorer's score after consuming all tag/segment features
     */
    public double getPairScore(SegmentationInstance instance, int l_start, int l_end, int tag) {
        scorer_.reset();
        consumeTagPair(scorer_, instance, l_start, l_end, tag);
        return scorer_.getScore();
    }

    /**
     * Scores the transition from {@code last_tag} to {@code tag}.
     *
     * @return the scorer's score after consuming the transition features
     */
    public double getTransitionScore(SegmentationInstance instance, int last_tag, int tag, int l_start, int l_end) {
        scorer_.reset();
        consumeTransition(scorer_, instance, l_start, l_end, last_tag, tag);
        return scorer_.getScore();
    }

    /**
     * Feeds every feature of the (segment, tag) pair to {@code consumer}: the
     * pair feature, the character feature, the tag feature and, if any
     * dictionaries are loaded, two dictionary features per dictionary.
     */
    private void consumeTagPair(IndexConsumer consumer, SegmentationInstance instance, int l_start, int l_end,
            int tag) {
        assert l_start >= 0 && l_end <= instance.getLength();

        consumePairFeature(consumer, instance, l_start, l_end, tag);
        consumeCharacterFeature(consumer, instance, l_start, l_end, tag);
        consumeTagFeature(consumer, instance, l_start, l_end, tag);

        if (dictionaries_.isEmpty()) {
            return;
        }

        String segment = instance.getWord().getWord().substring(l_start, l_end);
        int length = l_end - l_start;
        assert segment.length() == length;
        assert length <= max_segment_length_;

        List<Integer> subtags = tags_to_subtags_.get(tag);
        int dict_index = 0;
        for (Dictionary dictionary : dictionaries_) {
            boolean value = dictionary.contains(segment);
            // One feature with the length slot fixed to 0 and one that also
            // carries the actual segment length.
            consumeDictionaryFeature(consumer, dict_index, 0, value, subtags);
            consumeDictionaryFeature(consumer, dict_index, length, value, subtags);
            dict_index++;
        }
    }

    /** Encodes and consumes a single dictionary-membership feature. */
    private void consumeDictionaryFeature(IndexConsumer consumer, int dict_index, int length, boolean value,
            List<Integer> subtags) {
        prepareEncoder();
        encoder_.append(DICT_FEAT, FEATURE_BITS);
        encoder_.append(dict_index, dictionary_bits_);
        encoder_.append(length, max_segment_length_bits_);
        encoder_.append(value);
        consumer.consume(encoder_, subtags);
    }

    /**
     * If {@code USE_CHARACTER_FEATURE} is enabled, consumes one feature per
     * window size for the characters starting at {@code l_start} (marker 0)
     * and one per window size for the characters ending at {@code l_end}
     * (marker 2).
     */
    private void consumeCharacterFeature(IndexConsumer consumer, SegmentationInstance instance, int l_start,
            int l_end, int tag) {
        if (!options_.getBoolean(SegmenterOptions.USE_CHARACTER_FEATURE)) {
            return;
        }

        short[] chars = instance.getFormCharIndexes(char_table_);
        int max_window = options_.getInt(SegmenterOptions.MAX_CHARACTER_WINDOW);
        List<Integer> subtags = tags_to_subtags_.get(tag);

        prepareEncoder();
        encoder_.append(CHARACTER_FEAT, FEATURE_BITS);
        encoder_.append(0, 2);
        encoder_.storeState(encoder_state_);
        for (int window = 1; window <= max_window; window++) {
            encoder_.restoreState(encoder_state_);
            addSegment(chars, l_start, l_start + window);
            consumer.consume(encoder_, subtags);
        }

        prepareEncoder();
        encoder_.append(CHARACTER_FEAT, FEATURE_BITS);
        encoder_.append(2, 2);
        encoder_.storeState(encoder_state_);
        for (int window = 1; window <= max_window; window++) {
            encoder_.restoreState(encoder_state_);
            addSegment(chars, l_end - window, l_end);
            consumer.consume(encoder_, subtags);
        }
    }

    /** Consumes all transition features; currently only the tag-to-tag feature. */
    private void consumeTransition(IndexConsumer consumer, SegmentationInstance instance, int l_start, int l_end,
            int last_tag, int tag) {
        consumeTransitionFeature(consumer, instance, l_start, l_end, last_tag, tag);
    }

    /**
     * Consumes the feature that consists of the template id only, paired with
     * the subtags of {@code tag}. The span arguments are not used.
     */
    public void consumeTagFeature(IndexConsumer consumer, SegmentationInstance instance, int l_start, int l_end,
            int tag) {
        prepareEncoder();
        encoder_.append(TAG_FEAT, FEATURE_BITS);
        consumer.consume(encoder_, tags_to_subtags_.get(tag));
    }

    /**
     * Consumes the feature made of the segment length and the character
     * indexes of {@code [l_start, l_end)}, followed by the optional
     * surrounding-context features.
     *
     * <p>If any character of the segment has a negative index, nothing at all
     * is consumed (neither the pair feature nor its context features).
     */
    public void consumePairFeature(IndexConsumer consumer, SegmentationInstance instance, int l_start, int l_end,
            int tag) {
        assert l_start >= 0 && l_end <= instance.getLength();

        prepareEncoder();
        short[] chars = instance.getFormCharIndexes(char_table_);
        assert chars.length == instance.getLength();

        encoder_.append(PAIR_FEAT, FEATURE_BITS);
        encoder_.append(l_end - l_start, max_segment_length_bits_);
        for (int l = l_start; l < l_end; l++) {
            int c = chars[l];
            if (c < 0) {
                return;
            }
            encoder_.append(c, num_char_bits_);
        }

        consumer.consume(encoder_, tags_to_subtags_.get(tag));
        // Extends the encoder content built above, so it must run right after.
        addCharacterContext(instance, consumer, l_start, l_end, tag);
    }

    /**
     * Consumes one transition feature per subtag index of {@code last_tag},
     * each paired with the subtags of {@code tag}. Does nothing when
     * {@code last_tag} is negative (no previous tag).
     */
    public void consumeTransitionFeature(IndexConsumer consumer, SegmentationInstance instance, int l_start,
            int l_end, int last_tag, int tag) {
        if (last_tag < 0) {
            return;
        }

        List<Integer> subtags = tags_to_subtags_.get(tag);
        for (int last_subtag : tags_to_subtags_.get(last_tag)) {
            prepareEncoder();
            encoder_.append(TRANS_FEAT, FEATURE_BITS);
            encoder_.append(last_subtag, num_tag_bits_);
            consumer.consume(encoder_, subtags);
        }
    }

    /**
     * Applies {@code update} to every feature of the complete segmentation
     * {@code result}.
     *
     * <p>Requires the updater, i.e. must not be called after {@link #setFinal()}.
     */
    public void update(SegmentationInstance instance, SegmentationResult result, double update) {
        updater_.setUpdate(update);
        consumeResult(updater_, instance, result);
    }

    /**
     * @return the scorer's score over every feature of the complete
     *         segmentation {@code result}
     */
    public double getScore(SegmentationInstance instance, SegmentationResult result) {
        scorer_.reset();
        consumeResult(scorer_, instance, result);
        return scorer_.getScore();
    }

    /**
     * Walks the segments of {@code result} from left to right and feeds the
     * transition features (from the second segment on) and the tag/segment
     * features of each segment to {@code consumer}.
     */
    private void consumeResult(IndexConsumer consumer, SegmentationInstance instance, SegmentationResult result) {
        Iterator<Integer> tag_iterator = result.getTags().iterator();
        Iterator<Integer> input_iterator = result.getInputIndexes().iterator();

        int last_tag = -1;
        int l_start = 0;
        while (tag_iterator.hasNext()) {
            int tag = tag_iterator.next();
            int l_end = input_iterator.next();
            assert l_start >= 0 && l_end <= instance.getLength();

            if (last_tag >= 0) {
                consumeTransition(consumer, instance, l_start, l_end, last_tag, tag);
            }
            consumeTagPair(consumer, instance, l_start, l_end, tag);

            last_tag = tag;
            l_start = l_end;
        }
    }

    /**
     * Converts a word into an instance whose results mirror the word's
     * readings: tags become tag-table indexes ({@code -1} is passed to the
     * table as the fallback for the lookup) and segments become cumulative end
     * offsets.
     */
    public SegmentationInstance getInstance(Word word) {
        List<SegmentationResult> results = new LinkedList<>();

        for (SegmentationReading reading : word.getReadings()) {
            List<Integer> tags = new ArrayList<>();
            for (String tag : reading.getTags()) {
                int tag_index = tag_table_.toIndex(tag, -1, false);
                tags.add(tag_index);
            }

            List<Integer> input_indexes = new ArrayList<>();
            int index = 0;
            for (String segment : reading.getSegments()) {
                index += segment.length();
                assert index <= word.getLength() : word + " " + reading;
                input_indexes.add(index);
            }

            results.add(new SegmentationResult(tags, input_indexes));
        }

        return new SegmentationInstance(word, results);
    }

    /** Installs {@code weights} on both the scorer and the updater. */
    public void setWeights(DynamicWeights weights) {
        setScorerWeights(weights);
        setUpdaterWeights(weights);
    }

    /** Installs {@code weights} on the scorer only. */
    public void setScorerWeights(DynamicWeights weights) {
        scorer_.setWeights(weights);
    }

    /** Installs {@code weights} on the updater only. */
    public void setUpdaterWeights(DynamicWeights weights) {
        updater_.setWeights(weights);
    }

    /**
     * Builds a word with a single reading from a segmentation result.
     *
     * @param form   the surface form that was segmented
     * @param result tag indexes and segment end offsets into {@code form}
     * @return a new word carrying the decoded segments and tag symbols
     */
    public Word toWord(String form, SegmentationResult result) {
        List<String> tags = new ArrayList<>();
        for (int tag_index : result.getTags()) {
            tags.add(tag_table_.toSymbol(tag_index));
        }

        List<String> segments = new ArrayList<>();
        int start_index = 0;
        for (int end_index : result.getInputIndexes()) {
            segments.add(form.substring(start_index, end_index));
            start_index = end_index;
        }

        Word word = new Word(form);
        word.add(new SegmentationReading(segments, tags));
        return word;
    }

    /**
     * Applies {@code update} to the tag/segment features of one span.
     *
     * <p>Requires the updater, i.e. must not be called after {@link #setFinal()}.
     */
    public void update(SegmentationInstance instance, int l_start, int l_end, int tag, double update) {
        updater_.setUpdate(update);
        consumeTagPair(updater_, instance, l_start, l_end, tag);
    }

    /**
     * Applies {@code update} to the transition features of one tag pair.
     *
     * <p>Requires the updater, i.e. must not be called after {@link #setFinal()}.
     */
    public void update(SegmentationInstance instance, int l_start, int l_end, int last_tag, int tag,
            double update) {
        updater_.setUpdate(update);
        consumeTransition(updater_, instance, l_start, l_end, last_tag, tag);
    }

    /**
     * Prints the scorer weights and, while the updater still exists, the
     * updater weights to standard error.
     */
    public void printWeights() {
        System.err.println(Arrays.toString(scorer_.getWeights().getWeights()));
        // setFinal() discards the updater; a finalized model has only scorer weights.
        if (updater_ != null) {
            System.err.println(Arrays.toString(updater_.getWeights().getWeights()));
        }
    }

    /**
     * Freezes the model: drops the updater, stops the scorer from inserting
     * new features and its weights from expanding, and releases the scratch
     * encoder (it is recreated on demand).
     */
    public void setFinal() {
        updater_ = null;
        scorer_.setInsert(false);
        scorer_.getWeights().setExapnd(false);
        encoder_ = null;
        encoder_state_ = null;
    }

    /** @return the scorer */
    public IndexScorer getScorer() {
        return scorer_;
    }

    /** @return the updater, or {@code null} once {@link #setFinal()} has been called */
    public IndexUpdater getUpdater() {
        return updater_;
    }

    /**
     * If {@code USE_SEGMENT_CONTEXT} is enabled, extends the feature currently
     * held by the encoder with the characters before {@code l_start}
     * (marker 0) and after {@code l_end} (marker 1), one feature per window
     * size each.
     *
     * <p>Relies on the encoder still holding the caller's feature; it does not
     * reset the encoder itself.
     */
    private void addCharacterContext(SegmentationInstance instance, IndexConsumer consumer, int l_start,
            int l_end, int tag_index) {
        if (!options_.getBoolean(SegmenterOptions.USE_SEGMENT_CONTEXT)) {
            return;
        }

        short[] chars = instance.getFormCharIndexes(char_table_);
        int max_window = options_.getInt(SegmenterOptions.MAX_CHARACTER_WINDOW);
        List<Integer> subtags = tags_to_subtags_.get(tag_index);

        encoder_.storeState(encoder_state_);

        for (int window = 1; window <= max_window; window++) {
            encoder_.restoreState(encoder_state_);
            encoder_.append(0, 1);
            addSegment(chars, l_start - window, l_start);
            consumer.consume(encoder_, subtags);
        }

        for (int window = 1; window <= max_window; window++) {
            encoder_.restoreState(encoder_state_);
            encoder_.append(1, 1);
            addSegment(chars, l_end, l_end + window);
            consumer.consume(encoder_, subtags);
        }
    }

    /**
     * Appends the window length and the character indexes of
     * {@code [start, end)} to the encoder.
     *
     * <p>Positions outside the array are encoded as {@code char_table_.size()}.
     * On the first negative character index the method returns immediately,
     * leaving the encoder with whatever was appended up to that point.
     */
    private void addSegment(short[] chars, int start, int end) {
        encoder_.append(end - start, window_length_bits_);
        for (int i = start; i < end; i++) {
            int c;
            if (i >= 0 && i < chars.length) {
                c = chars[i];
            } else {
                c = char_table_.size();
            }
            if (c < 0) {
                return;
            }
            encoder_.append(c, num_char_bits_);
        }
    }
}