// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core;
import java.util.List;
import lemming.lemma.ranker.RankerCandidate;
import marmot.util.Check;
import marmot.util.Numerics;
/**
 * A decoding state identified by an integer index, optionally chained to a
 * state one level below it ({@link #getSubLevelState()}). A state carries a
 * feature vector, a score, an array of incoming/outgoing transitions indexed
 * by previous-state index, an accumulated expected count used for weight
 * updates, and (for joint lemmatization) an optional list of lemma candidates.
 */
public class State {
    private FeatureVector vector_;
    private double score_;
    /** Expected count accumulated via incrementEstimatedCounts; flushed by updateWeights. */
    protected double estimated_count_;
    /** Transitions looked up by previous-state index in getTransition. */
    private Transition[] transitions_;
    /** State index; -1 when created by the no-arg constructor and not yet assigned. */
    private int index_;
    /** State one level below this one, or null if this is the lowest level. */
    private State sub_level_state_;

    // For joint lemmatization.
    private List<RankerCandidate> candidates_;
    /**
     * Combined score of all lemma candidates, computed by setLemmaScoreSum.
     * Starts at NEGATIVE_INFINITY for every constructor.
     */
    private double candidate_score_sum_;

    /** Creates an unindexed state (index -1) with no sub-level state. */
    public State() {
        index_ = -1;
        candidate_score_sum_ = Double.NEGATIVE_INFINITY;
    }

    /**
     * Creates a state with the given index and no sub-level state.
     *
     * @param index non-negative state index (checked by assertion)
     */
    public State(int index) {
        // Chain to the no-arg constructor so candidate_score_sum_ gets the same
        // NEGATIVE_INFINITY start value as unindexed states and setLemmaScoreSum.
        this();
        assert index >= 0;
        index_ = index;
    }

    /**
     * Creates a state with the given index on top of a sub-level state.
     *
     * @param index non-negative state index (checked by assertion)
     * @param sub_level_state the state one level below; may be null
     */
    public State(int index, State sub_level_state) {
        this(index);
        sub_level_state_ = sub_level_state;
    }

    /** Sets the feature vector of this state. */
    public void setVector(FeatureVector vector) {
        vector_ = vector;
    }

    /** Sets the state's own score (see also {@link #getRealScore()}). */
    public void setScore(double score) {
        score_ = score;
    }

    /** Returns the feature vector of this state (may be null if never set). */
    public FeatureVector getVector() {
        return vector_;
    }

    /**
     * Returns the score used for decoding: the combined lemma-candidate score
     * when lemma candidates are attached, otherwise the state's own score.
     */
    public double getScore() {
        if (candidates_ != null) {
            return candidate_score_sum_;
        }
        return score_;
    }

    /** Attaches the transition array, indexed by previous-state index. */
    public void setTransitions(Transition[] transitions) {
        transitions_ = transitions;
    }

    /** Returns the index of this state. */
    public int getIndex() {
        return index_;
    }

    /** Returns this state; subclasses may return a different zero-order state. */
    public State getZeroOrderState() {
        return this;
    }

    /**
     * Returns the transition stored for the given previous-state index.
     * Throws NullPointerException if no transitions have been set.
     */
    public Transition getTransition(int previous_state_index) {
        return transitions_[previous_state_index];
    }

    /** Returns the order of this state; always 1 for this base class. */
    public int getOrder() {
        return 1;
    }

    /**
     * Always returns true. With assertions enabled, checks that only the state
     * with index 0 is asked about a target whose order is not 1.
     */
    public boolean canTransitionTo(State other) {
        if (other.getOrder() != 1) {
            assert getIndex() == 0;
        }
        return true;
    }

    /** Returns the raw transition array (may be null). */
    public Transition[] getTransitions() {
        return transitions_;
    }

    /** Adds {@code d} to the accumulated expected count of this state. */
    public void incrementEstimatedCounts(double d) {
        estimated_count_ += d;
    }

    /**
     * Pushes the accumulated expected count into {@code weights} (if non-zero)
     * and resets it, then lets every attached lemma candidate update the
     * weights for this state.
     */
    public void updateWeights(WeightVector weights) {
        if (estimated_count_ != 0.0) {
            weights.updateWeights(this, estimated_count_, true);
            estimated_count_ = 0.0;
        }
        if (candidates_ != null) {
            for (RankerCandidate candidate : candidates_) {
                candidate.updateWeights(this, weights);
            }
        }
    }

