// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.

package marmot.core;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import marmot.core.lattice.Hypothesis;
import marmot.core.lattice.SequenceSumLattice;
import marmot.core.lattice.SequenceViterbiLattice;
import marmot.core.lattice.SumLattice;
import marmot.core.lattice.ViterbiLattice;
import marmot.core.lattice.ZeroOrderSumLattice;
import marmot.core.lattice.ZeroOrderViterbiLattice;

/**
 * {@link Tagger} implementation that builds candidate lattices level by level
 * (one level per tag table of the model) and, within a level, order by order.
 *
 * <p>
 * For every (level, order) pair the tagger keeps a pruning threshold together
 * with running counters of how many candidate states survived. The counters
 * are consumed by {@link #setThresholds(boolean)}, which nudges each threshold
 * towards the configured number of candidates per state.
 */
public class SimpleTagger implements Tagger {

    private static final long serialVersionUID = 1L;

    /**
     * Factor used to presize the per-position state lists built in
     * {@link #increaseLevel(List, Sequence)}; it only affects the initial
     * capacity of those lists.
     */
    private static final int AVERAGE_NUMBER_OF_CANDIDATES = 5;

    private Model model_;
    private WeightVector weight_vector_;
    private int num_level_;
    /** Pruning thresholds, indexed as [level][order]. */
    private double[][] threshs_;
    /** Desired average number of candidates, indexed by order. */
    private double[] candidates_per_state_;
    /** Accumulated number of candidate states, indexed as [level][order]. */
    private double[][] num_states_;
    /** Accumulated number of lattice positions, indexed as [level][order]. */
    private double[][] length_;
    private int order_;
    private boolean prune_;
    private int effective_order_;
    private int beam_size_;
    private boolean oracle_;
    /** When set, state feature vectors are stored on the token after extraction. */
    private boolean cache_feature_vector_ = false;
    private Result result_;

    /**
     * Creates a tagger for the given model.
     *
     * @param model
     *            model providing options, tag tables, candidates and
     *            boundary states
     * @param order
     *            maximal order used on the last level
     * @param weight_vector
     *            weight vector used for feature extraction and scoring
     */
    public SimpleTagger(Model model, int order, WeightVector weight_vector) {
        order_ = order;
        model_ = model;
        prune_ = model.getOptions().getPrune();
        beam_size_ = model.getOptions().getBeamSize();
        oracle_ = model.getOptions().getOracle();
        effective_order_ = Math.min(order, model.getOptions()
                .getEffectiveOrder());
        weight_vector_ = weight_vector;
        candidates_per_state_ = model.getOptions().getCandidatesPerState();

        int levels = model_.getTagTables().size();
        num_level_ = levels;

        // Freshly allocated double arrays are already zero-filled, so only
        // the thresholds need an explicit initial value.
        threshs_ = new double[levels][getOrder() + 1];
        length_ = new double[levels][getOrder() + 1];
        num_states_ = new double[levels][getOrder() + 1];

        double initial_threshold = model.getOptions().getProbThreshold();
        for (int level = 0; level < threshs_.length; level++) {
            Arrays.fill(threshs_[level], initial_threshold);
        }
    }

