/*
* Copyright 2014 Alpha Cephei Inc.
* All Rights Reserved. Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/
package edu.cmu.sphinx.alignment;
import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.util.Arrays.fill;
import static java.util.Collections.emptyList;
import java.util.*;
import edu.cmu.sphinx.util.Range;
import edu.cmu.sphinx.util.Utilities;
/**
 * Aligns a query word sequence against a fixed database word sequence.
 * <p>
 * Both sequences are split into overlapping tuples of {@code tupleSize}
 * consecutive words. A query is aligned by a best-first search over the
 * tuples it shares with the database. When either the query or the searched
 * database range is shorter than one tuple, a plain word-level edit-distance
 * alignment is used instead.
 *
 * @author Alexander Solovets
 */
public class LongTextAligner {

    /**
     * Best-first search for a monotone alignment between the tuples of one
     * query and the database tuples inside a given range.
     * <p>
     * The search grid has one row per query tuple that occurs somewhere in
     * the database. It has one column per database position (inside the
     * range) holding a tuple that is equal to some query tuple. Row 0 and
     * column 0 form a virtual boundary. The result is the list of matching
     * interior cells on the path found from the top-left corner to the
     * bottom-right corner.
     */
    private final class Alignment {

        /** A cell of the search grid. */
        public final class Node {
            /** Grid column; 0 is the boundary, c > 0 refers to shifts[c - 1]. */
            private final int databaseIndex;
            /** Grid row; 0 is the boundary, r > 0 refers to indices[r - 1]. */
            private final int queryIndex;

            private Node(int row, int column) {
                this.databaseIndex = column;
                this.queryIndex = row;
            }

            /**
             * @return position of this column's tuple in the database tuple
             *         list; only valid for columns greater than 0
             */
            public int getDatabaseIndex() {
                return shifts.get(databaseIndex - 1);
            }

            /**
             * @return position of this row's tuple in the query tuple list;
             *         only valid for rows greater than 0
             */
            public int getQueryIndex() {
                return indices.get(queryIndex - 1);
            }

            /** @return the query tuple of this row, or null on the boundary row */
            public String getQueryWord() {
                if (queryIndex > 0)
                    return query.get(getQueryIndex());
                return null;
            }

            /**
             * @return the database tuple of this column, or null on the
             *         boundary column
             */
            public String getDatabaseWord() {
                if (databaseIndex > 0)
                    return reftup.get(getDatabaseIndex());
                return null;
            }

            /**
             * Returns the cost of stepping onto this node: 0 for a matching
             * interior cell, 1 for a mismatching interior cell, and the
             * larger grid coordinate for a boundary cell.
             */
            public int getValue() {
                if (isBoundary())
                    return max(queryIndex, databaseIndex);
                return hasMatch() ? 0 : 1;
            }

            /**
             * @return true if this is an interior cell whose query tuple equals
             *         its database tuple; boundary cells never match
             */
            public boolean hasMatch() {
                // Guard first: on the boundary row getQueryWord() is null.
                return !isBoundary()
                        && getQueryWord().equals(getDatabaseWord());
            }

            /** @return true if this cell lies on row 0 or column 0 */
            public boolean isBoundary() {
                return queryIndex == 0 || databaseIndex == 0;
            }

            /** @return true for the bottom-right corner of the grid */
            public boolean isTarget() {
                return queryIndex == indices.size() &&
                        databaseIndex == shifts.size();
            }

            /**
             * @return the successors of this node that stay inside the grid,
             *         in the order: diagonal, next column, next row
             */
            public List<Node> adjacent() {
                List<Node> result = new ArrayList<Node>(3);
                boolean hasNextRow = queryIndex < indices.size();
                boolean hasNextColumn = databaseIndex < shifts.size();
                if (hasNextRow && hasNextColumn)
                    result.add(new Node(queryIndex + 1, databaseIndex + 1));
                if (hasNextColumn)
                    result.add(new Node(queryIndex, databaseIndex + 1));
                if (hasNextRow)
                    result.add(new Node(queryIndex + 1, databaseIndex));
                return result;
            }

            @Override
            public boolean equals(Object object) {
                if (!(object instanceof Node))
                    return false;
                Node other = (Node) object;
                return queryIndex == other.queryIndex &&
                        databaseIndex == other.databaseIndex;
            }

            @Override
            public int hashCode() {
                return 31 * queryIndex + databaseIndex;
            }

            @Override
            public String toString() {
                return String.format("[%d %d]", queryIndex, databaseIndex);
            }
        }

        /**
         * Open-set entry. The priority is frozen when the entry is created,
         * because java.util.PriorityQueue does not reorder elements whose
         * comparison key changes after insertion. Superseded entries stay in
         * the queue and are skipped once their node has been closed.
         */
        private final class QueueEntry {
            final Node node;
            final int priority;

