/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.core;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.lang.String.format;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.List;
import java.util.Set;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.And;
import com.analog.lyric.dimple.factorfunctions.Xor;
import com.analog.lyric.dimple.factorfunctions.core.CustomFactorFunctionWrapper;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.solvers.core.CustomFactors;
import com.analog.lyric.dimple.solvers.core.CustomFactorsOptionKey;
import com.analog.lyric.dimple.solvers.core.ISolverFactorCreator;
import com.analog.lyric.dimple.solvers.core.SolverFactorCreationException;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.dimple.test.solvers.core.customFactors.MyCustomFactor;
import com.analog.lyric.dimple.test.solvers.core.customFactors.MyCustomXor;
/**
* Unit test for {@link CustomFactors}
* @since 0.08
* @author Christopher Barber
*/
public class TestCustomFactors extends DimpleTestBase
{
/**
 * Abstract {@link CustomFactors} specialization producing {@link MyCustomFactor} instances
 * for any {@link ISolverFactorGraph}.
 * <p>
 * Being abstract, it cannot be instantiated directly; concrete subclasses supply the
 * remaining behavior.
 */
public abstract static class MyCustomFactorsBase extends CustomFactors<MyCustomFactor, ISolverFactorGraph>
{
    private static final long serialVersionUID = 1L;

    /**
     * Forwards the solver factor and solver graph classes to the superclass constructor.
     *
     * @param sfactorClass solver factor class handled by this instance
     * @param sgraphClass solver graph class handled by this instance
     */
    MyCustomFactorsBase(Class<MyCustomFactor> sfactorClass, Class<ISolverFactorGraph> sgraphClass)
    {
        super(sfactorClass, sgraphClass);
    }

    /**
     * Copy constructor; forwards {@code other} to the superclass copy constructor.
     *
     * @param other instance to copy from
     */
    MyCustomFactorsBase(MyCustomFactorsBase other)
    {
        super(other);
    }
}
/**
 * Concrete {@link MyCustomFactorsBase} used by the tests in this class.
 * <p>
 * Registers {@link MyCustomXor} for {@link Xor} as its only builtin, and creates a
 * {@link MyCustomFactor} tagged {@code "default"} when asked for a default solver factor.
 */
public static class MyCustomFactors extends MyCustomFactorsBase
{
    private static final long serialVersionUID = 1L;

    /*--------------
     * Construction
     */

    /**
     * Creates an instance bound to {@link MyCustomFactor} and {@link ISolverFactorGraph}.
     */
    public MyCustomFactors()
    {
        super(MyCustomFactor.class, ISolverFactorGraph.class);
    }

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * @param other instance to copy from
     */
    MyCustomFactors(MyCustomFactors other)
    {
        super(other);
    }

    /**
     * @return a new {@link MyCustomFactors} built from this one via the copy constructor
     */
    @Override
    public CustomFactors<MyCustomFactor, ISolverFactorGraph> clone()
    {
        final MyCustomFactors copy = new MyCustomFactors(this);
        return copy;
    }

    /*------------------------
     * CustomFactors methods
     */

    /**
     * Registers {@link MyCustomXor} as the custom factor for {@link Xor}.
     */
    @Override
    public void addBuiltins()
    {
        add(Xor.class, MyCustomXor.class);
    }

    /**
     * @return a new {@link MyCustomFactor} for {@code factor} and {@code sgraph} tagged
     * {@code "default"}
     */
    @Override
    public MyCustomFactor createDefault(Factor factor, ISolverFactorGraph sgraph)
    {
        return new MyCustomFactor(factor, sgraph, "default");
    }

    /*---------------------------------------------
     * Pass-through overrides called from the tests
     */

    /**
     * Delegates to the superclass. Re-declared here so that test code in this class can
     * invoke it directly.
     */
    @Override
    protected void freeze()
    {
        super.freeze();
    }

