/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.lp;
import static com.analog.lyric.math.Utilities.*;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.solvers.core.PriorAndCondition;
import com.analog.lyric.dimple.solvers.core.SDiscreteVariableBase;
import com.analog.lyric.dimple.solvers.core.SVariableBase;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import net.jcip.annotations.NotThreadSafe;
/**
 * Solver variable for Discrete variables under LP solver.
 * <p>
 * Each instance tracks which domain values are valid (non-zero weight), the range of LP
 * variables allocated to those values, and the beliefs read back from the LP solution.
 * The expected call order when building an LP is {@link #computeValidAssignments()},
 * then {@link #computeObjectiveFunction}, then {@link #computeConstraints}.
 *
 * @since 0.07
 */
@NotThreadSafe
public class LPDiscrete extends SDiscreteVariableBase
{
    /*-------
     * State
     */

    /**
     * The LP solver factor graph that owns this instance.
     */
    private final LPSolverGraph _solverGraph;

    /**
     * The index of the first marginal LP variable associated with the model variable.
     * There will be one LP variable for each valid assignment to the variable, i.e. for
     * each value of the variable with non-zero probability.
     * <p>
     * Set to negative value if not yet computed or if variable does not have any associated LP
     * variables (because it has a fixed value).
     */
    private int _lpVarIndex = -1;

    /**
     * Represents the invalid assignments to the variable: which discrete values the variable
     * is not allowed to take based on the input probabilities. If all values are allowed or
     * if not yet computed, this will simply be null.
     */
    private @Nullable BitSet _invalidAssignments;

    /**
     * The number of valid assignments to the variable, or negative if not yet computed.
     */
    private int _nValidAssignments = -1;

    /**
     * Beliefs produced by the most recent call to {@link #setBeliefsFromLPSolution},
     * or null if none have been set.
     */
    private @Nullable double[] _beliefs = null;

    /*--------------
     * Construction
     */

    /**
     * Constructs the solver variable for model variable {@code var} owned by {@code solverGraph}.
     */
    LPDiscrete(LPSolverGraph solverGraph, Discrete var)
    {
        super(var, solverGraph);
        _solverGraph = solverGraph;
    }

    /*---------------------
     * ISolverNode methods
     */

    /**
     * Returns the LP solver graph object to which this variable instance belongs.
     * Note that unlike the default implementation provided by {@link SVariableBase#getParentGraph()},
     * this method returns the graph that was used to construct this instance even
     * if the solver on the associated model variable has changed.
     */
    @Override
    public LPSolverGraph getParentGraph()
    {
        return _solverGraph;
    }

    /**
     * Does nothing.
     */
    @Override
    protected void doUpdateEdge(int outPortNum)
    {
    }

    /*-------------------------
     * ISolverVariable methods
     */

    /**
     * Returns the beliefs for this variable, indexed by domain index.
     * <p>
     * If beliefs have been set by {@link #setBeliefsFromLPSolution} that array is returned
     * directly (not copied). Otherwise a new array is computed on each call: a fixed value
     * gets belief 1.0, a prior yields its normalized weights, and with neither the
     * beliefs are uniform.
     */
    @Override
    public double[] getBelief()
    {
        final double[] cached = _beliefs;
        if (cached != null)
        {
            return cached;
        }

        final int size = getModelObject().getDomain().size();
        final double[] beliefs = new double[size];

        final PriorAndCondition known = getPriorAndCondition();
        final Value fixedValue = known.value();
        if (fixedValue != null)
        {
            beliefs[fixedValue.getIndex()] = 1.0;
        }
        else
        {
            final DiscreteMessage prior = toEnergyMessage(known);
            if (prior == null)
            {
                Arrays.fill(beliefs, 1.0 / size);
            }
            else
            {
                prior.getWeights(beliefs);
                normalize(beliefs);
            }
        }
        known.release();

        return beliefs;
    }

    /*--------------------
     * LPDiscrete methods
     */

    /**
     * Returns the index of the domain element for this variable for the given lp variable.
     *
     * @return the domain index, or -1 if {@code lpVar} is not one of this variable's
     * LP variables (see {@link #hasLPVariable(int)}).
     */
    public int lpVarToDomainIndex(int lpVar)
    {
        if (!hasLPVariable(lpVar))
        {
            return -1;
        }

        final int offset = lpVar - _lpVarIndex;
        final BitSet invalidAssignments = _invalidAssignments;
        if (invalidAssignments == null)
        {
            // Every domain value has an LP variable, so the mapping is a plain offset.
            return offset;
        }

        // Walk forward to the (offset + 1)-th valid (clear) bit.
        int domainIndex = -1;
        for (int remaining = offset; remaining >= 0; --remaining)
        {
            domainIndex = invalidAssignments.nextClearBit(domainIndex + 1);
        }
        return domainIndex;
    }

