package chipmunk.segmenter;
import java.util.Arrays;
import java.util.Collection;
import marmot.util.DynamicWeights;
import cc.mallet.optimize.Optimizable.ByGradientValue;
public class SemiCrfObjective implements ByGradientValue {
private SegmenterModel model_;
private Collection<Word> words_;
private double value_;
private double[] gradient_;
private double[] weights_;
private double penalty_;
public SemiCrfObjective(SegmenterModel model, Collection<Word> words,
double penalty) {
model_ = model;
words_ = words;
penalty_ = penalty;
}
public void init() {
DynamicWeights weights = new DynamicWeights(null);
model_.setScorerWeights(weights);
DynamicWeights gradient = new DynamicWeights(null);
model_.setUpdaterWeights(gradient);
model_.getUpdater().setInsert(false);
calcLikelihood();
DynamicWeights scorer = model_.getScorer().getWeights();
DynamicWeights updater = model_.getUpdater().getWeights();
if (scorer.getLength() != updater.getLength()) {
int length = Math.max(scorer.getLength(), updater.getLength());
scorer.setLength(length);
updater.setLength(length);
}
weights_ = scorer.getWeights();
scorer.setExapnd(false);
gradient_ = updater.getWeights();
updater.setExapnd(false);
assert weights_.length == gradient_.length : weights_.length + " "
+ gradient_.length;
calcPenalty();
}
public void update() {
value_ = 0.;
Arrays.fill(gradient_, 0.);
calcLikelihood();
calcPenalty();
}
private void calcPenalty() {
if (penalty_ > 0.0) {
for (int i = 0; i < weights_.length; i++) {
double w = weights_[i];
value_ -= penalty_ * w * w;
gradient_[i] -= 2. * penalty_ * w;
}
}
}
private void calcLikelihood() {
SegmentationSumLattice lattice = new SegmentationSumLattice(model_);
for (Word word : words_) {
SegmentationInstance instance = model_.getInstance(word);
value_ += lattice.update(instance, true);
}
}
@Override
public int getNumParameters() {
return weights_.length;
}
@Override
public double getParameter(int arg0) {
throw new UnsupportedOperationException();
}
@Override
public void getParameters(double[] params) {
System.arraycopy(weights_, 0, params, 0, weights_.length);
}
@Override
public void setParameter(int arg0, double arg1) {
throw new UnsupportedOperationException();
}
@Override
public void setParameters(double[] params) {
System.arraycopy(params, 0, weights_, 0, weights_.length);
update();
}
@Override
public double getValue() {
return value_;
}
@Override
public void getValueGradient(double[] gradient) {
System.arraycopy(gradient_, 0, gradient, 0, gradient_.length);
}
}