/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.matlabproxy;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.matlabproxy.ModelFactory;
import com.analog.lyric.dimple.matlabproxy.PCustomFactors;
import com.analog.lyric.dimple.matlabproxy.PDimpleEventLogger;
import com.analog.lyric.dimple.matlabproxy.PDiscreteDomain;
import com.analog.lyric.dimple.matlabproxy.PDiscreteVariableVector;
import com.analog.lyric.dimple.matlabproxy.PFactorGraphVector;
import com.analog.lyric.dimple.matlabproxy.PFactorTable;
import com.analog.lyric.dimple.matlabproxy.PFiniteFieldDomain;
import com.analog.lyric.dimple.matlabproxy.PFiniteFieldVariableVector;
import com.analog.lyric.dimple.matlabproxy.PHelpers;
import com.analog.lyric.dimple.matlabproxy.PLogger;
import com.analog.lyric.dimple.matlabproxy.PMultiplexerCPD;
import com.analog.lyric.dimple.matlabproxy.PRealDomain;
import com.analog.lyric.dimple.matlabproxy.PRealJointDomain;
import com.analog.lyric.dimple.matlabproxy.PRealJointVariableVector;
import com.analog.lyric.dimple.matlabproxy.PRealVariableVector;
import com.analog.lyric.dimple.matlabproxy.PScheduler;
import com.analog.lyric.dimple.matlabproxy.PTableFactorFunction;
import com.analog.lyric.dimple.matlabproxy.PVariableVector;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PFactorFunctionDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PMultivariateDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.PMultivariateDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PVariableStreamBase;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.FiniteFieldDomain;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.schedulers.CustomScheduler;
import com.analog.lyric.dimple.schedulers.GibbsRandomScanScheduler;
import com.analog.lyric.dimple.schedulers.schedule.FixedSchedule;
import com.analog.lyric.dimple.schedulers.schedule.ISchedule;
import com.analog.lyric.dimple.schedulers.scheduleEntry.IScheduleEntry;
import com.analog.lyric.dimple.schedulers.scheduleEntry.NodeScheduleEntry;
import com.analog.lyric.dimple.solvers.core.multithreading.ThreadPool;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.dimple.solvers.gibbs.GibbsCustomFactors;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.interfaces.IFactorGraphFactory;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Unit tests for {@link ModelFactory}
 * <p>
 * These are mostly shallow tests of the code in {@link ModelFactory} itself; for the most part they
 * do not test the correctness of the returned objects.
 * @since 0.08
 * @author Christopher Barber
 */
public class TestModelFactory extends DimpleTestBase
{
	/** The factory instance exercised by every test in this class. */
	private final ModelFactory mf = new ModelFactory();

	/**
	 * Checks that {@code createCustomFactors} resolves a qualified option name to a custom factors
	 * object, and that it throws {@link DimpleException} for an unknown option or for an option
	 * that is not a CustomFactors option.
	 */
	@Test
	public void createCustomFactors()
	{
		PCustomFactors pcf = mf.createCustomFactors("GibbsOptions.customFactors");
		assertSame(GibbsCustomFactors.class, pcf.getDelegate().getClass());
		expectThrow(DimpleException.class, "Cannot find option 'Bogus.customFactors'",
			mf, "createCustomFactors", "Bogus.customFactors");
		expectThrow(DimpleException.class, "Option 'BPOptions.damping' is not a CustomFactors option",
			mf, "createCustomFactors", BPOptions.damping.qualifiedName());
	}