    /**
     * Convert index into variable domain into index of corresponding LP variable.
     * Returns negative value if {@link #computeObjectiveFunction} not yet called,
     * if {@code domainIndex} is outside the variable's domain,
     * or if there is no LP variable for the given {@code domainIndex} (because
     * the value has been pruned from the equations because of zero weight).
     */
    public int domainIndexToLPVar(int domainIndex)
    {
        final int firstLPVar = _lpVarIndex;
        if (firstLPVar < 0)
        {
            return firstLPVar;
        }

        // An out-of-range index has no LP variable; without this check a negative index
        // would throw from BitSet.get and a too-large index would map into another
        // variable's LP range.
        if (domainIndex < 0 || domainIndex >= getDomain().size())
        {
            return -1;
        }

        final BitSet invalidAssignments = _invalidAssignments;
        if (invalidAssignments == null)
        {
            return firstLPVar + domainIndex;
        }
        if (invalidAssignments.get(domainIndex))
        {
            return -1;
        }

        // The LP offset is the number of valid assignments preceding domainIndex.
        int lpVar = firstLPVar;
        for (int i = invalidAssignments.nextClearBit(0);
            i < domainIndex;
            i = invalidAssignments.nextClearBit(i + 1))
        {
            ++lpVar;
        }
        return lpVar;
    }

    /**
     * Discovers the non-zero input weights and
     * returns the number of LP variables to generate for this variable.
     * <p>
     * Any previously computed set of invalid assignments is discarded first, so the
     * method can safely be called again without an intervening {@link #clearLPState()}.
     *
     * @return the number of LP variables to generate, which is zero if the variable
     * has a fixed value or only one valid assignment, and otherwise
     * equals the number of non-zero weights.
     */
    int computeValidAssignments()
    {
        final int domlength = getDomain().size();

        // Start clean so bits from an earlier computation cannot leak into this one.
        _invalidAssignments = null;

        final PriorAndCondition known = getPriorAndCondition();
        final Value fixedValue = known.value();
        if (fixedValue != null)
        {
            // Everything except the fixed value is invalid.
            final BitSet allButFixed = new BitSet(domlength);
            allButFixed.set(0, domlength);
            allButFixed.clear(fixedValue.getIndex());
            _invalidAssignments = allButFixed;
            _nValidAssignments = 1;
            known.release();
            return 0;
        }

        final DiscreteMessage prior = toEnergyMessage(known);
        known.release();

        if (prior == null)
        {
            // No prior: all values are allowed, which is represented by a null bit set.
            _nValidAssignments = domlength;
            // A single-valued domain needs no LP variable; this matches what
            // computeObjectiveFunction allocates (nothing unless more than one valid value).
            return domlength > 1 ? domlength : 0;
        }

        BitSet invalidAssignments = null;
        int cardinality = 0;
        for (int i = 0; i < domlength; ++i)
        {
            if (prior.hasZeroWeight(i))
            {
                if (invalidAssignments == null)
                {
                    invalidAssignments = new BitSet(domlength);
                }
                invalidAssignments.set(i);
            }
            else
            {
                ++cardinality;
            }
        }

        _invalidAssignments = invalidAssignments;
        _nValidAssignments = cardinality;
        return cardinality > 1 ? cardinality : 0;
    }

    /**
     * Compute the objective function parameters for this variable. This is simply
     * the log probabilities of each possible variable value (the negated energies
     * of the prior), or zeros when there is no prior.
     * <p>
     * Call after {@link #computeValidAssignments()}. Also records the index of this
     * variable's first LP variable, or -1 if it has none.
     *
     * @param objectiveFunction is the array containing the objective function.
     * @param start is the index of the first available slot in {@code objectiveFunction}.
     * @return the index of the next available slot in {@code objectiveFunction}. The
     * difference between this value and {@code start} should be equal to the number of
     * valid variable assignments unless there is only one valid assignment.
     */
    int computeObjectiveFunction(double[] objectiveFunction, int start)
    {
        if (_nValidAssignments <= 1)
        {
            // Zero or one valid assignment: nothing for the LP to decide.
            _lpVarIndex = -1;
            return start;
        }

        _lpVarIndex = start;
        int next = start;

        final PriorAndCondition known = getPriorAndCondition();
        if (known.value() != null)
        {
            objectiveFunction[next++] = 0;
        }
        else
        {
            final DiscreteMessage prior = toEnergyMessage(known);
            if (prior == null)
            {
                final int end = next + getDomain().size();
                Arrays.fill(objectiveFunction, next, end, 0.0);
                next = end;
            }
            else
            {
                for (double energy : prior.getEnergies())
                {
                    // Infinite energy means zero weight: such values get no LP variable.
                    if (energy < Double.POSITIVE_INFINITY)
                    {
                        objectiveFunction[next++] = -energy;
                    }
                }
            }
        }
        known.release();

        return next;
    }

    /**
     * Computes constraint equation for this variable and adds to {@code constraints}.
     * Nothing is added if the variable has no LP variables.
     * <p>
     * Call after {@link #computeObjectiveFunction}.
     *
     * @return the number of non-zero terms in the constraint.
     */
    int computeConstraints(List<IntegerEquation> constraints)
    {
        if (_lpVarIndex < 0)
        {
            return 0;
        }

        final LPVariableConstraint constraint = new LPVariableConstraint(this);
        constraints.add(constraint);
        return constraint.size();
    }

