/*******************************************************************************
* Copyright (c) 2012 György Orosz, Attila Novák.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Lesser Public License v3
* which accompanies this distribution, and is available at
* http://www.gnu.org/licenses/
*
* This file is part of PurePos.
*
* PurePos is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PurePos is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser Public License for more details.
*
* Contributors:
* György Orosz - initial API and implementation
******************************************************************************/
package hu.ppke.itk.nlpg.purepos.decoder;
import hu.ppke.itk.nlpg.purepos.model.internal.CompiledModel;
import hu.ppke.itk.nlpg.purepos.model.internal.NGram;
import hu.ppke.itk.nlpg.purepos.morphology.IMorphologicalAnalyzer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.lang3.tuple.Pair;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.collect.Table.Cell;
/**
* Decoder that implements the Viterbi search method speed up with using beams.
*
* @author György Orosz
*
*/
public class BeamedViterbi extends AbstractDecoder {
public BeamedViterbi(CompiledModel<String, Integer> model,
IMorphologicalAnalyzer morphologicalAnalyzer, double logTheta,
double sufTheta, int maxGuessedTags) {
super(model, morphologicalAnalyzer, logTheta, sufTheta, maxGuessedTags);
}
// protected Logger logger = Logger.getLogger(this.getClass());
@Override
public List<Pair<List<Integer>, Double>> decode(List<String> observations,
int maxResultsNumber) {
List<String> obs = prepareObservations(observations);
NGram<Integer> startNGram = createInitialElement();
List<Pair<List<Integer>, Double>> tagSeqList = beamedSearch(startNGram,
obs, maxResultsNumber);
List<Pair<List<Integer>, Double>> ret = cleanResults(tagSeqList);
return ret;
}
//
// public List<Integer> beamedSearch(final NGram<Integer> start,
// final List<String> obs) {
// return beamedSearch(start, obs, 1).get(0);
//
// }
public List<Pair<List<Integer>, Double>> beamedSearch(
final NGram<Integer> start,
final List<String> observations,
int resultsNumber) {
HashMap<NGram<Integer>, Node> beam = new HashMap<NGram<Integer>, Node>();
beam.put(start, startNode(start));
boolean isFirst = true;
int pos = 0;
for (String obs : observations) {
HashMap<NGram<Integer>, Node> newBeam = new HashMap<NGram<Integer>, Node>();
Table<NGram<Integer>, Integer, Double> nextProbs = HashBasedTable
.create();
Map<NGram<Integer>, Double> obsProbs = new HashMap<NGram<Integer>, Double>();
Set<NGram<Integer>> contexts = beam.keySet();
Map<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nexts = getNextProbs(
contexts, obs, pos, isFirst);
for (Map.Entry<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nextsEntry : nexts
.entrySet()) {
NGram<Integer> context = nextsEntry.getKey();
Map<Integer, Pair<Double, Double>> nextContextProbs = nextsEntry
.getValue();
for (Map.Entry<Integer, Pair<Double, Double>> entry : nextContextProbs
.entrySet()) {
Integer tag = entry.getKey();
nextProbs.put(context, tag, entry.getValue().getLeft());
obsProbs.put(context.add(tag), entry.getValue().getRight());
}
}
for (Cell<NGram<Integer>, Integer, Double> cell : nextProbs
.cellSet()) {
Integer nextTag = cell.getColumnKey();
NGram<Integer> context = cell.getRowKey();
Double transVal = cell.getValue();
NGram<Integer> newState = context.add(nextTag);
Node from = beam.get(context);
double newVal = transVal + beam.get(context).getWeight();
update(newBeam, newState, newVal, from);
}
if (nextProbs.size() > 1)
for (NGram<Integer> tagSeq : newBeam.keySet()) {
Node node = newBeam.get(tagSeq);
Double obsProb = obsProbs.get(tagSeq);
node.setWeight(obsProb + node.getWeight());
}
beam = prune(newBeam);
isFirst = false;
++pos;
}
return findMax(beam, resultsNumber);
}
private List<Pair<List<Integer>, Double>> findMax(
final HashMap<NGram<Integer>, Node> beam, int resultsNumber) {
// Node max = Collections.max(beam.values());
// Node act = max;
// return decompose(max);
SortedSet<Node> sortedKeys = new TreeSet<Node>(beam.values());
List<Pair<List<Integer>, Double>> ret = new ArrayList<Pair<List<Integer>, Double>>();
Node max;
for (int i = 0; i < resultsNumber && !sortedKeys.isEmpty(); ++i) {
max = sortedKeys.last();
sortedKeys.remove(max);
List<Integer> maxTagSeq = decompose(max);
ret.add(Pair.of(maxTagSeq, max.weight));
}
return ret;
}
private HashMap<NGram<Integer>, Node> prune(final HashMap<NGram<Integer>, Node> beam) {
HashMap<NGram<Integer>, Node> ret = new HashMap<NGram<Integer>, Node>();
Node maxNode = Collections.max(beam.values());
Double max = maxNode.getWeight();
for (NGram<Integer> key : beam.keySet()) {
Node actNode = beam.get(key);
Double actVal = actNode.getWeight();
if (!(actVal < max - logTheta)) {
ret.put(key, actNode);
}
}
return ret;
}
private void update(HashMap<NGram<Integer>, Node> beam,
NGram<Integer> newState, Double newWeight, Node fromNode) {
if (!beam.containsKey(newState)) {
// logger.trace("\t\t\tAS: " + newNGram + " from " + context
// + " with " + newValue);
beam.put(newState, new Node(newState, newWeight, fromNode));
} else if (beam.get(newState).getWeight() < newWeight) {
// logger.trace("\t\t\tUS: " + old + " to " + newNGram + " from "
// + context + " with " + newValue);
beam.get(newState).setPrevious(fromNode);
beam.get(newState).setWeight(newWeight);
} else {
// logger.trace("\t\t\tNU: " + old + " to " + newNGram + " from "
// + context + " with " + newValue);
}
}
}