package experimental.igel;
import java.util.ArrayList;
import java.util.List;
import org.javatuples.Pair;
/**
 * Factor graph over the candidate segments of a {@link Word}.
 *
 * <p>Each (startPos, endPos) key of {@code word.getPos2String()} becomes one
 * {@link SegmentVariable}. Every variable gets its own {@link UnaryFactor}, and a
 * {@link BinaryFactor} connects two variables whenever the first one's end position equals the
 * second one's start position. For every factor-variable edge a {@code new Message(2)} is
 * appended to the variable's message list, and the factor records that message's index via
 * {@code getMessageIds()}.
 */
public class FactorGraph {
    /**
     * Largest variable count {@link #inferenceBruteForce()} accepts: {@code 1L << 62} is the
     * largest power of two that still fits in a positive {@code long}.
     */
    private static final int MAX_BRUTE_FORCE_VARIABLES = 62;

    /** Segment variables, in the iteration order of {@code word.getPos2String().keySet()}. */
    private final List<SegmentVariable> variables;
    /** One unary factor per variable, in the same order as {@link #variables}. */
    private final List<UnaryFactor> unaryFactors;
    /** One binary factor per ordered pair (sv1, sv2) with sv1's end position == sv2's start. */
    private final List<BinaryFactor> binaryFactors;
    // NOTE(review): never assigned in this class; global-factor message passing is still a TODO.
    private SemiMarkovFactor globalFactor;
    /** Number of segment variables; equals {@code variables.size()} after construction. */
    private int numVariables;

    /**
     * Builds the factor graph for {@code word}.
     *
     * @param word source of the candidate segments: {@code getPos2String()} maps
     *     (startPos, endPos) pairs to the segment strings
     */
    public FactorGraph(Word word) {
        variables = new ArrayList<SegmentVariable>();
        unaryFactors = new ArrayList<UnaryFactor>();
        binaryFactors = new ArrayList<BinaryFactor>();

        // One variable per candidate segment. Variable order follows the map's key iteration
        // order, which is not specified here -- TODO confirm whether callers rely on it.
        for (Pair<Integer, Integer> key : word.getPos2String().keySet()) {
            int startPos = key.getValue0();
            int endPos = key.getValue1();
            String segment = word.getPos2String().get(key);
            variables.add(new SegmentVariable(segment, startPos, endPos));
        }
        this.numVariables = variables.size();

        // One unary factor per variable. The factor stores the index of the message slot it
        // owns inside the variable's message list; Message(2) presumably means a binary
        // domain -- verify against Message.
        for (SegmentVariable sv : variables) {
            UnaryFactor uf = new UnaryFactor();
            uf.getNeighbors().add(sv);
            uf.getMessageIds().add(sv.getMessages().size());
            sv.getMessages().add(new Message(2));
            unaryFactors.add(uf);
        }

        // One binary factor for every ordered pair of adjacent segments (sv1 ends where sv2
        // starts).
        // NOTE(review): if getEndPos()/getStartPos() return Integer rather than int, this ==
        // compares references and fails for positions outside the Integer cache (-128..127).
        for (SegmentVariable sv1 : variables) {
            for (SegmentVariable sv2 : variables) {
                if (sv1.getEndPos() == sv2.getStartPos()) {
                    System.out.println(sv1.getSegment() + "\t" + sv2.getSegment());
                    BinaryFactor bf = new BinaryFactor();
                    // Edge to the left segment: neighbor 0, with its own message slot.
                    bf.getNeighbors().add(sv1);
                    bf.getMessageIds().add(sv1.getMessages().size());
                    sv1.getMessages().add(new Message(2));
                    // Edge to the right segment: neighbor 1, with its own message slot.
                    bf.getNeighbors().add(sv2);
                    bf.getMessageIds().add(sv2.getMessages().size());
                    sv2.getMessages().add(new Message(2));
                    binaryFactors.add(bf);
                }
            }
        }
    }

    /**
     * Runs loopy belief propagation for exactly {@code maxIterations} sweeps. Each sweep lets
     * all unary factors, then all binary factors, then all variables pass their messages.
     *
     * @param maxIterations number of sweeps to perform
     * @param convergence currently unused: no convergence test is implemented, so the loop
     *     always runs {@code maxIterations} sweeps
     */
    public void inferenceBP(int maxIterations, double convergence) {
        for (int iterNum = 0; iterNum < maxIterations; ++iterNum) {
            // unary factors
            for (UnaryFactor uf : this.unaryFactors) {
                uf.passMessages();
            }
            // binary factors
            for (BinaryFactor bf : this.binaryFactors) {
                bf.passMessages();
            }
            // global factor
            // TODO
            // variables
            for (SegmentVariable sv : this.variables) {
                sv.passMessages();
            }
        }
    }

    /**
     * Enumerates every joint 0/1 assignment of the variables. Each assignment is printed as a
     * zero-padded bit string of length {@code numVariables}, most significant bit first.
     * Character {@code n} is intended to be the value of {@code variables.get(n)}.
     *
     * <p>Scoring and marginal accumulation are not implemented yet (see the TODO sketch in the
     * loop). This prints {@code 2^numVariables} lines, so it is only practical for tiny graphs.
     * A graph with zero variables yields a single empty configuration.
     *
     * @throws IllegalStateException if there are more than {@value #MAX_BRUTE_FORCE_VARIABLES}
     *     variables, because the number of configurations would no longer fit in a {@code long}
     */
    public void inferenceBruteForce() {
        if (numVariables > MAX_BRUTE_FORCE_VARIABLES) {
            throw new IllegalStateException("brute-force inference supports at most "
                    + MAX_BRUTE_FORCE_VARIABLES + " variables, got " + numVariables);
        }
        long numConfigurations = 1L << numVariables;
        for (long config = 0; config < numConfigurations; config++) {
            String bits = toFixedWidthBits(config, numVariables);
            System.out.println(bits);
            /*
            TODO: score the configuration and accumulate marginals. Sketch:

            double configurationScore = 1.0;
            List<Integer> configuration = new ArrayList<Integer>();
            for (int n = 0; n < this.numVariables; ++n) {
                configuration.add(Character.getNumericValue(bits.charAt(n)));
            }
            // product over unary factors
            for (UnaryFactor uf : this.unaryFactors) {
                int value = configuration.get(uf.getI());
                configurationScore *= uf.potential[value];
            }
            // product over binary factors
            for (BinaryFactor bf : this.binaryFactors) {
                int value1 = configuration.get(bf.getI());
                int value2 = configuration.get(bf.getJ());
                configurationScore *= bf.potential[value1][value2];
            }
            // add configuration score
            for (int n = 0; n < this.numVariables; ++n) {
                int value = configuration.get(n);
                marginals[n][value] += configurationScore;
            }
            */
        }
    }

    /**
     * Formats the low {@code width} bits of {@code value} as a string of '0'/'1' characters,
     * most significant bit first and left-padded with zeros. Returns "" when {@code width} is 0.
     */
    private static String toFixedWidthBits(long value, int width) {
        StringBuilder sb = new StringBuilder(width);
        for (int bit = width - 1; bit >= 0; --bit) {
            sb.append(((value >>> bit) & 1L) == 0L ? '0' : '1');
        }
        return sb.toString();
    }
}