// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.util.eval;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Paired randomization significance test based on random sign flips.
 *
 * <p>Two predictions are scored against the same gold standard, the scores are
 * paired by position and their differences are computed. The test then
 * repeatedly flips the sign of every difference with probability 1/2 and counts
 * how often the absolute sum of the randomized differences is at least as large
 * as the absolute sum of the observed differences. The fraction of such samples
 * is returned as the p-value.
 */
public class RandomizationTest implements SignificanceTest {

    /** Number of random sign assignments sampled when none is specified (2^20). */
    private static final int DEFAULT_NUM_TRIALS = 1 << 20;

    /** Slack allowed when comparing a sampled difference against the observed one. */
    private static final double COMPARISON_EPSILON = 1.e-10;

    /** Differences whose absolute value is below this are treated as zero. */
    private static final double ZERO_THRESHOLD = 1e-99;

    /** Source of the random sign flips. */
    private final Random random_;

    /** Number of random sign assignments sampled by {@link #test}. */
    private final int num_trials_;

    /**
     * Creates a test with an unseeded random generator and the default number
     * of trials.
     */
    public RandomizationTest() {
        this(new Random(), DEFAULT_NUM_TRIALS);
    }

    /**
     * Creates a test with a seeded random generator and the default number of
     * trials.
     *
     * @param seed seed passed to {@link Random#Random(long)}
     */
    public RandomizationTest(long seed) {
        this(new Random(seed), DEFAULT_NUM_TRIALS);
    }

    /**
     * Creates a test with a seeded random generator and an explicit number of
     * trials.
     *
     * @param seed seed passed to {@link Random#Random(long)}
     * @param num_trials number of random sign assignments to sample; must be positive
     * @throws IllegalArgumentException if {@code num_trials} is not positive
     */
    public RandomizationTest(long seed, int num_trials) {
        this(new Random(seed), num_trials);
    }

    private RandomizationTest(Random random, int num_trials) {
        if (num_trials <= 0) {
            throw new IllegalArgumentException(
                    "Number of trials must be positive: " + num_trials);
        }
        random_ = random;
        num_trials_ = num_trials;
    }

    /**
     * Returns the sum of the given scores.
     *
     * @param scores values to add up
     * @return the sum, or {@code 0.0} for an empty list
     */
    public double getSum(List<Double> scores) {
        double sum = 0.0;
        for (double score : scores) {
            sum += score;
        }
        return sum;
    }

    /**
     * Estimates the p-value for the difference between two predictions.
     *
     * <p>The score totals of both predictions and the number of non-zero
     * differences are written to {@code System.err}.
     *
     * @param scorer produces the score lists for each prediction
     * @param gold gold standard handed to the scorer
     * @param pred1 first prediction handed to the scorer
     * @param pred2 second prediction handed to the scorer
     * @return {@code 1.0} if all paired scores are equal; otherwise the fraction
     *         of sampled sign assignments whose absolute summed difference is
     *         at least the observed one. No smoothing is applied, so the result
     *         can be exactly {@code 0}.
     * @throws IllegalArgumentException if the scorer returns lists of different
     *         lengths for the two predictions
     */
    public double test(Scorer scorer, String gold, String pred1, String pred2) {
        List<Double> scores1 = scorer.getScores(gold, pred1);
        System.err.println(pred1 + ": " + getSum(scores1));
        List<Double> scores2 = scorer.getScores(gold, pred2);
        System.err.println(pred2 + ": " + getSum(scores2));

        List<Double> diffs = getDifferences(scores1, scores2, true);
        if (diffs.isEmpty()) {
            // Both predictions score identically everywhere: nothing to test.
            return 1.0;
        }
        System.err.println("|Diffs|: " + diffs.size());

        // Unbox once; the sampling loop below walks the differences
        // num_trials_ times.
        double[] values = new double[diffs.size()];
        for (int index = 0; index < values.length; index++) {
            values[index] = diffs.get(index);
        }

        double observed = getAbsoluteDifference(values, false);
        int at_least_as_extreme = 0;
        for (int trial = 0; trial < num_trials_; trial++) {
            double sampled = getAbsoluteDifference(values, true);
            // The epsilon makes samples that equal the observed value up to
            // floating-point noise count as "at least as large".
            if (observed - sampled < COMPARISON_EPSILON) {
                at_least_as_extreme += 1;
            }
        }
        return at_least_as_extreme / (double) num_trials_;
    }