    /**
     * Delegates to the superclass. Declared public here so that the tests can call it both
     * directly and by method name.
     */
    @Override
    public String qualifiedFactorFunctionName(String factorFunction)
    {
        return super.qualifiedFactorFunctionName(factorFunction);
    }
}
/**
 * {@link MyCustomXor} subclass whose two-argument constructor uses the tag {@code "xor3"}.
 */
public static class MyCustomXor3 extends MyCustomXor
{
    /**
     * Constructs the factor with the tag {@code "xor3"}.
     *
     * @param factor model factor being wrapped
     * @param parent solver graph that owns the new solver factor
     */
    public MyCustomXor3(Factor factor, ISolverFactorGraph parent)
    {
        this(factor, parent, "xor3");
    }

    /**
     * Constructs the factor with a caller-supplied tag.
     *
     * @param factor model factor being wrapped
     * @param parent solver graph that owns the new solver factor
     * @param tag tag forwarded to the {@link MyCustomXor} constructor
     */
    public MyCustomXor3(Factor factor, ISolverFactorGraph parent, String tag)
    {
        super(factor, parent, tag);
    }
}
/**
 * {@link MyCustomXor} subclass whose construction can be made to fail on demand.
 * <p>
 * When {@link #throwMe} is non-null, every constructor throws that object after the
 * superclass constructor has run. Tests that set {@link #throwMe} must reset it to
 * {@code null} when done, because the field is shared static state.
 */
public static class MyCustomXor4 extends MyCustomXor
{
    /**
     * Throwable to be thrown by the constructors, or {@code null} to construct normally.
     */
    public static @Nullable Throwable throwMe = null;

    /**
     * Constructs the factor with the tag {@code "xor4"}.
     *
     * @param factor model factor being wrapped
     * @param parent solver graph that owns the new solver factor
     * @throws Throwable the current value of {@link #throwMe}, if non-null
     */
    public MyCustomXor4(Factor factor, ISolverFactorGraph parent) throws Throwable
    {
        this(factor, parent, "xor4");
    }

