// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core.lattice;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import lemming.lemma.ranker.RankerCandidate;
import marmot.core.State;
import marmot.core.Transition;
import marmot.core.WeightVector;
import marmot.util.Check;
import marmot.util.Numerics;
/**
 * A {@link SumLattice} over a sequence of candidate state lists.
 *
 * <p>The lattice combines a {@link ForwardSequenceLattice} and a
 * {@link BackwardSequenceLattice} built over the same candidates. The sum of
 * the forward and backward score of a state (in log space) is compared against
 * the partition function in order to prune unlikely states
 * ({@link #pruneStates()}) and to compute the updates applied to a
 * {@link WeightVector} with respect to a gold sequence
 * ({@link #update(WeightVector, double)}).
 *
 * <p>NOTE(review): no synchronization is performed anywhere in this class;
 * presumably an instance is meant to be used from a single thread.
 */
public class SequenceSumLattice implements SumLattice {

    /**
     * Once the list of kept states of a position holds more than this many
     * entries, further non-oracle states are dropped during pruning. Because
     * the check happens before adding, up to {@code PRUNE_STATE_LIMIT + 1}
     * non-oracle states can be kept per position.
     */
    private static final int PRUNE_STATE_LIMIT = 50;

    /** Largest tolerated difference between per-position score sum and partition function. */
    private static final double FORWARD_BACKWARD_TOLERANCE = 1e-5;

    private final ForwardSequenceLattice forward_;
    private final BackwardSequenceLattice backward_;
    /** Candidate states per sequence position; stored by reference, never copied. */
    private final List<List<State>> candidates_;
    /** Natural logarithm of the pruning threshold passed to the constructor. */
    private final double log_threshold_;
    private final State boundary_;
    private final int order_;
    /** If true (and gold candidates are set), the gold state survives pruning. */
    private final boolean oracle_;

    /** Guards {@link #init()} so that the sub-lattices are initialized only once. */
    private boolean initialized_;
    /** Index of the gold state for every position; {@code null} until set. */
    private List<Integer> gold_candidate_indexes_;

    /**
     * Creates a lattice over the given candidates. The forward and backward
     * lattices are constructed immediately but only initialized lazily by
     * {@link #init()}.
     *
     * @param candidates candidate states per sequence position
     * @param boundary   boundary state handed to both sub-lattices; its index
     *                   is used to detect the boundary in
     *                   {@link #getZeroOrderCandidates(boolean)}
     * @param threshold  pruning threshold; its logarithm is compared against
     *                   the log-space score of a state minus the partition
     *                   function
     * @param order      value reported by {@link #getOrder()}
     * @param oracle     whether the gold state should always be kept by
     *                   {@link #pruneStates()}
     */
    public SequenceSumLattice(List<List<State>> candidates, State boundary,
            double threshold, int order, boolean oracle) {
        forward_ = new ForwardSequenceLattice(candidates, boundary);
        backward_ = new BackwardSequenceLattice(candidates, boundary);
        candidates_ = candidates;
        log_threshold_ = Math.log(threshold);
        initialized_ = false;
        boundary_ = boundary;
        order_ = order;
        oracle_ = oracle;
    }

    /** Returns the candidate lists this lattice was built with (no copy is made). */
    @Override
    public List<List<State>> getCandidates() {
        return candidates_;
    }

    /** Initializes the forward and backward lattices; subsequent calls are no-ops. */
    public void init() {
        if (initialized_) {
            return;
        }
        initialized_ = true;
        forward_.init();
        backward_.init();
    }

    /** Equivalent to {@link #pruneStates()}. */
    @Override
    public List<List<State>> prune() {
        return pruneStates();
    }