    /**
     * Attaches scored transitions to every state in {@code states}.
     *
     * <p>
     * For each position, a transition is created from every state of the
     * previous position (the boundary state of {@code level} for the first
     * position) to every state of the current position that the previous
     * state can transition to. Each target state then receives an array of
     * its incoming transitions indexed by the previous state's position;
     * entries without a permitted transition stay {@code null}.
     *
     * @param states
     *            candidate states per position
     * @param level
     *            level whose boundary state precedes the first position
     * @param order
     *            order passed to the created {@link Transition} objects
     */
    private void addTransitions(List<List<State>> states, int level, int order) {
        List<State> last_states = Collections.singletonList((State) model_
                .getBoundaryState(level));

        for (int index = 0; index < states.size(); index++) {
            List<State> current_states = states.get(index);
            Transition[][] transitions = new Transition[last_states.size()][current_states
                    .size()];

            int from_index = 0;
            for (State last_state : last_states) {
                // Transition features only depend on the source state and
                // are shared by all transitions leaving it.
                FeatureVector vector = weight_vector_
                        .extractTransitionFeatures(last_state);

                int to_index = 0;
                for (State state : current_states) {
                    if (last_state.canTransitionTo(state)) {
                        Transition transition = new Transition(last_state,
                                state, order);
                        transition.setVector(vector);

                        // Sum the score over the target state and all of
                        // its sub-level states.
                        double score = 0.0;
                        State run = state;
                        while (run != null) {
                            score += weight_vector_.dotProduct(run, vector);
                            run = run.getSubLevelState();
                        }

                        transition.setScore(score);
                        transitions[from_index][to_index] = transition;
                    }
                    to_index++;
                }
                from_index++;
            }

            // Hand every target state the column of its incoming transitions.
            int to_index = 0;
            for (State state : current_states) {
                boolean found_transition = false;
                Transition[] transition_row = new Transition[last_states.size()];
                for (from_index = 0; from_index < last_states.size(); from_index++) {
                    transition_row[from_index] = transitions[from_index][to_index];
                    if (transition_row[from_index] != null) {
                        found_transition = true;
                    }
                }
                assert (found_transition);
                state.setTransitions(transition_row);
                to_index++;
            }

            last_states = current_states;
        }
    }

    /**
     * Turns the transitions attached to the given states into the states of
     * the next higher order.
     *
     * <p>
     * Every non-null transition of a state becomes a new candidate whose
     * score is the transition score plus the score of the state it leads to.
     * The transitions of the consumed states are cleared. A list holding only
     * the boundary state of {@code level} is appended at the end.
     *
     * @param states
     *            candidate states per position, with transitions attached
     * @param level
     *            level whose boundary state is appended
     * @return the new candidates; one list more than {@code states}
     */
    protected List<List<State>> increaseOrder(List<List<State>> states,
            int level) {
        List<List<State>> new_state_candidates = new ArrayList<List<State>>(
                states.size() + 1);

        for (int index = 0; index < states.size(); index++) {
            // The first position is only preceded by the boundary state.
            int num_previous_states;
            if (index == 0) {
                num_previous_states = 1;
            } else {
                num_previous_states = states.get(index - 1).size();
            }

            List<State> current_states = states.get(index);
            List<State> new_states = new ArrayList<State>(current_states.size()
                    * num_previous_states);

            for (State state : current_states) {
                Transition[] transitions = state.getTransitions();
                state.setTransitions(null);
                assert num_previous_states <= transitions.length;

                for (int previous_state_index = 0; previous_state_index < num_previous_states; previous_state_index++) {
                    Transition t = transitions[previous_state_index];
                    if (t == null) {
                        continue;
                    }
                    t.setScore(t.getScore() + state.getScore());
                    new_states.add(t);
                    t.getSubOrderState().setTransitions(null);
                    assert t.check();
                }
            }

            assert !new_states.isEmpty();
            new_state_candidates.add(new_states);
        }

        new_state_candidates.add(Collections.singletonList(model_
                .getBoundaryState(level)));
        return new_state_candidates;
    }

    /**
     * Builds the level-0 candidate states for every token of the sequence.
     *
     * <p>
     * The token's stored feature vector is used when present; otherwise it
     * is extracted from the weight vector. The returned list ends with a
     * list holding only the level-0 boundary state.
     *
     * @param sequence
     *            sequence to create states for
     * @param training
     *            forwarded to the model when setting lemma candidates
     * @return candidate states per position plus the trailing boundary list
     */
    protected List<List<State>> getStates(Sequence sequence, boolean training) {
        List<List<State>> candidates = new ArrayList<List<State>>(
                sequence.size() + 1);

        for (int index = 0; index < sequence.size(); index++) {
            Token token = sequence.get(index);

            FeatureVector vector = token.getVector();
            if (vector == null) {
                vector = weight_vector_.extractStateFeatures(sequence, index);
                if (cache_feature_vector_) {
                    token.setVector(vector);
                }
            }

            int[] tag_indexes = model_.getTagCandidates(sequence, index, null);
            List<State> states = new ArrayList<State>(tag_indexes.length);
            for (int tag_index : tag_indexes) {
                // A value of -1 terminates the candidate array early.
                if (tag_index == -1) {
                    break;
                }
                State state = new State(tag_index);
                state.setVector(vector);
                state.setScore(weight_vector_.dotProduct(state, vector));
                model_.setLemmaCandidates(token, state, true, training);
                states.add(state);
            }

            assert states.size() > 0;
            candidates.add(states);
        }

        candidates.add(Collections.singletonList(model_.getBoundaryState(0)));
        return candidates;
    }

