/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct.pseudolikelihood;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Random;
import java.util.Set;
import org.eclipse.jdt.annotation.NonNull;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.factors.FactorList;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableList;
import com.analog.lyric.dimple.solvers.core.ParameterEstimator;
/**
 * Estimates the parameters (factor table weights) of a factor graph using the
 * pseudolikelihood algorithm.
 * <p>
 * Training data is supplied either as domain objects or as domain indices, one row per
 * sample and one column per variable, in the order of the variable array handed to the
 * constructor. Learning is started through one of the {@code learn} methods; calling
 * {@link #run(int, int)} directly is not supported.
 */
public class PseudoLikelihood extends ParameterEstimator
{
    // Multiplier applied to the gradient on every update step.
    private double _scaleFactor;

    // Per-factor bookkeeping, used to look up the empirical distribution over each factor.
    private final HashMap<Factor,FactorInfo> _factor2factorInfo = new HashMap<Factor, FactorInfo>();

    // Per-variable bookkeeping, used to look up unique neighbor samples and conditional probabilities.
    private final HashMap<Variable, VariableInfo> _var2varInfo = new HashMap<Variable, VariableInfo>();

    // Most recently supplied training data as domain indices; null until setData is called.
    private @Nullable int [][] _data;

    // Maps each input variable to its column in the training data.
    private final HashMap<Variable,Integer> _var2index = new HashMap<Variable, Integer>();

    // The input variables, in training-data column order.
    private final Variable [] _vars;

    /**
     * Creates an estimator for the given graph.
     * <p>
     * Passes the graph and tables to the base class, records the column index of every
     * input variable, and builds one {@code FactorInfo} per factor of the graph and one
     * {@code VariableInfo} per variable that is a sibling of at least one factor.
     *
     * @param fg the factor graph whose parameters are to be estimated
     * @param tables the factor tables of interest
     * @param vars the variables, in the column order used by the training data
     */
    public PseudoLikelihood(FactorGraph fg,
        IFactorTable[] tables,
        Variable [] vars)
    {
        super(fg, tables, new Random());
        _vars = vars;

        // Map each input variable to its index in the input variable array.
        for (int i = 0; i < vars.length; i++)
            _var2index.put(vars[i], i);

        // Single pass over the factors: create the factor info and collect the
        // variables that are connected to factors in the graph.
        FactorList fl = fg.getFactors();
        HashSet<Variable> varsConnectedToFactors = new HashSet<Variable>();
        for (Factor f : fl)
        {
            _factor2factorInfo.put(f, new FactorInfo(f, _var2index));
            for (int vi = 0, endvi = f.getSiblingCount(); vi < endvi; ++vi)
                varsConnectedToFactors.add(f.getSibling(vi));
        }

        // One variable info per connected variable.
        for (Variable v : varsConnectedToFactors)
            _var2varInfo.put(v, VariableInfo.createVariableInfo(v, _var2index));
    }

    /**
     * Sets the training data from domain objects. The objects are converted to domain
     * indices and forwarded to {@link #setData(int[][])}.
     *
     * @param data one row per sample, one column per input variable
     */
    public void setData(Object [][] data)
    {
        setData(convertObjects2Indices(data));
    }

    /**
     * Sets the training data from domain indices. All factor and variable infos are
     * reset, then every sample is added to each of them.
     *
     * @param data one row per sample, one column per input variable
     */
    public void setData(int [][] data)
    {
        _data = data;

        // Start from a clean slate before adding the new samples.
        for (FactorInfo fi : _factor2factorInfo.values())
            fi.reset();
        for (VariableInfo vi : _var2varInfo.values())
            vi.reset();

        for (int[] sample : data)
        {
            for (FactorInfo fi : _factor2factorInfo.values())
                fi.addSample(sample);
            for (VariableInfo vi : _var2varInfo.values())
                vi.addSample(sample);
        }
    }

    /**
     * Sets the factor by which the gradient is multiplied on every update step.
     *
     * @param scale the gradient multiplier
     */
    public void setScaleFactor(double scale)
    {
        _scaleFactor = scale;
    }

    /**
     * Sets the data (as domain objects) and the scale factor, then runs {@code numSteps}
     * gradient steps.
     *
     * @param data one row per sample, one column per input variable
     * @param numSteps number of steps to run
     * @param scaleFactor the gradient multiplier
     */
    public void learn(Object [][] data, int numSteps, double scaleFactor)
    {
        setForceKeep(true);
        setData(data);
        runGradientSteps(numSteps, scaleFactor);
    }

