/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.lp;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.core.PriorAndCondition;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteEnergyMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage;
import com.analog.lyric.dimple.solvers.lp.IntegerEquation;
import com.analog.lyric.dimple.solvers.lp.LPDiscrete;
import com.analog.lyric.dimple.solvers.lp.LPFactorMarginalConstraint;
import com.analog.lyric.dimple.solvers.lp.LPSolverGraph;
import com.analog.lyric.dimple.solvers.lp.LPTableFactor;
import com.analog.lyric.dimple.solvers.lp.LPVariableConstraint;
import com.analog.lyric.dimple.solvers.lp.MatlabConstraintTermIterator;
import com.analog.lyric.dimple.solvers.lp.Solver;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Reusable test driver for the LP solver.
 * <p>
 * Wraps a model {@link FactorGraph}, creates the corresponding {@link LPSolverGraph}, and verifies
 * the linear-programming state the solver builds for it.
 * <p>
 * The constructor already checks several things:
 * <ul>
 * <li>the solver's initial (empty) state;</li>
 * <li>its LP solver name / Matlab wrapper handling;</li>
 * <li>its do-nothing methods;</li>
 * <li>that {@code estimateParameters} and {@code baumWelch} throw {@link DimpleException}.</li>
 * </ul>
 * Call {@link #testLPState()} to exercise the LP state itself.
 */
public class LPSolverTestCase extends DimpleTestBase
{
	/**
	 * Seed for the pseudo-random solution vector used in {@link #testLPState()}.
	 * <p>
	 * A fixed seed keeps the test reproducible: a failure can be re-run with exactly the same
	 * solution values.
	 */
	private static final long SOLUTION_SEED = 42L;

	/** The model graph under test. */
	public final FactorGraph model;

	/**
	 * Optional expected textual form of each constraint, in constraint order.
	 * <p>
	 * When non-null, {@link #testLPState()} asserts that there are exactly this many constraints.
	 * It also compares the trimmed printed form of each constraint with the trimmed corresponding
	 * entry.
	 */
	public @Nullable String[] expectedConstraints = null;

	/** LP solver graph created from {@link #model}. */
	public final LPSolverGraph solver;

	/**
	 * Creates the LP solver graph for {@code model} and checks its initial state and the
	 * behavior of its trivial and unsupported methods.
	 *
	 * @param model the factor graph to build an LP solver graph for
	 */
	public LPSolverTestCase(FactorGraph model)
	{
		this.model = model;
		solver = new Solver().createFactorGraph(model);
		assertInitialState();
		assertNull(solver.getParentGraph());
		assertSame(solver, solver.getRootSolverGraph());

		// The Matlab solve wrapper is reported for the default and "Matlab" solver names,
		// but not for other solver names.
		assertEquals("dimpleLPSolve", solver.getMatlabSolveWrapper());
		solver.setLPSolverName("Matlab");
		assertEquals("Matlab", solver.getLPSolverName());
		assertEquals("dimpleLPSolve", solver.getMatlabSolveWrapper());
		solver.setLPSolverName("foo");
		assertEquals("foo", solver.getLPSolverName());
		assertNull(solver.getMatlabSolveWrapper());

		// Restore the default name; assertInitialState() later expects it to be "".
		solver.setLPSolverName(null);

		// Test do nothing methods: the iteration count is ignored and always reported as 1.
		solver.setNumIterations(42);
		assertEquals(1, solver.getNumIterations());
		solver.update();
		solver.updateEdge(0);
		solver.initialize();

		// Test unsupported methods
		expectThrow(DimpleException.class, solver, "estimateParameters", null, 0, 0, 0.0);
		expectThrow(DimpleException.class, solver, "baumWelch", null, 0, 0);
	}

	/**
	 * Tests construction of the linear programming version of the factor graph.
	 * <p>
	 * Builds the LP state and checks, in turn:
	 * <ul>
	 * <li>the per-variable LP variable mapping and objective coefficients;</li>
	 * <li>the solver factors;</li>
	 * <li>the variable and factor-marginal constraints, including their Matlab sparse encoding;</li>
	 * <li>the optional {@link #expectedConstraints} text;</li>
	 * <li>that a solution passed to {@code setSolution} is reflected in the variable beliefs.</li>
	 * </ul>
	 * Finally clears the LP state and checks that the solver is back in its initial state.
	 */
	public void testLPState()
	{
		solver.buildLPState();
		assertTrue(solver.hasLPState());

		final int nLPVars = solver.getNumberOfLPVariables();
		assertTrue(nLPVars >= 0);
		final double[] objective = requireNonNull(solver.getObjectiveFunction());
		assertEquals(nLPVars, objective.length);

		// Number of model variables that got an LP variable. This must match the number of
		// variable constraints checked further below.
		int nVarsUsed = 0;
		for (Variable var : model.getVariables())
		{
			LPDiscrete svar = requireNonNull(solver.getSolverVariable(var));
			Discrete mvar = svar.getModelObject();
			assertSame(var, mvar);
			assertSame(solver, svar.getParentGraph());

			// Test do-nothing methods
			svar.updateEdge(0);

			int lpVar = svar.getLPVarIndex();
			int nValidAssignments = svar.getNumberOfValidAssignments();

			if (var.hasFixedValue())
			{
				// Currently the converse is not true because model variables
				// do not currently check to see if there is only one non-zero input weight.
				assertTrue(svar.hasFixedValue());
			}

			if (svar.hasFixedValue())
			{
				assertFalse(svar.hasLPVariable());
			}

			if (svar.hasLPVariable())
			{
				assertTrue(lpVar >= 0);
				assertTrue(nValidAssignments > 1);
				++nVarsUsed;
			}
			else
			{
				assertEquals(-1, lpVar);
				assertTrue(nValidAssignments <= 1);
			}

			// A domain value that maps to an LP variable must round-trip through
			// lpVarToDomainIndex, and its objective coefficient must be log(prior weight).
			// Zero-weight values, and all values of a variable without an LP variable,
			// must not map to an LP variable.
			DiscreteMessage prior = new DiscreteWeightMessage(mvar.getDomain(), mvar.getPrior());
			for (int i = 0, end = mvar.getDomain().size(); i < end; ++i)
			{
				double w = prior.getWeight(i);
				int lpVarForValue = svar.domainIndexToLPVar(i);
				int i2 = svar.lpVarToDomainIndex(lpVarForValue);
				if (lpVarForValue >= 0)
				{
					assertEquals(i, i2);
					assertEquals(Math.log(w), objective[lpVarForValue], 1e-6);
				}
				if (!svar.hasLPVariable() || w == 0.0)
				{
					assertEquals(-1, lpVarForValue);
				}
			}
		}

		for (Factor factor : model.getFactors())
		{
			LPTableFactor sfactor = requireNonNull(solver.getSolverFactor(factor));
			assertSame(factor, sfactor.getModelObject());
			assertSame(solver, sfactor.getParentGraph());

			// Test do nothing methods
			sfactor.updateEdge(0);
		}

		final List<IntegerEquation> constraints = solver.getConstraints();
		assertNotNull(constraints);
		int nConstraints = solver.getNumberOfConstraints();
		int nVarConstraints = solver.getNumberOfVariableConstraints();
		int nMarginalConstraints = solver.getNumberOfMarginalConstraints();
		assertEquals(nConstraints, constraints.size());
		assertEquals(nConstraints, nVarConstraints + nMarginalConstraints);
		assertEquals(nVarsUsed, nVarConstraints);

		{
			MatlabConstraintTermIterator termIter = solver.getMatlabSparseConstraints();

			// Flattened (row, lpVar, coefficient) triples of every constraint term, in order.
			// These are compared against the Matlab sparse iterator below.
			List<Integer> constraintTerms = new ArrayList<Integer>(termIter.size() * 3);

			// Rows are 1-based. Variable constraints come first, followed by the factor marginal
			// constraints.
			Iterator<IntegerEquation> constraintIter = constraints.iterator();
			for (int row = 1; constraintIter.hasNext(); ++row)
			{
				IntegerEquation constraint = constraintIter.next();
				int nExpectedTerms = -1;
				int lpVar = -1;

				if (row <= nVarConstraints)
				{
					LPVariableConstraint varConstraint = constraint.asVariableConstraint();
					assertNotNull(varConstraint);
					assertNull(constraint.asFactorConstraint());
					LPDiscrete svar = varConstraint.getSolverVariable();
					assertTrue(svar.hasLPVariable());
					assertEquals(1, varConstraint.getRHS());
					nExpectedTerms = svar.getNumberOfValidAssignments();
					lpVar = svar.getLPVarIndex();
				}
				else
				{
					LPFactorMarginalConstraint factorConstraint = constraint.asFactorConstraint();
					assertNotNull(factorConstraint);
					assertNull(constraint.asVariableConstraint());
					LPTableFactor sfactor = factorConstraint.getSolverFactor();
					lpVar = sfactor.getLPVarIndex();
					assertEquals(0, factorConstraint.getRHS());
					nExpectedTerms = sfactor.getNumberOfValidAssignments();
				}

				int[] lpvars = constraint.getVariables();
				assertEquals(constraint.size(), lpvars.length);

				// No coefficient at index -1, nor at lpVar + nExpectedTerms. NOTE(review): the
				// latter is assumed to lie just past this constraint's block of LP variables.
				assertEquals(0, constraint.getCoefficient(-1));
				assertEquals(0, constraint.getCoefficient(lpVar + nExpectedTerms));
				assertFalse(constraint.hasCoefficient(-1));
				assertTrue(lpVar >= 0);

				IntegerEquation.TermIterator constraintTermIter = constraint.getTerms();
				for (int i = 0; constraintTermIter.advance(); ++i)
				{
					assertEquals(lpvars[i], constraintTermIter.getVariable());
					assertEquals(constraintTermIter.getCoefficient(), constraint.getCoefficient(lpvars[i]));
					assertTrue(constraint.hasCoefficient(lpvars[i]));
					constraintTerms.add(row);
					constraintTerms.add(constraintTermIter.getVariable());
					constraintTerms.add(constraintTermIter.getCoefficient());
				}
				assertFalse(constraintTermIter.advance());
			}

			// The Matlab iterator must reproduce exactly the same terms in the same order.
			// Its LP variable indices are 1-based, hence the +1.
			int termIndex = 0;
			while (termIter.advance())
			{
				assertTrue("Matlab sparse constraints contain more terms than the constraints",
					termIndex + 2 < constraintTerms.size());
				assertEquals((int)constraintTerms.get(termIndex), termIter.getRow());
				assertEquals(constraintTerms.get(termIndex + 1) + 1, termIter.getVariable());
				assertEquals((int)constraintTerms.get(termIndex + 2), termIter.getCoefficient());
				termIndex += 3;
			}
			assertEquals("Matlab sparse constraints are missing terms",
				constraintTerms.size(), termIndex);
			assertFalse(termIter.advance());
		}

		// Copy the public field into a local so the null check covers the value actually used.
		final String[] expectedConstraints2 = expectedConstraints;
		if (expectedConstraints2 != null)
		{
			Iterator<IntegerEquation> constraintIter = constraints.iterator();
			assertEquals(expectedConstraints2.length, solver.getNumberOfConstraints());
			for (int i = 0, end = expectedConstraints2.length; i < end; ++i)
			{
				ByteArrayOutputStream out = new ByteArrayOutputStream();
				IntegerEquation constraint = constraintIter.next();
				constraint.print(new PrintStream(out));
				String actual = out.toString().trim();
				String expected = expectedConstraints2[i].trim();
				assertEquals("Constraint " + i + " mismatch", expected, actual);
			}
		}

		// Test setting solution. A real solution will only use ones and zeros,
		// but we will use random values to make sure they are assigned correctly.
		// The generator is seeded so that any failure is reproducible.
		final double[] solution = new double[nLPVars];
		final Random rand = new Random(SOLUTION_SEED);
		for (int i = solution.length; --i >= 0;)
		{
			solution[i] = rand.nextDouble();
		}
		solver.setSolution(solution);

		for (Variable var : model.getVariables())
		{
			LPDiscrete svar = requireNonNull(solver.getSolverVariable(var));
			double[] belief = svar.getBelief();
			final PriorAndCondition known = svar.getPriorAndCondition();
			try
			{
				Value fixedValue = known.value();
				if (fixedValue != null)
				{
					// A fixed variable's belief puts all weight on its fixed value.
					int vali = fixedValue.getIndex();
					for (int i = belief.length; --i >= 0;)
					{
						if (i == vali)
						{
							assertEquals(1.0, belief[i], 1e-6);
						}
						else
						{
							assertEquals(0.0, belief[i], 1e-6);
						}
					}
				}
				else
				{
					// Values with an LP variable take the solution value. Other values take the
					// prior weight when there is a prior, and zero otherwise.
					DiscreteMessage prior = DiscreteEnergyMessage.convertFrom(svar.getDomain(), known);
					for (int i = svar.getModelObject().getDomain().size(); --i >= 0;)
					{
						int lpVar = svar.domainIndexToLPVar(i);
						if (lpVar < 0)
						{
							if (prior != null)
							{
								assertEquals(prior.getWeight(i), belief[i], 0.0);
							}
							else
							{
								assertEquals(0, belief[i], 1e-6);
							}
						}
						else
						{
							assertEquals(solution[lpVar], belief[i], 1e-6);
						}
					}
				}
			}
			finally
			{
				// Release even when an assertion above fails.
				known.release();
			}
		}

		solver.clearLPState();
		assertInitialState();
	}

	/**
	 * Asserts that the solver currently holds no LP state.
	 * <p>
	 * Specifically, it checks for:
	 * <ul>
	 * <li>no objective and no constraint list;</li>
	 * <li>all LP/constraint counts equal to -1;</li>
	 * <li>an empty Matlab sparse-constraint iterator;</li>
	 * <li>a "Constraints not yet computed." constraint printout;</li>
	 * <li>the default (empty) LP solver name.</li>
	 * </ul>
	 */
	private void assertInitialState()
	{
		assertFalse(solver.hasLPState());
		assertNull(solver.getObjectiveFunction());
		assertNull(solver.getConstraints());
		assertEquals(-1, solver.getNumberOfConstraints());
		assertEquals(-1, solver.getNumberOfLPVariables());
		assertEquals(-1, solver.getNumberOfVariableConstraints());
		assertEquals(-1, solver.getNumberOfMarginalConstraints());

		// Disabled check. NOTE(review): confirm whether solver variables are expected to be
		// absent in the initial state before re-enabling or deleting this.
//		for (Variable var : model.getVariables())
//		{
//			assertNull(solver.getSolverVariable(var));
//		}

		MatlabConstraintTermIterator terms = solver.getMatlabSparseConstraints();
		assertEquals(0, terms.size());
		assertFalse(terms.advance());
		assertEquals(-1, terms.getVariable());
		assertEquals(0, terms.getCoefficient());
		assertEquals(-1, terms.getRow());

		ByteArrayOutputStream out = new ByteArrayOutputStream();
		solver.printConstraints(new PrintStream(out));
		assertEquals("Constraints not yet computed.", out.toString().trim());
		assertEquals("", solver.getLPSolverName());
	}
}