    /**
     * Prints the marginal constraint for this variable to {@code out} in the form
     * {@code p(var=a) + p(var=b) + ... = 1}, with one term per valid assignment.
     */
    void printConstraintEquation(PrintStream out)
    {
        final Discrete mvar = getModelObject();
        final String varName = mvar.getName();
        final DiscreteDomain domain = mvar.getDomain();
        final int size = domain.size();
        final BitSet invalidAssignments = _invalidAssignments;

        boolean needsSeparator = false;
        for (int i = nextValidIndex(invalidAssignments, 0);
            i < size;
            i = nextValidIndex(invalidAssignments, i + 1))
        {
            if (needsSeparator)
            {
                out.print(" + ");
            }
            needsSeparator = true;
            out.format("p(%s=%s)", varName, domain.getElement(i));
        }
        out.println(" = 1");
    }

    /**
     * Sets this variable's beliefs from the LP {@code solution} vector.
     * <p>
     * If the variable has LP variables, each valid assignment's belief is read from its
     * slot in {@code solution} and invalid assignments get zero. Otherwise the single
     * possible value (the known value, or the sole valid assignment found by
     * {@link #computeValidAssignments()}) gets belief 1.0. If neither can be determined
     * the previously stored beliefs are left unchanged.
     */
    void setBeliefsFromLPSolution(double[] solution)
    {
        final int beliefSize = getModelObject().getDomain().size();
        final double[] beliefs = new double[beliefSize];
        final BitSet invalidAssignments = _invalidAssignments;

        final int start = _lpVarIndex;
        if (start >= 0)
        {
            int lpVar = start;
            for (int i = nextValidIndex(invalidAssignments, 0);
                i < beliefSize;
                i = nextValidIndex(invalidAssignments, i + 1))
            {
                beliefs[i] = solution[lpVar++];
            }
            _beliefs = beliefs;
            return;
        }

        // No LP variables: the value is determined without consulting the solution.
        final Value value = getKnownValue();
        if (value != null)
        {
            beliefs[value.getIndex()] = 1.0;
            _beliefs = beliefs;
        }
        else if (_nValidAssignments == 1)
        {
            // The prior leaves exactly one possible value, so it gets all the belief.
            // Without this, beliefs from an earlier solve would remain in place.
            final int only = nextValidIndex(invalidAssignments, 0);
            if (only < beliefSize)
            {
                beliefs[only] = 1.0;
                _beliefs = beliefs;
            }
        }
    }

    /**
     * Resets the LP variable index, invalid assignment set and valid assignment count to
     * their "not yet computed" states. Stored beliefs are not affected.
     */
    void clearLPState()
    {
        _lpVarIndex = -1;
        _invalidAssignments = null;
        _nValidAssignments = -1;
    }

    /**
     * Returns the index of the first LP variable associated with the values of this variable.
     * Returns negative value if not yet computed or if there are no associated LP variables
     * (because there is only one valid value).
     */
    public int getLPVarIndex()
    {
        return _lpVarIndex;
    }

    /**
     * Returns the number of valid assignments to the variable, or a negative value
     * if not yet computed.
     */
    public int getNumberOfValidAssignments()
    {
        return _nValidAssignments;
    }

    /**
     * True if the variable currently has a known value.
     */
    public boolean hasFixedValue()
    {
        return getKnownValue() != null;
    }

    /**
     * True if {@link #computeObjectiveFunction} has been called and
     * variable has at least one associated LP variable,
     * i.e. {@link #getLPVarIndex} is non-negative.
     */
    public boolean hasLPVariable()
    {
        return _lpVarIndex >= 0;
    }

    /**
     * True if {@link #computeObjectiveFunction} has been called and
     * {@code lpVar} index refers to one of the LP variables used by
     * this variable.
     */
    boolean hasLPVariable(int lpVar)
    {
        final int first = _lpVarIndex;
        return first >= 0 && lpVar >= first && lpVar < first + _nValidAssignments;
    }

    /**
     * True if variable cannot have given index given its priors/fixed value.
     * Returns false when there is neither a fixed value nor a prior.
     *
     * @since 0.08
     */
    boolean hasZeroWeight(int index)
    {
        final PriorAndCondition known = getPriorAndCondition();
        final Value fixedValue = known.value();

        final boolean result;
        if (fixedValue != null)
        {
            result = index != fixedValue.getIndex();
        }
        else
        {
            final DiscreteMessage prior = toEnergyMessage(known);
            result = prior != null && prior.hasZeroWeight(index);
        }

        known.release();
        return result;
    }

    /*-----------------
     * Private methods
     */

    /**
     * Returns the smallest index at or after {@code fromIndex} that is not marked invalid
     * in {@code invalidAssignments}; a null set means no index is invalid.
     */
    private static int nextValidIndex(@Nullable BitSet invalidAssignments, int fromIndex)
    {
        return invalidAssignments == null ? fromIndex : invalidAssignments.nextClearBit(fromIndex);
    }
}