/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.lp;
import java.io.PrintStream;
import java.util.BitSet;
import java.util.List;
import java.util.SortedSet;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableBase;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.factors.DiscreteFactor;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.solvers.core.STableFactorBase;
import com.analog.lyric.dimple.solvers.core.SVariableBase;
import com.google.common.collect.SortedSetMultimap;
import com.google.common.collect.TreeMultimap;
import net.jcip.annotations.NotThreadSafe;
/**
* Solver table factor under LP solver.
*
* @since 0.07
* @author Christopher Barber
*/
@NotThreadSafe
public class LPTableFactor extends STableFactorBase
{
/*-------
* State
*/
/**
* The LP solver factor graph that owns this instance.
*/
private final LPSolverGraph _solverGraph;
/**
* The index of the first LP variable associated with the model factor.
* There will be one LP variable for each valid (non-zero probability) joint assignment of the variables
* input to the factor table.
* <p>
* Set to negative value if not yet computed.
*/
private int _lpVarIndex = -1;
private int _nLpVars = 0;
/**
* Represents the invalid joint assignments of the variables input to the factor table
* as an index into a factor table's {@link FactorTableBase#getWeightsSparseUnsafe()} array. An assignment is invalid
* either if the weight in the table is zero (which normally would only happen if it was explicitly
* set to zero after the table was constructed) or if one of the input parameters has a zero input
* probability.
* <p>
* This will be null if all assignments in the weights list are valid.
* <p>
* NOTE: for large non-sparse factor tables, we would want a representation that does not take
* O(size-of-factor-table) space.
*/
private @Nullable BitSet _invalidAssignments = null;
/*--------------
* Construction
*/
LPTableFactor(LPSolverGraph solverGraph, DiscreteFactor factor)
{
super(factor, solverGraph);
_solverGraph = solverGraph;
}
/*---------------------
* ISolverNode methods
*/
@Override
public DiscreteFactor getModelObject()
{
return (DiscreteFactor)super.getModelObject();
}
/**
* Returns the LP solver graph object to which this variable instance belongs.
* Note that unlike the default implementation provided by {@link SVariableBase#getParentGraph()},
* this method returns the graph that was used to construct this instance even
* if the solver on the associated model variable has changed.
*/
@Override
public LPSolverGraph getParentGraph()
{
return _solverGraph;
}
/*----------------------
* ISolverFactor methods
*/
/**
* Does nothing.
*/
@Override
public void doUpdateEdge(int outPortNum)
{
}
/*--------------------------
* STableFactorBase methods
*/
@Override
protected void setTableRepresentation(IFactorTable table)
{
table.setRepresentation(FactorTableRepresentation.SPARSE_WEIGHT_WITH_INDICES);
}
/*-------------------------
* LP STableFactor methods
*/
/**
* It is assumed that {@link LPDiscrete#computeValidAssignments()} has already been invoked on the variables
* connected to this factor.
*/
int computeValidAssignments()
{
final DiscreteFactor factor = getModelObject();
final IFactorTable factorTable = factor.getFactorTable();
final double[] weights = factorTable.getWeightsSparseUnsafe();
final LPDiscrete[] svariables = getSVariables();
int cardinality = 0;
boolean hasNonFixedVariable = true;
for (LPDiscrete svar : svariables)
{
if (!svar.hasFixedValue())
{
hasNonFixedVariable = true;
break;
}
}
if (hasNonFixedVariable)
{
for (int i = weights.length; --i >= 0;)
{
double weight = weights[i];
boolean skipEntry = true;
if (weight != 0)
{
// If any of the variable input weights for these values
// is zero, then we can skip this entry.
skipEntry = false;
final int[] indices = factorTable.sparseIndexToIndices(i);
for (int j = 0, endj = indices.length; j < endj; ++j)
{
LPDiscrete svar = svariables[j];
if (svar.hasZeroWeight(indices[j]))
{
skipEntry = true;
break;
}
}
}
if (skipEntry)
{
BitSet invalidAssignments = _invalidAssignments;
if (invalidAssignments == null)
{
invalidAssignments = _invalidAssignments = new BitSet(i);
}
invalidAssignments.set(i);
}
else
{
++cardinality;
}
}
}
_nLpVars = cardinality;
return cardinality;
}
/*
*/
int computeObjectiveFunction(double[] objectiveFunction, int start)
{
if (_nLpVars > 0)
{
_lpVarIndex = start;
final DiscreteFactor factor = getModelObject();
final IFactorTable factorTable = factor.getFactorTable();
final double[] weights = factorTable.getWeightsSparseUnsafe();
final int nWeights = weights.length;
final BitSet invalidAssignments = _invalidAssignments;
for (int i = 0; i < nWeights; ++i)
{
if (invalidAssignments != null && nWeights <= (i = invalidAssignments.nextClearBit(i)))
{
break;
}
objectiveFunction[start++] = Math.log(weights[i]);
}
}
return start;
}
private LPDiscrete[] getSVariables()
{
// Build array of solver variables for input variables.
final Factor factor = getModelObject();
final int nVars = factor.getSiblingCount();
final LPDiscrete[] svariables = new LPDiscrete[nVars];
for (int i = nVars; --i >= 0;)
{
// Getting this from the solver graph instead of from the model variable allows
// the solver to operate even when it is detached from the model.
svariables[i] = _solverGraph.getSolverVariable(factor.getSibling(i));
}
return svariables;
}
/**
* Computes constraint equations for this factor table and adds to {@code constraints}.
* @return the total number of non-zero terms in added constraints.
* <p>
* Call after {@link #computeObjectiveFunction}.
*/
int computeConstraints(List<IntegerEquation> constraints)
{
if (_nLpVars <= 0)
{
return 0;
}
final DiscreteFactor factor = getModelObject();
final IFactorTable factorTable = factor.getFactorTable();
final int[][] rows = factorTable.getIndicesSparseUnsafe();
final int nRows = rows.length;
final LPDiscrete[] svariables = getSVariables();
// Table of the marginal constraints for this factor where key is the index of the LP variable for
// the marginal variable value, and the associated values are the indexes of the LP
// variables in this factor that have the same variable value.
final SortedSetMultimap<Integer, Integer> marginalConstraints = TreeMultimap.create();
final BitSet invalidAssignments = _invalidAssignments;
for (int i = 0, lpFactor = _lpVarIndex; i < nRows; ++i, ++lpFactor)
{
if (invalidAssignments != null && nRows <= (i = invalidAssignments.nextClearBit(i)))
{
break;
}
int [] indices = rows[i];
for (int j = 0, endj = indices.length; j < endj; ++j)
{
LPDiscrete svar = svariables[j];
if (svar.hasLPVariable())
{
// Only build marginal constraints for variables that have LP variables
// (i.e. don't have fixed values).
int valueIndex = indices[j];
int lpVar = svar.domainIndexToLPVar(valueIndex);
marginalConstraints.put(lpVar, lpFactor);
}
}
}
int nTerms = 0;
for (int lpVar : marginalConstraints.keySet())
{
// This expresses the constraint that the marginal probability of a particular variable value
// is equal to the sum of the non-zero factor table entries for the same variable value.
SortedSet<Integer> lpFactorVars = marginalConstraints.get(lpVar);
int[] lpVars = new int[1 + lpFactorVars.size()];
lpVars[0] = lpVar;
int i = 0;
for (int lpFactorVar : lpFactorVars)
{
lpVars[++i] = lpFactorVar;
}
constraints.add(new LPFactorMarginalConstraint(this, lpVars));
nTerms += lpVars.length;
}
return nTerms;
}
void clearLPState()
{
_lpVarIndex = -1;
_invalidAssignments = null;
}
/**
* Returns the index of the first LP variable for this factor, or else a negative value if
* LP state has not yet been computed or if factor is not included in LP representation
* (e.g. because all its variables are fixed).
*/
public int getLPVarIndex()
{
return _lpVarIndex;
}
public int getNumberOfValidAssignments()
{
return _nLpVars;
}
/**
* Underlying implementation of {@link LPFactorMarginalConstraint#print}.
*/
void printConstraintEquation(PrintStream out, int[] lpVars)
{
final int lpVar = lpVars[0];
final DiscreteFactor factor = getModelObject();
final IFactorTable factorTable = factor.getFactorTable();
final int[][] rows = factorTable.getIndicesSparseUnsafe();
final int nRows = rows.length;
// Build array of solver variables for input variables.
final LPDiscrete[] svariables = getSVariables();
// Find the marginal variable from its lpVar index
for (LPDiscrete svar : svariables)
{
// Linear search could be replaced by a binary search.
if (svar.hasLPVariable(lpVar))
{
// Print out term for marginal variable value
Discrete var = svar.getModelObject();
String varName = var.getName();
int varValueIndex = svar.lpVarToDomainIndex(lpVar);
Object varValue = var.getDomain().getElement(varValueIndex);
out.format("-p(%s=%s)", varName, varValue);
break;
}
}
final BitSet invalidAssignments = _invalidAssignments;
int lpVarsIndex = 1;
for (int rowIndex = 0, lpVarForRow = _lpVarIndex; rowIndex < nRows; ++rowIndex, ++lpVarForRow)
{
if (invalidAssignments != null && nRows <= (rowIndex = invalidAssignments.nextClearBit(rowIndex)))
{
break;
}
if (lpVarForRow == lpVars[lpVarsIndex])
{
int[] indices = rows[rowIndex];
out.print(" + p(");
for (int i = 0, end = indices.length; i < end ; ++i)
{
LPDiscrete svar = svariables[i];
Discrete var = svar.getModelObject();
Object[] elements = var.getDomain().getElements();
if (i > 0)
{
out.print(",");
}
out.format("%s=%s", var.getName(), elements[indices[i]]);
}
out.print(")");
if (++lpVarsIndex == lpVars.length)
{
break;
}
}
}
out.println(" = 0");
}
}