    /**
     * Builds a pruned copy of the candidate lists.
     *
     * <p>For every position a state is kept if it still has a transition to a
     * state kept at the previous position and either its forward + backward
     * score minus the partition function exceeds the log threshold, or it is
     * the oracle (gold) state. Kept states other than those of the first
     * position are copies whose transition arrays are re-indexed to the kept
     * states of the previous position. If no state of a position is kept, the
     * best scoring state that passed the predecessor check is kept instead.
     *
     * @return the pruned candidate lists, one non-empty list per position
     * @throws IllegalStateException if at some position no state can be kept
     *         at all (no state passed the predecessor check with a comparable
     *         score)
     */
    public List<List<State>> pruneStates() {
        init();

        double score_sum_forward = forward_.partitionFunction();
        assert Check.isNormal(score_sum_forward);
        assert Check.isNormal(backward_.partitionFunction());
        assert Numerics.approximatelyEqual(score_sum_forward, backward_.partitionFunction());

        List<List<State>> candidates = new ArrayList<List<State>>(candidates_.size());

        // index_map[i] is the position of the i-th original state of the
        // previous position inside its pruned list, or -1 if it was dropped.
        // It stays null while processing the first position.
        int[] index_map = null;
        int num_previous_states = 1;

        for (int index = 0; index < candidates_.size(); index++) {
            List<State> position_states = candidates_.get(index);
            int num_states = position_states.size();

            int[] new_index_map = new int[num_states];
            Arrays.fill(new_index_map, -1);

            List<State> states = new ArrayList<State>(num_states);
            double score_sum = Double.NEGATIVE_INFINITY;
            int max_state_index = -1;
            double max_score = Double.NEGATIVE_INFINITY;

            for (int state_index = 0; state_index < num_states; state_index++) {
                State state = position_states.get(state_index);
                double score = forward_.get(index, state_index)
                        + backward_.get(index + 1, state_index);
                // Accumulated over all states (even dropped ones) for the
                // consistency check against the partition function below.
                score_sum = Numerics.sumLogProb(score_sum, score);

                if (index_map != null
                        && !hasSurvivingPredecessor(state, index_map, num_previous_states)) {
                    continue;
                }

                boolean is_oracle_state = isOracleState(index, state_index);
                boolean above_threshold = score - score_sum_forward > log_threshold_;

                if (above_threshold || is_oracle_state) {
                    // The size cut-off must never remove the oracle state,
                    // otherwise requesting an oracle would give no guarantee.
                    if (is_oracle_state || states.size() <= PRUNE_STATE_LIMIT) {
                        states.add(fixTransitions(state, index_map, num_previous_states));
                        new_index_map[state_index] = states.size() - 1;
                    }
                }

                if (score > max_score) {
                    max_score = score;
                    max_state_index = state_index;
                }
            }

            assert score_sum != Double.NEGATIVE_INFINITY;
            double fb_difference = Math.abs(score_sum - score_sum_forward);
            if (fb_difference > FORWARD_BACKWARD_TOLERANCE) {
                Logger logger = Logger.getLogger(getClass().getName());
                logger.warning(String.format("Difference in FB: %g %g", score_sum, score_sum_forward));
            }
            assert fb_difference < FORWARD_BACKWARD_TOLERANCE;

            if (states.isEmpty()) {
                // Nothing passed the threshold: fall back to the best state.
                if (max_state_index < 0) {
                    throw new IllegalStateException(String.format(
                            "Cannot keep any state at position %d: no state passed the"
                                    + " predecessor check with a comparable score", index));
                }
                states.add(fixTransitions(position_states.get(max_state_index),
                        index_map, num_previous_states));
                new_index_map[max_state_index] = 0;
            }
            assert !states.isEmpty();

            candidates.add(states);
            num_previous_states = num_states;
            index_map = new_index_map;
        }

        assert candidates.size() == candidates_.size();
        return candidates;
    }

