package chipmunk.segmenter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * Finds the highest-scoring tagged segmentation of an instance under a
 * {@link SegmenterModel} using dynamic programming.
 *
 * <p>The table is indexed by {@code (tag, end position)}. The cell for
 * {@code (tag, l_end - 1)} holds the best score of any segmentation of the
 * first {@code l_end} input positions whose last segment carries {@code tag},
 * together with back-pointers (previous tag and start of the last segment)
 * that {@link #backTrace()} follows to rebuild the path.
 *
 * <p>The scratch tables and the current input length are instance fields that
 * are overwritten by every {@link #decode} call, so a single decoder instance
 * must not be used for concurrent decodes.
 */
public class SegmentationDecoder {

    private final SegmenterModel model_;
    private final int num_tags_;
    private final int max_segment_length_;

    /** Best score per cell; reused across calls and grown on demand. */
    private double[] score_array_;
    /** Back-pointer: tag of the preceding segment, or -1 if the segment starts at 0. */
    private int[] tag_array_;
    /** Back-pointer: start position of the segment ending at this cell. */
    private int[] index_array_;
    /** Length of the instance currently being decoded; also the table row stride. */
    private int input_length_;

    /**
     * Creates a decoder bound to {@code model}; the tag count and the maximum
     * segment length are read from the model once, here.
     *
     * @param model model providing pair scores, transition scores and limits
     */
    public SegmentationDecoder(SegmenterModel model) {
        model_ = model;
        num_tags_ = model_.getNumTags();
        max_segment_length_ = model_.getMaxSegmentLength();
    }

    /**
     * Computes the best-scoring segmentation of {@code instance}.
     *
     * @param instance the input to segment; must have a positive length
     * @return the best tag sequence, the end position of every segment and the
     *         total score
     * @throws IllegalArgumentException if the instance length is not positive
     * @throws IllegalStateException if no segmentation scores above negative
     *         infinity (for example when the model has no tags)
     */
    SegmentationResult decode(SegmentationInstance instance) {
        int length = instance.getLength();
        if (length <= 0) {
            // Previously this fell through to a negative array index in backTrace().
            throw new IllegalArgumentException(
                    "Cannot decode an instance of length " + length);
        }
        input_length_ = length;

        int required = num_tags_ * input_length_;
        checkArraySize(required);
        // The buffers may be larger than needed after an earlier, longer input;
        // only the cells this call can touch have to be reset.
        Arrays.fill(score_array_, 0, required, Double.NEGATIVE_INFINITY);
        Arrays.fill(tag_array_, 0, required, -1);
        Arrays.fill(index_array_, 0, required, -1);

        for (int l_end = 1; l_end <= input_length_; l_end++) {
            // A segment may not be longer than max_segment_length_.
            int first_start = Math.max(0, l_end - max_segment_length_);

            for (int tag = 0; tag < num_tags_; tag++) {
                double best_score = Double.NEGATIVE_INFINITY;
                int best_output = -1;
                int best_index = -1;

                for (int l_start = first_start; l_start < l_end; l_start++) {
                    double pair_score = model_.getPairScore(instance, l_start, l_end, tag);

                    if (l_start == 0) {
                        // First segment of the path: no predecessor, no transition.
                        if (pair_score > best_score) {
                            best_score = pair_score;
                            best_output = -1;
                            best_index = l_start;
                        }
                        continue;
                    }

                    for (int last_tag = 0; last_tag < num_tags_; last_tag++) {
                        // Best prefix that ends exactly where this segment starts.
                        double prev_cost = score_array_[getIndex(last_tag, l_start - 1)];
                        double transition_score = model_.getTransitionScore(
                                instance, last_tag, tag, l_start, l_end);
                        double score = pair_score + transition_score + prev_cost;
                        if (score > best_score) {
                            best_score = score;
                            best_output = last_tag;
                            best_index = l_start;
                        }
                    }
                }

                int cell = getIndex(tag, l_end - 1);
                score_array_[cell] = best_score;
                tag_array_[cell] = best_output;
                index_array_[cell] = best_index;
            }
        }

        return backTrace();
    }

    /**
     * Rebuilds the best path from the tables filled by {@link #decode}.
     *
     * <p>The returned tag list is in left-to-right order; the index list holds
     * the end position of each segment, so its last element equals the input
     * length.
     *
     * @throws IllegalStateException if no final cell has a score above negative
     *         infinity, i.e. there is no path to follow
     */
    private SegmentationResult backTrace() {
        int end_index = input_length_;
        double best_score = Double.NEGATIVE_INFINITY;
        int end_tag = -1;

        // Pick the tag of the last segment: best cell in the final column.
        for (int tag = 0; tag < num_tags_; tag++) {
            double score = score_array_[getIndex(tag, end_index - 1)];
            if (score > best_score) {
                best_score = score;
                end_tag = tag;
            }
        }

        if (end_tag < 0) {
            // Without this guard the loop below would index the tables with tag -1.
            throw new IllegalStateException(
                    "No segmentation with a score above negative infinity exists"
                            + " (input length " + input_length_
                            + ", " + num_tags_ + " tags"
                            + ", max segment length " + max_segment_length_ + ")");
        }

        List<Integer> tags = new ArrayList<>();
        List<Integer> input_indexes = new ArrayList<>();
        tags.add(end_tag);
        input_indexes.add(end_index);

        while (true) {
            int cell = getIndex(end_tag, end_index - 1);
            int start_index = index_array_[cell];
            int start_tag = tag_array_[cell];
            if (start_tag < 0) {
                // Reached the segment that starts at position 0.
                break;
            }
            // The previous segment ends where the current one starts.
            tags.add(start_tag);
            input_indexes.add(start_index);
            end_tag = start_tag;
            end_index = start_index;
        }

        // The path was collected from the end towards the beginning.
        Collections.reverse(tags);
        Collections.reverse(input_indexes);
        return new SegmentationResult(tags, input_indexes, best_score);
    }

    /** Maps a (tag, position) pair to its offset in the flat tables. */
    private int getIndex(int tag, int index) {
        return tag * input_length_ + index;
    }

    /**
     * Ensures the three tables hold at least {@code required_length} cells,
     * reallocating all of them together when they are missing or too small.
     */
    private void checkArraySize(int required_length) {
        if (score_array_ == null || score_array_.length < required_length) {
            score_array_ = new double[required_length];
            tag_array_ = new int[score_array_.length];
            index_array_ = new int[score_array_.length];
        }
    }
}