package experimental.ising;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import marmot.util.Numerics;
import org.javatuples.Pair;
public class BruteForceUnit {
public static void main(String[] args) {
int correct = 0;
int total = 100;
for (int i = 0; i < total; ++i) {
if (test(3)) {
++correct;
}
}
System.out.println( ((double) correct / total));
}
public static boolean test(int numVariables) {
Random rand = new Random();
List<String> tagNames = new LinkedList<String>();
for (int i = 0; i < numVariables; ++i) {
tagNames.add("");
}
// adjust tree as seen fit
List<Pair<Integer,Integer>> pairs = new LinkedList<Pair<Integer,Integer>>();
pairs.add(new Pair<>(0,1));
pairs.add(new Pair<>(1,2));
pairs.add(new Pair<>(0,2));
//pairs.add(new Pair<>(1,3));
//pairs.add(new Pair<>(2,4));
//pairs.add(new Pair<>(4,5));
//pairs.add(new Pair<>(0,5));
// golden
List<Integer> golden = new ArrayList<Integer>();
/*
for (int i = 0; i < numVariables; ++i) {
if (rand.nextBoolean()) {
golden.add(1);
} else {
golden.add(0);
}
}*/
for (int i = 0; i < numVariables; ++i) {
golden.add(0);
}
IsingFactorGraph fg = new IsingFactorGraph(numVariables, pairs, golden, tagNames);
int numParameters = 2 * fg.unaryFactors.size() + 4 * fg.binaryFactors.size();
double[] parameters = new double[numParameters];
//for (int i = 0; i < parameters.length; ++i) {
// parameters[i] = 1.0;
//}
//parameters[0] = 1.0;
parameters[2 * fg.unaryFactors.size()] = 1.0;
parameters[2 * fg.unaryFactors.size() + 2] = 1.0;
//System.out.println(Arrays.toString(parameters));
// random unary potentials
int counter = 0;
for (UnaryFactor uf : fg.unaryFactors) {
//parameters[counter] = rand.nextGaussian();
uf.setPotential(0, Math.exp(parameters[counter]));
++counter;
//parameters[counter] = rand.nextGaussian();
uf.setPotential(1, Math.exp(parameters[counter]));
++counter;
}
// random binary potentials
for (BinaryFactor bf : fg.binaryFactors) {
//parameters[counter] = rand.nextGaussian();
bf.setPotential(0, 0, Math.exp(parameters[counter]));
++counter;
//parameters[counter] = rand.nextGaussian();
bf.setPotential(0, 1, Math.exp(parameters[counter]));
++counter;
//parameters[counter] = rand.nextGaussian();
bf.setPotential(1, 0, Math.exp(parameters[counter]));
++counter;
//parameters[counter] = rand.nextGaussian();
bf.setPotential(1, 1, Math.exp(parameters[counter]));
++counter;
}
//System.out.println("...brute-force inference...");
double[][] marginalsBruteForce = fg.inferenceBruteForce();
//System.out.println("...belief propagation...");
fg.inference(10, 1.0);
for (int n = 0; n < numVariables; ++n) {
double[] marginal = fg.variables.get(n).getBelief().measure;
//System.out.println(fg.approximateZ());
//System.out.println(fg.logLikelihood());
//System.out.println("REAL GRAD:\t" + Arrays.toString(fg.unfeaturizedGradient()));
//System.out.println("FINITE DIFF:\t" + Arrays.toString(fg.finiteDifference(parameters, 0.0001)));
/*
System.out.println("BRUTE FORCE 0:" + 0 + "\t" + Arrays.toString(marginalsBruteForce[0]));
System.out.println("MARGINALS 0:" + 0 + "\t" + Arrays.toString(marginal));
System.out.println("BRUTE FORCE 1:" + 1 + "\t" + Arrays.toString(marginalsBruteForce[1]));
System.out.println("MARGINALS 1:" + 1 + "\t" + Arrays.toString(marginal));
*/
if (!Numerics.approximatelyEqual(marginalsBruteForce[n],marginal,0.1)) {
//System.out.println("False");
return false;
} else {
//System.out.println("True");
}
//System.exit(0);
}
return true;
}
}