package chipmunk.test.segmenter;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import marmot.util.DynamicWeights;
import marmot.util.Numerics;

import org.junit.Assert;
import org.junit.Test;

import chipmunk.segmenter.SegmentationInstance;
import chipmunk.segmenter.SegmentationReading;
import chipmunk.segmenter.SegmentationResult;
import chipmunk.segmenter.SegmentationSumLattice;
import chipmunk.segmenter.SegmenterModel;
import chipmunk.segmenter.SegmenterOptions;
import chipmunk.segmenter.Word;

/**
 * Compares the value and gradient produced by {@link SegmentationSumLattice}
 * with a brute-force reference that enumerates every candidate segmentation
 * of a word explicitly.
 */
public class SumLatticeTest {

    /** Number of random weight vectors the comparison is repeated for. */
    private static final int NUM_TRIALS = 10;

    /** Length of the weight and gradient arrays handed to the model. */
    private static final int NUM_WEIGHTS = 50;

    /** Fixed seed so that the random weights are reproducible. */
    private static final long RANDOM_SEED = 42;

    /** Tolerance passed to the gradient comparison. */
    private static final double GRADIENT_TOLERANCE = 1e-5;

    /**
     * Brute-force counterpart of the lattice update.
     *
     * <p>Enumerates all candidate results via {@link #addAllResults}, folds
     * their scores together with {@code Numerics.sumLogProb} (presumably a
     * log-space sum, i.e. the log normalizer), updates the model with
     * {@code -exp(score - sum)} for every candidate and with {@code 1.0} for
     * the single result stored on the instance.
     *
     * @param instance instance to process; must carry exactly one result
     * @param model    model used for scoring and updating
     * @return the score of the instance's stored result minus the folded score sum
     */
    double explicit_update(SegmentationInstance instance, SegmenterModel model) {
        int max_segment_length = model.getMaxSegmentLength();
        List<SegmentationResult> candidates = new LinkedList<>();
        addAllResults(instance, model, max_segment_length, candidates, 0);

        double score_sum = Double.NEGATIVE_INFINITY;
        for (SegmentationResult candidate : candidates) {
            score_sum = Numerics.sumLogProb(model.getScore(instance, candidate), score_sum);
        }

        for (SegmentationResult candidate : candidates) {
            double log_prob = model.getScore(instance, candidate) - score_sum;
            model.update(instance, candidate, -Math.exp(log_prob));
        }

        // A JUnit assertion rather than the `assert` keyword: the keyword is
        // skipped unless the JVM runs with -ea, which would hide a broken
        // precondition.
        Assert.assertEquals("instance must carry exactly one result", 1,
                instance.getResults().size());

        SegmentationResult expected = instance.getResults().iterator().next();
        double expected_log_prob = model.getScore(instance, expected) - score_sum;
        model.update(instance, expected, 1.0);
        return expected_log_prob;
    }

    /**
     * Recursively enumerates every way to split the suffix of the word that
     * begins at {@code start} into segments of at most
     * {@code max_segment_length} characters, combined with every tag index in
     * {@code [0, model.getNumTags())} for each segment.
     *
     * <p>Each produced result holds one tag per segment and, in parallel, the
     * end offset of that segment within the word.
     *
     * @param instance           instance whose word is segmented
     * @param model              supplies the number of tags
     * @param max_segment_length maximum number of characters per segment
     * @param results            list the enumerated results are appended to
     * @param start              offset in the word at which the next segment begins
     */
    private void addAllResults(SegmentationInstance instance, SegmenterModel model,
            int max_segment_length, List<SegmentationResult> results, int start) {
        String word = instance.getWord().getWord();
        int num_tags = model.getNumTags();
        int last_end = Math.min(start + max_segment_length, word.length());

        for (int end = start + 1; end <= last_end; end++) {
            if (end == word.length()) {
                // The segment reaches the end of the word: emit a
                // single-segment result for every tag.
                for (int tag = 0; tag < num_tags; tag++) {
                    List<Integer> tags = new LinkedList<>();
                    tags.add(tag);
                    List<Integer> indexes = new LinkedList<>();
                    indexes.add(end);
                    results.add(new SegmentationResult(tags, indexes));
                }
            } else {
                // Enumerate the remainder first, then prepend the current
                // segment (with every tag) to each completion.
                List<SegmentationResult> intermediates = new LinkedList<>();
                addAllResults(instance, model, max_segment_length, intermediates, end);
                for (SegmentationResult intermediate : intermediates) {
                    for (int tag = 0; tag < num_tags; tag++) {
                        List<Integer> tags = new LinkedList<>();
                        tags.add(tag);
                        tags.addAll(intermediate.getTags());
                        List<Integer> indexes = new LinkedList<>();
                        indexes.add(end);
                        indexes.addAll(intermediate.getInputIndexes());
                        results.add(new SegmentationResult(tags, indexes));
                    }
                }
            }
        }
    }

