/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.MixedNormal;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.Sum;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.TypedDiscreteDomain;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.solvers.particleBP.ParticleBPReal;
import com.analog.lyric.dimple.solvers.particleBP.ParticleBPSolverGraph;
// TODO: move to particleBP test directory
/**
 * Tests of the particle BP solver on factor graphs that contain {@link Real} variables.
 * <p>
 * The tests that call {@code solve()} set the solver seed when {@link #repeatable} is true and
 * then compare numerically integrated belief statistics against fixed values. The constants
 * asserted in {@link #basicTest1()} and {@link #basicTest2()} appear to be values recorded from
 * a seeded run rather than closed-form answers (the closed-form means are computed only for
 * debug output).
 */
public class RealVariableParticleBPTest extends DimpleTestBase
{
    /** When true, the tests print beliefs, particles and summary statistics to standard output. */
    protected static boolean debugPrint = false;

    /** When true, the tests that call {@code solve()} first set the solver seed to 1. */
    protected static boolean repeatable = true;

    /** Absolute tolerance used when comparing a computed value with its expected value. */
    private static final double TOLERANCE = 1e-12;

    // Profiling aid, intentionally disabled: runs basicTest1 in an endless loop.
    // @Test
    // public void profileTest1()
    // {
    // while (true)
    // {
    // basicTest1();
    // }
    // }

    /**
     * Two real variables, each with a Normal input, attached to one Normal factor.
     * <p>
     * After solving, the belief of each variable is evaluated on a uniform grid of 500 points
     * starting at -3 with spacing 6/500, and the belief-weighted mean of the grid is compared
     * with a fixed value.
     */
    @SuppressWarnings({ "null", "deprecation" })
    @Test
    public void basicTest1()
    {
        if (debugPrint) System.out.println("== basicTest1 ==");

        int numIterations = 10;
        int numParticlesPerRealVariable = 200;
        int numResamplingUpdatesPerParticle = 50;

        FactorGraph graph = new FactorGraph();
        graph.setSolverFactory(new com.analog.lyric.dimple.solvers.particleBP.Solver());
        ParticleBPSolverGraph solver = (ParticleBPSolverGraph)graph.getSolver();
        solver.setNumIterations(numIterations);
        solver.setNumParticles(numParticlesPerRealVariable);
        solver.setResamplingUpdatesPerParticle(numResamplingUpdatesPerParticle);

        double aPriorMean = 1;
        double aPriorSigma = 0.5;
        double aPriorR = 1/(aPriorSigma*aPriorSigma);
        double bPriorMean = -1;
        double bPriorSigma = 2.;
        double bPriorR = 1/(bPriorSigma*bPriorSigma);

        Real a = new Real();
        Real b = new Real();
        a.setInputObject(new Normal(aPriorMean,aPriorR));
        b.setInputObject(new Normal(bPriorMean, bPriorR));
        a.setName("a");
        b.setName("b");

        // The factor parameters are the same values used in the expected-mean formulas below.
        double abMean = 0;
        double abSigma = 1;
        double abR = 1/(abSigma*abSigma);
        graph.addFactor(new Normal(abMean, abR), a, b);

        ParticleBPReal sa = (ParticleBPReal)a.getSolver();
        ParticleBPReal sb = (ParticleBPReal)b.getSolver();
        sa.setProposalStandardDeviation(0.5);
        sb.setProposalStandardDeviation(0.5);

        if (repeatable) solver.setSeed(1); // Make this repeatable
        graph.solve();

        double[] aBelief = (double[])a.getBeliefObject();
        double[] bBelief = (double[])b.getBeliefObject();
        double[] aParticles = sa.getParticleValues();
        double[] bParticles = sb.getParticleValues();
        debugPrintArray("aBelief", aBelief);
        debugPrintArray("bBelief", bBelief);
        debugPrintArray("aParticles", aParticles);
        debugPrintArray("bParticles", bParticles);

        double[] aUniformPointSet = uniformPointSet(-3, 3, 500);
        double[] aUniformBelief = sa.getBelief(aUniformPointSet);
        double[] bUniformPointSet = uniformPointSet(-3, 3, 500);
        double[] bUniformBelief = sb.getBelief(bUniformPointSet);
        debugPrintArray("aUniformBelief", aUniformBelief);
        debugPrintArray("aUniformPointSet", aUniformPointSet);
        debugPrintArray("bUniformBelief", bUniformBelief);
        debugPrintArray("bUniformPointSet", bUniformPointSet);

        double aSolverMean = beliefWeightedMean(aUniformPointSet, aUniformBelief);
        double bSolverMean = beliefWeightedMean(bUniformPointSet, bUniformBelief);
        debugPrintValue("aSolverMean", aSolverMean);
        debugPrintValue("bSolverMean", bSolverMean);

        // Precision-weighted combination of each prior with the factor; printed for reference only.
        double aExpectedMean = (aPriorMean*aPriorR + abMean*abR)/(aPriorR + abR);
        double bExpectedMean = (bPriorMean*bPriorR + abMean*abR)/(bPriorR + abR);
        debugPrintValue("aExpectedMean", aExpectedMean);
        debugPrintValue("bExpectedMean", bExpectedMean);

        // assertEquals with a delta fails on NaN, unlike a hand-written "difference within bounds" check.
        assertEquals(0.7999989412684679, aSolverMean, TOLERANCE);
        assertEquals(-0.19800348473801446, bSolverMean, TOLERANCE);
    }

