// Copyright 2014 Thomas Müller
// This file is part of HMMLA, which is licensed under GPLv3.
package hmmla.splitmerge;
import hmmla.hmm.Model;
import hmmla.hmm.Statistics;
import hmmla.hmm.Tree;
import hmmla.io.Sentence;
import hmmla.util.SymbolTable;
import hmmla.util.Tuple;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Merges pairs of tags of a {@link Model} back into a single tag, moving all
 * counts of the right tag of each pair onto the left tag and rebuilding the
 * model's tag table, statistics and clustering accordingly.
 *
 * <p>The code assumes a tag table of size {@code 2N + 1} in which tag
 * {@code i} and tag {@code i + N} form a pair, and the tag at
 * {@link Model#BorderIndex} has no partner (presumably index 0 — TODO
 * confirm against {@code Model}). The two tags of a pair are presumably the
 * children of the same clustering tree node; a merged pair is replaced by
 * that parent node.
 */
public class Merger {
  private final LossEstimator estimator_;

  /**
   * @param estimator fills in the (tag index, loss) merge candidates used by
   *     {@link #merge(Model, Iterable, double)}
   */
  public Merger(LossEstimator estimator) {
    estimator_ = estimator;
  }

  /**
   * Asks the estimator for merge candidates, sorts them, and merges the first
   * {@code (int) (N * mergeFactor)} of them, where N is the number of tag
   * pairs in the model's tag table.
   *
   * @param model model to modify in place
   * @param reader sentences handed to the loss estimator
   * @param mergeFactor fraction of the N tag pairs to merge
   * @return the summed loss of the merged candidates
   * @throws IllegalArgumentException if the computed number of merges
   *     exceeds the number of candidates produced by the estimator
   */
  public double merge(Model model, Iterable<Sentence> reader, double mergeFactor) {
    SymbolTable<String> inputTable = model.getTagTable();
    int numPairs = (inputTable.size() - 1) / 2;
    List<Tuple<Integer, Double>> tuples = new ArrayList<Tuple<Integer, Double>>(numPairs);
    estimator_.estimateLosses(model, reader, tuples);
    // Presumably Tuple orders by its loss (y), so the cheapest merges come
    // first — confirm against Tuple.compareTo.
    Collections.sort(tuples);
    return merge(model, tuples, (int) (numPairs * mergeFactor));
  }

  /**
   * Merges the first {@code limit} candidates of {@code tuples} and installs
   * the resulting tag table and statistics into {@code model}. The model's
   * clustering map is updated in place.
   *
   * <p>For each candidate, {@code x} is the left tag index of a pair (its
   * partner is {@code x + N}) and {@code y} is the loss added to the returned
   * total.
   *
   * @param model model to modify in place
   * @param tuples merge candidates as (left tag index, loss) pairs
   * @param limit number of leading candidates to merge; a non-positive value
   *     merges nothing
   * @return the sum of the losses of the merged candidates
   * @throws IllegalArgumentException if {@code limit > tuples.size()}
   */
  public double merge(Model model, List<Tuple<Integer, Double>> tuples, int limit) {
    // Checked explicitly: with assertions disabled (the JVM default) an
    // oversized limit would otherwise fail deep inside tuples.get().
    if (limit > tuples.size()) {
      throw new IllegalArgumentException("limit (" + limit
          + ") exceeds the number of merge candidates (" + tuples.size() + ")");
    }
    SymbolTable<String> tagTable = model.getTagTable();
    SymbolTable<String> wordTable = model.getWordTable();
    Statistics statistics = model.getStatistics();
    int numTags = tagTable.size();
    int numWords = wordTable.size();
    int numPairs = (numTags - 1) / 2;

    Set<Integer> merged = new HashSet<Integer>();
    double loss = 0;
    for (int k = 0; k < limit; k++) {
      Tuple<Integer, Double> t = tuples.get(k);
      int lindex = t.x;
      merged.add(lindex);
      foldInto(statistics, lindex, lindex + numPairs, numTags, numWords);
      loss += t.y;
    }

    Map<String, Tree> clustering = model.getClustering();
    SymbolTable<String> newTagTable = new SymbolTable<String>();
    List<Tuple<Integer, Integer>> indexes =
        buildIndexMapping(clustering, tagTable, newTagTable, merged, numPairs);
    Statistics newStatistics = remap(statistics, indexes, numWords);

    // Every count must have been moved out of the old statistics.
    assert statistics.totalEmission() == 0.0;
    assert statistics.totalTransmission() == 0.0;
    model.setStatistics(newStatistics);
    model.setTagTable(newTagTable);
    return loss;
  }