    /**
     * For several random weight vectors and a handful of small words, checks
     * that {@code SegmentationSumLattice.update} returns approximately the
     * same value and accumulates approximately the same gradient as
     * {@link #explicit_update}.
     */
    @Test
    public void test() {
        List<Word> words = new LinkedList<>();
        words.add(toWord(Arrays.asList("b"), Arrays.asList("B")));
        words.add(toWord(Arrays.asList("aa"), Arrays.asList("A")));
        words.add(toWord(Arrays.asList("a", "bb"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "bb"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("a", "b"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "b"), Arrays.asList("A", "B")));
        words.add(toWord(Arrays.asList("aa", "c"), Arrays.asList("A", "C")));

        SegmenterModel model = new SegmenterModel();
        SegmenterOptions options = new SegmenterOptions();
        options.setOption(SegmenterOptions.USE_CHARACTER_FEATURE, false);
        options.setOption(SegmenterOptions.USE_SEGMENT_CONTEXT, false);
        model.init(options, words);

        SegmentationSumLattice lattice = new SegmentationSumLattice(model);
        Random random = new Random(RANDOM_SEED);

        for (int trial = 0; trial < NUM_TRIALS; trial++) {
            double[] weights = new double[NUM_WEIGHTS];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = random.nextGaussian();
            }

            // Scores are read from `weights`; updates are written to the
            // separate `gradient` array, which is snapshotted and zeroed
            // after each of the two update variants.
            double[] gradient = new double[weights.length];
            model.setScorerWeights(new DynamicWeights(weights, false, false));
            model.setUpdaterWeights(new DynamicWeights(gradient, false, false));

            for (Word word : words) {
                SegmentationInstance instance = model.getInstance(word);

                double act_value = lattice.update(instance, true);
                double[] act_gradient = gradient.clone();
                Arrays.fill(gradient, 0.0);

                double real_value = explicit_update(instance, model);
                double[] real_gradient = gradient.clone();
                Arrays.fill(gradient, 0.0);

                Assert.assertTrue(
                        "gradient mismatch (trial " + trial + ") for " + word
                                + "\nlattice:  " + Arrays.toString(act_gradient)
                                + "\nexplicit: " + Arrays.toString(real_gradient),
                        Numerics.approximatelyEqual(act_gradient, real_gradient,
                                GRADIENT_TOLERANCE));
                Assert.assertTrue(
                        "value mismatch (trial " + trial + ") for " + word
                                + ": lattice " + act_value + ", explicit " + real_value,
                        Numerics.approximatelyEqual(act_value, real_value));
            }
        }
    }

    /**
     * Builds a {@link Word} whose form is the concatenation of
     * {@code segments} and attaches one reading made of the given segments
     * and tags.
     *
     * @param segments segment strings, in order
     * @param tags     tag of each segment, parallel to {@code segments}
     * @return the word with a single reading added
     */
    private Word toWord(List<String> segments, List<String> tags) {
        StringBuilder form = new StringBuilder();
        for (String segment : segments) {
            form.append(segment);
        }
        Word w = new Word(form.toString());
        w.add(new SegmentationReading(segments, tags));
        return w;
    }
}