    /**
     * Sets the data (as domain indices) and the scale factor, then runs {@code numSteps}
     * gradient steps.
     *
     * @param data one row per sample, one column per input variable
     * @param numSteps number of steps to run
     * @param scaleFactor the gradient multiplier
     */
    public void learn(int [][] data, int numSteps, double scaleFactor)
    {
        setForceKeep(true);
        setData(data);
        runGradientSteps(numSteps, scaleFactor);
    }

    // Shared tail of the learn overloads: store the scale factor and run the base
    // class loop with zero restarts.
    private void runGradientSteps(int numSteps, double scaleFactor)
    {
        setScaleFactor(scaleFactor);
        super.run(0, numSteps);
    }

    /**
     * Not supported; use one of the {@code learn} methods instead.
     *
     * @throws DimpleException always
     */
    @Override
    public void run(int numRestarts, int numSteps)
    {
        throw new DimpleException("Not supported");
    }

    /**
     * Calculates the gradient for every weight of every table of interest.
     * <p>
     * For each factor that uses a table, each weight receives
     * (factor degree) * (empirical probability of that weight's index combination).
     * Then, for each sibling variable of the factor, each domain value and each unique
     * neighbor sample, the probability reported by the variable's info is subtracted
     * from the weight that this combination selects.
     *
     * @return one gradient array per table, each with one entry per sparse weight
     * @throws DimpleException if no data has been set
     */
    public double [][] calculateGradient()
    {
        if (_data == null)
            throw new DimpleException("Must set data first");

        IFactorTable [] tables = getTables();
        HashMap<IFactorTable,ArrayList<Factor>> table2factors = getTable2Factors();
        double [][] gradients = new double[tables.length][];

        // Invalidate the distributions because parameters may have changed.
        for (VariableInfo vi : _var2varInfo.values())
            vi.invalidateDistributions();

        for (int i = 0; i < tables.length; i++)
        {
            double [] weights = tables[i].getWeightsSparseUnsafe();
            double [] gradient = new double[weights.length];
            gradients[i] = gradient;

            // Nothing to accumulate if the table is unrelated to this graph or has no
            // sparse entries (in which case there is no indices[0] to read a degree from).
            ArrayList<Factor> factors = table2factors.get(tables[i]);
            if (factors == null || weights.length == 0)
                continue;

            int [][] indices = tables[i].getIndicesSparseUnsafe();
            int degree = indices[0].length;

            for (Factor f : factors)
                accumulateFactorGradient(f, indices, degree, gradient);
        }

        return gradients;
    }

    // Adds the contribution of a single factor to the gradient of its table.
    private void accumulateFactorGradient(Factor f, int [][] indices, int degree, double [] gradient)
    {
        FactorInfo fi = _factor2factorInfo.get(f);

        // Positive term: degree * empirical probability of each weight's indices.
        for (int j = 0; j < gradient.length; j++)
        {
            double empiricalFactorD = fi.getDistribution().get(indices[j]);
            gradient[j] += degree * empiricalFactorD;
        }

        // Negative term, accumulated over every sibling variable of the factor.
        for (int vindex = 0, size = f.getSiblingCount(); vindex < size; ++vindex)
        {
            Variable v = f.getSibling(vindex);
            VariableInfo vi = _var2varInfo.get(v);

            // Neither of these changes inside the loops below, so fetch them once.
            int domainSize = v.asDiscreteVariable().getDiscreteDomain().size();
            Set<LinkedList<Integer>> samples = vi.getUniqueSamples();

            for (int d = 0; d < domainSize; d++)
            {
                for (LinkedList<Integer> sample : samples)
                {
                    double prob = vi.getProb(d, sample);

                    // Weight selected by this domain value together with the neighbor sample.
                    int index = vi.getFactorTableIndex(f, d, sample);
                    gradient[index] -= prob;
                }
            }
        }
    }

    /**
     * Performs one step: calculates the gradient and applies it to the tables.
     *
     * @param fg unused by this implementation
     */
    @Override
    public void runStep(@NonNull FactorGraph fg)
    {
        double [][] gradient = calculateGradient();
        applyGradient(gradient);
    }

    // Given a gradient, updates the weights of every table: each weight is moved in
    // log space by gradient * scale factor, then the table is renormalized to sum to one.
    private void applyGradient(double [][] gradient)
    {
        IFactorTable [] tables = getTables();

        for (int i = 0; i < tables.length; i++)
        {
            double [] ws = tables[i].getWeightsSparseUnsafe();
            double normalizer = 0;

            for (int j = 0; j < ws.length; j++)
            {
                // Update the parameter in the log domain.
                double logWeight = Math.log(ws[j]) + gradient[i][j] * _scaleFactor;
                ws[j] = Math.exp(logWeight);
                normalizer += ws[j];
            }

            for (int j = 0; j < ws.length; j++)
                ws[j] /= normalizer;

            // Save the changed weights.
            tables[i].replaceWeightsSparse(ws);
        }
    }