    /**
     * Computes the element-wise differences {@code scores1[i] - scores2[i]}.
     *
     * @param scores1 scores of the first system
     * @param scores2 scores of the second system; must have the same length as
     *        {@code scores1}
     * @param remove_zeroes if {@code true}, differences that are (numerically)
     *        zero are left out of the result
     * @return a new list with the retained differences in input order
     * @throws IllegalArgumentException if the two lists differ in length
     */
    public static List<Double> getDifferences(List<Double> scores1,
            List<Double> scores2, boolean remove_zeroes) {
        if (scores1.size() != scores2.size()) {
            // The differences are paired by position, so a length mismatch
            // means the two score lists do not describe the same items.
            throw new IllegalArgumentException("Score lists differ in length: "
                    + scores1.size() + " vs. " + scores2.size());
        }
        List<Double> list = new ArrayList<Double>(scores1.size());
        for (int index = 0; index < scores1.size(); index++) {
            double diff = scores1.get(index) - scores2.get(index);
            if (remove_zeroes && Math.abs(diff) < ZERO_THRESHOLD) {
                continue;
            }
            list.add(diff);
        }
        return list;
    }

    /**
     * Returns the absolute value of the sum of the differences, optionally with
     * random signs.
     *
     * @param differences paired differences
     * @param random if {@code true}, each difference is negated with
     *        probability 1/2 before it is added; if {@code false}, the random
     *        generator is not consulted
     * @return absolute value of the (possibly sign-flipped) sum
     */
    private double getAbsoluteDifference(double[] differences, boolean random) {
        double diff = 0.;
        for (double current_diff : differences) {
            if (random && random_.nextBoolean()) {
                diff -= current_diff;
            } else {
                diff += current_diff;
            }
        }
        return Math.abs(diff);
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments: the fully qualified name of a {@code Scorer} implementation
     * with a no-argument constructor, the gold standard, the first prediction,
     * the second prediction and, optionally, one scorer option of the form
     * {@code key=value}. The p-value is printed to {@code System.out}.
     *
     * @param args command-line arguments as described above
     */
    public static void main(String[] args) {
        if (args.length < 4) {
            System.err.println("Usage: RandomizationTest <scorer-class> <gold>"
                    + " <prediction1> <prediction2> [<key>=<value>]");
            System.exit(1);
            return;
        }

        String scorer_string = args[0];
        Scorer scorer;
        try {
            // Class.newInstance() is deprecated and lets checked exceptions
            // thrown by the constructor escape undeclared; go through the
            // constructor object instead.
            scorer = (Scorer) Class.forName(scorer_string)
                    .getDeclaredConstructor().newInstance();
        } catch (InstantiationException e) {
            throw new RuntimeException(e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw new RuntimeException(e);
        }

        String actual = args[1];
        String prediction1 = args[2];
        String prediction2 = args[3];

        if (args.length > 4) {
            // Limit the split to two parts so that values may contain '='.
            String[] key_value = args[4].split("=", 2);
            if (key_value.length != 2) {
                throw new IllegalArgumentException(
                        "Expected option of the form key=value, got: " + args[4]);
            }
            scorer.setOption(key_value[0], key_value[1]);
        }

        SignificanceTest test = new RandomizationTest();
        DecimalFormat df = new DecimalFormat("0.#######################");
        System.out.println(df.format(test.test(scorer, actual, prediction1, prediction2)));
    }
}