package experimental.ising; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedList; import java.util.List; import org.javatuples.Pair; public class IsingFactorGraph { private String word; private int numVariables; protected List<Variable> variables; protected List<UnaryFactor> unaryFactors; protected List<BinaryFactor> binaryFactors; protected List<Integer> golden; protected int numParameters; protected UnaryFeatureExtractor ufe; private int inferenceIterations; public IsingFactorGraph(String word, UnaryFeatureExtractor ufe, int inferenceIterations, int numVariables, List<Pair<Integer,Integer>> pairwise, List<Integer> golden, List<String> tagNames) { this.numVariables = numVariables; this.variables = new ArrayList<Variable>(); this.unaryFactors = new ArrayList<UnaryFactor>(); this.binaryFactors = new ArrayList<BinaryFactor>(); this.word = word; this.ufe = ufe; this.inferenceIterations = inferenceIterations; this.golden = golden; // ADD VARIABLES AND UNARY FACTORS for (int i = 0; i < this.numVariables; ++i) { Variable v = new Variable(2,i,tagNames.get(i)); UnaryFactor uf = new UnaryFactor(word, tagNames.get(i), 2, i, ufe); // add neighbors to variable v.getNeighbors().add(uf); v.getMessageIds().add(0); v.getMessages().add(new Message(2)); // add neighbors to unary uf.getNeighbors().add(v); uf.getMessageIds().add(v.getMessages().size()); uf.getMessages().add(new Message(2)); // add to graph this.variables.add(v); this.unaryFactors.add(uf); } for (Pair<Integer,Integer> p : pairwise) { int i = p.getValue0(); int j = p.getValue1(); BinaryFactor bf = new BinaryFactor(2,2,i,j); // add neighbors to variable Variable v1 = this.variables.get(i); Variable v2 = this.variables.get(j); v1.getNeighbors().add(bf); v1.getMessageIds().add(0); bf.getMessages().add(new Message(2)); v2.getNeighbors().add(bf); v2.getMessageIds().add(1); bf.getMessages().add(new Message(2)); // add neighbors to factor bf.getNeighbors().add(v1); bf.getMessageIds().add(v1.getMessages().size()); v1.getMessages().add(new Message(2)); bf.getNeighbors().add(v2); bf.getMessageIds().add(v2.getMessages().size()); v2.getMessages().add(new Message(2)); // add to graph this.binaryFactors.add(bf); } this.numParameters = 2 * this.unaryFactors.size() + 4 * this.binaryFactors.size(); } public IsingFactorGraph(int numVariables2, List<Pair<Integer, Integer>> pairs, List<Integer> golden2, List<String> tagNames) { throw new UnsupportedOperationException(); } /** * Brute force inference for the Ising factor graph */ public double[][] inferenceBruteForce() { double[][] marginals = new double[this.numVariables][2]; double Z = 0.0; for(int i = 0; i < Math.pow(2,this.numVariables); i++) { double configurationScore = 1.0; String format="%0"+this.numVariables+"d"; String newString = String.format(format,Integer.valueOf(Integer.toBinaryString(i))); List<Integer> configuration = new ArrayList<Integer>(); for (int n = 0; n < this.numVariables; ++n) { configuration.add(Character.getNumericValue(newString.charAt(n))); } // sum over unary factors for (UnaryFactor uf : this.unaryFactors) { int value = configuration.get(uf.getI()); configurationScore *= uf.potential[value]; } // sum over binary factors for (BinaryFactor bf : this.binaryFactors) { int value1 = configuration.get(bf.getI()); int value2 = configuration.get(bf.getJ()); configurationScore *= bf.potential[value1][value2]; } Z += configurationScore; //add configuration score for (int n = 0; n < this.numVariables; ++n) { int value = configuration.get(n); marginals[n][value] += configurationScore; } } for (int n = 0; n < this.numVariables; ++n) { double Z_local = marginals[n][0] + marginals[n][1]; marginals[n][0] /= Z_local; marginals[n][1] /= Z_local; } return marginals; } public double betheFreeEnergy() { double betheFreeEnergy = 0.0; /* // binary factor beliefs for (BinaryFactor bf : this.binaryFactors) { bf.computeFactorBelief(); for (int i = 0; i < bf.getSize1(); ++i) { for (int j = 0; j < bf.getSize2(); ++j) { betheFreeEnergy -= bf.factorBelief[i][j] * Math.log(bf.factorBelief[i][j]); betheFreeEnergy += bf.factorBelief[i][j] * Math.log(bf.potential[i][j]); } } }*/ // unary factor belief = variable belief for (Variable v: this.variables) { v.computeBelief(); UnaryFactor uf = this.unaryFactors.get(v.getI()); for (int i = 0; i < v.getSize(); ++i) { // -2 to get rid of unary factor // generally -1 int n = v.getNeighbors().size() - 2; if (n != 0) { betheFreeEnergy += n * v.getBelief().measure[i] * Math.log(v.getBelief().measure[i]); } betheFreeEnergy += v.getBelief().measure[i] * Math.log(uf.potential[i]); } } return betheFreeEnergy; } /** * Returns an approximate partition function. * This is simply the exp of the Bethe Free Energy * @return */ public double approximateZ() { return Math.exp(this.betheFreeEnergy()); } /** * Performs inference by belief propagation * @param maxIterNum * @param convergence */ public void inference(int maxIterNum, double convergence) { for (int iterNum = 0; iterNum < maxIterNum; ++iterNum) { // update unary factors for (UnaryFactor ur : this.unaryFactors) { ur.passMessage(); } /* // update binary factors for (BinaryFactor bf : this.binaryFactors) { bf.passMessage(); } // update variables for (Variable v : this.variables) { v.passMessage(); }*/ } for (Variable v : this.variables) { v.computeBelief(); } } /** * Returns the most probable configuration under 0/1 loss */ public List<String> viterbiDecode() { return null; } /** * Returns the most probable configuration under Hamming loss * @return */ public List<String> posteriorDecode() { List<String> tags = new LinkedList<String>(); this.inference(this.inferenceIterations, 0.01); for (Variable v: this.variables) { Belief b = v.getBelief(); if (b.measure[1] > b.measure[0]) { tags.add(v.getTagName()); } } return tags; } /** * */ public double logLikelihood() { // partition function double logZ_B = this.betheFreeEnergy(); double configurationScore = 1.0; // sum over unary factors for (UnaryFactor uf : this.unaryFactors) { int value = this.golden.get(uf.getI()); configurationScore *= uf.potential[value]; } /* // sum over binary factors for (BinaryFactor bf : this.binaryFactors) { int value1 = this.golden.get(bf.getI()); int value2 = this.golden.get(bf.getJ()); configurationScore *= bf.potential[value1][value2]; } */ return Math.log(configurationScore) - logZ_B; } /** * Finite Difference Results * @return */ public double[] finiteDifference(double[] parameters, double epsilon) { double[] gradient = new double[parameters.length]; for (int i = 0; i < parameters.length; ++i) { parameters[i] += epsilon; this.updatePotentials(parameters); this.inference(10, 1.0); double val1 = this.logLikelihood(); parameters[i] -= 2 * epsilon; this.updatePotentials(parameters); this.inference(10, 1.0); double val2 = this.logLikelihood(); gradient[i] = (val1 - val2) / (2 * epsilon); parameters[i] += epsilon; } return gradient; } public void updatePotentials2(double[] parameters) { int counter = 0; for (UnaryFactor uf : this.unaryFactors) { uf.setPotential(0, Math.exp(parameters[counter])); ++counter; uf.setPotential(1, Math.exp(parameters[counter])); ++counter; //uf.renormalize(); } // random binary potentials for (BinaryFactor bf : this.binaryFactors) { bf.setPotential(0, 0, Math.exp(parameters[counter])); ++counter; bf.setPotential(0, 1, Math.exp(parameters[counter])); ++counter; bf.setPotential(1, 0, Math.exp(parameters[counter])); ++counter; bf.setPotential(1, 1, Math.exp(parameters[counter])); ++counter; } } public void updatePotentials(double[] parameters) { for (UnaryFactor uf : this.unaryFactors) { uf.updatePotential(parameters); } } /** * * @return */ public void featurizedGradient(double[] gradient, int numData) { this.inference(this.inferenceIterations, 0.01); for (UnaryFactor uf : this.unaryFactors) { if (this.golden.get(uf.getI()) == 1) { for (Integer feat : uf.getFeaturesPositive()) { gradient[feat] += 1.0 ; } } for (Integer feat : uf.getFeaturesPositive()) { gradient[feat] -= this.variables.get(uf.getI()).getBelief().measure[1]; } if (this.golden.get(uf.getI()) == 0) { for (Integer feat : uf.getFeaturesNegative()) { gradient[feat] += 1.0; } } for (Integer feat : uf.getFeaturesNegative()) { gradient[feat] -= this.variables.get(uf.getI()).getBelief().measure[0]; } } } /** * * @return */ public double[] unfeaturizedGradient() { this.inference(10, 0.01); double[] gradient = new double[this.numParameters]; int counter = 0; for (UnaryFactor uf : this.unaryFactors) { if (this.golden.get(uf.getI()) == 0) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(uf.getI()).getBelief().measure[0]; ++counter; if (this.golden.get(uf.getI()) == 1) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(uf.getI()).getBelief().measure[1]; ++counter; } // random binary potentials for (BinaryFactor bf : this.binaryFactors) { if (this.golden.get(bf.getI()) == 0 && this.golden.get(bf.getJ()) == 0) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(bf.getI()).getBelief().measure[0] * this.variables.get(bf.getJ()).getBelief().measure[0]; ++counter; if (this.golden.get(bf.getI()) == 0 && this.golden.get(bf.getJ()) == 1) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(bf.getI()).getBelief().measure[0] * this.variables.get(bf.getJ()).getBelief().measure[1]; ++counter; if (this.golden.get(bf.getI()) == 1 && this.golden.get(bf.getJ()) == 0) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(bf.getI()).getBelief().measure[1] * this.variables.get(bf.getJ()).getBelief().measure[0]; ++counter; if (this.golden.get(bf.getI()) == 1 && this.golden.get(bf.getJ()) == 1) { gradient[counter] += 1.0; } gradient[counter] -= this.variables.get(bf.getI()).getBelief().measure[1] * this.variables.get(bf.getJ()).getBelief().measure[1]; ++counter; } return gradient; } public String getWord() { return word; } public void setWord(String word) { this.word = word; } public List<Variable> getVariables() { return this.variables; } public List<UnaryFactor> getUnaryFactor() { return this.unaryFactors; } }