            QueueEntry(Node node, int priority) {
                this.node = node;
                this.priority = priority;
            }
        }

        /** Sorted database tuple positions (within the range) used as columns. */
        private final List<Integer> shifts;
        /** The query tuples. */
        private final List<String> query;
        /** Positions of query tuples that occur in the database; used as rows. */
        private final List<Integer> indices;
        /** Matching cells of the path found, from first to last. */
        private final List<Node> alignment;

        /**
         * Builds the grid for the given query tuples and runs the search.
         *
         * @param query tuples of the query
         * @param range range of the database whose tuple positions are
         *        considered (tested with {@code Range.contains})
         */
        public Alignment(List<String> query, Range range) {
            this.query = query;
            indices = new ArrayList<Integer>();
            Set<Integer> shiftSet = new TreeSet<Integer>();
            for (int i = 0; i < query.size(); i++) {
                List<Integer> positions = tupleIndex.get(query.get(i));
                if (positions == null)
                    continue;
                indices.add(i);
                for (Integer shift : positions) {
                    if (range.contains(shift))
                        shiftSet.add(shift);
                }
            }
            shifts = new ArrayList<Integer>(shiftSet);
            alignment = search();
        }

        /**
         * Absolute difference between the number of rows and the number of
         * columns still left after the given node.
         * <p>
         * NOTE(review): a step onto a matching cell costs 0 whatever its
         * direction, so this potential can make the adjusted step cost
         * negative. The search may therefore not be guaranteed to find a
         * minimum-cost path.
         */
        private int heuristic(Node node) {
            return abs(indices.size() - shifts.size() - node.queryIndex
                    + node.databaseIndex);
        }

        /**
         * Runs the best-first search from the top-left to the bottom-right
         * corner.
         *
         * @return matching cells on the path found, or an empty list if the
         *         target cannot be reached
         */
        private List<Node> search() {
            Map<Node, Integer> cost = new HashMap<Node, Integer>();
            PriorityQueue<QueueEntry> openSet = new PriorityQueue<QueueEntry>(
                    16, new Comparator<QueueEntry>() {
                        public int compare(QueueEntry a, QueueEntry b) {
                            if (a.priority < b.priority)
                                return -1;
                            return a.priority == b.priority ? 0 : 1;
                        }
                    });
            Set<Node> closedSet = new HashSet<Node>();
            Map<Node, Node> parents = new HashMap<Node, Node>();

            Node startNode = new Node(0, 0);
            cost.put(startNode, 0);
            openSet.add(new QueueEntry(startNode, 0));

            while (!openSet.isEmpty()) {
                Node q = openSet.poll().node;
                // A node already closed was reached again through a stale entry.
                if (!closedSet.add(q))
                    continue;
                if (q.isTarget())
                    return backtrace(q, parents);

                // Every enqueued node has a recorded cost.
                int qScore = cost.get(q);
                int qHeuristic = heuristic(q);
                for (Node nb : q.adjacent()) {
                    if (closedSet.contains(nb))
                        continue;
                    int newScore = qScore + nb.getValue() - qHeuristic
                            + heuristic(nb);
                    Integer oldScore = cost.get(nb);
                    if (oldScore == null || newScore < oldScore) {
                        cost.put(nb, newScore);
                        parents.put(nb, q);
                        openSet.add(new QueueEntry(nb, newScore));
                    }
                }
            }
            return emptyList();
        }

        /**
         * Walks the parent links back from the target and keeps the matching
         * interior cells.
         *
         * @return matching cells ordered from the start towards the target
         */
        private List<Node> backtrace(Node target, Map<Node, Node> parents) {
            List<Node> path = new ArrayList<Node>();
            Node node = target;
            while (parents.containsKey(node)) {
                if (node.hasMatch())
                    path.add(node);
                node = parents.get(node);
            }
            Collections.reverse(path);
            return path;
        }

        /** @return matching cells of the alignment, from first to last */
        public List<Node> getIndices() {
            return alignment;
        }
    }

    private final int tupleSize;
    /** Database tuples; tuple i starts at database word i. */
    private final List<String> reftup;
    /** Maps every database tuple to the positions where it occurs. */
    private final HashMap<String, ArrayList<Integer>> tupleIndex;
    private final List<String> refWords;

    /**
     * Constructs a new text aligner that serves requests to align sequences
     * of words with the provided database sequence. Sequences are aligned by
     * tuples of one or more consecutive words.
     *
     * @param words list of words forming the database
     * @param tupleSize size of a tuple, must be greater than or equal to 1
     */
    public LongTextAligner(List<String> words, int tupleSize) {
        assert words != null;
        assert tupleSize > 0;

        this.tupleSize = tupleSize;
        this.refWords = words;
        reftup = getTuples(words);
        tupleIndex = new HashMap<String, ArrayList<Integer>>();
        for (int offset = 0; offset < reftup.size(); ++offset) {
            String tuple = reftup.get(offset);
            ArrayList<Integer> indexes = tupleIndex.get(tuple);
            if (indexes == null) {
                indexes = new ArrayList<Integer>();
                tupleIndex.put(tuple, indexes);
            }
            indexes.add(offset);
        }
    }