    /**
     * Adapts the pruning thresholds using the counters collected since the
     * last call and resets those counters.
     *
     * <p>
     * For every (level, order) pair with a non-zero accumulated length the
     * average number of states per position is compared with the desired
     * number of candidates for that order. If they differ by more than 0.1
     * the threshold is raised by 10% (too many states) or lowered by 10%
     * (too few states).
     *
     * @param print
     *            whether to report the observed averages
     * @return the averages, one line per level, if {@code print} is set;
     *         {@code null} otherwise
     */
    @Override
    public String setThresholds(boolean print) {
        StringBuilder sb = null;
        if (print) {
            sb = new StringBuilder();
        }

        for (int level = 0; level < num_states_.length; level++) {
            for (int order = 0; order < num_states_[level].length; order++) {
                if (length_[level][order] > 0) {
                    double num_states = num_states_[level][order]
                            / length_[level][order];

                    // Orders beyond the configured array reuse its last
                    // entry.
                    int effective_order = Math.min(order,
                            candidates_per_state_.length - 1);
                    double want = candidates_per_state_[effective_order];

                    if (Math.abs(num_states - want) > 1e-1) {
                        if (num_states > want) {
                            threshs_[level][order] += 0.1 * threshs_[level][order];
                        } else {
                            threshs_[level][order] -= 0.1 * threshs_[level][order];
                        }
                    }

                    if (print) {
                        sb.append(' ');
                        sb.append(num_states);
                    }

                    num_states_[level][order] = 0;
                    length_[level][order] = 0;
                }
            }
            if (print) {
                sb.append('\n');
            }
        }

        if (print) {
            return sb.toString();
        }
        return null;
    }

    /**
     * Expands every candidate state into states of the next level.
     *
     * <p>
     * For each state the model proposes tag candidates of the next level;
     * each becomes a new state wrapping the original one. The last entry of
     * {@code candidates} is treated as the boundary position and replaced by
     * the boundary state of the next level.
     *
     * @param candidates
     *            candidate states per position, boundary list last
     * @param sentence
     *            sequence the candidates belong to
     * @return the next-level candidates, same number of positions
     */
    private List<List<State>> increaseLevel(List<List<State>> candidates,
            Sequence sentence) {
        List<List<State>> new_candidates = new ArrayList<List<State>>(
                candidates.size());

        int index = 0;
        for (List<State> current_states : candidates) {
            List<State> new_current_states;

            if (index < candidates.size() - 1) {
                new_current_states = new ArrayList<State>(current_states.size()
                        * AVERAGE_NUMBER_OF_CANDIDATES);

                for (State state : current_states) {
                    FeatureVector vector = weight_vector_
                            .extractStateFeatures(state);
                    assert state.getTransitions() == null;

                    int[] tag_indexes = model_.getTagCandidates(sentence,
                            index, state);
                    for (int tag_index : tag_indexes) {
                        // A value of -1 terminates the candidate array early.
                        if (tag_index == -1) {
                            break;
                        }
                        assert state.getOrder() == 1;

                        State new_state = new State(tag_index, state);
                        new_state.setVector(vector);
                        new_state.setScore(weight_vector_.dotProduct(new_state,
                                vector) + state.getRealScore());
                        model_.setLemmaCandidates(new_state, true);
                        new_current_states.add(new_state);
                    }
                }
            } else {
                // Last position: boundary state of the next level.
                new_current_states = Collections
                        .singletonList(model_.getBoundaryState(current_states
                                .get(0).getLevel() + 1));
            }

            new_candidates.add(new_current_states);
            index++;
        }

        return new_candidates;
    }