    /**
     * One real variable and one binary discrete variable connected to a single
     * {@link MixedNormal} factor.
     * <p>
     * The mean of the real variable is integrated over a uniform grid of 100 points starting at
     * -3 with spacing 6/100; the mean of the discrete variable is the belief-weighted sum of its
     * domain elements. Both are compared with fixed values.
     */
    @SuppressWarnings({ "null", "deprecation" })
    @Test
    public void basicTest2()
    {
        // Test a combination of real and discrete variables connected to a single factor
        if (debugPrint) System.out.println("== basicTest2 ==");

        int numIterations = 10;
        int numParticlesPerRealVariable = 200;
        int numResamplingUpdatesPerParticle = 50;

        FactorGraph graph = new FactorGraph();
        graph.setSolverFactory(new com.analog.lyric.dimple.solvers.particleBP.Solver());
        ParticleBPSolverGraph solver = (ParticleBPSolverGraph)graph.getSolver();
        solver.setNumIterations(numIterations);
        solver.setNumParticles(numParticlesPerRealVariable);
        solver.setResamplingUpdatesPerParticle(numResamplingUpdatesPerParticle);

        double aPriorMean = 0;
        double aPriorSigma = 5;
        double aPriorR = 1/(aPriorSigma*aPriorSigma);
        double bProb1 = 0.6;
        double bProb0 = 1 - bProb1;

        Real a = new Real();
        a.setInputObject(new Normal(aPriorMean,aPriorR));
        Discrete b = new Discrete(0,1);
        b.setInput(bProb0, bProb1);
        a.setName("a");
        b.setName("b");

        double fMean0 = -1;
        double fSigma0 = 0.75;
        double fMean1 = 1;
        double fSigma1 = 0.75;
        double fR0 = 1/(fSigma0*fSigma0);
        double fR1 = 1/(fSigma1*fSigma1);
        graph.addFactor(new MixedNormal(fMean0, fR0, fMean1, fR1), a, b);

        ParticleBPReal sa = (ParticleBPReal)a.getSolver();
        //SVariable sb = (SVariable)b.getSolver();
        sa.setProposalStandardDeviation(1.0);

        if (repeatable) solver.setSeed(1); // Make this repeatable
        graph.solve();

        double[] aUniformPointSet = uniformPointSet(-3, 3, 100);
        double[] aUniformBelief = sa.getBelief(aUniformPointSet);
        double[] aBelief = (double[])a.getBeliefObject();
        double[] bBelief = (double[])b.getBeliefObject();
        double[] aParticles = sa.getParticleValues();
        @SuppressWarnings("unchecked")
        TypedDiscreteDomain<Integer> bDomain = (TypedDiscreteDomain<Integer>) b.getDiscreteDomain();

        debugPrintArray("aUniformBelief", aUniformBelief);
        debugPrintArray("aUniformPointSet", aUniformPointSet);
        debugPrintArray("aBelief", aBelief);
        debugPrintArray("aParticles", aParticles);
        debugPrintArray("bBelief", bBelief);

        double aSolverMean = beliefWeightedMean(aUniformPointSet, aUniformBelief);
        double bSolverMean = 0;
        for (int i = 0; i < bDomain.size(); i++) bSolverMean += bDomain.getElement(i) * bBelief[i];
        debugPrintValue("aSolverMean", aSolverMean);
        debugPrintValue("bSolverMean", bSolverMean);

        // Mixture of the two per-component precision-weighted means; printed for reference only.
        double aExpectedMean = bProb0*(aPriorMean*aPriorR + fMean0*fR0)/(aPriorR + fR0) + bProb1*(aPriorMean*aPriorR + fMean1*fR1)/(aPriorR + fR1);
        debugPrintValue("aExpectedMean", aExpectedMean);
        debugPrintValue("bExpectedMean", bProb1);

        // JUnit argument order is (expected, actual, delta).
        assertEquals(0.1929829696757485, aSolverMean, TOLERANCE);
        assertEquals(0.5809732764581331, bSolverMean, TOLERANCE);
    }