    /**
     * Calculates a numerical (finite difference) gradient for one weight. Useful for
     * debugging. The weight is multiplied by {@code exp(delta)}, the pseudolikelihood is
     * re-evaluated, and the original weight is restored afterwards, even if the
     * evaluation throws.
     *
     * @param table the table that holds the weight
     * @param weight sparse index of the weight within the table
     * @param delta the change applied to the log of the weight
     * @return (change of pseudolikelihood) / delta
     * @throws DimpleException if no data has been set
     */
    public double calculateNumericalGradient(IFactorTable table, int weight, double delta)
    {
        if (_data == null)
            throw new DimpleException("Must set data first");

        double y1 = calculatePseudoLikelihood();
        double oldval = table.getWeightsSparseUnsafe()[weight];
        double newval = oldval * Math.exp(delta);

        double y2;
        table.setWeightForSparseIndex(newval, weight);
        try
        {
            y2 = calculatePseudoLikelihood();
        }
        finally
        {
            // Never leave the table perturbed.
            table.setWeightForSparseIndex(oldval, weight);
        }

        return (y2 - y1) / delta;
    }

    /**
     * Calculates the pseudolikelihood of the current data under the current weights.
     * Used for calculating the numerical gradient.
     * <p>
     * The value is
     * <pre>
     *   1/M sum(m) sum(i) sum(j in neighbors(i)) theta(xj,xi)
     * - 1/M sum(m) sum(i) log Z(neighbors(i))
     * </pre>
     * where M is the number of samples.
     *
     * @return the pseudolikelihood, or 0 if no data has been set or the data has no samples
     */
    public double calculatePseudoLikelihood()
    {
        final int[][] data = _data;

        // Without samples the averages below would be 0/0.
        if (data == null || data.length == 0)
        {
            return 0;
        }

        FactorGraph fg = getFactorGraph();
        VariableList vl = fg.getVariablesFlat();
        final int size = data.length;

        // 1/M sum(m) sum(i) sum(j in neighbors(i)) theta(xj,xi)
        double total = 0;
        for (int m = 0; m < size; m++)
        {
            for (Variable v : vl)
            {
                for (Factor f : v.getFactorsFlat())
                {
                    total -= f.getFactorTable().getEnergyForIndices(observedIndices(f, data[m]));
                }
            }
        }
        total /= size;

        // 1/M sum(m) sum(i) log Z(neighbors(i))
        double total2 = 0;
        for (int m = 0; m < size; m++)
        {
            for (Variable v : vl)
            {
                final int domainSize = ((Discrete)v).getDomain().size();
                double sum = 0;

                for (int d = 0; d < domainSize; d++)
                {
                    // Product over every factor connected to the variable, with the
                    // variable itself set to domain value d.
                    double product = 1;
                    for (Factor f : v.getFactorsFlat())
                    {
                        product *= f.getFactorTable().getWeightForIndices(indicesWithValue(f, data[m], v, d));
                    }
                    sum += product;
                }

                // Log of the local partition function.
                total2 += Math.log(sum);
            }
        }

        total -= total2 / size;
        return total;
    }

    // Builds the domain indices of a factor's siblings as observed in one data sample.
    private int [] observedIndices(Factor f, int [] sample)
    {
        final int nVars = f.getSiblingCount();
        int [] indices = new int[nVars];
        for (int i = 0; i < nVars; i++)
            indices[i] = sample[_var2index.get(f.getSibling(i))];
        return indices;
    }

    // Builds the domain indices of a factor's siblings from one data sample, except that
    // the given variable takes the given domain value instead of its observed one.
    private int [] indicesWithValue(Factor f, int [] sample, Variable v, int value)
    {
        final int nVars = f.getSiblingCount();
        int [] indices = new int[nVars];
        for (int i = 0; i < nVars; i++)
        {
            Variable fv = f.getSibling(i);
            if (fv == v)
                indices[i] = value;
            else
                indices[i] = sample[_var2index.get(fv)];
        }
        return indices;
    }

    // Converts data that is provided as domain objects into domain indices. Column j is
    // looked up in the domain of the j-th input variable. The column count is taken from
    // the first row; an empty data set yields an empty result.
    private int [][] convertObjects2Indices(Object [][] data)
    {
        if (data.length == 0)
            return new int[0][];

        int [][] retval = new int[data.length][data[0].length];

        for (int i = 0; i < retval.length; i++)
        {
            for (int j = 0; j < retval[i].length; j++)
            {
                retval[i][j] = _vars[j].asDiscreteVariable().getDiscreteDomain().getIndex(data[i][j]);
            }
        }

        return retval;
    }
}