    /**
     * Adds the size of the given candidate lattice to the counters that
     * {@link #setThresholds(boolean)} evaluates.
     *
     * @param level
     *            level the candidates belong to
     * @param order
     *            order the candidates belong to
     * @param candidates
     *            candidate states per position
     */
    protected void incrementStateCounter(int level, int order,
            List<List<State>> candidates) {
        int num_states = 0;
        for (List<State> states : candidates) {
            num_states += states.size();
        }
        int length = candidates.size();

        num_states_[level][order] += num_states;
        length_[level][order] += length;
    }

    /**
     * Builds the sum lattice for a sentence, level by level and order by
     * order.
     *
     * <p>
     * On every level a zero-order lattice is created first and then extended
     * up to the effective order (the full order on the last level). When
     * {@code training} is set and the gold sequence is no longer among the
     * candidates, the lattice built so far is returned immediately.
     *
     * @param training
     *            whether the lattice is built for training
     * @param sentence
     *            sequence to build the lattice for
     * @return the final lattice, or an intermediate one on an early return
     */
    @Override
    public SumLattice getSumLattice(boolean training, Sequence sentence) {
        int order = getOrder();
        List<List<State>> candidates = null;
        SumLattice lattice = null;

        for (int level = 0; level < getNumLevels(); level++) {
            if (level == 0) {
                candidates = getStates(sentence, training);
            } else {
                candidates = lattice.getZeroOrderCandidates(prune_);
                incrementStateCounter(level - 1, lattice.getOrder(), candidates);

                if (training
                        && testForGoldCandidates(sentence, candidates) == null) {
                    return lattice;
                }

                int old_size = candidates.size();
                candidates = increaseLevel(candidates, sentence);
                assert candidates.size() == old_size;
                for (List<State> states : candidates) {
                    assert !states.isEmpty();
                }
            }

            lattice = new ZeroOrderSumLattice(candidates, threshs_[level][0],
                    oracle_);
            if (oracle_ || training) {
                lattice.setGoldCandidates(getGoldIndexes(sentence,
                        lattice.getCandidates()));
            }

            // Only the last level is expanded up to the full order.
            int effective_order = effective_order_;
            if (level + 1 == getNumLevels()) {
                effective_order = order;
            }

            for (int current_order = 0; current_order < effective_order; current_order++) {
                if (prune_) {
                    candidates = lattice.prune();
                    incrementStateCounter(level, current_order,
                            lattice.getZeroOrderCandidates(true));
                    assert candidates.size() > 0;
                }

                if (current_order == 0) {
                    addLemmaScores(candidates, sentence, level, training);
                }

                /*
                 * During training, if gold sequence is not among the new
                 * candidates return the lattice immediately to do an early
                 * update.
                 */
                if (training
                        && testForGoldCandidates(sentence, candidates) == null) {
                    return lattice;
                }

                if (current_order > 0) {
                    candidates = increaseOrder(candidates, level);
                }

                addTransitions(candidates, level, current_order + 2);
                lattice = new SequenceSumLattice(candidates,
                        model_.getBoundaryState(level),
                        threshs_[level][current_order + 1], current_order + 1,
                        false);
                if (oracle_ || training) {
                    lattice.setGoldCandidates(getGoldIndexes(sentence,
                            lattice.getCandidates()));
                }
            }
        }

        assert lattice.getCandidates().size() >= sentence.size();
        return lattice;
    }

