package hu.ppke.itk.nlpg.purepos.decoder;

import hu.ppke.itk.nlpg.purepos.model.internal.CompiledModel;
import hu.ppke.itk.nlpg.purepos.model.internal.History;
import hu.ppke.itk.nlpg.purepos.model.internal.NGram;
import hu.ppke.itk.nlpg.purepos.morphology.IMorphologicalAnalyzer;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.tuple.Pair;

import com.google.common.collect.MinMaxPriorityQueue;

/**
 * Decoder that tags a sentence with a beam search.
 * <p>
 * The beam is a {@link MinMaxPriorityQueue} of {@link History} objects (a tag
 * sequence plus its accumulated log probability). For every observed word each
 * history is extended with all candidate next tags and the beam is then pruned,
 * either to a fixed number of hypotheses or to those within a log-probability
 * margin of the best one.
 * <p>
 * The code treats the queue's last element as the best hypothesis and the
 * first as the worst; this relies on the ordering defined by {@code History},
 * which is not visible in this file.
 */
public class BeamSearch extends AbstractDecoder {

    /** Maximum number of hypotheses kept when {@link #fixedBeam} is set. */
    protected Integer beamSize = 10;

    /**
     * {@code true}: prune to {@link #beamSize} hypotheses; {@code false}:
     * prune by the log-probability margin {@code logTheta}.
     */
    protected boolean fixedBeam = true;

    /**
     * Creates a decoder with a variable-width beam that is pruned by
     * {@code logTheta}.
     * <p>
     * Note that this constructor differs from the fixed-beam one only in the
     * type of its third parameter ({@code double} here, {@code int} there), so
     * an {@code int} argument selects the fixed-beam constructor.
     *
     * @param model                 the compiled tagging model
     * @param morphologicalAnalyzer the morphological analyzer handed to the
     *                              base decoder
     * @param logTheta              margin below the best hypothesis' log
     *                              probability at which hypotheses are dropped
     * @param sufTheta              forwarded unchanged to the base decoder
     * @param maxGuessedTags        forwarded unchanged to the base decoder
     */
    public BeamSearch(CompiledModel<String, Integer> model,
            IMorphologicalAnalyzer morphologicalAnalyzer, double logTheta,
            double sufTheta, int maxGuessedTags) {
        super(model, morphologicalAnalyzer, logTheta, sufTheta, maxGuessedTags);
        fixedBeam = false;
    }

    /**
     * Creates a decoder with a fixed-size beam. The base decoder receives
     * {@code 0} as its {@code logTheta} argument.
     *
     * @param model                 the compiled tagging model
     * @param morphologicalAnalyzer the morphological analyzer handed to the
     *                              base decoder
     * @param beamSize              maximum number of hypotheses kept after
     *                              each step
     * @param sufTheta              forwarded unchanged to the base decoder
     * @param maxGuessedTags        forwarded unchanged to the base decoder
     */
    public BeamSearch(CompiledModel<String, Integer> model,
            IMorphologicalAnalyzer morphologicalAnalyzer, int beamSize,
            double sufTheta, int maxGuessedTags) {
        super(model, morphologicalAnalyzer, 0, sufTheta, maxGuessedTags);
        this.beamSize = beamSize;
        fixedBeam = true;
    }

    /**
     * Tags the given observations and returns the highest-ranked tag
     * sequences found in the final beam.
     *
     * @param observations     the words of the sentence; they are passed
     *                         through {@code prepareObservations} first
     * @param maxResultsNumber upper bound on the number of returned sequences
     * @return at most {@code maxResultsNumber} pairs of (tag sequence, log
     *         probability), in the order they are taken from the end of the
     *         beam
     */
    @Override
    public List<Pair<List<Integer>, Double>> decode(List<String> observations,
            int maxResultsNumber) {
        List<String> prepared = prepareObservations(observations);
        MinMaxPriorityQueue<History> beam = beamSearch(prepared);
        return getKTop(beam, maxResultsNumber);
    }

    /**
     * Removes up to {@code maxResultsNumber} elements from the end of the beam
     * and converts them to (cleaned tag list, log probability) pairs.
     * The beam is modified by this call.
     */
    private List<Pair<List<Integer>, Double>> getKTop(
            MinMaxPriorityQueue<History> beam, int maxResultsNumber) {
        List<Pair<List<Integer>, Double>> ret = new ArrayList<Pair<List<Integer>, Double>>();
        int n = Math.min(maxResultsNumber, beam.size());
        for (int i = 0; i < n; ++i) {
            History lastElement = beam.removeLast();
            List<Integer> tagSeq = lastElement.getTagSeq().toList();
            ret.add(Pair.of(clean(tagSeq), lastElement.getLogProb()));
        }
        return ret;
    }