    /**
     * Aligns a query sequence with the whole database.
     *
     * @param query list of words to look for
     * @return for every query word, the index of the aligned database word,
     *         or -1 if the word is not aligned
     */
    public int[] align(List<String> query) {
        return align(query, new Range(0, refWords.size()));
    }

    /**
     * Aligns a query sequence with a range of the database.
     *
     * @param words list of words to look for
     * @param range range of the database to look for the alignment in
     * @return for every query word, the index of the aligned database word,
     *         or -1 if the word is not aligned
     */
    public int[] align(List<String> words, Range range) {
        int lower = range.lowerEndpoint();
        int upper = range.upperEndpoint();
        // Too short to form a single tuple: fall back to word-level alignment.
        if (upper - lower < tupleSize || words.size() < tupleSize) {
            return alignTextSimple(refWords.subList(lower, upper), words,
                    lower);
        }

        int[] result = new int[words.size()];
        fill(result, -1);
        int lastIndex = 0;
        for (Alignment.Node node : new Alignment(getTuples(words), range)
                .getIndices()) {
            int queryStart = node.getQueryIndex();
            int databaseStart = node.getDatabaseIndex();
            // A matched tuple covers tupleSize consecutive words. Words already
            // assigned by an earlier overlapping tuple are not overwritten.
            lastIndex = max(lastIndex, queryStart);
            for (; lastIndex < queryStart + tupleSize; ++lastIndex)
                result[lastIndex] = databaseStart + lastIndex - queryStart;
        }
        return result;
    }

    /**
     * Makes a list of overlapping tuples of {@link #tupleSize} consecutive
     * words. Tuple i starts at word i.
     *
     * @param words words
     * @return list of joined tuples; empty if there are fewer words than
     *         {@link #tupleSize}
     */
    private List<String> getTuples(List<String> words) {
        List<String> result = new ArrayList<String>();
        if (words.size() < tupleSize)
            return result;

        // Sliding window: prime it with tupleSize - 1 words, then slide.
        LinkedList<String> tuple = new LinkedList<String>();
        Iterator<String> it = words.iterator();
        for (int i = 0; i < tupleSize - 1; i++) {
            tuple.add(it.next());
        }
        while (it.hasNext()) {
            tuple.addLast(it.next());
            result.add(Utilities.join(tuple));
            tuple.removeFirst();
        }
        return result;
    }

    /**
     * Aligns two word sequences with a word-level edit-distance table and a
     * greedy backtrace that records matching words only.
     *
     * @param database database words
     * @param query query words
     * @param offset value added to every database index in the result
     * @return for every query word, {@code offset} plus the index of the
     *         aligned database word, or -1 if the word is not aligned
     */
    static int[] alignTextSimple(List<String> database, List<String> query,
            int offset) {
        // Arrays give O(1) access even when the inputs are linked lists.
        String[] ref = database.toArray(new String[database.size()]);
        String[] hyp = query.toArray(new String[query.size()]);
        int n = ref.length;
        int m = hyp.length;

        // f[i][j] = edit distance between the first i database words and the
        // first j query words.
        int[][] f = new int[n + 1][m + 1];
        for (int i = 1; i <= n; ++i)
            f[i][0] = i;
        for (int j = 1; j <= m; ++j)
            f[0][j] = j;
        for (int i = 1; i <= n; ++i) {
            String refWord = ref[i - 1];
            for (int j = 1; j <= m; ++j) {
                int match = f[i - 1][j - 1];
                if (!refWord.equals(hyp[j - 1]))
                    ++match;
                int insert = f[i][j - 1] + 1;
                int delete = f[i - 1][j] + 1;
                f[i][j] = min(match, min(insert, delete));
            }
        }

        int[] alignment = new int[m];
        fill(alignment, -1);
        int dbPos = n;
        int queryPos = m;
        // Once either sequence is exhausted no further matches are possible.
        while (dbPos > 0 && queryPos > 0) {
            if (f[dbPos - 1][queryPos - 1] <= f[dbPos][queryPos - 1]
                    && ref[dbPos - 1].equals(hyp[queryPos - 1])) {
                --dbPos;
                --queryPos;
                alignment[queryPos] = dbPos + offset;
            } else if (f[dbPos - 1][queryPos] < f[dbPos][queryPos - 1]) {
                --dbPos;
            } else {
                --queryPos;
            }
        }
        return alignment;
    }
}