/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct;
import com.analog.lyric.dimple.factorfunctions.ComplexNegate;
import com.analog.lyric.dimple.factorfunctions.ComplexSubtract;
import com.analog.lyric.dimple.factorfunctions.ComplexSum;
import com.analog.lyric.dimple.factorfunctions.FiniteFieldAdd;
import com.analog.lyric.dimple.factorfunctions.FiniteFieldMult;
import com.analog.lyric.dimple.factorfunctions.FiniteFieldProjection;
import com.analog.lyric.dimple.factorfunctions.LinearEquation;
import com.analog.lyric.dimple.factorfunctions.MatrixRealJointVectorProduct;
import com.analog.lyric.dimple.factorfunctions.Multiplexer;
import com.analog.lyric.dimple.factorfunctions.MultivariateNormal;
import com.analog.lyric.dimple.factorfunctions.Negate;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.Product;
import com.analog.lyric.dimple.factorfunctions.RealJointNegate;
import com.analog.lyric.dimple.factorfunctions.RealJointSubtract;
import com.analog.lyric.dimple.factorfunctions.RealJointSum;
import com.analog.lyric.dimple.factorfunctions.Subtract;
import com.analog.lyric.dimple.factorfunctions.Sum;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.VariablePredicates;
import com.analog.lyric.dimple.solvers.core.CustomFactors;
import com.analog.lyric.dimple.solvers.core.ISolverFactorCreator;
import com.analog.lyric.dimple.solvers.core.SolverFactorCreationException;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomComplexGaussianPolynomial;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomFiniteFieldAdd;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomFiniteFieldConstantMult;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomFiniteFieldMult;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomFiniteFieldProjection;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianLinear;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianLinearEquation;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianNegate;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianProduct;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianSubtract;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomGaussianSum;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultiplexer;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultivariateGaussianNegate;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultivariateGaussianProduct;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultivariateGaussianSubtract;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultivariateGaussianSum;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomMultivariateNormalConstantParameters;
import com.analog.lyric.dimple.solvers.sumproduct.customFactors.CustomNormalConstantParameters;
import com.analog.lyric.dimple.solvers.sumproduct.sampledfactor.SampledFactor;
import com.google.common.collect.Iterables;
/**
*
* @since 0.08
* @author Christopher Barber
*/
public class SumProductCustomFactors extends CustomFactors<ISolverFactor, SumProductSolverGraph>
{
private static final long serialVersionUID = 1L;
/*--------------
* Construction
*/
public SumProductCustomFactors()
{
super(ISolverFactor.class, SumProductSolverGraph.class);
}
protected SumProductCustomFactors(SumProductCustomFactors other)
{
super(other);
}
@Override
public SumProductCustomFactors clone()
{
return new SumProductCustomFactors(this);
}
/*-----------------------
* CustomFactors methods
*/
@Override
public void addBuiltins()
{
add(ComplexNegate.class, CustomMultivariateGaussianNegate.class);
add(ComplexSubtract.class, CustomMultivariateGaussianSubtract.class);
add(ComplexSum.class, CustomMultivariateGaussianSum.class);
add(FiniteFieldAdd.class, CustomFiniteFieldAdd.class);
add(FiniteFieldMult.class, CustomFiniteFieldConstantMult.class);
add(FiniteFieldMult.class, CustomFiniteFieldMult.class);
add(FiniteFieldProjection.class, CustomFiniteFieldProjection.class);
add(LinearEquation.class, CustomGaussianLinearEquation.class);
add(MatrixRealJointVectorProduct.class, CustomMultivariateGaussianProduct.class);
add(Multiplexer.class, CustomMultiplexer.class);
add(MultivariateNormal.class, CustomMultivariateNormalConstantParameters.class);
add(Negate.class, CustomGaussianNegate.class);
add(Normal.class, CustomNormalConstantParameters.class);
add(Product.class, CustomGaussianProduct.class);
add(RealJointNegate.class, CustomMultivariateGaussianNegate.class);
add(RealJointSubtract.class, CustomMultivariateGaussianSubtract.class);
add(RealJointSum.class, CustomMultivariateGaussianSum.class);
add(Subtract.class, CustomGaussianSubtract.class);
add(Sum.class, CustomGaussianSum.class);
// Backwards compatibility
add("add", new ISolverFactorCreator<ISolverFactor, SumProductSolverGraph>() {
@Override
public ISolverFactor create(Factor factor, SumProductSolverGraph sgraph)
{
// We don't need to implement this using a single creator, but this way we can produce
// a better error message.
if (Iterables.all(factor.getSiblings(), VariablePredicates.isUnboundedReal()))
return new CustomGaussianSum(factor, sgraph);
if (Iterables.all(factor.getSiblings(), VariablePredicates.isUnboundedRealJoint()))
return new CustomMultivariateGaussianSum(factor, sgraph);
throw new SolverFactorCreationException("Variables must be unbounded and all Real or all RealJoint");
}
});
add("constmult", CustomGaussianProduct.class);
add("constmult", CustomMultivariateGaussianProduct.class);
add("finiteFieldAdd", CustomFiniteFieldAdd.class);
add("finiteFieldMult", CustomFiniteFieldConstantMult.class);
add("finiteFieldMult", CustomFiniteFieldMult.class);
add("finiteFieldProjection", CustomFiniteFieldProjection.class);
add("linear", CustomGaussianLinear.class);
add("multiplexerCPD", CustomMultiplexer.class);
add("polynomial", CustomComplexGaussianPolynomial.class);
}
@Override
public ISolverFactor createDefault(Factor factor, SumProductSolverGraph sgraph)
{
if (factor.isDiscrete())
{
@SuppressWarnings("deprecation") // FIXME remove when STableFactor removed
ISolverFactor sfactor = new STableFactor(factor, sgraph);
return sfactor;
}
else
{
// For non-discrete factor that doesn't have a custom factor, create a sampled factor
return new SampledFactor(factor, sgraph);
}
}
}