// Copyright 2014 Thomas Müller // This file is part of HMMLA, which is licensed under GPLv3. package hmmla.splitmerge; import hmmla.hmm.Model; import hmmla.hmm.Statistics; import hmmla.hmm.Tree; import hmmla.util.SymbolTable; import java.util.HashMap; import java.util.Map; import java.util.Random; public class Splitter { private Random rng_; private double randomness_; public Splitter(double randomness, Random rng) { this.randomness_ = randomness; this.rng_ = rng; } public void split(Model model) { SymbolTable<String> tag_table = model.getTagTable(); SymbolTable<String> word_table = model.getWordTable(); Map<String, Tree> clustering = model.getClustering(); Statistics stats = model.getStatistics(); int num_tags = tag_table.size(); SymbolTable<String> new_tag_table = new SymbolTable<String>(); Statistics new_statistics = new Statistics((num_tags - 1) * 2 + 1, word_table.size()); Map<String, Tree> new_clustering = new HashMap<String, Tree>(); new_tag_table.toIndex(Model.BorderSymbol, true); model.getTopLevel().get(Model.BorderSymbol).incrementLevel(); for (int i = 1; i < num_tags; i++) { String name = tag_table.toSymbol(i); String lname = String.format("%s0", name); int left = new_tag_table.toIndex(lname, true); assert left == i; assert clustering.containsKey(name); new_clustering.put(lname, clustering.get(name).setLeft(lname)); } for (int i = 1; i < num_tags; i++) { String name = tag_table.toSymbol(i); String rname = String.format("%s1", name); int right = new_tag_table.toIndex(rname, true); assert right == i + num_tags - 1; assert clustering.containsKey(name); new_clustering.put(rname, clustering.get(name).setRight(rname)); } for (int index = 1; index < num_tags; index++) { double random_double; double freq; int left = index; int right = index + num_tags - 1; for (int o = 0; o < word_table.size(); o++) { freq = 0.5 * stats.getEmissions(index, o); random_double = (rng_.nextDouble() - 0.5) * 2.0 * freq * randomness_; assert (freq + random_double >= 0.0) && (freq - random_double >= 0.0); new_statistics.setEmissions(left, o, freq + random_double); new_statistics.setEmissions(right, o, freq - random_double); } for (int i = 1; i < tag_table.size(); i++) { freq = 0.25 * stats.getTransitions(i, index); random_double = (rng_.nextDouble() - 0.5) * 2.0 * freq * randomness_; new_statistics.setTransitions(i, left, freq + random_double); new_statistics.setTransitions(i, right, freq - random_double); random_double = (rng_.nextDouble() - 0.5) * 2.0 * freq * randomness_; new_statistics.setTransitions(i + num_tags - 1, left, freq + random_double); new_statistics.setTransitions(i + num_tags - 1, right, freq - random_double); } freq = 0.5 * stats.getTransitions(Model.BorderIndex, index); random_double = (rng_.nextDouble() - 0.5) * 2.0 * freq * randomness_; new_statistics.setTransitions(Model.BorderIndex, left, freq + random_double); new_statistics.setTransitions(Model.BorderIndex, right, freq - random_double); freq = 0.5 * stats.getTransitions(index, Model.BorderIndex); random_double = (rng_.nextDouble() - 0.5) * 2.0 * freq * randomness_; new_statistics.setTransitions(left, Model.BorderIndex, freq + random_double); new_statistics.setTransitions(right, Model.BorderIndex, freq - random_double); } model.setStatistics(new_statistics); model.setTagTable(new_tag_table); model.setClustering(new_clustering); } }