/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.FactorFunctions;
import static org.junit.Assert.*;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.Bernoulli;
import com.analog.lyric.dimple.factorfunctions.Beta;
import com.analog.lyric.dimple.factorfunctions.Categorical;
import com.analog.lyric.dimple.factorfunctions.Dirichlet;
import com.analog.lyric.dimple.factorfunctions.Gamma;
import com.analog.lyric.dimple.factorfunctions.InverseGamma;
import com.analog.lyric.dimple.factorfunctions.LogNormal;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.Poisson;
import com.analog.lyric.dimple.factorfunctions.Rayleigh;
import com.analog.lyric.dimple.factorfunctions.VonMises;
import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
* Test implementations of IParametricFactorFunction interface.
* <p>
* @since 0.07
* @author Christopher Barber
*/
public class TestParametricFactorFunction extends DimpleTestBase
{
@Test
public void test()
{
Map<String,Object> emptyMap = Collections.emptyMap();
// Bernoulli
Bernoulli bernoulli = new Bernoulli(.4);
assertEquals(.4, bernoulli.getParameter(), 0.0);
assertEquals(.4, bernoulli.getParameter("p"));
assertInvariants(bernoulli);
bernoulli = new Bernoulli(emptyMap);
assertEquals(.5, bernoulli.getParameter(), 0.0);
assertInvariants(bernoulli);
assertInvariants(new Bernoulli());
// Beta
Beta beta = new Beta(3,4);
assertEquals(2, beta.getAlphaMinusOne(), 0.0);
assertEquals(3, beta.getBetaMinusOne(), 0.0);
assertEquals(3.0, beta.getParameter("alpha"));
assertEquals(4.0, beta.getParameter("beta"));
assertInvariants(beta);
beta = new Beta(emptyMap);
assertEquals(1.0, beta.getParameter("alpha"));
assertEquals(1.0, beta.getParameter("beta"));
assertInvariants(new Beta());
// Categorical
Categorical categorical = new Categorical(new double[] { .6, .8 ,.6 }); // will be normalized
assertArrayEquals(new double[] {.3,.4,.3}, (double[])categorical.getParameter("alpha"), 0.0);
assertArrayEquals(new double[] {.3,.4,.3}, (double[])categorical.getParameter("alphas"), 0.0);
assertInvariants(categorical);
assertInvariants(new Categorical());
// Dirichlet
Dirichlet dirichlet= new Dirichlet(new double[] { .6, .8 ,.6 }); // not normalized
assertArrayEquals(new double[] {.6,.8,.6}, (double[])dirichlet.getParameter("alpha"), 0.0);
assertArrayEquals(new double[] {.6,.8,.6}, (double[])dirichlet.getParameter("alphas"), 0.0);
assertInvariants(dirichlet);
assertInvariants(new Dirichlet());
// Gamma
Gamma gamma = new Gamma(2,3);
assertEquals(2.0, gamma.getParameter("alpha"));
assertEquals(1.0, gamma.getAlphaMinusOne(), 0.0);
assertEquals(3.0, gamma.getParameter("beta"));
assertEquals(3.0, gamma.getBeta(), 0.0);
assertInvariants(gamma);
gamma = new Gamma(emptyMap);
assertEquals(1.0, gamma.getParameter("alpha"));
assertEquals(1.0, gamma.getParameter("beta"));
assertInvariants(new Gamma());
// InverseGamma
InverseGamma inverseGamma = new InverseGamma(2,3);
assertEquals(2.0, inverseGamma.getParameter("alpha"));
assertEquals(3.0, inverseGamma.getParameter("beta"));
assertInvariants(inverseGamma);
inverseGamma = new InverseGamma(emptyMap);
assertEquals(1.0, inverseGamma.getParameter("alpha"));
assertEquals(1.0, inverseGamma.getParameter("beta"));
assertInvariants(new Gamma());
// LogNormal
LogNormal logNormal = new LogNormal(2.0, .5);
assertEquals(2.0, logNormal.getMean(), 0.0);
assertEquals(.5, logNormal.getPrecision(), 0.0);
assertInvariants(logNormal);
logNormal = new LogNormal(newMap("mu", 1.5, "sigma", 2.0));
assertEquals(1.5, logNormal.getParameter("mean"));
assertEquals(.25, logNormal.getParameter("precision"));
assertEquals(2.0, logNormal.getParameter("sigma"));
assertInvariants(logNormal);
logNormal = new LogNormal(newMap("variance", 9.0));
assertEquals(9.0, logNormal.getParameter("variance"));
assertEquals(3.0, logNormal.getParameter("std"));
assertInvariants(logNormal);
logNormal = new LogNormal(emptyMap);
assertEquals(0.0, logNormal.getMean(), 0.0);
assertEquals(1.0, logNormal.getPrecision(), 0.0);
assertInvariants(new LogNormal());
// Normal
Normal normal = new Normal(2.0, .5);
assertEquals(2.0, normal.getMean(), 0.0);
assertEquals(.5, normal.getPrecision(), 0.0);
assertInvariants(normal);
normal = new Normal(newMap("mu", 1.5, "sigma", 2.0));
assertEquals(1.5, normal.getParameter("mean"));
assertEquals(.25, normal.getParameter("precision"));
assertEquals(2.0, normal.getParameter("sigma"));
assertInvariants(normal);
normal = new Normal(newMap("variance", 9.0));
assertEquals(3.0, normal.getStandardDeviation(), 0.0);
assertEquals(9.0, normal.getParameter("variance"));
assertEquals(3.0, normal.getParameter("std"));
assertInvariants(normal);
normal = new Normal(emptyMap);
assertEquals(0.0, normal.getMean(), 0.0);
assertEquals(1.0, normal.getPrecision(), 0.0);
assertInvariants(new Normal());
// Poisson
Poisson poisson = new Poisson(.3);
assertEquals(.3, poisson.getLambda(), 0.0);
assertEquals(.3, poisson.getParameter("lambda"));
assertInvariants(poisson);
poisson = new Poisson(emptyMap);
assertEquals(1.0, poisson.getParameter("lambda"));
assertInvariants(new Poisson());
// Rayleigh
Rayleigh rayleigh = new Rayleigh(2.3);
assertInvariants(rayleigh);
rayleigh = new Rayleigh(emptyMap);
assertEquals(1.0, rayleigh.getParameter("sigma"));
assertInvariants(new Rayleigh());
// VonMises
VonMises vonMises = new VonMises(2.0, .5);
assertEquals(2.0, vonMises.getParameter("mean"));
assertEquals(.5, vonMises.getParameter("precision"));
assertInvariants(vonMises);
vonMises = new VonMises(newMap("mu", 1.5, "sigma", 2.0));
assertEquals(1.5, vonMises.getParameter("mean"));
assertEquals(.25, vonMises.getParameter("precision"));
assertEquals(2.0, vonMises.getParameter("sigma"));
assertInvariants(vonMises);
vonMises = new VonMises(newMap("variance", 9.0));
assertEquals(9.0, vonMises.getParameter("variance"));
assertEquals(3.0, vonMises.getParameter("std"));
assertInvariants(vonMises);
vonMises = new VonMises(emptyMap);
assertEquals(0.0, vonMises.getParameter("mean"));
assertEquals(1.0, vonMises.getParameter("precision"));
assertInvariants(new VonMises());
}
private void assertInvariants(IParametricFactorFunction function)
{
assertTrue(function.isParametric());
Map<String,Object> parameters = new TreeMap<>();
int nCopied = function.copyParametersInto(parameters);
assertEquals(nCopied, parameters.size());
if (function.hasConstantParameters())
{
assertTrue(nCopied > 0);
}
else
{
assertEquals(0, nCopied);
}
for (String name : parameters.keySet())
{
Object val1 = parameters.get(name);
Object val2 = function.getParameter(name);
assertNotNull(val1);
assertNotNull(val2);
assertReallyEquals(val1,val2);
}
assertNull(function.getParameter("bogusParameterName"));
try
{
Constructor<? extends IParametricFactorFunction> constructor =
function.getClass().getConstructor(Map.class);
IParametricFactorFunction function2 = constructor.newInstance(parameters);
assertTrue(function2.hasConstantParameters());
Map<String,Object> parameters2 = new TreeMap<>();
function2.copyParametersInto(parameters2);
assertFalse(parameters2.isEmpty());
assertTrue(parameters2.size() >= parameters.size());
if (function.hasConstantParameters())
{
assertEquals(parameters.keySet(), parameters2.keySet());
for (String name : parameters.keySet())
{
assertReallyEquals(parameters.get(name), parameters2.get(name));
}
}
else
{
for (String name : parameters2.keySet())
{
assertNull(parameters.get(name));
}
}
}
catch (InvocationTargetException ex)
{
// Not all parametric functions have default values, so constructing using an empty
// constructor will not always work.
}
catch (Exception ex)
{
fail(ex.toString());
}
}
private void assertReallyEquals(Object val1, Object val2)
{
if (val1.getClass().isArray())
{
final int length1 = Array.getLength(val1);
final int length2 = Array.getLength(val2);
assertEquals(length1, length2);
for (int i = length1; --i>=0;)
{
assertEquals(Array.get(val1,i), Array.get(val2,i));
}
}
else
{
assertEquals(val1,val2);
}
}
private Map<String,Object> newMap(Object ... args)
{
Map<String,Object> map = new TreeMap<>();
for (int i = 0; i < args.length; i += 2)
{
map.put((String)args[i], args[i+1]);
}
return map;
}
}