// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.core.lattice; import java.util.ArrayList; import java.util.Collection; import java.util.List; import lemming.lemma.ranker.RankerCandidate; import marmot.core.State; import marmot.core.WeightVector; import marmot.util.Check; import marmot.util.Numerics; public class ZeroOrderSumLattice implements SumLattice { private List<List<State>> candidates_; private double log_threshold_; private List<Integer> gold_candidate_indexes_; private double[] score_sums_; private boolean initialized_; private boolean oracle_; public ZeroOrderSumLattice(List<List<State>> candidates, double threshold, boolean oracle) { candidates_ = candidates; log_threshold_ = Math.log(threshold); initialized_ = false; oracle_ = oracle; } private void init() { if (initialized_) return; initialized_ = true; score_sums_ = new double[candidates_.size()]; for (int index = 0; index < candidates_.size(); index++) { List<State> states = candidates_.get(index); assert !states.isEmpty(); score_sums_[index] = getScoreSum(states); } } @Override public List<List<State>> getCandidates() { return candidates_; } @Override public List<List<State>> prune() { return prune(log_threshold_); } public List<List<State>> prune(double log_threshold) { init(); List<List<State>> candidates = new ArrayList<List<State>>( candidates_.size()); for (int index = 0; index < candidates_.size(); index++) { int num_states = candidates_.get(index).size(); assert num_states >= 0; double score_sum = score_sums_[index]; List<State> states = new ArrayList<State>(num_states); State max_state = null; double max_score = Double.NEGATIVE_INFINITY; int state_index = 0; for (State state : candidates_.get(index)) { double score = state.getScore() - score_sum; assert Check.isNormal(score); boolean is_oracle_state = false; if (oracle_ && gold_candidate_indexes_ != null ) { is_oracle_state = gold_candidate_indexes_.get(index) == state_index; } if (score > log_threshold || is_oracle_state) { states.add(state); } if (score > max_score) { max_score = score; max_state = state; } state_index ++; } assert max_state != null; if (states.isEmpty()) { states.add(max_state); } candidates.add(states); } return candidates; } private double getScoreSum(Collection<State> states) { double score_sum = Double.NEGATIVE_INFINITY; for (State state : states) { assert Check.isNormal(state.getScore()); score_sum = Numerics.sumLogProb(score_sum, state.getScore()); } assert score_sum != Double.NEGATIVE_INFINITY; assert Check.isNormal(score_sum); return score_sum; } @Override public double update(WeightVector weights, double step_width) { init(); double ll = 0; if (gold_candidate_indexes_ == null) { System.err.println("Warning: Gold sequence not in zero order lattice!"); return ll; } for (int index = 0; index < candidates_.size() - 1; index++) { int gold_candidate_index = gold_candidate_indexes_.get(index); List<State> states = candidates_.get(index); double score_sum = score_sums_[index]; ll += update(states, gold_candidate_index, score_sum, weights, step_width); } return ll; } private double update(List<State> states, int gold_candidate_index, double score_sum, WeightVector weights, double step_width) { int candidate_index = 0; double ll=0; for (State state : states) { assert state.getZeroOrderState() == state; double p = Math.exp(state.getScore() - score_sum); double value = -p; if (candidate_index == gold_candidate_index) { value += 1.0; ll = states.get(gold_candidate_index).getScore() - score_sum; } weights.updateWeights(state, value * step_width, false); if (state.getLemmaCandidates() != null) { double new_score = state.getRealScore(); for (RankerCandidate candidate : state.getLemmaCandidates()) { double score = candidate.getScore() + new_score; p = Math.exp(score - score_sum); value = -p; if (candidate.isCorrect() && candidate_index == gold_candidate_index) { value += 1.0; ll = score - score_sum; } candidate.update(state, weights, value * step_width); } } candidate_index++; } return ll; } @Override public int getOrder() { return 0; } @Override public List<List<State>> getZeroOrderCandidates(boolean filter) { if (filter) { return prune(); } return candidates_; } @Override public void setGoldCandidates(List<Integer> candidates) { gold_candidate_indexes_ = candidates; } @Override public int getLevel() { return candidates_.get(0).get(0).getLevel(); } @Override public List<Integer> getGoldCandidates() { return gold_candidate_indexes_; } }