    /**
     * Returns true if {@code state} has a non-null transition to at least one
     * previous state that was kept according to {@code index_map}.
     */
    private static boolean hasSurvivingPredecessor(State state, int[] index_map,
            int num_previous_states) {
        for (int transition_index = 0; transition_index < num_previous_states; transition_index++) {
            State transition = state.getTransition(transition_index);
            if (transition != null && index_map[transition_index] >= 0) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if oracle mode is active and the given state is the gold state of the position. */
    private boolean isOracleState(int index, int state_index) {
        if (!oracle_ || gold_candidate_indexes_ == null) {
            return false;
        }
        return gold_candidate_indexes_.get(index) == state_index;
    }

    /**
     * Returns a copy of {@code state} whose transitions are re-indexed to the
     * pruned states of the previous position. Transitions to dropped states
     * are discarded. If {@code index_map} is null the state itself is
     * returned unchanged.
     *
     * <p>The new transition array keeps the length {@code num_states} (the
     * unpruned size of the previous position); unused trailing slots are null.
     */
    private State fixTransitions(State state, int[] index_map, int num_states) {
        if (index_map == null) {
            return state;
        }
        State copy = state.copy();
        Transition[] old_transitions = copy.getTransitions();
        Transition[] new_transitions = new Transition[num_states];
        for (int old_index = 0; old_index < old_transitions.length; old_index++) {
            int new_index = index_map[old_index];
            if (new_index >= 0) {
                new_transitions[new_index] = old_transitions[old_index];
            }
        }
        copy.setTransitions(new_transitions);
        return copy;
    }

    /**
     * Updates the weights with respect to the gold sequence set via
     * {@link #setGoldCandidates(List)}.
     *
     * <p>For every transition and state the value {@code exp(score -
     * partition function)} is subtracted and, for elements on the gold
     * sequence, 1 is added; the result is scaled by {@code step_width}. The
     * same is done for the lemma candidates of the zero order state, if any.
     *
     * @param weights    weight vector receiving the updates
     * @param step_width factor applied to every update
     * @return the sum of the gold state and transition scores minus the
     *         forward partition function
     * @throws IllegalStateException if the lattice is non-empty and no gold
     *         candidates have been set
     */
    @Override
    public double update(WeightVector weights, double step_width) {
        init();

        if (gold_candidate_indexes_ == null && !candidates_.isEmpty()) {
            throw new IllegalStateException(
                    "Gold candidates must be set before calling update().");
        }

        double ll = 0;
        double score_sum = forward_.partitionFunction();
        int last_gold_candidate_index = 0;

        for (int index = 0; index < candidates_.size(); index++) {
            int gold_candidate_index = gold_candidate_indexes_.get(index);
            List<State> position_states = candidates_.get(index);

            int state_index = 0;
            for (State state : position_states) {
                boolean is_gold_sequence_state = state_index == gold_candidate_index;

                int trans_index = 0;
                for (State transition : state.getTransitions()) {
                    if (transition != null) {
                        double trans_score = forward_.get(index - 1, trans_index)
                                + state.getScore()
                                + transition.getScore()
                                + backward_.get(index + 1, state_index);
                        double trans_p = Math.exp(trans_score - score_sum);

                        boolean is_gold_transition = trans_index == last_gold_candidate_index
                                && is_gold_sequence_state;
                        if (is_gold_transition) {
                            ll += transition.getScore();
                            weights.updateWeights(transition, (1.0 - trans_p) * step_width, true);
                        } else {
                            weights.updateWeights(transition, -trans_p * step_width, true);
                        }
                    }
                    trans_index++;
                }

                double state_score = forward_.get(index, state_index)
                        + backward_.get(index + 1, state_index);
                double state_value = -Math.exp(state_score - score_sum);
                if (is_gold_sequence_state) {
                    ll += state.getScore();
                    state_value += 1.0;
                }
                state.incrementEstimatedCounts(state_value * step_width);

                State zero_order_state = state.getZeroOrderState();
                if (zero_order_state.getLemmaCandidates() != null) {
                    // Swap the zero order state's score for its real score
                    // before adding the individual lemma candidate scores.
                    double new_state_score = state_score - zero_order_state.getScore()
                            + zero_order_state.getRealScore();
                    for (RankerCandidate candidate : zero_order_state.getLemmaCandidates()) {
                        double candidate_score = new_state_score + candidate.getScore();
                        double candidate_value = -Math.exp(candidate_score - score_sum);
                        if (is_gold_sequence_state && candidate.isCorrect()) {
                            candidate_value += 1.0;
                        }
                        candidate.incrementEstimatedCounts(candidate_value * step_width);
                    }
                }

                state_index++;
            }

            // Flush the estimated counts of this position into the weights.
            for (State state : position_states) {
                state.updateWeights(weights);
            }

            last_gold_candidate_index = gold_candidate_index;
        }

        ll -= score_sum;
        return ll;
    }

    /** Returns the order passed to the constructor. */
    @Override
    public int getOrder() {
        return order_;
    }

    /**
     * Reduces the given candidates to their distinct zero order states.
     *
     * <p>For every position the zero order state of each candidate is copied
     * (with its transitions set to null) unless a state with equal indexes
     * (see {@code State.equalIndexes}) was already collected for that
     * position. Processing stops after the first position that contains a
     * state whose index equals {@code boundary_index}.
     *
     * @param candidates     candidate states per position
     * @param boundary_index index identifying the boundary state
     * @return the zero order candidates up to and including the boundary
     *         position
     */
    public static List<List<State>> getZeroOrderCandidates(List<List<State>> candidates, int boundary_index) {
        List<List<State>> new_candidates = new ArrayList<List<State>>(candidates.size());
        boolean found_boundary = false;

        for (List<State> states : candidates) {
            List<State> new_states = new ArrayList<State>();

            for (State state : states) {
                State zero_order_state = state.getZeroOrderState();
                assert !(zero_order_state instanceof Transition);

                if (zero_order_state.getIndex() == boundary_index) {
                    found_boundary = true;
                }

                if (!containsEqualIndexes(new_states, zero_order_state)) {
                    State new_state = zero_order_state.copy();
                    new_state.setTransitions(null);
                    new_states.add(new_state);
                    assert new_state.getIndex() >= 0;
                    assert new_state.getTransitions() == null;
                }
            }

            new_candidates.add(new_states);
            if (found_boundary) {
                assert new_states.size() == 1;
                break;
            }
            assert !new_states.isEmpty();
        }

        assert !new_candidates.isEmpty();
        assert allTransitionsCleared(new_candidates);
        assert found_boundary;
        return new_candidates;
    }

    /** Returns true if {@code states} already holds a state with the same indexes as {@code candidate}. */
    private static boolean containsEqualIndexes(List<State> states, State candidate) {
        for (State existing : states) {
            if (existing.equalIndexes(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Sanity check used in assertions: true if no state carries a transition array. */
    private static boolean allTransitionsCleared(List<List<State>> candidates) {
        for (List<State> states : candidates) {
            for (State state : states) {
                if (state.getTransitions() != null) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the zero order candidates of this lattice, see
     * {@link #getZeroOrderCandidates(List, int)}.
     *
     * @param filter if true the candidates are pruned via {@link #prune()}
     *               first, otherwise the unpruned candidates are used
     */
    @Override
    public List<List<State>> getZeroOrderCandidates(boolean filter) {
        List<List<State>> candidates = filter ? prune() : candidates_;
        return getZeroOrderCandidates(candidates, boundary_.getIndex());
    }

    /** Stores (by reference) the index of the gold state for every position. */
    @Override
    public void setGoldCandidates(List<Integer> candidates) {
        gold_candidate_indexes_ = candidates;
    }

    /**
     * Returns the level of the first state of the first position. The lattice
     * must therefore contain at least one state at position 0.
     */
    @Override
    public int getLevel() {
        return candidates_.get(0).get(0).getLevel();
    }

    /** Returns the gold indexes set via {@link #setGoldCandidates(List)}, or null if none were set. */
    @Override
    public List<Integer> getGoldCandidates() {
        return gold_candidate_indexes_;
    }
}