/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.FactorFunctions;
import static com.analog.lyric.dimple.test.FactorFunctions.FactorFunctionTester.*;
import static com.analog.lyric.math.Utilities.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.exceptions.InvalidDistributionException;
import com.analog.lyric.dimple.factorfunctions.Categorical;
import com.analog.lyric.dimple.factorfunctions.CategoricalBase;
import com.analog.lyric.dimple.factorfunctions.CategoricalEnergyParameters;
import com.analog.lyric.dimple.factorfunctions.CategoricalUnnormalizedParameters;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Tests the {@link Categorical} factor function together with its
 * {@link CategoricalUnnormalizedParameters} and {@link CategoricalEnergyParameters} variants.
 * <p>
 * Each variant is exercised both when built without a parameter array and when built
 * with one, and every instance is additionally run through {@link #assertInvariants}.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class TestCategorical extends DimpleTestBase
{
    /**
     * Single JUnit entry point; runs the per-variant checks in a fixed order.
     */
    @Test
    public void test()
    {
        checkCategoricalWithoutParameterArray();
        checkCategoricalWithParameterArray();
        checkUnnormalizedParameters();
        checkEnergyParameters();
    }

    /**
     * Checks a {@link Categorical} created with the no-argument constructor.
     */
    @SuppressWarnings({ "unused", "deprecation" })
    private void checkCategoricalWithoutParameterArray()
    {
        final Categorical categorical = new Categorical();
        assertInvariants(categorical);

        assertArrayEquals(new int[] {1}, categorical.getDirectedToIndices(2));
        assertArrayEquals(new int[] {1,2}, categorical.getDirectedToIndices(3));
        assertFalse(categorical.hasConstantParameters());

        // Here the first argument passed to the function is a real-joint value.
        assertEvalEnergy(categorical, 0.0,
            Value.createRealJoint(1.0, 2.0), Value.createReal(0.0));
        assertEvalEnergy(categorical, weightToEnergy(2.0),
            Value.createRealJoint(1.0, 2.0), Value.createReal(1));
        assertEvalEnergy(categorical, 0.0,
            Value.createRealJoint(2.0, 1.0), Value.create(DiscreteDomain.bool(), true));
    }

    /**
     * Checks a {@link Categorical} created from an explicit parameter array, and that a
     * negative entry in such an array is rejected.
     */
    @SuppressWarnings({ "unused", "deprecation" })
    private void checkCategoricalWithParameterArray()
    {
        final double[] weights = { .6, 1.4 };
        final Categorical categorical = new Categorical(weights);
        assertInvariants(categorical);

        // The inputs {.6, 1.4} are expected to be reported back as {.3, .7}
        // (presumably normalized to sum to one).
        assertArrayEquals(new double[] { .3, .7 }, categorical.getParameters(), 0.0);
        assertTrue(categorical.hasConstantParameters());
        assertArrayEquals(new int[] { 0 , 1 }, categorical.getDirectedToIndices(2));

        assertEvalEnergyBit(categorical, weightToEnergy(.3), 0);
        assertEvalEnergyBit(categorical, weightToEnergy(.7), 1);
        assertEvalEnergyBit(categorical, weightToEnergy(.3 * .7), 0, 1);

        final double[] invalidWeights = { .3, -.4 };
        try
        {
            new Categorical(invalidWeights);
            fail("expected InvalidDistributionException");
        }
        catch (InvalidDistributionException expected)
        {
            // Intentionally empty: this exception is the outcome the test requires.
        }
    }

    /**
     * Checks {@link CategoricalUnnormalizedParameters}, first built from a dimension only
     * and then built from a dimension plus a parameter array.
     */
    @SuppressWarnings({ "unused", "deprecation" })
    private void checkUnnormalizedParameters()
    {
        final CategoricalUnnormalizedParameters withoutArray = new CategoricalUnnormalizedParameters(2.0);
        assertInvariants(withoutArray);

        assertEquals(2, withoutArray.getDimension());
        assertArrayEquals(new int[] {2}, withoutArray.getDirectedToIndices(3));
        assertFalse(withoutArray.hasConstantParameters());

        assertEvalEnergyReal(withoutArray, weightToEnergy(.3), .6, 1.4, 0);
        // A negative argument is expected to produce infinite energy.
        assertEvalEnergyReal(withoutArray, Double.POSITIVE_INFINITY, .6, -1.4, 0);
        assertEvalEnergyReal(withoutArray, weightToEnergy(.3 * .7), .6, 1.4, 1, 0);

        final double[] weights = {.2, .3};
        final CategoricalUnnormalizedParameters withArray = new CategoricalUnnormalizedParameters(2, weights);
        assertInvariants(withArray);

        assertTrue(withArray.hasConstantParameters());
        assertArrayEquals(new double[] { .4, .6 }, withArray.getParameters(), 0.0);
        assertEvalEnergyBool(withArray, weightToEnergy(.4), false);
    }

    /**
     * Checks {@link CategoricalEnergyParameters}, first built from a dimension only
     * and then built from a dimension plus an array of energy values.
     */
    @SuppressWarnings({ "unused", "deprecation" })
    private void checkEnergyParameters()
    {
        final CategoricalEnergyParameters withoutArray = new CategoricalEnergyParameters(2.0);
        assertInvariants(withoutArray);

        assertEquals(2, withoutArray.getDimension());
        assertArrayEquals(new int[] {2}, withoutArray.getDirectedToIndices(3));
        assertFalse(withoutArray.hasConstantParameters());

        assertEvalEnergyReal(withoutArray, weightToEnergy(.3),
            weightToEnergy(.6), weightToEnergy(1.4), 0);
        assertEvalEnergyReal(withoutArray, weightToEnergy(.3 * .7),
            weightToEnergy(.6), weightToEnergy(1.4), 1, 0);

        final double[] energies = {weightToEnergy(.2), weightToEnergy(.3)};
        final CategoricalEnergyParameters withArray = new CategoricalEnergyParameters(2, energies);
        assertInvariants(withArray);

        assertTrue(withArray.hasConstantParameters());
        // Compared with a small tolerance rather than exactly, unlike the weight-based checks.
        assertArrayEquals(
            new double[] { weightToEnergy(.4), weightToEnergy(.6) },
            withArray.getParameters(),
            1e-10);
        assertEvalEnergyBool(withArray, weightToEnergy(.4), false);
        assertEvalEnergyBool(withArray, weightToEnergy(.4 * .6), false, true);
    }

    /**
     * Asserts properties the test expects of every {@link CategoricalBase} instance:
     * it equals itself and a clone of itself, does not equal an unrelated object,
     * is directed, and reports a parameter array whose length matches its dimension.
     * <p>
     * The named-parameter lookups are expected to return the parameter array when the
     * function has constant parameters and {@code null} otherwise.
     *
     * @param function the factor function to check
     */
    private void assertInvariants(CategoricalBase function)
    {
        assertTrue(function.objectEquals(function));
        assertFalse(function.objectEquals("bogus"));
        assertTrue(function.isDirected());
        assertEquals(function.getParameters().length, function.getDimension());

        final boolean constantParameters = function.hasConstantParameters();
        if (!constantParameters)
        {
            assertNull(function.getParameter("alpha"));
            assertNull(function.getParameter("alphas"));
        }
        else
        {
            // Both spellings are expected to name the same array; an unknown name yields null.
            assertArrayEquals(function.getParameters(), function.getParameter("alpha"), 0.0);
            assertArrayEquals(function.getParameters(), function.getParameter("alphas"), 0.0);
            assertNull(function.getParameter("beta"));
        }

        final CategoricalBase copy = (CategoricalBase)function.clone();
        assertTrue(function.objectEquals(copy));
    }
}