/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.lp;
import static java.util.Objects.*;
import java.io.PrintStream;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.DiscreteFactor;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.schedulers.SchedulerOptionKey;
import com.analog.lyric.dimple.solvers.core.NoSolverEdge;
import com.analog.lyric.dimple.solvers.core.NoSolverVariableBlock;
import com.analog.lyric.dimple.solvers.core.SFactorGraphBase;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverNode;
import com.analog.lyric.dimple.solvers.lp.IntegerEquation.TermIterator;
import com.analog.lyric.util.misc.Matlab;
import net.jcip.annotations.NotThreadSafe;
import net.sf.javailp.Linear;
import net.sf.javailp.Operator;
import net.sf.javailp.Problem;
import net.sf.javailp.Result;
import net.sf.javailp.SolverFactory;
/**
* Solver-specific factor graph for LP solver.
* <p>
* <em>Previously was com.analog.lyric.dimple.solvers.lp.SFactorGraph</em>
* <p>
* @since 0.07
*/
@NotThreadSafe
public class LPSolverGraph extends SFactorGraphBase<LPTableFactor, LPDiscrete, NoSolverEdge, NoSolverVariableBlock>
{
/*-------
* State
*/
/**
* Maps model variables to their corresponding solver variable. Iterators will return
* variables in the order in which they were first added to the map.
*/
private final LinkedHashMap<Variable, LPDiscrete> _varMap;
/**
* Maps model factors to their corresponding solver factor. Iterators will return
* variables in the order in which they were first added to the map.
*/
private final LinkedHashMap<Factor, LPTableFactor> _factorMap;
/**
* Contains the parameters of the linear objective function for the LP solve.
* It's length is the number of LP variables.
* <p>
* Null if not yet computed.
*/
private @Nullable double[] _objectiveFunction = null;
/**
* List of linear constraints describing this graph. The first {@link #getNumberOfVariableConstraints()}
* will be {@link LPVariableConstraint}s and the remainder will be {@link LPFactorMarginalConstraint}s.
* <p>
* Null if not yet computed.
*/
private @Nullable List<IntegerEquation> _constraints = null;
/**
* Number of non-zero terms in all of the {@link #_constraints}.
*/
private int _nConstraintTerms = 0;
/**
* The number of variable constraints, equal to the number of rows in the constraint
* list dedicated to variable constraints, which will be at the front of the
* {@link _constraints} list.
* <p>
* Negative if not yet computed.
*/
private int _nVariableConstraints = -1;
/**
* Name of external LP solver to be used to do the actual solving.
*/
private String _lpSolverName = "";
private String _lpMatlabSolver = ""; // TODO: merge lpSolverName and lpSolver.
/*--------------
* Construction
*/
LPSolverGraph(FactorGraph model, @Nullable ISolverFactorGraph parent)
{
super(model, parent);
_varMap = new LinkedHashMap<Variable, LPDiscrete>(model.getVariableCount());
_factorMap = new LinkedHashMap<Factor, LPTableFactor>(model.getFactorCount());
}
/*---------------------
* ISolverNode methods
*/
@Override
public void initialize()
{
super.initialize();
_lpSolverName = getOptionOrDefault(LPOptions.LPSolver);
_lpMatlabSolver = getOptionOrDefault(LPOptions.MatlabLPSolver);
}
/**
* Does nothing for this solver.
*/
@Override
public void update()
{
}
/**
* Does nothing for this solver.
*/
@Override
public void updateEdge(int outPortNum)
{
}
/*----------------------------
* ISolverFactorGraph methods
*/
@SuppressWarnings("deprecation") // TODO remove when STableFactor removed
@Override
public LPTableFactor createFactor(Factor factor)
{
LPTableFactor sfactor = _factorMap.get(factor);
if (sfactor == null)
{
if (!(factor instanceof DiscreteFactor))
{
throw new DimpleException("Factor '%s' is not a DiscreteFactor. LP solver only supports discrete factors.",
factor.getName());
}
sfactor = new STableFactor(this, (DiscreteFactor)factor);
_factorMap.put(factor, sfactor);
}
return sfactor;
}
@SuppressWarnings("deprecation") // Remove when SVariable removed
@Override
public LPDiscrete createVariable(Variable var)
{
LPDiscrete svar = _varMap.get(var);
if (svar == null)
{
if (!(var instanceof Discrete))
{
throw new DimpleException("Variable '%s' is not discrete. LP solver only supports discrete variables.",
var.getName());
}
// Reuse svar already associated with var if applicable.
svar = var.getSolverIfTypeAndGraph(LPDiscrete.class, this);
if (svar == null)
{
svar = new SVariable(this, (Discrete)var);
}
else
{
svar.clearLPState();
}
_varMap.put(var, svar);
}
return svar;
}
@Override
public ISolverFactorGraph createSubgraph(FactorGraph subgraph)
{
return new LPSolverGraph(subgraph, this);
}
/**
* {@inheritDoc}
* <p>
* This implementation returns "dimpleLPSolve" if the value of {@link #getLPSolverName()} is
* "matlab" and otherwise returns null.
*/
@Override
public @Nullable String getMatlabSolveWrapper()
{
return useMatlabSolver() ? "dimpleLPSolve" : null;
}
/**
* {@inheritDoc}
* @return {@code null}
*/
@Override
public @Nullable SchedulerOptionKey getSchedulerKey()
{
return null;
}
@SuppressWarnings("unchecked")
@Override
public Collection<LPTableFactor> getSolverFactorsRecursive()
{
return (Collection<LPTableFactor>) super.getSolverFactorsRecursive();
}
@SuppressWarnings("unchecked")
@Override
public Collection<LPDiscrete> getSolverVariablesRecursive()
{
return (Collection<LPDiscrete>) super.getSolverVariablesRecursive();
}
@Override
public boolean hasEdgeState()
{
return false;
}
private boolean useMatlabSolver()
{
return _lpSolverName.isEmpty() || _lpSolverName.equalsIgnoreCase("matlab");
}
@Override
public void iterate(int numIters)
{
if (useMatlabSolver())
{
throw new DimpleException("Java solve() not supported for LP solver using 'MATLAB' as underlying solver");
}
net.sf.javailp.Solver solver = null;
try
{
@SuppressWarnings("unchecked")
Class<SolverFactory> factoryClass =
(Class<SolverFactory>)Class.forName(String.format("net.sf.javailp.SolverFactory%s", _lpSolverName));
solver = factoryClass.newInstance().get();
}
catch (Exception ex)
{
throw new DimpleException("Cannot load underlying LP solver '%s': %s'", _lpSolverName, ex.toString());
}
buildLPState(); // computes object function and constraints
Problem problem = new Problem();
double[] objectiveCoefficients = requireNonNull(getObjectiveFunction());
Linear objective = new Linear();
for (int i = 0, end = objectiveCoefficients.length; i < end; ++i)
{
objective.add(objectiveCoefficients[i], i);
problem.setVarBounds(0.0, i, 1.0);
}
problem.setObjective(objective);
for (IntegerEquation constraint : requireNonNull(getConstraints()))
{
Linear linear = new Linear();
TermIterator iter = constraint.getTerms();
while (iter.advance())
{
linear.add(iter.getCoefficient(), iter.getVariable());
}
problem.add(linear, Operator.EQ, constraint.getRHS());
}
Result result = solver.solve(problem);
double[] solution = new double[getNumberOfLPVariables()];
for (int i = 0, end = solution.length; i < end; ++i)
{
solution[i] = result.get(i).doubleValue();
}
setSolution(solution);
}
/**
* Does nothing. Input ignored.
*/
@Override
public void setNumIterations(int numIterations)
{
}
/**
* Always returns one.
*/
@Override
public int getNumIterations()
{
return 1;
}
@Override
public void estimateParameters(IFactorTable[] tables, int numRestarts, int numSteps, double stepScaleFactor)
{
throw unsupported("estimateParameters");
}
@Override
public void baumWelch(IFactorTable[] tables, int numRestarts, int numSteps)
{
throw unsupported("baumWelch");
}
@Override
public void moveMessages(ISolverNode other)
{
throw unsupported("moveMessages");
}
@Override
protected String getSolverName()
{
return "LP";
}
/*-------------------------
* LP SFactorGraph methods
*/
/**
* Returns the linear objective function for the underlying LP solver or null if
* not yet computed.
* @see #buildLPState()
*/
@Matlab
public @Nullable double[] getObjectiveFunction()
{
return _objectiveFunction;
}
/**
* Get constraint linear equations. The first {@link #getNumberOfVariableConstraints()}
* constraints will be of type {@link LPVariableConstraint} and the remainder will be of
* type {@link LPFactorMarginalConstraint}.
*/
public @Nullable List<IntegerEquation> getConstraints()
{
return _constraints;
}
/**
* Returns an object that can iterate over the non-zero terms of the linear
* constraint equations for constructing a sparse MATLAB matrix.
*/
@Matlab
public MatlabConstraintTermIterator getMatlabSparseConstraints()
{
return new MatlabConstraintTermIterator(_constraints, _nConstraintTerms);
}
@Matlab
public double[][] getMatlabConstraintArrays()
{
MatlabConstraintTermIterator termIter = getMatlabSparseConstraints();
int numel = termIter.size();
double[][] result= new double[numel][3];
int ct=0;
while (termIter.advance())
{
result[ct][0]=termIter.getRow();
result[ct][1]=termIter.getVariable();
result[ct][2]=termIter.getCoefficient();
ct++;
}
return result;
}
@Matlab
public String getLPSolverName()
{
return _lpSolverName;
}
@Matlab
public String getMatlabLPSolver()
{
return _lpMatlabSolver;
}
@Matlab
public void setMatlabLPSolver(@Nullable String name)
{
_lpMatlabSolver = name != null ? name : "";
setOption(LPOptions.MatlabLPSolver, _lpMatlabSolver);
}
@Matlab
public void setLPSolverName(@Nullable String name)
{
_lpSolverName = name != null ? name : "";
setOption(LPOptions.LPSolver, _lpSolverName);
}
/**
* The number of constraints equations returned by {@link #getConstraints}
* or -1 if not yet computed.
* @see #hasLPState()
* @see #getNumberOfVariableConstraints()
*/
@Matlab
public int getNumberOfConstraints()
{
final List<IntegerEquation> constraints = _constraints;
return constraints != null ? constraints.size() : -1;
}
/**
* Returns the number of unique linear variables in the constraints to be
* solved using LP.
* <p>
* Returns -1 if not yet computed.
*/
public int getNumberOfLPVariables()
{
final double[] objectiveFunction = _objectiveFunction;
return objectiveFunction != null ? objectiveFunction.length : -1;
}
@Matlab
public int getNumberOfMarginalConstraints()
{
return _constraints != null ? getNumberOfConstraints() - getNumberOfVariableConstraints() : -1;
}
/**
* The number of constraint equations describing a single variable or
* -1 if not yet computed.
* @see #getConstraints()
* @see #hasLPState()
*/
@Matlab
public int getNumberOfVariableConstraints()
{
return _nVariableConstraints;
}
/**
* Prints constraint equations using model variable names and values for debugging purposes.
*/
public void printConstraints(PrintStream out)
{
final List<IntegerEquation> constraints = _constraints;
if (constraints == null)
{
out.println("Constraints not yet computed.");
return;
}
for (IntegerEquation constraint : constraints)
{
constraint.print(out);
}
}
@Matlab
public void printConstraints()
{
printConstraints(System.out);
}
@Matlab
public void setSolution(double[] solution)
{
for (LPDiscrete svar : _varMap.values())
{
svar.setBeliefsFromLPSolution(solution);
}
}
/**
* Builds the LP description of the problem for the underlying LP solver
* to work on.
* @see #hasLPState()
*/
@Matlab
public void buildLPState()
{
final FactorGraph model = getModelObject();
int nLPVars = 0;
// Create solver variables, if not already created
for (Variable var : model.getVariables())
{
LPDiscrete svar = createVariable(var);
nLPVars += svar.computeValidAssignments();
}
// Create solver factor tables, if not already created.
for (Factor factor : model.getFactors())
{
LPTableFactor sfactor = createFactor(factor);
nLPVars += sfactor.computeValidAssignments();
}
double[] objectiveFunction = new double[nLPVars];
int lpVarIndex = 0;
List<IntegerEquation> constraints = new LinkedList<IntegerEquation>();
int nTerms = 0;
for (LPDiscrete svar : _varMap.values())
{
lpVarIndex = svar.computeObjectiveFunction(objectiveFunction, lpVarIndex);
nTerms += svar.computeConstraints(constraints);
}
_nVariableConstraints = constraints.size();
for (LPTableFactor sfactor : _factorMap.values())
{
lpVarIndex = sfactor.computeObjectiveFunction(objectiveFunction, lpVarIndex);
nTerms += sfactor.computeConstraints(constraints);
}
_objectiveFunction = objectiveFunction;
_constraints = constraints;
_nConstraintTerms = nTerms;
}
public void clearLPState()
{
_objectiveFunction = null;
_nVariableConstraints = -1;
_constraints = null;
_nConstraintTerms = 0;
for (LPDiscrete svar : _varMap.values())
{
svar.clearLPState();
}
for (LPTableFactor sfactor : _factorMap.values())
{
sfactor.clearLPState();
}
_varMap.clear();
_factorMap.clear();
}
/**
* Returns solver factor belonging to this solver graph that is
* associated with input model factor or else null.
*/
@Override
public LPTableFactor getSolverFactor(Factor factor)
{
return _factorMap.get(factor);
}
/**
* Returns solver variable belonging to this solver graph that is
* associated with input model variable or else null.
*/
@Override
public LPDiscrete getSolverVariable(Variable var)
{
return _varMap.get(var);
}
/**
* Returns true if state needed for external LP solver to operate has been computed. If true
* the following methods will return valid values:
* <ul>
* <li>{@link #getConstraints}
* <li>{@link #getObjectiveFunction}
* <li>{@link #getMatlabSparseConstraints()}
* <li>{@link #getNumberOfConstraints}
* <li>{@link #getNumberOfMarginalConstraints()}
* <li>{@link #getNumberOfVariableConstraints}
* </ul>
*
* @see #buildLPState()
* @see #clearLPState()
*/
public boolean hasLPState()
{
return _objectiveFunction != null;
}
/*-----------------
* Private methods
*/
private DimpleException unsupported(String methodName)
{
return DimpleException.unsupportedBySolver("LP", methodName);
}
/*
*
*/
@Override
protected void doUpdateEdge(int edge)
{
}
}