    /**
     * Checks the particle values present after {@code graph.initialize()}.
     * <p>
     * The first and last particle of each variable are expected to be:
     * <ul>
     * <li>a (no domain, initial range -11..14): -11 and 14</li>
     * <li>b (domain -10..10, initial range 0..7): 0 and 7, i.e. the explicit range wins</li>
     * <li>c (domain -20..20, no explicit range): -20 and 20</li>
     * <li>d (no domain, no explicit range): 0 and 0</li>
     * </ul>
     */
    @SuppressWarnings("null")
    @Test
    public void basicTest3()
    {
        // Test particle initialization based on the domain and setInitialParticleRange, which should override the domain
        if (debugPrint) System.out.println("== basicTest3 ==");

        int numIterations = 1;
        int numParticlesPerRealVariable = 100;

        FactorGraph graph = new FactorGraph();
        graph.setSolverFactory(new com.analog.lyric.dimple.solvers.particleBP.Solver());
        ParticleBPSolverGraph solver = (ParticleBPSolverGraph)graph.getSolver();
        solver.setNumIterations(numIterations);
        solver.setNumParticles(numParticlesPerRealVariable);

        Real a = new Real();
        Real b = new Real(new RealDomain(-10,10));
        Real c = new Real(new RealDomain(-20,20));
        Real d = new Real();
        a.setName("a");
        b.setName("b");
        c.setName("c");
        d.setName("d");
        graph.addFactor(new Normal(0,1), a, b, c, d);

        ParticleBPReal sa = (ParticleBPReal)a.getSolver();
        ParticleBPReal sb = (ParticleBPReal)b.getSolver();
        ParticleBPReal sc = (ParticleBPReal)c.getSolver();
        ParticleBPReal sd = (ParticleBPReal)d.getSolver();
        sa.setInitialParticleRange(-11, 14);
        sb.setInitialParticleRange(0, 7);

        graph.initialize();

        double[] aParticles = sa.getParticleValues();
        double[] bParticles = sb.getParticleValues();
        double[] cParticles = sc.getParticleValues();
        double[] dParticles = sd.getParticleValues();
        debugPrintArray("aParticles", aParticles);
        debugPrintArray("bParticles", bParticles);
        debugPrintArray("cParticles", cParticles);
        debugPrintArray("dParticles", dParticles);

        int last = numParticlesPerRealVariable - 1;
        assertEquals(-11.0, aParticles[0], TOLERANCE);
        assertEquals(14.0, aParticles[last], TOLERANCE);
        assertEquals(0.0, bParticles[0], TOLERANCE);
        assertEquals(7.0, bParticles[last], TOLERANCE);
        assertEquals(-20.0, cParticles[0], TOLERANCE);
        assertEquals(20.0, cParticles[last], TOLERANCE);
        assertEquals(0.0, dParticles[0], TOLERANCE);
        assertEquals(0.0, dParticles[last], TOLERANCE);
    }