    /**
     * Lets the model set lemma candidates on the given states.
     *
     * <p>
     * On level 0 the token-based overload of
     * {@code Model.setLemmaCandidates} is used; otherwise, on the last level
     * only, the state-based overload is used. On any other level nothing
     * happens. The last list of {@code candidates} holds the boundary state
     * and is skipped.
     */
    private void addLemmaScores(List<List<State>> candidates,
            Sequence sentence, int level, boolean training) {
        boolean first_level = (level == 0);
        // Level 0 takes precedence when it is also the last level.
        if (!first_level && level + 1 != getNumLevels()) {
            return;
        }

        int index = 0;
        for (List<State> states : candidates) {
            if (index + 1 < candidates.size()) {
                // Last state is boundary state
                for (State state : states) {
                    if (first_level) {
                        // Add lemma scores with pos features
                        model_.setLemmaCandidates(sentence.get(index), state,
                                false, training);
                    } else {
                        // Add lemma scores with morph features
                        model_.setLemmaCandidates(state, false);
                    }
                }
            }
            index++;
        }
    }

    /**
     * Looks up the gold candidate indexes for the given candidates.
     *
     * @return the gold indexes, or {@code null} if the gold sequence is not
     *         among the candidates
     */
    private List<Integer> testForGoldCandidates(Sequence sentence,
            List<List<State>> candidates) {
        return getGoldIndexes(sentence, candidates);
    }

    /**
     * @return the order this tagger was created with
     */
    public int getOrder() {
        return order_;
    }

    @Override
    public int getNumLevels() {
        return num_level_;
    }

    /**
     * Determines, for every position, the index of the candidate that
     * matches the gold tags of the sequence on all levels.
     *
     * <p>
     * Positions beyond the end of the sequence are matched against the
     * model's boundary index. If a candidate carries transitions, it is only
     * accepted when it has a transition from the gold candidate of the
     * previous position.
     *
     * @param sequence
     *            sequence providing the gold tag indexes
     * @param candidates
     *            candidate states per position
     * @return one candidate index per position, or {@code null} as soon as
     *         some position has no matching candidate
     */
    @Override
    public List<Integer> getGoldIndexes(Sequence sequence,
            List<List<State>> candidates) {
        List<Integer> list = new ArrayList<Integer>(candidates.size());
        int last_candidate_index = 0;

        for (int index = 0; index < candidates.size(); index++) {
            List<State> current_candidates = candidates.get(index);

            // Start with all candidates and filter them level by level.
            List<Integer> current_candidate_indexes = new ArrayList<Integer>(
                    current_candidates.size());
            for (int candidate_index = 0; candidate_index < current_candidates
                    .size(); candidate_index++) {
                current_candidate_indexes.add(candidate_index);
            }

            int max_level = current_candidates.get(0).getZeroOrderState()
                    .getLevel();

            for (int level = max_level; level >= 0; level--) {
                List<Integer> new_current_candidate_indexes = new ArrayList<Integer>(
                        current_candidate_indexes.size());

                int gold_tag_index;
                if (index < sequence.size()) {
                    gold_tag_index = sequence.get(index).getTagIndexes()[level];
                } else {
                    gold_tag_index = model_.getBoundaryIndex();
                }

                for (int state_index = 0; state_index < current_candidate_indexes
                        .size(); state_index++) {
                    int candidate_index = current_candidate_indexes
                            .get(state_index);
                    State state = current_candidates.get(candidate_index);

                    if (level == max_level) {
                        // check transition!
                        boolean valid = (state.getTransitions() == null || state
                                .getTransition(last_candidate_index) != null);
                        if (!valid) {
                            continue;
                        }
                    }

                    if (gold_tag_index == state.getZeroOrderState()
                            .getSubLevel(max_level - level).getIndex()) {
                        new_current_candidate_indexes.add(candidate_index);
                    }
                }

                current_candidate_indexes = new_current_candidate_indexes;
                if (current_candidate_indexes.isEmpty()) {
                    return null;
                }
            }

            assert current_candidate_indexes.size() == 1;
            int gold_candidate_index = current_candidate_indexes.get(0);
            list.add(gold_candidate_index);
            last_candidate_index = gold_candidate_index;
        }

        return list;
    }

    @Override
    public Model getModel() {
        return model_;
    }

    @Override
    public WeightVector getWeightVector() {
        return weight_vector_;
    }

