package chipmunk.segmenter; import java.util.Arrays; import marmot.util.Numerics; public class SegmentationSumLattice { private SegmenterModel model_; private int num_tags_; private int max_segment_length; private double[] forward_score_array_; private double[] backward_score_array_; private int input_length_; public SegmentationSumLattice(SegmenterModel model) { model_ = model; num_tags_ = model_.getNumTags(); max_segment_length = model_.getMaxSegmentLength(); } public double update(SegmentationInstance instance, boolean do_update) { input_length_ = instance.getLength(); checkArraySize(num_tags_ * input_length_); Arrays.fill(forward_score_array_, Double.NEGATIVE_INFINITY); for (int l_end = 1; l_end < input_length_ + 1; l_end++) { for (int tag = 0; tag < num_tags_; tag++) { double score_sum = Double.NEGATIVE_INFINITY; for (int l_start = Math.max(0, l_end - max_segment_length); l_start < l_end; l_start++) { double pair_score = model_.getPairScore(instance, l_start, l_end, tag); if (l_start == 0) { double score = pair_score; score_sum = Numerics.sumLogProb(score, score_sum); } else { for (int last_tag = 0; last_tag < num_tags_; last_tag++) { double prev_cost = forward_score_array_[getIndex(last_tag, l_start - 1)]; double transiton_score = model_.getTransitionScore(instance, last_tag, tag, l_start, l_end); double score = pair_score + transiton_score + prev_cost; score_sum = Numerics.sumLogProb(score, score_sum); } } } forward_score_array_[getIndex(tag, l_end - 1)] = score_sum; //System.err.println("FB scoresum tag" + tag + " " + score_sum); } } //double forward_sum = sumTag(forward_score_array_, input_length_ - 1); Arrays.fill(backward_score_array_, Double.NEGATIVE_INFINITY); for (int l_start = input_length_ - 1; l_start >= 0; l_start--) { for (int tag = 0; tag < num_tags_; tag++) { double score_sum = Double.NEGATIVE_INFINITY; for (int l_end = Math.min(input_length_, l_start + max_segment_length); l_end > l_start; l_end--) { double pair_score = model_.getPairScore(instance, l_start, l_end, tag); if (l_end == input_length_) { double score = pair_score; score_sum = Numerics.sumLogProb(score, score_sum); } else { for (int next_tag = 0; next_tag < num_tags_; next_tag++) { double prev_cost = backward_score_array_[getIndex(next_tag, l_end)]; double transiton_score = model_.getTransitionScore(instance, tag, next_tag, l_start, l_end); double score = pair_score + transiton_score + prev_cost; score_sum = Numerics.sumLogProb(score, score_sum); } } } backward_score_array_[getIndex(tag, l_start)] = score_sum; } } double backward_sum = sumTag(backward_score_array_, 0); double sum = backward_sum; for (int l_end = 1; l_end < input_length_ + 1; l_end++) { for (int tag = 0; tag < num_tags_; tag++) { for (int l_start = Math.max(0, l_end - max_segment_length); l_start < l_end; l_start++) { double pair_score = model_.getPairScore(instance, l_start, l_end, tag); double backward_score = 0.0; if (l_end < input_length_) { backward_score = Double.NEGATIVE_INFINITY; for (int next_tag = 0; next_tag < num_tags_; next_tag++) { double trans_score = model_.getTransitionScore(instance, tag, next_tag, l_start, l_end); double next_tag_score = backward_score_array_[getIndex(next_tag, l_end)] + trans_score; backward_score = Numerics.sumLogProb(next_tag_score, backward_score); } } if (l_start == 0) { double score = backward_score + pair_score; double log_prob = score - sum; double prob = Math.exp(log_prob); double update = -prob; if (do_update) { model_.update(instance, l_start, l_end, tag, update); } } else { double update = 0; for (int last_tag = 0; last_tag < num_tags_; last_tag++) { double forward_score = forward_score_array_[getIndex(last_tag, l_start - 1)]; double transiton_score = model_.getTransitionScore(instance, last_tag, tag, l_start, l_end); double score = forward_score + pair_score + transiton_score + backward_score; double log_prob = score - sum; double prob = Math.exp(log_prob); double tag_update = -prob; if (do_update) { model_.update(instance, l_start, l_end, last_tag, tag, tag_update); } update += tag_update; } if (do_update) { model_.update(instance, l_start, l_end, tag, update); } } } } } double real_value = 0.0; for (SegmentationResult result : instance.getResults()) { model_.update(instance, result, 1. / instance.getResults().size()); real_value += model_.getScore(instance, result) - sum; } return real_value; } private double sumTag(double[] score_array, int l) { double score_sum = Double.NEGATIVE_INFINITY; for (int tag = 0; tag < num_tags_; tag++) { double score = score_array[getIndex(tag, l)]; score_sum = Numerics.sumLogProb(score, score_sum); } return score_sum; } private int getIndex(int tag, int index) { return tag * input_length_ + index; } private void checkArraySize(int required_length) { if (forward_score_array_ == null || forward_score_array_.length < required_length) { forward_score_array_ = new double[required_length]; backward_score_array_ = new double[required_length]; } } }