	/**
	 * Checks that the domain factory methods return proxies whose delegates are the same instances
	 * produced by the corresponding model-level {@code create} methods.
	 */
	@Test
	public void createDomains()
	{
		PDiscreteDomain pdiscrete = mf.createDiscreteDomain(new Object[] { 0.0, 1.0, 2.0 });
		assertSame(DiscreteDomain.create(0.0,1.0,2.0), pdiscrete.getDelegate());
		PRealDomain punit = mf.createRealDomain(0.0, 1.0);
		assertEquals(0.0, punit.getLowerBound(), 0.0);
		// Exact comparison: a delta of 1.0 here would accept any upper bound in [0, 2].
		assertEquals(1.0, punit.getUpperBound(), 0.0);
		assertSame(RealDomain.create(0.0, 1.0), punit.getDelegate());
		PRealDomain preal = mf.createRealDomain(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
		assertSame(RealDomain.unbounded(), preal.getDelegate());
		PRealJointDomain prj = mf.createRealJointDomain(new Object[] { punit, preal });
		assertSame(RealJointDomain.create(RealDomain.create(0.0,1.0), RealDomain.unbounded()), prj.getDelegate());
		PFiniteFieldDomain pff = mf.createFiniteFieldDomain(0x2f);
		// Compare by value; assertSame on autoboxed ints only works by accident of the Integer cache.
		assertEquals(0x2f, pff.getPrimitivePolynomial());
	}

	/**
	 * Checks the initial verbosity and filename of a newly created event logger.
	 */
	@Test
	public void createEventLogger()
	{
		PDimpleEventLogger plogger = mf.createEventLogger();
		try
		{
			assertEquals(0, plogger.verbosity());
			assertEquals("stderr", plogger.filename());
		}
		finally
		{
			// Close even when an assertion fails so the logger is not leaked.
			plogger.close();
		}
	}

	/**
	 * Exercises the factory-table constructors: empty table from domains, sparse table from
	 * indices/values, named table factor function, and table from a nested raw value array.
	 */
	@Test
	public void createFactorTable()
	{
		PDiscreteDomain pdd2 = mf.createDiscreteDomain(new Object[] {1, 2});
		PDiscreteDomain pdd3 = mf.createDiscreteDomain(new Object[] { 1,2,3});
		PDiscreteDomain[] domains = new PDiscreteDomain[] { pdd2, pdd3 };
		PFactorTable pft = mf.createFactorTable(domains);
		assertArrayEquals(domains, pft.getDomains());
		IFactorTable ft = pft.getDelegate();
		assertEquals(2, ft.getDimensions());
		assertEquals(pdd2.getDelegate(), ft.getDomainIndexer().get(0));
		assertEquals(pdd3.getDelegate(), ft.getDomainIndexer().get(1));
		assertEquals(0, ft.countNonZeroWeights());
		int[][] indices = new int[][] {
			new int[] { 0, 0 },
			new int[] { 1, 1 },
			new int[] { 0, 2 },
			new int[] { 1, 2 }
		};
		double[] values = new double[] {1,2,3,4};
		pft = mf.createFactorTable(indices, values, domains);
		for (int i = values.length; --i>=0;)
		{
			assertEquals(values[i], pft.get(indices[i]), 0.0);
		}
		PTableFactorFunction ptff = mf.createTableFactorFunction("ff", indices, values, domains);
		assertEquals("ff", ptff.getDelegate().getName());
		pft = ptff.getFactorTable();
		for (int i = values.length; --i>=0;)
		{
			assertEquals(values[i], pft.get(indices[i]), 0.0);
		}
		double[][][] rawValues = new double[][][] {
			new double[][] {
				new double[] { 1, 5 },
				new double[] { 3, 0 },
			},
			new double [][] {
				new double[] { 2, 6 },
				new double[] { 4, 8 },
			}
		};
		pft = mf.createFactorTable(rawValues, new Object[] { pdd2, pdd2, pdd2});
		// The single zero entry in rawValues does not appear in the expected weights.
		assertArrayEquals(new double[] { 1,2,3,4,5,6,8 }, pft.getWeights(), 0.0);
	}

	/**
	 * Checks graph creation with no boundary variables and with boundary variables taken from
	 * real and discrete variable vectors.
	 */
	@Test
	public void createGraph()
	{
		PFactorGraphVector pfg = mf.createGraph(new Object[0]);
		assertEquals(1, pfg.getDelegate().length);
		FactorGraph fg = pfg.getGraph();
		assertSame(fg, pfg.getDelegate()[0]);
		assertEquals(0, fg.getBoundaryVariableCount());
		PRealDomain prd = mf.createRealDomain(0.0, 1.0);
		PDiscreteDomain pdd = mf.createDiscreteDomain(new Object[]{0.0,1.0});
		PVariableVector vv1 = mf.createRealVariableVector(prd, 3);
		PVariableVector vv2 = mf.createDiscreteVariableVector(pdd, 2);
		PFactorGraphVector pfg2 = mf.createGraph(new Object[] { vv1, vv2 });
		FactorGraph fg2 = pfg2.getGraph();
		// 3 real + 2 discrete boundary variables, in argument order.
		assertEquals(5, fg2.getBoundaryVariableCount());
		assertSame(prd.getDelegate(), fg2.getBoundaryVariable(2).getDomain());
		assertSame(pdd.getDelegate(), fg2.getBoundaryVariable(3).getDomain());
	}

	/**
	 * Checks creation of discrete, normal, and multivariate normal parameterized messages.
	 */
	@Test
	public void createParameterizedMessages()
	{
		// DiscreteMessage
		DiscreteMessage discrete = mf.createDiscreteMessage("energy", 2, null);
		assertFalse(discrete.storesWeights());
		assertEquals(2, discrete.size());
		assertArrayEquals(new double[2], discrete.getEnergies(), 0.0);
		discrete = mf.createDiscreteMessage("weight", 3, new double[] { 1,2,3 });
		assertTrue(discrete.storesWeights());
		assertArrayEquals(new double[] {1,2,3}, discrete.getWeights(), 0.0);

		// NormalParameters
		NormalParameters normal = mf.createNormalParameters(1.0, 2.0);
		assertEquals(1.0, normal.getMean(), 0.0);
		assertEquals(2.0, normal.getPrecision(), 0.0);

		// MultivariateNormalParameters
		final double[] means = new double[] {1.0, 2.0};
		final double[][] covariance = new double[][] {
			new double[] {1.0, .5},
			new double[] {.5, 2.0}
		};
		MultivariateNormalParameters multi = mf.createMultivariateNormalParameters(means, covariance);
		assertArrayEquals(means, multi.getMeans(), 0.0);
		assertArrayEquals(covariance[0], multi.getCovariance()[0], 0.0);
		assertArrayEquals(covariance[1], multi.getCovariance()[1], 0.0);
	}

	/**
	 * Checks scheduler creation by name, and custom scheduler creation with the scheduler option
	 * given as an option key, a qualified name string, {@code null}, and an empty string.
	 */
	@Test
	public void createScheduler()
	{
		PScheduler scheduler = mf.createScheduler("GibbsRandomScanScheduler");
		assertSame(GibbsRandomScanScheduler.class, scheduler.getDelegate().getClass());
		expectThrow(RuntimeException.class, mf, "createScheduler", "NoSuchScheduler");

		PFactorGraphVector pfg = mf.createGraph(new Object[0]);
		PRealDomain prd = mf.createRealDomain(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
		PVariableVector pvars = mf.createRealVariableVector(prd, 5);
		Variable[] vars = pvars.getVariableArray();
		pfg.getGraph().addVariables(vars);

		// Option given as an option key object.
		scheduler = mf.createCustomScheduler(pfg, GibbsOptions.scheduler, new Object[] {
			PHelpers.wrapObject(vars[0]),
			PHelpers.wrapObject(vars[3])
		});
		assertTrue(scheduler.getDelegate() instanceof CustomScheduler);
		ISchedule schedule = scheduler.getDelegate().createSchedule(pfg.getGraph());
		assertTrue(schedule instanceof FixedSchedule);
		FixedSchedule fixed = (FixedSchedule)schedule;
		assertEquals(2, fixed.size());
		IScheduleEntry entry = fixed.get(0);
		assertEquals(IScheduleEntry.Type.NODE, entry.type());
		assertEquals(vars[0], ((NodeScheduleEntry)entry).getNode());
		assertEquals(vars[3], ((NodeScheduleEntry)fixed.get(1)).getNode());

		// Option given as a qualified name string.
		scheduler = mf.createCustomScheduler(pfg, "GibbsOptions.scheduler", new Object[] {
			PHelpers.wrapObject(vars[1])
		});
		schedule = scheduler.getDelegate().createSchedule(pfg.getGraph());
		fixed = (FixedSchedule)schedule;
		assertEquals(1, fixed.size());
		assertEquals(vars[1], ((NodeScheduleEntry)fixed.get(0)).getNode());

		// No option given (null).
		scheduler = mf.createCustomScheduler(pfg, null, new Object[] {
			PHelpers.wrapObject(vars[2])
		});
		schedule = scheduler.getDelegate().createSchedule(pfg.getGraph());
		fixed = (FixedSchedule)schedule;
		assertEquals(1, fixed.size());
		assertEquals(vars[2], ((NodeScheduleEntry)fixed.get(0)).getNode());

		// No option given (empty string).
		scheduler = mf.createCustomScheduler(pfg, "", new Object[] {
			PHelpers.wrapObject(vars[2])
		});
		schedule = scheduler.getDelegate().createSchedule(pfg.getGraph());
		fixed = (FixedSchedule)schedule;
		assertEquals(1, fixed.size());
		assertEquals(vars[2], ((NodeScheduleEntry)fixed.get(0)).getNode());
	}

	/**
	 * Checks that each variable-vector factory method, with and without a leading name argument,
	 * creates the requested number of variables sharing the expected domain instance.
	 */
	@SuppressWarnings("deprecation")
	@Test
	public void createVariableVectors()
	{
		final RealDomain rd = RealDomain.create(0.0,1.0);
		final PRealDomain prd = mf.createRealDomain(0.0, 1.0);
		assertSame(rd, prd.getDelegate());
		PRealVariableVector prv = mf.createRealVariableVector(prd, 2);
		assertEquals(2, prv.size());
		assertSame(rd, prv.getVariable(0).getDomain());
		assertSame(rd, prv.getVariable(1).getDomain());
		prv = mf.createRealVariableVector("ignored", prd, 3);
		assertEquals(3, prv.size());
		assertSame(rd, prv.getVariable(0).getDomain());

		DiscreteDomain dd = DiscreteDomain.bit();
		PDiscreteDomain pdd = mf.createDiscreteDomain(new Object[] {0,1});
		assertSame(dd, pdd.getDelegate());
		PDiscreteVariableVector pdv = mf.createDiscreteVariableVector(pdd, 3);
		assertEquals(3, pdv.size());
		assertSame(dd, pdv.getVariable(0).getDomain());
		assertSame(dd, pdv.getVariable(2).getDomain());
		pdv = mf.createDiscreteVariableVector("gag", pdd, 2);
		assertEquals(2, pdv.size());
		assertSame(dd, pdv.getVariable(0).getDomain());
		pdv = mf.createVariableVector("xxx", pdd, 4);
		assertEquals(4, pdv.size());
		assertSame(dd, pdv.getVariable(3).getDomain());

		RealJointDomain rjd = RealJointDomain.create(rd, 2);
		assertSame(rjd, RealJointDomain.create(rd,rd));
		PRealJointDomain prdj = mf.createRealJointDomain(new Object[] { prd, prd } );
		assertSame(rjd, prdj.getDelegate());
		PRealJointVariableVector prjv = mf.createRealJointVariableVector(prdj, 2);
		assertEquals(2, prjv.size());
		assertSame(rjd, prjv.getVariable(0).getDomain());
		assertSame(rjd, prjv.getVariable(1).getDomain());
		prjv = mf.createRealJointVariableVector("bogus", prdj, 3);
		assertEquals(3, prjv.size());
		assertSame(rjd, prjv.getVariable(2).getDomain());

		FiniteFieldDomain ffd = FiniteFieldDomain.create(0x2f);
		PFiniteFieldDomain pffd = mf.createFiniteFieldDomain(0x2f);
		assertSame(ffd, pffd.getDelegate());
		PFiniteFieldVariableVector pffv = mf.createFiniteFieldVariableVector(pffd, 2);
		assertEquals(2, pffv.size());
		assertSame(ffd, pffv.getVariable(0).getDomain());
	}

	/**
	 * Checks that repeated calls to {@code getLogger} return the same instance.
	 */
	@Test
	public void getLogger()
	{
		PLogger logger = mf.getLogger();
		assertSame(logger, mf.getLogger());
	}

	/**
	 * Checks multiplexer CPD creation from a single domain with a Z count, and from a list of
	 * per-Z domains.
	 */
	@Test
	public void getMultiplexerCPD()
	{
		PDiscreteDomain pdd2 = mf.createDiscreteDomain(new Object[] {1,2});
		PMultiplexerCPD pmcpd = mf.getMultiplexerCPD(new Object[]{1,2}, 3);
		assertEquals(1, pmcpd.size());
		assertSame(pdd2.getDelegate(), pmcpd.getY().getDomain().getDelegate());
		assertEquals(3, pmcpd.getZs().length);
		pmcpd = mf.getMultiplexerCPD(new Object[][] { new Object[] {1,2}, new Object [] {3,4}});
		assertEquals(1, pmcpd.size());
		assertEquals(2, pmcpd.getZs().length);
	}

	/**
	 * Checks that {@code setSolver} updates the default solver of the active environment,
	 * including clearing it with {@code null}.
	 */
	@Test
	public void setSolver()
	{
		mf.setSolver(null);
		assertNull(DimpleEnvironment.active().defaultSolver());
		IFactorGraphFactory<?> solverFactory = new GibbsSolver();
		mf.setSolver(solverFactory);
		assertSame(solverFactory, DimpleEnvironment.active().defaultSolver());
	}

	/**
	 * Checks variable stream creation for discrete, real, and real-joint domains, and the sizes of
	 * the data source and data sink objects returned by the factory.
	 */
	@Test
	public void streamOperations()
	{
		PDiscreteDomain pdd = mf.createDiscreteDomain(new Object[] { 0.0,1.0 });
		PVariableStreamBase pds = mf.createDiscreteStream(pdd, 2);
		assertEquals(2, pds.getModelerObjects().length);
		assertSame(pdd.getDelegate(), pds.getModelerObjects()[0].getDomain());

		PRealDomain prd = mf.createRealDomain(0.0,1.0);
		PVariableStreamBase prs = mf.createRealStream(prd, 2);
		assertEquals(2, prs.getModelerObjects().length);
		assertSame(prd.getDelegate(), prs.getModelerObjects()[0].getDomain());

		PRealJointDomain prjd = mf.createRealJointDomain(new Object[] { prd, prd });
		PVariableStreamBase prjs = mf.createRealJointStream(prjd, 2);
		assertEquals(2, prjs.getModelerObjects().length);
		assertSame(prjd.getDelegate(), prjs.getModelerObjects()[0].getDomain());

		PFactorFunctionDataSource pffds = mf.getFactorFunctionDataSource(3);
		assertEquals(3, pffds.getModelObjects().length);
		PDoubleArrayDataSource pdads = mf.getDoubleArrayDataSource(2);
		assertEquals(2, pdads.getModelObjects().length);
		PDoubleArrayDataSink pdadsink = mf.getDoubleArrayDataSink(3);
		assertEquals(3, pdadsink.getModelObjects().length);
		PMultivariateDataSource pmds = mf.getMultivariateDataSource(4);
		assertEquals(4, pmds.getModelObjects().length);
		PMultivariateDataSink pmdsink = mf.getMultivariateDataSink(2);
		assertEquals(2, pmdsink.getModelObjects().length);
	}

	/**
	 * Checks that the factory's thread-count methods read and write the {@link ThreadPool} setting,
	 * and that resetting to default restores the default count.
	 */
	@Test
	public void threadOperations()
	{
		// Start from the default so the count captured below is the default thread count.
		ThreadPool.setNumThreadsToDefault();
		int nthreads = mf.getNumThreads();
		mf.setNumThreads(42);
		assertEquals(42, mf.getNumThreads());
		assertEquals(42, ThreadPool.getNumThreads());
		// Also leaves the shared ThreadPool at its default for other tests.
		mf.setNumThreadsToDefault();
		assertEquals(nthreads, mf.getNumThreads());
	}
}