// Copyright 2014 Thomas Müller // This file is part of HMMLA, which is licensed under GPLv3. package hmmla.splitmerge; import hmmla.hmm.HmmModel; import hmmla.hmm.Model; import hmmla.hmm.Statistics; import hmmla.io.Sentence; import hmmla.util.AbstractSPMDCallable; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; public class ConcurrentEmTrainer extends EmTrainer { private ArrayList<Worker> workers_ = null; private class Worker extends AbstractSPMDCallable<Sentence, Double> { private SimpleEmTrainer trainer_; private Model model_; private boolean update_; public Worker() { trainer_ = new SimpleEmTrainer(); } public void reset(Iterator<Sentence> iter, Model model, HmmModel normalizedStatistics, boolean update) { super.reset(iter, 0.0); update_ = update; if (update) { // Create a shallow copy of model an replace // statistics with a copy of statistics. model_ = new Model(model); Statistics statistics = new Statistics(model_.getTagTable() .size(), model_.getWordTable().size()); model_.setStatistics(statistics); } else { model_ = model; } trainer_.reset(model_, normalizedStatistics); } @Override protected Double apply(Sentence sentence, Double out) { out += trainer_.estep(sentence, update_); return out; } public Statistics getStatistics() { return model_.getStatistics(); } } public ConcurrentEmTrainer(int threadNumber) { workers_ = new ArrayList<Worker>(threadNumber); for (int i = 0; i < threadNumber; i++) { this.workers_.add(new Worker()); } } @Override public double estep(Model model, HmmModel normalizedStatistics, Iterable<Sentence> reader, boolean update) { final Iterator<Sentence> pairIterator = reader.iterator(); Iterator<Sentence> iterator = new Iterator<Sentence>() { @Override public boolean hasNext() { return pairIterator.hasNext(); } @Override public Sentence next() { return pairIterator.next(); } @Override public void remove() { throw new UnsupportedOperationException(); } }; for (Worker w : workers_) { w.reset(iterator, model, normalizedStatistics, update); } ExecutorService executorService = Executors.newFixedThreadPool(workers_ .size()); List<Future<Double>> results = null; try { results = executorService.invokeAll(workers_); } catch (InterruptedException e) { throw new RuntimeException(e); } double ll = 0.0; for (Future<Double> f : results) { try { ll += f.get(); } catch (InterruptedException e) { throw new RuntimeException(e); } catch (ExecutionException e) { throw new RuntimeException(e); } } if (update) { Statistics statistics = model.getStatistics(); statistics.setZero(); for (Worker w : workers_) { statistics.add(w.getStatistics()); } } executorService.shutdown(); return ll; } }