  /**
   * Adds every emission and transition count of tag {@code rindex} to tag
   * {@code lindex} and zeroes the row and column of {@code rindex}.
   * Transitions between the two tags end up in the self-transition
   * {@code (lindex, lindex)}.
   */
  private static void foldInto(Statistics statistics, int lindex, int rindex,
      int numTags, int numWords) {
    for (int o = 0; o < numWords; o++) {
      statistics.addEmissions(lindex, o, statistics.getEmissions(rindex, o));
      statistics.setEmissions(rindex, o, 0);
    }
    for (int i = 0; i < numTags; i++) {
      statistics.addTransitions(i, lindex, statistics.getTransitions(i, rindex));
      statistics.setTransitions(i, rindex, 0.0);
      statistics.addTransitions(lindex, i, statistics.getTransitions(rindex, i));
      statistics.setTransitions(rindex, i, 0.0);
    }
    // When lindex < rindex, the loop above leaves the old (rindex, rindex)
    // count parked in (rindex, lindex). Fold whatever remains in either cross
    // cell into the self-transition so no count is left on rindex.
    statistics.addTransitions(lindex, lindex, statistics.getTransitions(rindex, lindex));
    statistics.setTransitions(rindex, lindex, 0.0);
    statistics.addTransitions(lindex, lindex, statistics.getTransitions(lindex, rindex));
    statistics.setTransitions(lindex, rindex, 0.0);
  }

  /**
   * Fills {@code newTagTable} and returns the (old index, new index) mapping
   * for every surviving tag. For each merged pair, the two tag names are
   * removed from {@code clustering}, the parent tree node is pruned and
   * registered under its own name, and only the left index is mapped (to the
   * parent's name). The right index gets no entry because its counts were
   * already folded into the left one.
   */
  private static List<Tuple<Integer, Integer>> buildIndexMapping(
      Map<String, Tree> clustering, SymbolTable<String> tagTable,
      SymbolTable<String> newTagTable, Set<Integer> merged, int numPairs) {
    List<Tuple<Integer, Integer>> indexes = new ArrayList<Tuple<Integer, Integer>>();
    for (int lindex = 0; lindex <= numPairs; lindex++) {
      int rindex = lindex + numPairs;
      if (merged.contains(lindex)) {
        String leftName = tagTable.toSymbol(lindex);
        String rightName = tagTable.toSymbol(rindex);
        Tree parent = clustering.get(leftName).getParent();
        parent.prune();
        assert parent.getLeft() == null && parent.getRight() == null;
        clustering.remove(leftName);
        clustering.remove(rightName);
        String name = parent.getName();
        clustering.put(name, parent);
        indexes.add(new Tuple<Integer, Integer>(lindex, newTagTable.toIndex(name, true)));
      } else {
        indexes.add(new Tuple<Integer, Integer>(lindex,
            newTagTable.toIndex(tagTable.toSymbol(lindex), true)));
        // The border tag has no right partner.
        if (lindex != Model.BorderIndex) {
          indexes.add(new Tuple<Integer, Integer>(rindex,
              newTagTable.toIndex(tagTable.toSymbol(rindex), true)));
        }
      }
    }
    return indexes;
  }

  /**
   * Copies the counts of every mapped tag from {@code statistics} into a new
   * table indexed by the new tag indexes, zeroing each copied cell of the old
   * table so the caller can check that nothing was left behind.
   */
  private static Statistics remap(Statistics statistics,
      List<Tuple<Integer, Integer>> indexes, int numWords) {
    Statistics remapped = new Statistics(indexes.size(), numWords);
    for (Tuple<Integer, Integer> from : indexes) {
      int oldFrom = from.x;
      int newFrom = from.y;
      for (int o = 0; o < numWords; o++) {
        remapped.setEmissions(newFrom, o, statistics.getEmissions(oldFrom, o));
        statistics.setEmissions(oldFrom, o, 0.0);
      }
      for (Tuple<Integer, Integer> to : indexes) {
        int oldTo = to.x;
        int newTo = to.y;
        // The mapping is injective, so each new cell is written exactly once.
        assert remapped.getTransitions(newFrom, newTo) == 0.0;
        remapped.setTransitions(newFrom, newTo, statistics.getTransitions(oldFrom, oldTo));
        statistics.setTransitions(oldFrom, oldTo, 0.0);
      }
    }
    return remapped;
  }
}