// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package lemming.lemma.ranker;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lemming.lemma.LemmaCandidate;
import lemming.lemma.LemmaCandidateSet;
import lemming.lemma.ranker.RankerTrainer.RankerTrainerOptions;
import marmot.util.Numerics;
import cc.mallet.optimize.Optimizable.ByGradientValue;
/**
 * Training objective for the lemma candidate ranker, exposed through MALLET's
 * {@link ByGradientValue} interface.
 *
 * <p>
 * The value is the count-weighted sum, over all training instances, of the
 * log-probability of the gold lemma under a softmax over the instance's
 * candidate scores, minus a quadratic penalty on the weights. The same
 * per-instance routine is reused for stochastic updates (see
 * {@link #update(RankerInstance, boolean, double)}).
 *
 * <p>
 * Only the bulk parameter accessors are supported; the single-parameter
 * accessors throw {@link UnsupportedOperationException}.
 */
public class RankerObjective implements ByGradientValue {

    private final RankerModel model_;
    private final List<RankerInstance> instances_;
    private double value_;
    private final double[] gradient_;
    // Array handed out by the model in the constructor; presumably the model's
    // live weight vector (it is re-installed via setWeights after every update).
    private final double[] weights_;
    private final double penalty_;
    private final RankerTrainerOptions options_;
    private final int max_num_duplicates_;

    /**
     * @param options            trainer options; supplies the quadratic penalty
     *                           and the offline-feature-extraction switch
     * @param model              model whose weights are optimized
     * @param instances          training instances used by {@link #update()}
     * @param max_num_duplicates divisor applied to instance counts in SGD mode
     *                           (see {@link #update(RankerInstance, boolean, double)})
     */
    public RankerObjective(RankerTrainerOptions options, RankerModel model,
            List<RankerInstance> instances, int max_num_duplicates) {
        options_ = options;
        model_ = model;
        instances_ = instances;
        weights_ = model.getWeights();
        gradient_ = new double[weights_.length];
        penalty_ = options.getQuadraticPenalty();
        max_num_duplicates_ = max_num_duplicates;
    }

    /**
     * Same as the four-argument constructor with {@code max_num_duplicates = 1}.
     */
    public RankerObjective(RankerTrainerOptions options, RankerModel model,
            List<RankerInstance> instances) {
        this(options, model, instances, 1);
    }

    /**
     * Processes a single instance: scores all of its candidates, turns the
     * scores into softmax probabilities and pushes
     * {@code (indicator(gold) - prob) * scale * step_width} to the model for
     * every candidate.
     *
     * <p>
     * With {@code sgd == false} the model is temporarily pointed at the
     * gradient buffer, so the pushes accumulate into {@code gradient_}
     * (assuming {@code RankerModel.update} writes into the array installed by
     * {@code setWeights}), and {@code scale} is the raw instance count. With
     * {@code sgd == true} the pushes go to the weights and {@code scale} is a
     * log-damped count.
     *
     * <p>
     * In both modes the gold candidate's count-weighted log-probability is
     * added to the running objective value. The model is always left pointing
     * at the weight vector, and temporarily added candidate indexes are always
     * removed, even if scoring or updating throws.
     *
     * @param instance   instance to process
     * @param sgd        {@code true} for a stochastic weight update,
     *                   {@code false} to accumulate the batch gradient
     * @param step_width factor applied to every pushed update
     */
    public void update(RankerInstance instance, boolean sgd, double step_width) {
        final boolean online_features = !options_.getUseOfflineFeatureExtraction();
        final LemmaCandidateSet set = instance.getCandidateSet();

        if (online_features) {
            model_.addIndexes(instance, set, true);
        }

        try {
            int pos_index = instance.getPosIndex(model_.getPosTable(), false);
            int[] morph_indexes = instance.getMorphIndexes(model_.getMorphTable(), false);

            // First pass: score with the current weights and accumulate the
            // log partition function.
            model_.setWeights(weights_);
            double log_sum = Double.NEGATIVE_INFINITY;
            for (Map.Entry<String, LemmaCandidate> entry : set) {
                LemmaCandidate candidate = entry.getValue();
                double score = model_.score(candidate, pos_index, morph_indexes);
                candidate.setScore(score);
                log_sum = Numerics.sumLogProb(log_sum, score);
            }

            // Everything below is identical for all candidates of this instance.
            final String lemma = instance.getInstance().getLemma();
            final double count = instance.getInstance().getCount();
            final double scale = sgd ? sgdScale(count) : count;

            if (!sgd) {
                // Redirect the model so that the updates land in the gradient.
                model_.setWeights(gradient_);
            }

            // Second pass: push (indicator - probability) for every candidate.
            boolean found_target = false;
            for (Map.Entry<String, LemmaCandidate> entry : set) {
                double log_prob = entry.getValue().getScore() - log_sum;
                double update = -Math.exp(log_prob);

                String plemma = entry.getKey();
                if (plemma.equals(lemma)) {
                    update += 1.0;
                    found_target = true;
                    value_ += log_prob * count;
                }

                model_.update(instance, plemma, update * scale * step_width);
            }

            assert found_target : "gold lemma is not among the candidates";
        } finally {
            // Undo the temporary model state on every exit path; otherwise a
            // failure above would leave the model reading the gradient buffer
            // as its weights and keep the temporary candidate indexes around.
            if (online_features) {
                model_.removeIndexes(set);
            }
            model_.setWeights(weights_);
        }
    }

