// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.util; import java.util.LinkedList; import java.util.List; import java.util.ListIterator; public class LevenshteinLattice { protected int[][] cost_lattice_; protected short[][] op_lattice_; protected final static short START = 1; protected final static short INSERT = 2; protected final static short DELETE = 4; protected final static short COPY = 8; protected final static short REPLACE = 16; protected String input_; protected String output_; protected int replace_cost_; protected int insert_cost_; protected int delete_cost_; private boolean initialized_; public LevenshteinLattice(String input, String output) { this(input, output, 1, 1, 2); } public LevenshteinLattice(String input, String output, int insert_cost, int delete_cost, int replace_cost) { input_ = input; output_ = output; replace_cost_ = replace_cost; insert_cost_ = insert_cost; delete_cost_ = delete_cost; initialized_ = false; } protected void init() { if (! initialized_) { fillLattice(); } initialized_ = true; } protected int min(int a, int b, int c) { return Math.min(a, Math.min(b, c)); } protected void fillLattice() { int input_length = input_.length(); int output_length = output_.length(); cost_lattice_ = new int[input_length + 1][output_length + 1]; op_lattice_ = new short[input_length + 1][output_length + 1]; op_lattice_[0][0] = START; for (int input_index = 1; input_index <= input_length; input_index++) { cost_lattice_[input_index][0] = delete_cost_ * input_index; op_lattice_[input_index][0] = DELETE; } for (int output_index = 1; output_index <= output_length; output_index++) { cost_lattice_[0][output_index] = insert_cost_ * output_index; op_lattice_[0][output_index] = INSERT; } for (int input_index = 1; input_index <= input_length; input_index++) { char current_input = input_.charAt(input_index - 1); for (int output_index = 1; output_index <= output_length; output_index++) { char current_output = output_.charAt(output_index - 1); short diag_op; int diag_cost; if (current_input == current_output) { diag_op = COPY; diag_cost = getCopyCost(input_index); } else { diag_op = REPLACE; diag_cost = getReplaceCost(current_input, current_output); } int minimal_diag_cost = cost_lattice_[(input_index - 1)][(output_index - 1)] + diag_cost; int minimal_delete_cost = cost_lattice_[(input_index - 1)][output_index] + delete_cost_; int minimal_insert_cost = cost_lattice_[input_index][(output_index - 1)] + insert_cost_; int minimal_cost = min(minimal_delete_cost, minimal_insert_cost, minimal_diag_cost); cost_lattice_[input_index][output_index] = minimal_cost; short minimal_cost_op = 0; if (minimal_cost == minimal_diag_cost) { minimal_cost_op |= diag_op; } if (minimal_cost == minimal_delete_cost) { minimal_cost_op |= DELETE; } if (minimal_cost == minimal_insert_cost) { minimal_cost_op |= INSERT; } op_lattice_[input_index][output_index] = minimal_cost_op; } } } protected int getCopyCost(int input_index) { return 0; } protected int getReplaceCost(char input, char output) { return replace_cost_; } public String searchOperationSequence() { init(); StringBuilder sb = new StringBuilder(); int input_index = input_.length(); int output_index = output_.length(); boolean stop = false; while (!stop) { short op = op_lattice_[input_index][output_index]; if ((op & START) > 0) { stop = true; assert op == START; } else if ((op & COPY) > 0) { sb.append('C'); output_index--; input_index--; } else if ((op & REPLACE) > 0) { sb.append('R'); output_index--; input_index--; } else if ((op & INSERT) > 0) { sb.append('I'); output_index--; } else if ((op & DELETE) > 0) { sb.append('D'); input_index--; } else { throw new RuntimeException("Unexpected operation code: " + Integer.toBinaryString(op)); } } sb.reverse(); return sb.toString(); } public List<List<Character>> searchOperationSequences(boolean remove_redundant) { init(); int input_index = input_.length(); int output_index = output_.length(); List<List<Character>> lists = searchOperationSequences(input_index, output_index); if (remove_redundant) { ListIterator<List<Character>> iter = lists.listIterator(); while (iter.hasNext()) { List<Character> next = iter.next(); if (redundant(next)) { iter.remove(); } } } return lists; } public List<List<Character>> searchOperationSequences() { return searchOperationSequences(false); } public List<List<Character>> searchOperationSequences(int input_index, int output_index) { init(); short op = op_lattice_[input_index][output_index]; List<List<Character>> lists = new LinkedList<List<Character>>(); if ((op & START) > 0) { assert op == START; lists.add(new LinkedList<Character>()); } else { if ((op & COPY) > 0) { lists.addAll(appendToList( searchOperationSequences(input_index - 1, output_index - 1), 'C')); } if ((op & REPLACE) > 0) { lists.addAll(appendToList( searchOperationSequences(input_index - 1, output_index - 1), 'R')); } if ((op & INSERT) > 0) { lists.addAll(appendToList( searchOperationSequences(input_index, output_index - 1), 'I')); } if ((op & DELETE) > 0) { lists.addAll(appendToList( searchOperationSequences(input_index - 1, output_index), 'D')); } if (lists.isEmpty()) { throw new RuntimeException("Unexpected operation code: " + Integer.toBinaryString(op)); } } return lists; } protected List<List<Character>> appendToList(List<List<Character>> lists, char c) { for (List<Character> list : lists) { list.add(c); } return lists; } public int getDistance() { init(); return cost_lattice_[input_.length()][output_.length()]; } public boolean redundant(List<Character> seq) { char last_c = 'S'; for (char c : seq) { if (c == 'R' && last_c == 'D') { // Canonical form is RD return true; } if (c == 'R' && last_c == 'I') { // Canonical form is RI return true; } if (c == 'I' && last_c == 'D') { // Canonical form is R return true; } if (c == 'D' && last_c == 'I') { // Canonical form is R return true; } last_c = c; } return false; } }