    /** Returns the number of sub-level states chained below this one. */
    public int getLevel() {
        if (sub_level_state_ == null) {
            return 0;
        }
        return sub_level_state_.getLevel() + 1;
    }

    /**
     * Returns the state {@code depth} levels below this one.
     *
     * @param depth non-negative depth; 0 returns this state
     * @return the state at that depth, or null if depth is exactly one past
     *         the lowest level
     * @throws RuntimeException if depth is more than one past the lowest level
     */
    public State getSubLevel(int depth) {
        assert depth >= 0;
        if (depth == 0) {
            return this;
        }
        if (sub_level_state_ == null) {
            if (depth == 1) {
                return null;
            }
            throw new RuntimeException("Can't reach depth!");
        }
        return sub_level_state_.getSubLevel(depth - 1);
    }

    /** Returns the state one level below this one, or null. */
    public State getSubLevelState() {
        return sub_level_state_;
    }

    /** Replaces the state one level below this one. */
    public void setSubLevelState(State sub_level_state) {
        sub_level_state_ = sub_level_state;
    }

    /**
     * Returns true if this state and {@code other} have equal indexes at every
     * level of their sub-level chains. Chains of different length compare as
     * unequal (and trip the level assertion when assertions are enabled).
     */
    public boolean equalIndexes(State other) {
        if (index_ != other.getIndex()) {
            return false;
        }
        assert other.getLevel() == this.getLevel();
        State sub = getSubLevelState();
        State other_sub = other.getSubLevelState();
        if (sub == null || other_sub == null) {
            // Previously: NPE when only this side had no sub-level, and
            // 'true' when only the other side had none.
            return sub == other_sub;
        }
        return other_sub.equalIndexes(sub);
    }

    /** Returns null in this base class. */
    public State getSubOrderState() {
        return null;
    }

    /** Returns null in this base class. */
    public State getPreviousSubOrderState() {
        return null;
    }

    /** Formats as {@code [<sub-level> <index>]}, nesting sub-levels recursively. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        if (sub_level_state_ != null) {
            sb.append(sub_level_state_);
        }
        sb.append(' ');
        sb.append(index_);
        sb.append(']');
        return sb.toString();
    }

    /** Returns true iff no transition array is attached to this state. */
    public boolean check() {
        return transitions_ == null;
    }

    /**
     * Copies every field of this state into {@code state} and returns it. The
     * copy is shallow: vector, transitions, sub-level state and candidate
     * list are shared with the original.
     */
    public State copy(State state) {
        state.vector_ = vector_;
        state.score_ = score_;
        state.transitions_ = transitions_;
        state.index_ = index_;
        state.sub_level_state_ = sub_level_state_;
        state.estimated_count_ = estimated_count_;
        state.candidate_score_sum_ = candidate_score_sum_;
        state.candidates_ = candidates_;
        return state;
    }

    /** Returns a shallow copy as a plain State; asserts the index is assigned. */
    public State copy() {
        State state = copy(new State());
        assert state.index_ >= 0;
        return state;
    }

    /** Attaches the lemma candidates for joint lemmatization (order-1 states only). */
    public void setLemmaCandidates(List<RankerCandidate> candidates) {
        assert getOrder() == 1;
        candidates_ = candidates;
    }

    /** Returns the attached lemma candidates, or null if none are set. */
    public List<RankerCandidate> getLemmaCandidates() {
        assert getOrder() == 1;
        return candidates_;
    }

    /**
     * Recomputes candidate_score_sum_ by combining, via Numerics.sumLogProb,
     * each candidate's score plus this state's own score. Requires lemma
     * candidates to be set (checked by assertion).
     */
    public void setLemmaScoreSum() {
        assert getOrder() == 1;
        assert getLemmaCandidates() != null;
        assert Check.isNormal(score_);
        candidate_score_sum_ = Double.NEGATIVE_INFINITY;
        for (RankerCandidate candidate : getLemmaCandidates()) {
            double candidate_score = candidate.getScore() + score_;
            assert Check.isNormal(candidate_score);
            candidate_score_sum_ = Numerics.sumLogProb(candidate_score_sum_, candidate_score);
        }
        assert candidate_score_sum_ != Double.NEGATIVE_INFINITY;
        assert Check.isNormal(candidate_score_sum_);
    }

    /** Returns the state's own score, ignoring any lemma candidates. */
    public double getRealScore() {
        assert getOrder() == 1;
        return score_;
    }
}