// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.core;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import marmot.core.lattice.SequenceViterbiLattice;
import marmot.core.lattice.SumLattice;
import marmot.core.lattice.ViterbiLattice;
import marmot.core.lattice.ZeroOrderSumLattice;
import marmot.core.lattice.ZeroOrderViterbiLattice;
/**
 * Trains a tagger's weight vector with the structured perceptron.
 *
 * <p>For every training sequence the best candidate sequence is decoded from
 * the tagger's lattice. Whenever it differs from the gold candidate sequence,
 * the features along the gold path are updated with +1 and those along the
 * predicted path with -1. Optionally the data is shuffled before every epoch
 * and the weights are averaged.
 */
public class PerceptronTrainer implements Trainer {
    /** Number of passes (epochs) over the training data. */
    private int steps_;
    /** Whether to shuffle the training sequences before every epoch. */
    private boolean shuffle_;
    /** Whether to print progress and run the evaluator after every epoch. */
    private boolean verbose_;
    /** Whether to maintain and install averaged weights. */
    private boolean averaging_;
    /** Seed for shuffling; 0 means "use an unseeded Random". */
    private long seed_;

    /**
     * Runs {@code steps_} epochs of perceptron training.
     *
     * <p>Averaging is done lazily. An update made at position {@code c} of an
     * epoch with {@code N} sequences is also added, scaled by {@code N - c}, to
     * an accumulator. That factor is the number of per-sentence weight
     * snapshots it affects for the rest of the epoch. At the end of each epoch
     * the live weights are <em>replaced</em> by the average, and training
     * continues from them. The accumulator is then rescaled by
     * {@code (step + 2) / (step + 1)}, which adds {@code N} copies of the new
     * starting weights for the next epoch.
     *
     * <p>NOTE(review): the accumulator starts at zero, so this assumes the
     * tagger's initial weights are zero. Otherwise their first-epoch
     * contribution is missing from the average — TODO confirm.
     *
     * @param tagger       tagger whose weight vector is trained in place
     * @param in_sequences training data (copied; the collection is not modified)
     * @param evaluator    evaluated after every epoch when verbose; may be null
     */
    @Override
    public void train(Tagger tagger, Collection<Sequence> in_sequences,
            Evaluator evaluator) {
        Random rng = shuffle_ ? createRandom() : null;

        List<Sequence> sequences = new ArrayList<Sequence>(in_sequences);
        int num_sequences = sequences.size();
        // Report progress roughly four times per epoch.
        int fraction = Math.max(num_sequences / 4, 1);

        WeightVector weights = tagger.getWeightVector();
        assert weights != null;
        double[] sum_weights = null;
        if (averaging_) {
            sum_weights = new double[weights.getWeights().length];
        }
        Model model = tagger.getModel();

        for (int step = 0; step < steps_; step++) {
            if (verbose_) {
                System.err.println("step: " + step);
            }
            if (shuffle_) {
                Collections.shuffle(sequences, rng);
            }

            int current_sentence = 0;
            long train_time = System.currentTimeMillis();
            for (Sequence sequence : sequences) {
                SumLattice sum_lattice = tagger.getSumLattice(true, sequence);
                List<List<State>> candidates = sum_lattice.getCandidates();
                List<Integer> best_sequence = decodeBest(tagger, model,
                        sum_lattice, candidates);
                List<Integer> gold_sequence = sum_lattice.getGoldCandidates();

                if (!gold_sequence.equals(best_sequence)) {
                    update(weights, candidates, gold_sequence, +1);
                    update(weights, candidates, best_sequence, -1);
                    if (averaging_) {
                        // Number of remaining per-sentence snapshots in this
                        // epoch that include this update.
                        int amount = num_sequences - current_sentence;
                        assert amount > 0;
                        sum_weights = accumulate(weights, sum_weights,
                                candidates, gold_sequence, best_sequence,
                                amount);
                    }
                }

                current_sentence++;
                if (verbose_ && current_sentence % fraction == 0) {
                    System.err
                            .format("Processed %d sentences at %g sentence/s \n",
                                    current_sentence,
                                    current_sentence
                                            / ((System.currentTimeMillis() - train_time) / 1000.));
                }
            }

            // Without any sequences there is nothing to average, and dividing
            // by zero would turn every weight into NaN/Infinity.
            if (averaging_ && num_sequences > 0) {
                installAverage(weights.getWeights(), sum_weights, step,
                        num_sequences);
            }

            if (evaluator != null && verbose_) {
                weights.setExtendFeatureSet(false);
                evaluator.eval(tagger);
                weights.setExtendFeatureSet(true);
            }
        }
        weights.setExtendFeatureSet(false);
    }

