package resa.migrate.plan;
import java.util.Arrays;
/**
* Created by ding on 14-6-6.
*/
public class KuhnMunkres {
private int maxN, n, lenX, lenY;
private double[][] weights;
private boolean[] visitX, visitY;
private double[] lx, ly;
private double[] slack;
private int[] match;
public KuhnMunkres(int maxN) {
this.maxN = maxN;
visitX = new boolean[maxN];
visitY = new boolean[maxN];
lx = new double[maxN];
ly = new double[maxN];
slack = new double[maxN];
match = new int[maxN];
}
public int[][] getMaxBipartie(double weight[][], double[] result) {
if (!preProcess(weight)) {
throw new IllegalArgumentException("Data overflow, max num is " + maxN);
}
//initialize memo data for class
//initialize label X and Y
Arrays.fill(ly, 0);
Arrays.fill(lx, 0);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (lx[i] < weights[i][j]) {
lx[i] = weights[i][j];
}
}
}
//find a match for each X point
for (int u = 0; u < n; u++) {
Arrays.fill(slack, 0x7fffffff);
while (true) {
Arrays.fill(visitX, false);
Arrays.fill(visitY, false);
if (findPath(u)) //if find it, go on to the next point
break;
//otherwise update labels so that more edge will be added in
double inc = 0x7fffffff;
for (int v = 0; v < n; v++) {
if (!visitY[v] && slack[v] < inc) {
inc = slack[v];
}
}
for (int i = 0; i < n; i++) {
if (visitX[i]) {
lx[i] -= inc;
}
if (visitY[i]) {
ly[i] += inc;
}
}
}
}
result[0] = 0.0;
for (int i = 0; i < n; i++) {
if (match[i] >= 0) {
result[0] += weights[match[i]][i];
}
}
return matchResult();
}
public int[][] matchResult() {
int len = Math.min(lenX, lenY);
int[][] res = new int[len][2];
int count = 0;
for (int i = 0; i < lenY; i++) {
if (match[i] >= 0 && match[i] < lenX) {
res[count][0] = match[i];
res[count++][1] = i;
}
}
return res;
}
private boolean preProcess(double[][] weight) {
if (weight == null) {
return false;
}
lenX = weight.length;
lenY = weight[0].length;
if (lenX > maxN || lenY > maxN) {
return false;
}
Arrays.fill(match, -1);
n = Math.max(lenX, lenY);
weights = new double[n][n];
for (int i = 0; i < lenX; i++) {
for (int j = 0; j < lenY; j++) {
weights[i][j] = weight[i][j];
}
}
return true;
}
private boolean findPath(int u) {
visitX[u] = true;
for (int v = 0; v < n; v++) {
if (!visitY[v]) {
double temp = lx[u] + ly[v] - weights[u][v];
if (temp == 0.0) {
visitY[v] = true;
if (match[v] == -1 || findPath(match[v])) {
match[v] = u;
return true;
}
} else {
slack[v] = Math.min(slack[v], temp);
}
}
}
return false;
}
}