// Copyright 2013 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package marmot.util; import java.util.List; public class GeneralLevenshteinLattice<T> { protected int[][] cost_lattice_; protected short[][] op_lattice_; protected final static short START = 1; protected final static short INSERT = 2; protected final static short DELETE = 4; protected final static short COPY = 8; protected final static short REPLACE = 16; protected List<T> input_; protected List<T> output_; protected int replace_cost_; protected int insert_cost_; protected int delete_cost_; private boolean initialized_; public GeneralLevenshteinLattice(List<T> input, List<T> output) { this(input, output, 1, 1, 2); } public GeneralLevenshteinLattice(List<T> input, List<T> output, int insert_cost, int delete_cost, int replace_cost) { input_ = input; output_ = output; replace_cost_ = replace_cost; insert_cost_ = insert_cost; delete_cost_ = delete_cost; initialized_ = false; } protected void init() { if (! initialized_) { fillLattice(); } initialized_ = true; } protected int min(int a, int b, int c) { return Math.min(a, Math.min(b, c)); } protected void fillLattice() { int input_length = input_.size(); int output_length = output_.size(); cost_lattice_ = new int[input_length + 1][output_length + 1]; op_lattice_ = new short[input_length + 1][output_length + 1]; op_lattice_[0][0] = START; for (int input_index = 1; input_index <= input_length; input_index++) { cost_lattice_[input_index][0] = delete_cost_ * input_index; op_lattice_[input_index][0] = DELETE; } for (int output_index = 1; output_index <= output_length; output_index++) { cost_lattice_[0][output_index] = insert_cost_ * output_index; op_lattice_[0][output_index] = INSERT; } for (int input_index = 1; input_index <= input_length; input_index++) { T current_input = input_.get(input_index - 1); for (int output_index = 1; output_index <= output_length; output_index++) { T current_output = output_.get(output_index - 1); short diag_op; int diag_cost; if (current_input.equals(current_output)) { diag_op = COPY; diag_cost = getCopyCost(input_index); } else { diag_op = REPLACE; diag_cost = getReplaceCost(current_input, current_output); } int minimal_diag_cost = cost_lattice_[(input_index - 1)][(output_index - 1)] + diag_cost; int minimal_delete_cost = cost_lattice_[(input_index - 1)][output_index] + delete_cost_; int minimal_insert_cost = cost_lattice_[input_index][(output_index - 1)] + insert_cost_; int minimal_cost = min(minimal_delete_cost, minimal_insert_cost, minimal_diag_cost); cost_lattice_[input_index][output_index] = minimal_cost; short minimal_cost_op = 0; if (minimal_cost == minimal_diag_cost) { minimal_cost_op |= diag_op; } if (minimal_cost == minimal_delete_cost) { minimal_cost_op |= DELETE; } if (minimal_cost == minimal_insert_cost) { minimal_cost_op |= INSERT; } op_lattice_[input_index][output_index] = minimal_cost_op; } } } protected int getCopyCost(int input_index) { return 0; } protected int getReplaceCost(T input, T output) { return replace_cost_; } public int getDistance() { init(); return cost_lattice_[input_.size()][output_.size()]; } }