    /**
     * Creates the shuffling RNG: seeded from {@code seed_}, or unseeded when
     * the seed is 0.
     */
    private Random createRandom() {
        return seed_ == 0 ? new Random() : new Random(seed_);
    }

    /**
     * Decodes the single best candidate sequence.
     *
     * <p>A zero-order Viterbi lattice is used for zero-order sum lattices.
     * Otherwise a sequence Viterbi lattice anchored at the model's boundary
     * state for the last level is used.
     *
     * @return one candidate index per position
     */
    private List<Integer> decodeBest(Tagger tagger, Model model,
            SumLattice sum_lattice, List<List<State>> candidates) {
        ViterbiLattice lattice;
        if (sum_lattice instanceof ZeroOrderSumLattice) {
            lattice = new ZeroOrderViterbiLattice(candidates, 1, false);
        } else {
            lattice = new SequenceViterbiLattice(candidates,
                    model.getBoundaryState(tagger.getNumLevels() - 1), 1,
                    false);
        }
        return lattice.getViterbiSequence().getStates();
    }

    /**
     * Applies the perceptron update, scaled by {@code amount}, to the
     * averaging accumulator.
     *
     * <p>This reuses the weight vector's own update logic by temporarily
     * installing the accumulator as its backing array. The live weights are
     * always restored, even if an update throws.
     *
     * @return the accumulator array as held by the weight vector after the
     *         update; callers must keep using this array
     */
    private double[] accumulate(WeightVector weights, double[] sum_weights,
            List<List<State>> candidates, List<Integer> gold_sequence,
            List<Integer> best_sequence, int amount) {
        double[] current_weights = weights.getWeights();
        weights.setWeights(sum_weights);
        try {
            update(weights, candidates, gold_sequence, +amount);
            update(weights, candidates, best_sequence, -amount);
            // Re-read in case the weight vector replaced or copied its array.
            return weights.getWeights();
        } finally {
            weights.setWeights(current_weights);
        }
    }

    /**
     * Overwrites {@code current_weights} with the running average at the end
     * of epoch {@code step}.
     *
     * <p>It then rescales {@code sum_weights} so that it also contains
     * {@code num_sequences} copies of the newly installed weights, i.e. the
     * baseline contribution of the next epoch.
     */
    private static void installAverage(double[] current_weights,
            double[] sum_weights, int step, int num_sequences) {
        // Computed in double: the int product could overflow on huge corpora.
        double num_snapshots = (step + 1) * (double) num_sequences;
        assert num_snapshots > 0;
        double rescale = (step + 2) / (double) (step + 1);
        assert rescale > 0;
        assert rescale < 2 + 1e-5;
        for (int i = 0; i < current_weights.length; i++) {
            current_weights[i] = sum_weights[i] / num_snapshots;
            sum_weights[i] *= rescale;
        }
    }

    /**
     * Adds {@code amount} to the weights of every state along
     * {@code sequence}, and to the transition into each state from the
     * previously chosen candidate.
     *
     * <p>At position 0 the transition is taken from candidate index 0.
     * NOTE(review): this presumably denotes the start/boundary state — confirm
     * against {@code State.getTransition}.
     *
     * @param sequence candidate index chosen at each position
     */
    private void update(WeightVector weights, List<List<State>> candidates,
            List<Integer> sequence, double amount) {
        int last_candidate_index = 0;
        for (int index = 0; index < sequence.size(); index++) {
            int candidate_index = sequence.get(index);
            State state = candidates.get(index).get(candidate_index);
            weights.updateWeights(state, amount, false);
            State transition = state.getTransition(last_candidate_index);
            weights.updateWeights(transition, amount, true);
            last_candidate_index = candidate_index;
        }
    }

    /**
     * Reads the training configuration: number of iterations, shuffling,
     * verbosity, averaging and shuffle seed.
     */
    @Override
    public void setOptions(Options options) {
        steps_ = options.getNumIterations();
        shuffle_ = options.getShuffle();
        verbose_ = options.getVerbose();
        averaging_ = options.getAveraging();
        seed_ = options.getSeed();
    }
}