package chipmunk.segmenter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import marmot.util.DynamicWeights;
import marmot.util.Numerics;
/**
 * Trains a {@link Segmenter} from a collection of {@link Word}s.
 *
 * <p>Depending on {@link SegmenterOptions#CRF_MODE} the model parameters are
 * estimated either by optimizing a {@link SemiCrfObjective} with L-BFGS or by
 * perceptron updates (optionally with weight averaging).
 */
public class SegmenterTrainer {

    /** Maximum number of single-step L-BFGS calls made after the initial step. */
    private static final int MAX_CRF_ITERATIONS = 200;

    private final SegmenterOptions options_;

    /**
     * @param options training configuration; read for the training mode,
     *                verbosity, penalty, averaging, iteration count and
     *                random source
     */
    public SegmenterTrainer(SegmenterOptions options) {
        options_ = options;
    }

    /**
     * Initializes a new {@link SegmenterModel} from the options and the
     * training words, fits it, marks it final and wraps it in a
     * {@link StatSegmenter}.
     *
     * @param words training data
     * @return a segmenter backed by the trained model
     */
    public Segmenter train(Collection<Word> words) {
        SegmenterModel model = new SegmenterModel();
        model.init(options_, words);

        boolean verbose = options_.getBoolean(SegmenterOptions.VERBOSE);
        if (options_.getBoolean(SegmenterOptions.CRF_MODE)) {
            if (verbose) {
                System.err.println("Training CRF");
            }
            run_crf(model, words);
        } else {
            if (verbose) {
                System.err.println("Training Perceptron");
            }
            run_perceptron(model, words);
        }

        model.setFinal();
        return new StatSegmenter(model);
    }

    /**
     * Optimizes a {@link SemiCrfObjective} over the model with
     * {@link LimitedMemoryBFGS}, one iteration per call, until the optimizer
     * reports convergence or the iteration budget is used up.
     *
     * <p>Exceptions thrown by the optimizer end the optimization early but do
     * not abort training; they are reported on stderr in verbose mode.
     *
     * @param model model whose parameters are optimized
     * @param words training data handed to the objective
     */
    private void run_crf(SegmenterModel model, Collection<Word> words) {
        SemiCrfObjective objective = new SemiCrfObjective(model, words,
                options_.getDouble(SegmenterOptions.PENALTY));
        objective.init();

        Optimizer optimizer = new LimitedMemoryBFGS(objective);
        // Silence the optimizer's own java.util.logging output.
        Logger.getLogger(optimizer.getClass().getName()).setLevel(Level.OFF);

        try {
            optimizer.optimize(1);
            for (int i = 0; i < MAX_CRF_ITERATIONS && !optimizer.isConverged(); i++) {
                optimizer.optimize(1);
            }
        } catch (IllegalArgumentException e) {
            reportOptimizerStop(e);
        } catch (OptimizationException e) {
            reportOptimizerStop(e);
        }
    }

    /**
     * Reports an optimizer failure that ended CRF training early. Training
     * deliberately continues with whatever parameters were reached, so this
     * only prints (in verbose mode) instead of rethrowing.
     */
    private void reportOptimizerStop(Exception e) {
        if (options_.getBoolean(SegmenterOptions.VERBOSE)) {
            System.err.println("CRF optimization stopped early: " + e);
        }
    }

    /**
     * Runs perceptron training for {@link SegmenterOptions#NUM_ITERATIONS}
     * epochs over a shuffled copy of the training words. For every wrongly
     * decoded instance the features of the predicted result are decremented
     * and those of the closest gold result are incremented. If
     * {@link SegmenterOptions#AVERAGING} is set, a second weight vector
     * accumulates the updates and is used to overwrite the working weights at
     * the end of every epoch.
     *
     * @param model model that receives the weight vector and the updates
     * @param words training data
     */
    private void run_perceptron(SegmenterModel model, Collection<Word> words) {
        DynamicWeights weights = new DynamicWeights(null);
        DynamicWeights sum_weights = null;
        if (options_.getBoolean(SegmenterOptions.AVERAGING)) {
            sum_weights = new DynamicWeights(null);
        }

        model.setWeights(weights);
        SegmentationDecoder decoder = new SegmentationDecoder(model);

        // Private copy so shuffling does not disturb the caller's collection.
        List<Word> word_array = new ArrayList<>(words);

        for (int iter = 0; iter < options_.getInt(SegmenterOptions.NUM_ITERATIONS); iter++) {
            // Number of instances already processed in this epoch.
            int number = 0;
            Collections.shuffle(word_array, options_.getRandom());

            for (Word word : word_array) {
                SegmentationInstance instance = model.getInstance(word);
                SegmentationResult result = decoder.decode(instance);

                // Sanity check: the decoder's score must match the model's
                // score of the decoded result. The scores are doubles, so the
                // message must use a floating-point conversion; "%d" would
                // throw IllegalFormatConversionException and hide the values.
                double score = result.getScore();
                double exact_score = model.getScore(instance, result);
                assert Numerics.approximatelyEqual(score, exact_score) : String
                        .format("%g %g", score, exact_score);

                if (!result.isCorrect(instance)) {
                    SegmentationResult closest_result = Scorer.closest(result,
                            instance.getResults(), instance.getLength());

                    model.update(instance, result, -1.);
                    model.update(instance, closest_result, +1.);

                    if (sum_weights != null) { /* averaging */
                        // Scale the accumulated update by the number of
                        // instances left in this epoch, including this one.
                        double amount = word_array.size() - number;
                        assert amount > 0;

                        // Temporarily point the model at the accumulator so
                        // the same update routine writes into it.
                        model.setWeights(sum_weights);
                        model.update(instance, result, -amount);
                        model.update(instance, closest_result, +amount);
                        model.setWeights(weights);
                    }
                }
                number++;
            }

            if (sum_weights != null) { /* averaging */
                // Overwrite the working weights with the accumulator divided
                // by the number of instances seen so far, then rescale the
                // accumulator by (iter + 2) / (iter + 1).
                // NOTE(review): the rescaling presumably keeps the accumulator
                // consistent with the next epoch's normalization - confirm.
                double weights_scaling = 1. / ((iter + 1.) * word_array.size());
                double sum_weights_scaling = (iter + 2.) / (iter + 1.);
                for (int i = 0; i < weights.getLength(); i++) {
                    weights.set(i, sum_weights.get(i) * weights_scaling);
                    sum_weights.set(i, sum_weights.get(i) * sum_weights_scaling);
                }
            }
        }
    }
}