    /**
     * Returns a view of {@code tagSeq} without its first
     * {@code model.getTaggingOrder()} entries and without its final entry.
     * The leading entries correspond to the start tags inserted by
     * {@link #createInitialElement()}; the trailing one is presumably an
     * end-of-sentence tag added by {@code prepareObservations} (not visible
     * here).
     */
    private List<Integer> clean(List<Integer> tagSeq) {
        return tagSeq.subList(model.getTaggingOrder(), tagSeq.size() - 1);
    }

    /**
     * Runs the search over all observations, extending and pruning the beam
     * once per word.
     *
     * @return the beam after the last word has been processed
     */
    private MinMaxPriorityQueue<History> beamSearch(List<String> observations) {
        MinMaxPriorityQueue<History> beam = initBeam();
        int position = 0;
        for (String word : observations) {
            Set<NGram<Integer>> contexts = collectContexts(beam);
            // The last argument flags the first word of the sentence.
            Map<NGram<Integer>, Map<Integer, Pair<Double, Double>>> probs = getNextProbs(
                    contexts, word, position, position == 0);
            beam = updateBeam(beam, probs);
            prune(beam);
            position++;
        }
        return beam;
    }

    /** Collects the distinct tag sequences of all histories in the beam. */
    private Set<NGram<Integer>> collectContexts(
            MinMaxPriorityQueue<History> beam) {
        Set<NGram<Integer>> ret = new HashSet<NGram<Integer>>();
        for (History h : beam) {
            ret.add(h.getTagSeq());
        }
        return ret;
    }

    /**
     * Builds the next beam by extending every history with every candidate tag
     * listed for its context in {@code probs}. The new log probability is the
     * old one plus both components of the candidate's probability pair.
     * <p>
     * NOTE(review): a context missing from {@code probs} results in a
     * {@link NullPointerException}; this assumes {@code getNextProbs} returns
     * an entry for every context it was given - confirm.
     */
    private MinMaxPriorityQueue<History> updateBeam(
            MinMaxPriorityQueue<History> beam,
            Map<NGram<Integer>, Map<Integer, Pair<Double, Double>>> probs) {
        MinMaxPriorityQueue<History> newBeam = MinMaxPriorityQueue.create();
        for (History h : beam) {
            NGram<Integer> context = h.getTagSeq();
            double oldProb = h.getLogProb();
            Map<Integer, Pair<Double, Double>> transitions = probs.get(context);
            for (Map.Entry<Integer, Pair<Double, Double>> next : transitions
                    .entrySet()) {
                Pair<Double, Double> probVals = next.getValue();
                NGram<Integer> newSeq = context.add(next.getKey());
                double newProb = oldProb + probVals.getLeft()
                        + probVals.getRight();
                newBeam.add(new History(newSeq, newProb));
            }
        }
        return newBeam;
    }

    /**
     * Prunes the beam in place.
     * <ul>
     * <li>Fixed beam: removes elements from the front until at most
     * {@link #beamSize} remain.</li>
     * <li>Variable beam: removes elements from the front for as long as their
     * log probability is not greater than that of the last element minus
     * {@code logTheta}. If {@code logTheta} is not positive the last element
     * itself fails this test and the beam ends up empty.</li>
     * </ul>
     * An empty beam is left untouched.
     */
    private void prune(MinMaxPriorityQueue<History> beam) {
        if (fixedBeam) {
            while (beam.size() > this.beamSize) {
                beam.removeFirst();
            }
            return;
        }

        History max = beam.peekLast();
        if (max == null) {
            // Empty beam: nothing to prune.
            return;
        }
        // peekFirst() returns null once the queue is exhausted, so the loop
        // stops on an explicit check instead of a swallowed exception.
        History worst = beam.peekFirst();
        while (worst != null
                && !(worst.getLogProb() > max.getLogProb() - logTheta)) {
            beam.removeFirst();
            worst = beam.peekFirst();
        }
    }

    /**
     * Creates a beam holding a single history: the initial n-gram with log
     * probability {@code 0.0}.
     */
    private MinMaxPriorityQueue<History> initBeam() {
        MinMaxPriorityQueue<History> beam = MinMaxPriorityQueue.create();
        NGram<Integer> initNGram = createInitialElement();
        beam.add(new History(initNGram, 0.0));
        return beam;
    }

    /**
     * Builds the n-gram every search starts from:
     * {@code model.getTaggingOrder()} copies of the model's
     * beginning-of-sentence tag index.
     *
     * @return the initial n-gram
     */
    protected NGram<Integer> createInitialElement() {
        int order = model.getTaggingOrder();
        ArrayList<Integer> startTags = new ArrayList<Integer>();
        for (int j = 0; j < order; ++j) {
            startTags.add(model.getBOSIndex());
        }
        return new NGram<Integer>(startTags, model.getTaggingOrder());
    }
}