// Copyright 2015 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package experimental.analyzer.simple;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import marmot.util.Mutable;
import marmot.util.Numerics;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import experimental.analyzer.AnalyzerTag;
import experimental.analyzer.simple.SimpleAnalyzer.Mode;
import experimental.analyzer.simple.SimpleAnalyzerTrainer.PairConstraint;
public class SimpleAnalyzerObjective implements ByGradientValue {
private SimpleAnalyzerModel model_;
private Collection<SimpleAnalyzerInstance> instances_;
private double value_;
private double[] gradient_;
private double[] weights_;
private double penalty_;
private Mode mode_;
private double[] scores;
private double[] updates;
private Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_;
private PairConstraint pair_constraint_;
public SimpleAnalyzerObjective(
double penalty,
SimpleAnalyzerModel model,
Collection<SimpleAnalyzerInstance> instances,
Mode mode,
Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts,
PairConstraint pair_constraint) {
model_ = model;
instances_ = instances;
weights_ = model.getWeights();
gradient_ = new double[weights_.length];
penalty_ = penalty;
mode_ = mode;
relative_counts_ = relative_counts;
int num_tags = model_.getNumTags();
scores = new double[num_tags];
updates = new double[num_tags];
pair_constraint_ = pair_constraint;
}
public void update() {
// System.err.println("update");
value_ = 0.;
Arrays.fill(gradient_, 0.);
for (SimpleAnalyzerInstance instance : instances_) {
update(instance, 1.0, false);
}
for (int i = 0; i < weights_.length; i++) {
double w = weights_[i];
value_ -= penalty_ * w * w;
gradient_[i] -= 2. * penalty_ * w;
}
}
public void update(SimpleAnalyzerInstance instance, double step_width,
boolean sgd) {
Arrays.fill(scores, 0.0);
Arrays.fill(updates, 0.0);
int num_tags = model_.getNumTags();
model_.setWeights(weights_);
model_.score(instance, scores);
switch (mode_) {
case binary:
value_ += binaryUpdate(scores, updates, num_tags, instance);
break;
case classifier:
value_ += classifierUpdate(scores, updates, num_tags, instance);
break;
default:
throw new RuntimeException("Unsupported mode: " + mode_);
}
if (!sgd) {
model_.setWeights(gradient_);
}
if (!Numerics.approximatelyEqual(step_width, 1.0)) {
for (int i = 0; i < num_tags; i++) {
updates[i] *= step_width;
}
}
model_.update(instance, updates);
model_.setWeights(weights_);
}
private double classifierUpdate(double[] scores, double[] updates,
int num_tags, SimpleAnalyzerInstance instance) {
double value = 0;
double sum = Double.NEGATIVE_INFINITY;
int num_tag_indexes = instance.getTagIndexes().size();
for (int tag_index = 0; tag_index < num_tags; tag_index++) {
sum = Numerics.sumLogProb(scores[tag_index], sum);
}
value -= num_tag_indexes * sum;
for (int tag_index = 0; tag_index < num_tags; tag_index++) {
updates[tag_index] = -num_tag_indexes
* Math.exp(scores[tag_index] - sum);
}
if (pair_constraint_ != PairConstraint.none) {
for (int tag_index : instance.getTagIndexes()) {
AnalyzerTag tag = model_.getTagTable().toSymbol(tag_index);
Map<AnalyzerTag, Mutable<Double>> map = relative_counts_
.get(tag);
if (map != null) {
for (Map.Entry<AnalyzerTag, Mutable<Double>> entry : map
.entrySet()) {
int new_tag_index = model_.getTagTable().toIndex(
entry.getKey());
double count = entry.getValue().get();
if (pair_constraint_ == PairConstraint.weighted) {
value += count * scores[new_tag_index];
updates[new_tag_index] += count;
} else {
if (new_tag_index != tag_index) {
updates[new_tag_index] = 0;
}
}
}
}
if (map == null || pair_constraint_ == PairConstraint.simple) {
value += scores[tag_index];
updates[tag_index] += 1.0;
}
}
} else {
for (int tag_index : instance.getTagIndexes()) {
value += scores[tag_index];
updates[tag_index] += 1.0;
}
}
return value;
}
private double binaryUpdate(double[] scores, double[] updates,
int num_tags, SimpleAnalyzerInstance instance) {
double value = 0;
for (int tag_index = 0; tag_index < num_tags; tag_index++) {
double sum = Numerics.sumLogProb(scores[tag_index], 0);
value -= sum;
updates[tag_index] = -Math.exp(scores[tag_index] - sum);
}
for (int tag_index : instance.getTagIndexes()) {
value += scores[tag_index];
updates[tag_index] += 1.0;
}
return value;
}
@Override
public int getNumParameters() {
// System.err.println("getNumParameters");
return weights_.length;
}
@Override
public double getParameter(int arg0) {
throw new UnsupportedOperationException();
}
@Override
public void getParameters(double[] params) {
// System.err.println("getParameters");
System.arraycopy(weights_, 0, params, 0, weights_.length);
}
@Override
public void setParameter(int arg0, double arg1) {
throw new UnsupportedOperationException();
}
@Override
public void setParameters(double[] params) {
// System.err.println("setParameters");
System.arraycopy(params, 0, weights_, 0, weights_.length);
update();
}
@Override
public double getValue() {
// System.err.println("getValue");
return value_;
}
@Override
public void getValueGradient(double[] gradient) {
// System.err.println("getValueGradient " + gradient_.length + " " +
// gradient.length);
System.arraycopy(gradient_, 0, gradient, 0, gradient_.length);
}
}