    /**
     * Scale factor applied to the update of an instance in SGD mode.
     */
    private double sgdScale(double count) {
        // For max_num_duplicates = 2, we created two copies of every instance with count >= 2.
        // Therefore we only count half the count here.
        double effective_count = count;
        if (Numerics.approximatelyGreaterEqual(effective_count, (double) max_num_duplicates_)) {
            effective_count /= (double) max_num_duplicates_;
        }
        // 1 -> 1, 2 -> 1.69, 3 -> 2.09 ...
        return Math.log(effective_count * Math.E);
    }

    /**
     * Recomputes the objective value and its gradient from scratch over all
     * training instances, then applies the quadratic penalty
     * {@code penalty * w^2} per weight to the value and its derivative to the
     * gradient.
     */
    public void update() {
        value_ = 0.;
        Arrays.fill(gradient_, 0.);

        for (RankerInstance instance : instances_) {
            update(instance, false, 1.0);
        }

        for (int i = 0; i < weights_.length; i++) {
            double w = weights_[i];
            value_ -= penalty_ * w * w;
            gradient_[i] -= 2. * penalty_ * w;
        }
    }

    @Override
    public int getNumParameters() {
        return weights_.length;
    }

    /**
     * Not supported; use {@link #getParameters(double[])}.
     */
    @Override
    public double getParameter(int arg0) {
        throw new UnsupportedOperationException();
    }

    /**
     * Copies the current weights into {@code params}.
     */
    @Override
    public void getParameters(double[] params) {
        System.arraycopy(weights_, 0, params, 0, weights_.length);
    }

    /**
     * Not supported; use {@link #setParameters(double[])}.
     */
    @Override
    public void setParameter(int arg0, double arg1) {
        throw new UnsupportedOperationException();
    }

    /**
     * Overwrites the weights with {@code params} and immediately recomputes
     * value and gradient via {@link #update()}.
     */
    @Override
    public void setParameters(double[] params) {
        System.arraycopy(params, 0, weights_, 0, weights_.length);
        update();
    }

    /**
     * Returns the objective value as of the last recomputation (plus anything
     * accumulated by per-instance updates since then).
     */
    @Override
    public double getValue() {
        return value_;
    }

    /**
     * Copies the current gradient into {@code gradient}.
     */
    @Override
    public void getValueGradient(double[] gradient) {
        System.arraycopy(gradient_, 0, gradient, 0, gradient_.length);
    }
}