/*******************************************************************************
 * Copyright 2013 Analog Devices, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ********************************************************************************/

package com.analog.lyric.dimple.solvers.sumproduct.pseudolikelihood;

import static java.util.Objects.*;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Set;

import com.analog.lyric.dimple.model.core.INode;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.util.misc.IMapList;

/*
 * The VariableInfo object stores an empirical distribution over all of the variables it is
 * directly connected to through neighboring factors (excluding itself).
 *
 * Additionally, it provides a method to calculate the joint probability of a particular state
 * of variables using p(joint) = p(x|neighbors)*p(neighbors).
 */
public class VariableInfo extends NodeInfo
{
	private HashSet<LinkedList<Integer>> _uniqueSamplesPerValue = new HashSet<LinkedList<Integer>>();
	private Variable [] _neighbors;
	private Discrete _var;
	private HashMap<LinkedList<Integer>,double[]> _neighbors2distributions = new HashMap<LinkedList<Integer>, double[]>();
	private HashMap<Factor, int[]> _factor2mapping = new HashMap<Factor, int[]>();

	//This is a factory method for creating a VariableInfo.  It is necessary since
	//the parent class's constructor requires that we already know the mapping from all variables
	//to variables of interest.
	public static VariableInfo createVariableInfo(Variable var, HashMap<Variable,Integer> var2index)
	{
		//Find the variable's neighboring variables.
		Variable [] neighbors = getNeighbors(var);

		//Get the indices of interest for this variable.
		int [] indices = getIndices(var,var2index);

		return new VariableInfo(var,indices,neighbors,var2index);
	}

	private VariableInfo(Variable var, int [] indices, Variable [] neighbors, HashMap<Variable,Integer> var2index)
	{
		super(indices);

		_neighbors = neighbors;
		_var = (Discrete)var;

		//For every factor we build a mapping from the factor's variable index
		//to an index into the full list of neighbors.
		Factor [] fs = var.getFactorsFlat();

		for (Factor f : fs)
		{
			final int nVars = f.getSiblingCount();
			int [] mapping = new int[nVars];

			for (int i = 0; i < nVars; i++)
			{
				Variable tmp = f.getSibling(i);

				if (var == tmp)
				{
					//This is a special case: the variable itself.
					mapping[i] = _neighbors.length;
				}
				else
				{
					boolean found = false;

					//Silly linear search.  Fix this eventually.
					for (int j = 0; j < neighbors.length; j++)
					{
						if (neighbors[j] == tmp)
						{
							mapping[i] = j;
							found = true;
							break;
						}
					}

					if (!found)
						throw new RuntimeException("Neighbor not found for factor " + f + "; this should never happen");
				}
			}

			_factor2mapping.put(f, mapping);
		}
	}
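	// For example, for a factor connecting (x, a, b) where x is this variable and
	// _neighbors = {a, b}, the constructor above stores the mapping {_neighbors.length, 0, 1}.
	// The sentinel value _neighbors.length flags this variable's own position; it is
	// resolved against the variable's own domain value in getFactorTableIndex() below.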
	//Cleanup when reset is called so this can be reused.
	@Override
	public void reset()
	{
		_uniqueSamplesPerValue.clear();
		invalidateDistributions();
		super.reset();
	}

	public Variable getVariable()
	{
		return _var;
	}

	//Distributions are cached but, as a result, we have to invalidate the cache at the right
	//time.
	public void invalidateDistributions()
	{
		_neighbors2distributions.clear();
	}

	//Given a factor, a domain value, and domain values for the neighbors, calculate the
	//factor table index.
	public int getFactorTableIndex(Factor f, int domainValue, LinkedList<Integer> domainVals)
	{
		Integer [] domainValues = new Integer[domainVals.size()];
		domainValues = domainVals.toArray(domainValues);
		return getFactorTableIndex(f, domainValue, domainValues);
	}

	//Retrieve the joint probability of a variable and its neighbors.
	public double getProb(int varIndex, LinkedList<Integer> neighbors)
	{
		//First get the empirical probability of this neighbor assignment.
		double pneighbors = getDistribution().get(neighbors);

		Integer [] domainValues = new Integer[neighbors.size()];
		domainValues = neighbors.toArray(domainValues);

		//Cache the distribution of p(x|neighbors) since we have to calculate them all to
		//correctly normalize things.
		if (!_neighbors2distributions.containsKey(neighbors))
		{
			//initialize
			double [] distribution = new double[_var.getDiscreteDomain().size()];
			double normalizer = 0;

			//calculate the probability for each setting of this var.
			for (int i = 0; i < distribution.length; i++)
			{
				double total = 1;

				//For every factor
				for (Factor f : _factor2mapping.keySet())
				{
					//retrieve the factor table index from this var's domain value
					//and the neighbor domain values
					int index = getFactorTableIndex(f, i, domainValues);

					//retrieve the weight
					double weight = f.getFactorTable().getWeightsSparseUnsafe()[index];

					//TODO: should probably do this in the log domain
					//multiply it in
					total *= weight;
				}

				//save the value
				distribution[i] = total;

				//add to the normalizing constant
				normalizer += total;
			}

			//normalize
			for (int i = 0; i < distribution.length; i++)
				distribution[i] /= normalizer;

			//save
			_neighbors2distributions.put(neighbors, distribution);
		}

		return _neighbors2distributions.get(neighbors)[varIndex]*pneighbors;
	}

	//In addition to building up the empirical distribution, save all of the sample data
	//relevant to this variable.  This ends up duplicating the data across variables,
	//so this trade-off for speed will cost memory.  There is probably a better way.
	//There is some compression due to storing these as a set.
	@Override
	public void addSample(int [] allDataIndices)
	{
		super.addSample(allDataIndices);

		LinkedList<Integer> otherIndices = indicesToRelevantOnes(allDataIndices);
		_uniqueSamplesPerValue.add(otherIndices);
	}

	//Returns the set of unique samples.
	public Set<LinkedList<Integer>> getUniqueSamples()
	{
		return _uniqueSamplesPerValue;
	}

	//Converts the domain value for this variable plus the neighbor domain values
	//to a factor table index.
	private int getFactorTableIndex(Factor f, int domainValue, Integer [] domainVals)
	{
		int [] mapping = _factor2mapping.get(f);
		int [] indices = new int[f.getSiblingCount()];

		for (int j = 0; j < mapping.length; j++)
		{
			if (mapping[j] >= domainVals.length)
				//deal with the special case: this position is the variable itself.
				indices[j] = domainValue;
			else
				indices[j] = domainVals[mapping[j]];
		}

		int index = f.getFactorTable().sparseIndexFromIndices(indices);
		return index;
	}
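	// In equation form, the conditional cached by getProb() above is
	//
	//     p(x | neighbors) = prod_f w_f(x, neighbors) / sum_{x'} prod_f w_f(x', neighbors)
	//
	// where w_f is the weight looked up via getFactorTableIndex().  getProb() then returns
	// p(x|neighbors) * p(neighbors), matching the decomposition described at the top of this
	// file, with p(neighbors) taken from the empirical distribution.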
	//Uses a depth-limited graph search to find all neighboring variables, i.e. all
	//variables that share a factor with this variable.
	private static Variable [] getNeighbors(Variable var)
	{
		IMapList<INode> ml = requireNonNull(var.getRootGraph()).depthFirstSearchFlat(var, 2);

		HashSet<Variable> neighbors = new HashSet<Variable>();

		for (INode n : ml)
		{
			if (n.isVariable() && n != var)
				neighbors.add(n.asVariable());
		}

		Variable [] retVal = new Variable[neighbors.size()];
		return neighbors.toArray(retVal);
	}

	//Finds all neighboring variables and then builds the array of their indices
	//into the full sample data.
	private static int [] getIndices(Variable var, HashMap<Variable,Integer> var2index)
	{
		Variable [] neighbors = getNeighbors(var);

		int [] indices = new int[neighbors.length];

		int i = 0;
		for (Variable n : neighbors)
		{
			indices[i] = var2index.get(n);
			i++;
		}

		return indices;
	}
}
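/*
 * Usage sketch (illustrative only; `vars` and `samples` are hypothetical names, and the
 * pseudolikelihood solver normally drives this class itself, including any bookkeeping
 * NodeInfo needs before its empirical distribution can be queried):
 *
 *   HashMap<Variable, Integer> var2index = new HashMap<Variable, Integer>();
 *   for (int i = 0; i < vars.length; i++)
 *       var2index.put(vars[i], i);
 *
 *   VariableInfo info = VariableInfo.createVariableInfo(vars[0], var2index);
 *
 *   //sample[k] holds the observed domain index of vars[k]
 *   for (int [] sample : samples)
 *       info.addSample(sample);
 *
 *   //query the modeled joint probability for each observed neighbor assignment and each
 *   //possible value of this variable
 *   for (LinkedList<Integer> neighborValues : info.getUniqueSamples())
 *   {
 *       for (int xi = 0; xi < ((Discrete)info.getVariable()).getDiscreteDomain().size(); xi++)
 *       {
 *           double pJoint = info.getProb(xi, neighborValues);
 *           //...use pJoint, e.g. when accumulating pseudolikelihood terms
 *       }
 *   }
 */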