    /**
     * Constructs the factor with a caller-supplied tag.
     *
     * @param factor model factor being wrapped
     * @param parent solver graph that owns the new solver factor
     * @param tag tag forwarded to the {@link MyCustomXor} constructor
     * @throws Throwable the current value of {@link #throwMe}, if non-null
     */
    public MyCustomXor4(Factor factor, ISolverFactorGraph parent, String tag) throws Throwable
    {
        super(factor, parent, tag);

        // Read the static field once so the null check and the throw see the same value.
        final Throwable pending = throwMe;
        if (pending != null)
        {
            throw pending;
        }
    }
}
public static class MyCustomSolverGraph extends SumProductSolverGraph
{
public MyCustomSolverGraph(FactorGraph factorGraph)
{
super(factorGraph, null);
}
@Override
public ISolverFactor createFactor(Factor factor)
{
return option.createFactor(factor, this);
}
}
/**
 * Exercises the {@link CustomFactors} API directly through a fresh {@link MyCustomFactors}
 * instance: the initial empty state, the name-resolution helpers, {@code addBuiltins},
 * {@code defaultCreator}, the {@code add}/{@code addFirst} overloads and their ordering, and
 * finally that a frozen instance rejects all mutators.
 */
@Test
public void test()
{
    FactorGraph fg = new FactorGraph();
    Bit b1 = new Bit();
    Bit b2 = new Bit();
    Factor xorFactor = fg.addFactor(new Xor(), b1, b2);
    ISolverFactorGraph sfg = requireNonNull(fg.getSolver());

    // A newly constructed instance is mutable and has no registrations.
    MyCustomFactors customFactors = new MyCustomFactors();
    assertTrue(customFactors.isMutable());
    assertTrue(customFactors.keySet().isEmpty());
    assertTrue(customFactors.get("bogus").isEmpty());
    assertTrue(customFactors.get("Xor").isEmpty());
    assertEquals("MyCustomFactors()\n", customFactors.toString());

    // Test getFactorClass helper method
    assertSame(MyCustomXor.class, customFactors.getFactorClass("MyCustomXor"));
    assertSame(MyCustomXor.class, customFactors.getFactorClass(MyCustomXor.class.getName()));
    expectThrow(IllegalArgumentException.class, ".*ClassNotFoundException.*", customFactors, "getFactorClass",
        "CustomNormal");

    // Test qualifiedFactorFunctionName helper method
    assertEquals("com.analog.lyric.dimple.factorfunctions.Xor", customFactors.qualifiedFactorFunctionName("Xor"));
    assertEquals("alias", customFactors.qualifiedFactorFunctionName("alias"));
    try
    {
        customFactors.qualifiedFactorFunctionName("foo.bar");
        fail("expected exception");
    }
    catch (IllegalArgumentException expected)
    {
        // Expected: the dotted name "foo.bar" is rejected. Nothing further to verify.
    }

    // Test addBuiltins
    customFactors.addBuiltins();
    Set<String> keys = customFactors.keySet();
    assertEquals(1, keys.size());
    assertTrue(keys.contains(Xor.class.getName()));
    assertEquals(1, customFactors.get(Xor.class.getName()).size());
    assertEquals(format("MyCustomFactors(\n\t%s = %s)\n", Xor.class.getName(), MyCustomXor.class.getName()),
        customFactors.toString());

    // Test defaultCreator helper methods
    expectThrow(DimpleException.class, ".*NoSuchMethod.*", customFactors, "defaultCreator", getClass());
    ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph> creator = customFactors.defaultCreator("MyCustomXor");
    assertEquals(MyCustomXor.class.getName(), creator.toString());
    ISolverFactor sfactor = creator.create(xorFactor, sfg);
    assertSame(xorFactor, sfactor.getModelObject());
    assertSame(MyCustomXor.class, sfactor.getClass());

    creator = customFactors.defaultCreator(MyCustomXor4.class);
    assertSame(MyCustomXor4.class, creator.create(xorFactor, sfg).getClass());

    // MyCustomXor4.throwMe is shared static state. Reset it in a finally block so that a
    // failing assertion here cannot make later MyCustomXor4 constructions (in this or any
    // other test) throw unexpectedly.
    try
    {
        // A RuntimeException thrown by the constructor is expected to surface unchanged.
        final RuntimeException unchecked = new RuntimeException("xxx");
        MyCustomXor4.throwMe = unchecked;
        try
        {
            creator.create(xorFactor, sfg);
            fail("expected exception");
        }
        catch (RuntimeException ex)
        {
            assertSame(unchecked, ex);
        }

        // A checked exception is expected to arrive as the cause of a RuntimeException.
        final Exception checked = new Exception("yyy");
        MyCustomXor4.throwMe = checked;
        try
        {
            creator.create(xorFactor, sfg);
            fail("expected exception");
        }
        catch (RuntimeException ex)
        {
            assertSame(checked, ex.getCause());
        }
    }
    finally
    {
        MyCustomXor4.throwMe = null;
    }

    // Test add methods
    ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph> xor1 =
        new ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph>() {
            @Override
            public MyCustomFactor create(Factor factor, ISolverFactorGraph sgraph)
            {
                return new MyCustomXor(factor, sgraph, "xor1");
            }
        };
    customFactors.add("Xor", xor1);
    ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph> xor2 =
        new ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph>() {
            @Override
            public MyCustomFactor create(Factor factor, ISolverFactorGraph sgraph)
            {
                return new MyCustomXor(factor, sgraph, "xor2");
            }
        };
    customFactors.add(Xor.class, xor2);
    customFactors.add("Xor", MyCustomXor3.class);
    customFactors.add(Xor.class, MyCustomXor4.class);
    customFactors.add("xor_alias", "MyCustomXor");

    // The builtin entry stays at index 0; each add() appends after it in call order.
    List<ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph>> creators =
        customFactors.get(Xor.class.getName());
    assertEquals(5, creators.size());
    assertSame(xor1, creators.get(1));
    assertEquals("xor1", creators.get(1).create(xorFactor, sfg).tag);
    assertSame(xor2, creators.get(2));
    assertSame(MyCustomXor3.class, creators.get(3).create(xorFactor, sfg).getClass());
    assertSame(MyCustomXor4.class, creators.get(4).create(xorFactor, sfg).getClass());
    creators = customFactors.get("xor_alias");
    assertEquals(1, creators.size());

    // Not documented, but you can change the list directly!
    customFactors.get(Xor.class.getName()).clear();
    assertTrue(customFactors.get(Xor.class.getName()).isEmpty());
    customFactors.get("xor_alias").clear();
    assertTrue(customFactors.get("xor_alias").isEmpty());
    assertTrue(customFactors.keySet().isEmpty());

    // Test addFirst methods
    customFactors.addFirst("Xor", xor1);
    customFactors.addFirst(Xor.class, xor2);
    customFactors.addFirst("Xor", MyCustomXor3.class);
    customFactors.addFirst(Xor.class, MyCustomXor4.class);
    customFactors.addFirst("xor_alias", xor1);
    customFactors.addFirst("xor_alias", "MyCustomXor");

    // Each addFirst() prepends, so the entries appear in reverse call order.
    creators = customFactors.get(Xor.class.getName());
    assertEquals(4, creators.size());
    assertSame(xor1, creators.get(3));
    assertEquals("xor1", creators.get(3).create(xorFactor, sfg).tag);
    assertSame(xor2, creators.get(2));
    assertSame(MyCustomXor3.class, creators.get(1).create(xorFactor, sfg).getClass());
    assertSame(MyCustomXor4.class, creators.get(0).create(xorFactor, sfg).getClass());
    creators = customFactors.get("xor_alias");
    assertEquals(2, creators.size());

    // Test freeze
    customFactors.freeze();
    assertFalse(customFactors.isMutable());
    expectThrow(UnsupportedOperationException.class, customFactors, "addBuiltins");
    expectThrow(UnsupportedOperationException.class, customFactors, "add", "foo", "MyCustomXor");
    expectThrow(UnsupportedOperationException.class, customFactors, "add", "foo", xor1);
    expectThrow(UnsupportedOperationException.class, customFactors, "add", "foo", MyCustomXor.class);
    expectThrow(UnsupportedOperationException.class, customFactors, "add", Xor.class, MyCustomXor.class);
    expectThrow(UnsupportedOperationException.class, customFactors, "add", Xor.class, xor1);
    expectThrow(UnsupportedOperationException.class, customFactors, "addFirst", "foo", "MyCustomXor");
    expectThrow(UnsupportedOperationException.class, customFactors, "addFirst", "foo", xor1);
    expectThrow(UnsupportedOperationException.class, customFactors, "addFirst", "foo", MyCustomXor.class);
    expectThrow(UnsupportedOperationException.class, customFactors, "addFirst", Xor.class, MyCustomXor.class);
    expectThrow(UnsupportedOperationException.class, customFactors, "addFirst", Xor.class, xor1);
}
/**
 * Option key typed to the abstract {@link MyCustomFactorsBase}. {@link #testOptionKey()}
 * expects both {@code getOrCreate} and {@code defaultValue} on this key to fail with a
 * {@link DimpleException}.
 */
public static final CustomFactorsOptionKey<MyCustomFactor, ISolverFactorGraph, MyCustomFactorsBase> bogusOption =
    new CustomFactorsOptionKey<>(TestCustomFactors.class, "bogusOption", MyCustomFactorsBase.class);

/**
 * Option key typed to the concrete {@link MyCustomFactors}. Used by
 * {@link MyCustomSolverGraph#createFactor} and by {@link #testOptionKey()}.
 */
public static final CustomFactorsOptionKey<MyCustomFactor, ISolverFactorGraph, MyCustomFactors> option =
    new CustomFactorsOptionKey<>(TestCustomFactors.class, "option", MyCustomFactors.class);
/**
 * Exercises {@link CustomFactorsOptionKey} through {@link #option} and {@link #bogusOption}:
 * the frozen default value, {@code getOrCreate} on the active environment, the failure cases
 * for a key whose value type cannot be instantiated, and {@code createFactor} for default,
 * custom and alias-based lookups, including fallback when the first creator throws.
 */
@Test
public void testOptionKey()
{
    assertSame(MyCustomFactors.class, option.type());

    // The default value is frozen and holds only the builtin Xor registration.
    MyCustomFactors defaultCustomFactors = option.defaultValue();
    assertFalse(defaultCustomFactors.isMutable());
    Set<String> keys = defaultCustomFactors.keySet();
    assertEquals(1, keys.size());
    assertTrue(keys.contains(Xor.class.getName()));
    assertEquals(1, defaultCustomFactors.get(Xor.class.getName()).size());

    DimpleEnvironment env = DimpleEnvironment.active();
    FactorGraph fg = new FactorGraph();
    fg.setSolverFactory(null);
    ISolverFactorGraph sfg = new MyCustomSolverGraph(fg);
    Bit b1 = new Bit();
    Bit b2 = new Bit();
    Factor xorFactor = fg.addFactor(new Xor(), b1, b2);
    Factor xorAliasFactor = fg.addFactor(new CustomFactorFunctionWrapper("xor-alias"), b1, b2);
    Factor andFactor = fg.addFactor(new And(), b1, b2);

    // getOrCreate
    assertNull(option.get(env));
    MyCustomFactors envCustomFactors = option.getOrCreate(env);
    assertNotNull(envCustomFactors);
    assertTrue(envCustomFactors.keySet().isEmpty());
    assertSame(envCustomFactors, option.getOrCreate(env));

    // Test a couple of unusual exception cases
    expectThrow(DimpleException.class, bogusOption, "getOrCreate", env);
    expectThrow(DimpleException.class, bogusOption, "defaultValue");

    //
    // Test createFactor
    //

    // default solver factor creation
    MyCustomFactor sfactor = option.createFactor(andFactor, sfg);
    assertSame(MyCustomFactor.class, sfactor.getClass());
    assertSame(andFactor, sfactor.getModelObject());
    assertEquals("default", sfactor.tag);

    // override default by registering against FactorFunction base class
    option.getOrCreate(env).addBuiltins();
    option.getOrCreate(env).add(FactorFunction.class,
        new ISolverFactorCreator<MyCustomFactor, ISolverFactorGraph>() {
            @Override
            public MyCustomFactor create(Factor factor, ISolverFactorGraph sgraph)
            {
                return new MyCustomFactor(factor, sgraph, "custom-default");
            }
        });
    sfactor = option.createFactor(andFactor, sfg);
    assertEquals("custom-default", sfactor.tag);

    // custom creation
    sfactor = option.createFactor(xorFactor, sfg);
    assertSame(MyCustomXor.class, sfactor.getClass());
    assertEquals("xor", sfactor.tag);

    // custom creation with alias
    expectThrow(SolverFactorCreationException.class,
        "Cannot find factor function or custom factor implementation.*",
        option, "createFactor", xorAliasFactor, sfg);

    option.getOrCreate(env).addFirst("xor-alias", MyCustomXor4.class);

    // MyCustomXor4.throwMe is shared static state. Reset it in a finally block so that a
    // failing assertion here cannot make later MyCustomXor4 constructions (in this or any
    // other test) throw unexpectedly.
    try
    {
        // The only registered creator throws, so creation is expected to fail and the
        // failure message is expected to mention the creator's exception.
        MyCustomXor4.throwMe = new RuntimeException("last failure");
        expectThrow(SolverFactorCreationException.class,
            "Cannot find factor function 'xor-alias'.*last failure",
            option, "createFactor", xorAliasFactor, sfg);

        // throwMe is still set: MyCustomXor4 keeps failing, so the MyCustomXor creator
        // appended after it is expected to produce the factor.
        option.getOrCreate(env).add("xor-alias", MyCustomXor.class);
        sfactor = option.createFactor(xorAliasFactor, sfg);
        assertSame(MyCustomXor.class, sfactor.getClass());

        // Same expectation after prepending a second failing MyCustomXor4 creator.
        option.getOrCreate(env).addFirst("xor-alias", MyCustomXor4.class);
        MyCustomXor4.throwMe = new RuntimeException("last failure");
        sfactor = option.createFactor(xorAliasFactor, sfg);
        assertSame(MyCustomXor.class, sfactor.getClass());
    }
    finally
    {
        MyCustomXor4.throwMe = null;
    }
}
}