    /**
     * Tags a sentence and returns, per token, the tag symbols of all levels.
     *
     * @param sentence
     *            sequence to tag
     * @return one list of tag symbols per token, ordered by level
     */
    @Override
    public List<List<String>> tag(Sequence sentence) {
        List<int[]> indexes = tag_(sentence);
        List<List<String>> strings = new ArrayList<List<String>>(indexes.size());
        for (int[] array : indexes) {
            strings.add(indexesToStrings(array));
        }
        return strings;
    }

    /**
     * Converts tag indexes into symbols, using the tag table of the level
     * that corresponds to each array position.
     *
     * @param indexes
     *            tag indexes, one per level starting at level 0
     * @return the corresponding symbols
     */
    protected List<String> indexesToStrings(int[] indexes) {
        List<String> sarray = new ArrayList<String>(indexes.length);
        int level = 0;
        for (int index : indexes) {
            sarray.add(model_.getTagTables().get(level).toSymbol(index));
            level++;
        }
        return sarray;
    }

    /**
     * Collects the tag index of a state and of all its sub-level states.
     *
     * @param state
     *            state of the highest level
     * @return the tag indexes, with position {@code i} holding level {@code i}
     */
    protected int[] stateToIndexes(State state) {
        int num_levels = state.getLevel() + 1;
        int[] indexes = new int[num_levels];
        // Walk down the sub-level chain, filling the array from the back.
        for (int level = num_levels - 1; level >= 0; level--) {
            assert state != null;
            assert state.getIndex() >= 0;
            indexes[level] = state.getIndex();
            state = state.getSubLevelState();
        }
        return indexes;
    }

    /**
     * Decodes the best state sequence for the given sequence.
     *
     * <p>
     * The sum lattice is built in non-training mode; its candidates are then
     * decoded with a Viterbi lattice matching the type of the sum lattice.
     *
     * @param sequence
     *            sequence to tag
     * @return the zero-order state chosen for each token
     */
    protected List<State> tag_states(Sequence sequence) {
        List<State> list = new ArrayList<State>(sequence.size());

        SumLattice sum_lattice = getSumLattice(false, sequence);
        List<List<State>> candidates = sum_lattice.getCandidates();

        ViterbiLattice lattice;
        if (sum_lattice instanceof ZeroOrderSumLattice) {
            lattice = new ZeroOrderViterbiLattice(candidates, beam_size_,
                    model_.getMarganlizeLemmas());
        } else {
            lattice = new SequenceViterbiLattice(candidates,
                    model_.getBoundaryState(getNumLevels() - 1), beam_size_,
                    model_.getMarganlizeLemmas());
        }

        Hypothesis h = lattice.getViterbiSequence();
        List<Integer> state_indexes = h.getStates();

        // Only the token positions are read; the boundary position is
        // ignored.
        for (int index = 0; index < sequence.size(); index++) {
            int candidate_index = state_indexes.get(index);
            List<State> token_candidates = candidates.get(index);
            State state = token_candidates.get(candidate_index);
            state = state.getZeroOrderState();
            list.add(state);
        }

        return list;
    }

    /**
     * Tags a sequence and returns, per token, the tag indexes of all levels.
     *
     * @param sequence
     *            sequence to tag
     * @return one index array per token
     */
    protected List<int[]> tag_(Sequence sequence) {
        List<int[]> list = new ArrayList<int[]>(sequence.size());
        List<State> states = tag_states(sequence);
        for (State state : states) {
            int[] indexes = stateToIndexes(state);
            list.add(indexes);
        }
        return list;
    }

    /**
     * Overrides the number of levels reported by {@link #getNumLevels()}.
     *
     * <p>
     * NOTE(review): the per-level arrays are sized from the model's tag
     * tables in the constructor; values above that size presumably are not
     * supported - confirm against callers.
     *
     * @param level
     *            new number of levels
     */
    public void setMaxLevel(int level) {
        num_level_ = level;
    }

    @Override
    public void setResult(Result result) {
        result_ = result;
    }

    @Override
    public Result getResult() {
        return result_;
    }

}