    /**
     * Runs the solver with tempering on a graph with a single {@link Sum} factor over c, b, a,
     * where a and b have Normal inputs.
     * <p>
     * Verifies that setting the initial temperature and half-life enables tempering, and that
     * the numerically integrated mean of c is within 0.1 of {@code aPriorMean + bPriorMean}
     * and its standard deviation does not exceed the expected value by more than 0.01.
     */
    @SuppressWarnings("deprecation")
    @Test
    public void basicTest4()
    {
        // Test tempering
        if (debugPrint) System.out.println("== basicTest4 ==");

        int numIterations = 50;
        int numParticlesPerRealVariable = 20;
        int numResamplingUpdatesPerParticle = 10;

        FactorGraph graph = new FactorGraph();
        graph.setSolverFactory(new com.analog.lyric.dimple.solvers.particleBP.Solver());
        ParticleBPSolverGraph solver = requireNonNull((ParticleBPSolverGraph)graph.getSolver());
        solver.setNumIterations(numIterations);

        double aPriorMean = 1;
        double aPriorSigma = 0.1;
        double aPriorR = 1/(aPriorSigma*aPriorSigma);
        double bPriorMean = 2;
        double bPriorSigma = 0.1;
        double bPriorR = 1/(bPriorSigma*bPriorSigma);

        Real a = new Real();
        Real b = new Real();
        Real c = new Real();
        a.setInputObject(new Normal(aPriorMean,aPriorR));
        b.setInputObject(new Normal(bPriorMean,bPriorR));
        a.setName("a");
        b.setName("b");
        c.setName("c");
        graph.addFactor(new Sum(1.0),c,b,a);

        ParticleBPReal sa = requireNonNull((ParticleBPReal)a.getSolver());
        ParticleBPReal sb = requireNonNull((ParticleBPReal)b.getSolver());
        ParticleBPReal sc = requireNonNull((ParticleBPReal)c.getSolver());
        sa.setProposalStandardDeviation(0.1);
        sb.setProposalStandardDeviation(0.1);
        sc.setProposalStandardDeviation(0.1);
        sa.setInitialParticleRange(0, 2);
        sb.setInitialParticleRange(1, 3);
        sc.setInitialParticleRange(2, 4);

        // Test setting this after the variables have already been created
        solver.setNumParticles(numParticlesPerRealVariable);
        solver.setResamplingUpdatesPerParticle(numResamplingUpdatesPerParticle);

        // Enable tempering
        double initialTemperature = 1.0;
        double temperingHalfLifeInIterations = 5;
        solver.setInitialTemperature(initialTemperature);
        solver.setTemperingHalfLifeInIterations(temperingHalfLifeInIterations);
        assertTrue(solver.isTemperingEnabled()); // Make sure this automatically enabled tempering

        if (repeatable) solver.setSeed(1); // Make this repeatable
        graph.solve();

        double[] aBelief = requireNonNull((double[])a.getBeliefObject());
        double[] bBelief = requireNonNull((double[])b.getBeliefObject());
        double[] cBelief = requireNonNull((double[])c.getBeliefObject());
        double[] aParticles = sa.getParticleValues();
        double[] bParticles = sb.getParticleValues();
        double[] cParticles = sc.getParticleValues();
        debugPrintArray("aBelief", aBelief);
        debugPrintArray("bBelief", bBelief);
        debugPrintArray("cBelief", cBelief);
        debugPrintArray("aParticles", aParticles);
        debugPrintArray("bParticles", bParticles);
        debugPrintArray("cParticles", cParticles);

        double[] aUniformPointSet = uniformPointSet(-4, 4, 500);
        double[] aUniformBelief = sa.getBelief(aUniformPointSet);
        double[] bUniformPointSet = uniformPointSet(-4, 4, 500);
        double[] bUniformBelief = sb.getBelief(bUniformPointSet);
        double[] cUniformPointSet = uniformPointSet(-4, 4, 500);
        double[] cUniformBelief = sc.getBelief(cUniformPointSet);
        debugPrintArray("aUniformBelief", aUniformBelief);
        debugPrintArray("aUniformPointSet", aUniformPointSet);
        debugPrintArray("bUniformBelief", bUniformBelief);
        debugPrintArray("bUniformPointSet", bUniformPointSet);
        debugPrintArray("cUniformBelief", cUniformBelief);
        debugPrintArray("cUniformPointSet", cUniformPointSet);

        double aSolverMean = beliefWeightedMean(aUniformPointSet, aUniformBelief);
        double bSolverMean = beliefWeightedMean(bUniformPointSet, bUniformBelief);
        double cSolverMean = beliefWeightedMean(cUniformPointSet, cUniformBelief);
        debugPrintValue("aSolverMean", aSolverMean);
        debugPrintValue("bSolverMean", bSolverMean);
        debugPrintValue("cSolverMean", cSolverMean);

        double aSolverStdMeasured = beliefWeightedStd(aUniformPointSet, aUniformBelief, aSolverMean);
        double bSolverStdMeasured = beliefWeightedStd(bUniformPointSet, bUniformBelief, bSolverMean);
        double cSolverStdMeasured = beliefWeightedStd(cUniformPointSet, cUniformBelief, cSolverMean);
        debugPrintValue("aSolverStdMeasured", aSolverStdMeasured);
        debugPrintValue("bSolverStdMeasured", bSolverStdMeasured);
        debugPrintValue("cSolverStdMeasured", cSolverStdMeasured);

        double cExpectedMean = aPriorMean + bPriorMean;
        double cExpectedStd = Math.sqrt(aPriorSigma*aPriorSigma + bPriorSigma*bPriorSigma);
        debugPrintValue("cExpectedMean", cExpectedMean);
        debugPrintValue("cExpectedStd", cExpectedStd);

        assertTrue("cSolverMean = " + cSolverMean + ", expected about " + cExpectedMean,
            Math.abs(cSolverMean - cExpectedMean) < 0.1);
        assertTrue("cSolverStdMeasured = " + cSolverStdMeasured + ", expected at most " + (cExpectedStd + 0.01),
            cSolverStdMeasured <= cExpectedStd + 0.01);
    }

    /**
     * Prints {@code name = [v0 v1 ... ];} on one line when {@link #debugPrint} is set;
     * does nothing (and does not touch {@code values}) otherwise.
     */
    private static void debugPrintArray(String name, double[] values)
    {
        if (!debugPrint) return;
        System.out.print(name + " = [");
        for (int i = 0; i < values.length; i++) System.out.print(values[i] + " ");
        System.out.print("];\n");
    }

    /** Prints {@code label: value} on its own line when {@link #debugPrint} is set. */
    private static void debugPrintValue(String label, double value)
    {
        if (debugPrint) System.out.println(label + ": " + value);
    }

    /**
     * Returns {@code numPoints} equally spaced values starting at {@code lower} with spacing
     * {@code (upper-lower)/numPoints}; {@code upper} itself is not included.
     */
    private static double[] uniformPointSet(double lower, double upper, int numPoints)
    {
        double[] points = new double[numPoints];
        for (int i = 0; i < numPoints; i++)
            points[i] = lower + i*(upper-lower)/numPoints;
        return points;
    }

    /** Returns the sum over i of {@code points[i] * beliefs[i]}. */
    private static double beliefWeightedMean(double[] points, double[] beliefs)
    {
        double mean = 0;
        for (int i = 0; i < points.length; i++) mean += points[i] * beliefs[i];
        return mean;
    }

    /**
     * Returns the square root of the sum over i of
     * {@code (points[i] - mean)^2 * beliefs[i]}.
     */
    private static double beliefWeightedStd(double[] points, double[] beliefs, double mean)
    {
        double variance = 0;
        for (int i = 0; i < points.length; i++)
        {
            double pointDifference = points[i] - mean;
            variance += pointDifference * pointDifference * beliefs[i];
        }
        return Math.sqrt(variance);
    }
}