/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.sugar;
import java.lang.reflect.Array;
import java.util.ArrayDeque;
import java.util.Deque;
import com.analog.lyric.dimple.factorfunctions.ACos;
import com.analog.lyric.dimple.factorfunctions.ASin;
import com.analog.lyric.dimple.factorfunctions.ATan;
import com.analog.lyric.dimple.factorfunctions.Abs;
import com.analog.lyric.dimple.factorfunctions.Bernoulli;
import com.analog.lyric.dimple.factorfunctions.Beta;
import com.analog.lyric.dimple.factorfunctions.Binomial;
import com.analog.lyric.dimple.factorfunctions.ComplexAbs;
import com.analog.lyric.dimple.factorfunctions.ComplexExp;
import com.analog.lyric.dimple.factorfunctions.ComplexNegate;
import com.analog.lyric.dimple.factorfunctions.ComplexProduct;
import com.analog.lyric.dimple.factorfunctions.ComplexSum;
import com.analog.lyric.dimple.factorfunctions.ConstantPower;
import com.analog.lyric.dimple.factorfunctions.ConstantProduct;
import com.analog.lyric.dimple.factorfunctions.Cos;
import com.analog.lyric.dimple.factorfunctions.Cosh;
import com.analog.lyric.dimple.factorfunctions.Exp;
import com.analog.lyric.dimple.factorfunctions.Gamma;
import com.analog.lyric.dimple.factorfunctions.InverseGamma;
import com.analog.lyric.dimple.factorfunctions.Log;
import com.analog.lyric.dimple.factorfunctions.LogNormal;
import com.analog.lyric.dimple.factorfunctions.Negate;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.Power;
import com.analog.lyric.dimple.factorfunctions.Product;
import com.analog.lyric.dimple.factorfunctions.Sin;
import com.analog.lyric.dimple.factorfunctions.Sinh;
import com.analog.lyric.dimple.factorfunctions.Sqrt;
import com.analog.lyric.dimple.factorfunctions.Square;
import com.analog.lyric.dimple.factorfunctions.Sum;
import com.analog.lyric.dimple.factorfunctions.Tan;
import com.analog.lyric.dimple.factorfunctions.Tanh;
import com.analog.lyric.dimple.factorfunctions.Xor;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.Node;
import com.analog.lyric.dimple.model.domains.ComplexDomain;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.model.variables.Complex;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableBlock;
import com.google.common.collect.ObjectArrays;
/**
 * Static convenience functions for building a model against an implicit "current"
 * {@link FactorGraph}.
 * <p>
 * The current graph is kept in a per-thread stack. It is pushed by {@link #using(FactorGraph)}
 * and popped when the returned {@link CurrentModel} is closed, so the intended usage is:
 * <pre>
 * try (CurrentModel model = using(graph))
 * {
 *     Real x = real("x");
 *     Real y = exp(x);
 * }
 * </pre>
 * All functions that create variables or factors operate on {@link #graph()} and therefore throw
 * {@link IllegalStateException} when no current model has been set on the calling thread.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class ModelSyntacticSugar
{
    /**
     * Handle for one entry of the per-thread stack of current models.
     * <p>
     * Instances are created by {@link ModelSyntacticSugar#using(FactorGraph)}.
     */
    public static class CurrentModel implements AutoCloseable
    {
        /** The graph that model configuration functions operate on while this state is current. */
        public final FactorGraph graph;

        CurrentModel(FactorGraph fg)
        {
            graph = fg;
        }

        /**
         * Pops this state from the calling thread's stack if it is the top entry.
         * <p>
         * If this state is not on top of the stack (for example because it was already closed,
         * or it is closed out of order or from a different thread) the call does nothing.
         */
        @Override
        public void close()
        {
            Deque<CurrentModel> stack = stateStack();
            if (stack.peek() == this)
            {
                stack.pop();
            }
        }
    }

    /** Per-thread stack of current model states; the stack itself is created lazily. */
    private static final ThreadLocal<Deque<CurrentModel>> _currentState = new ThreadLocal<>();

    /**
     * Returns the calling thread's stack of model states, creating it on first use.
     */
    private static Deque<CurrentModel> stateStack()
    {
        Deque<CurrentModel> state = _currentState.get();
        if (state == null)
        {
            state = new ArrayDeque<CurrentModel>();
            _currentState.set(state);
        }
        return state;
    }

    /**
     * Makes {@code fg} the current graph for the calling thread.
     * <p>
     * A new state is pushed on the thread's stack; closing the returned object restores the
     * previous state, so this is meant to be used in a try-with-resources statement.
     *
     * @param fg the graph subsequent model configuration functions should operate on
     * @return the newly pushed state
     */
    public static CurrentModel using(FactorGraph fg)
    {
        CurrentModel state = new CurrentModel(fg);
        stateStack().push(state);
        return state;
    }

    /**
     * Current model state used by model configuration functions in this class.
     * <p>
     * This is set for the current thread by the {@link #using(FactorGraph)} method. The previous
     * state is restored when the returned state's {@link CurrentModel#close()} method is called.
     * @throws IllegalStateException if there is no current model state
     * @see #using
     * @since 0.08
     */
    public static CurrentModel current()
    {
        CurrentModel model = stateStack().peek();
        if (model == null)
        {
            throw new IllegalStateException("No current Dimple model state has been set.");
        }
        return model;
    }

    /**
     * The current {@link FactorGraph} used by model configuration functions in this class.
     * <p>
     * This is simply the {@link CurrentModel#graph} member of the {@link #current()} state.
     * @throws IllegalStateException if there is no current model state
     * @since 0.08
     */
    public static FactorGraph graph()
    {
        return current().graph;
    }

    /*-------
     * Nodes
     */

    /**
     * Set label on a model object and return it.
     * @param label the label to assign
     * @param node the node to be labeled
     * @return {@code node}
     * @since 0.08
     */
    public static <T extends Node> T label(String label, T node)
    {
        node.setLabel(label);
        return node;
    }

    /**
     * Set label for all nodes in array and return the array.
     *
     * @param labelPrefix is a prefix to which the array index will be appended to form each label.
     * @param nodes the nodes to be labeled
     * @return {@code nodes}
     * @since 0.08
     */
    public static <T extends Node> T[] label(String labelPrefix, T[] nodes)
    {
        for (int i = nodes.length; --i >= 0;)
        {
            nodes[i].setLabel(labelPrefix + i);
        }
        return nodes;
    }

    /**
     * Set name on a model object and return it.
     * @param name the name to assign
     * @param node the node to be named
     * @return {@code node}
     * @since 0.08
     */
    public static <T extends Node> T name(String name, T node)
    {
        node.setName(name);
        return node;
    }

    /**
     * Set name for all nodes in array and return the array.
     *
     * @param namePrefix is a prefix to which the array index will be appended to form each name.
     * @param nodes the nodes to be named
     * @return {@code nodes}
     * @since 0.08
     */
    public static <T extends Node> T[] name(String namePrefix, T[] nodes)
    {
        for (int i = nodes.length; --i >= 0;)
        {
            nodes[i].setName(namePrefix + i);
        }
        return nodes;
    }

    /*------------
     * Variables
     */

    /**
     * Add a new {@link Bit} variable with given name to current {@link #graph()}.
     */
    public static Bit bit(String name)
    {
        return nameAndAdd(name, new Bit());
    }

    /**
     * Adds a new variable block containing the specified variables.
     * @since 0.08
     */
    @SafeVarargs
    public static VariableBlock block(Variable ... vars)
    {
        return graph().addVariableBlock(vars);
    }

    /**
     * Adds variable as boundary variable to current {@link #graph()}.
     * @param var the variable to add
     * @return {@code var}
     * @since 0.08
     */
    public static <V extends Variable> V boundary(V var)
    {
        graph().addBoundaryVariables(var);
        return var;
    }

    /**
     * Add a new {@link Discrete} variable with given name and domain to current {@link #graph()}.
     * @since 0.08
     */
    public static Discrete discrete(String name, DiscreteDomain domain)
    {
        return nameAndAdd(name, new Discrete(domain));
    }

    /**
     * Add {@code n} new {@link Discrete} variables sharing {@code domain} to current
     * {@link #graph()}.
     *
     * @param namePrefix is a prefix to which the array index will be appended to form each name.
     * @param domain the domain of every created variable
     * @param n the number of variables to create
     * @return the created variables, in index order
     */
    public static Discrete[] discretes(String namePrefix, DiscreteDomain domain, int n)
    {
        final Discrete[] vars = new Discrete[n];
        for (int i = 0; i < n; ++i)
        {
            vars[i] = discrete(namePrefix + i, domain);
        }
        return vars;
    }

    /**
     * Add a new {@link Real} variable with given name to current {@link #graph()} and
     * pass {@code value} to its {@code setPrior} method.
     */
    public static Real fixed(String name, double value)
    {
        Real var = nameAndAdd(name, new Real());
        var.setPrior(value);
        return var;
    }

    /**
     * Add one new {@link Real} variable per entry of {@code values}, as if by
     * {@link #fixed(String, double)}.
     *
     * @param name is a prefix to which the array index will be appended to form each name.
     * @param values the value for each created variable
     * @return the created variables, in the order of {@code values}
     */
    public static Real[] fixed(String name, double ... values)
    {
        final int size = values.length;
        Real[] vars = new Real[size];
        for (int i = 0; i < size; ++i)
        {
            vars[i] = fixed(name + i, values[i]);
        }
        return vars;
    }

    /**
     * Add a new {@link RealJoint} variable with given name and with as many dimensions as
     * {@code value} has entries, and pass {@code value} to its {@code setPrior} method.
     */
    public static RealJoint fixedJoint(String name, double ... value)
    {
        RealJoint var = realjoint(name, value.length);
        var.setPrior(value);
        return var;
    }

    /**
     * Add a new {@link Discrete} variable with given name and domain to current
     * {@link #graph()} and pass {@code value} to its {@code setPrior} method.
     */
    public static Discrete fixed(String name, DiscreteDomain domain, Object value)
    {
        Discrete var = discrete(name, domain);
        var.setPrior(value);
        return var;
    }

    /**
     * Add a new {@link Complex} variable with given name and the domain returned by
     * {@link ComplexDomain#create()} to current {@link #graph()}.
     */
    public static Complex complex(String name)
    {
        return complex(name, ComplexDomain.create());
    }

    /**
     * Add a new {@link Complex} variable with given name and domain to current {@link #graph()}.
     */
    public static Complex complex(String name, ComplexDomain domain)
    {
        return nameAndAdd(name, new Complex(domain));
    }

    /**
     * Add a new {@link Real} variable with given name and unbounded domain to current
     * {@link #graph()}.
     * @since 0.08
     */
    public static Real real(String name)
    {
        return real(name, RealDomain.unbounded());
    }

    /**
     * Add a new {@link Real} variable with given name and domain to current {@link #graph()}.
     * @since 0.08
     */
    public static Real real(String name, RealDomain domain)
    {
        // NOTE(review): unlike the other variable factories in this class, which add the
        // variable to the graph and then name it, this one names the variable first.
        // The original order is kept here; confirm whether the difference matters.
        Real var = new Real(domain);
        var.setName(name);
        graph().addVariables(var);
        return var;
    }

    /**
     * Add {@code size} new {@link Real} variables sharing {@code domain} to current
     * {@link #graph()}.
     *
     * @param namePrefix is a prefix to which the array index will be appended to form each name.
     * @param size the number of variables to create
     * @param domain the domain of every created variable
     * @return the created variables, in index order
     */
    public static Real[] reals(String namePrefix, int size, RealDomain domain)
    {
        Real[] vars = new Real[size];
        for (int i = 0; i < size; ++i)
        {
            vars[i] = nameAndAdd(namePrefix + i, new Real(domain));
        }
        return vars;
    }

    /**
     * Add {@code size} new {@link Real} variables with unbounded domain to current
     * {@link #graph()}.
     *
     * @see #reals(String, int, RealDomain)
     */
    public static Real[] reals(String namePrefix, int size)
    {
        return reals(namePrefix, size, RealDomain.unbounded());
    }

    /**
     * Add a new {@link RealJoint} variable with given name and domain to current
     * {@link #graph()}.
     */
    public static RealJoint realjoint(String name, RealJointDomain domain)
    {
        return nameAndAdd(name, new RealJoint(domain));
    }

    /**
     * Add a new {@link RealJoint} variable with given name and the domain returned by
     * {@code RealJointDomain.create(size)} to current {@link #graph()}.
     */
    public static RealJoint realjoint(String name, int size)
    {
        return realjoint(name, RealJointDomain.create(size));
    }

    /*---------
     * Factors
     */

    /**
     * Adds a factor to the current {@link #graph()}, flattening variable arrays in the
     * argument list.
     * <p>
     * Every element of {@code args} that is an array whose component type is assignable to
     * {@link Variable} is replaced in place by its elements (an empty array contributes nothing).
     * All other elements, including {@code null}, are passed through unchanged.
     *
     * @param function the factor function for the new factor
     * @param args the factor arguments, possibly containing arrays of variables
     * @return the factor returned by the graph's {@code addFactor} method
     * @throws IllegalStateException if there is no current model state
     */
    public static Factor addFactor(FactorFunction function, Object ... args)
    {
        // First pass: detect variable arrays and compute the flattened length. An explicit
        // flag is required because arrays of length one (or a mix of empty and longer arrays)
        // need flattening without making the flattened list longer than the original.
        boolean hasVariableArray = false;
        int argCount = 0;
        for (Object arg : args)
        {
            if (isVariableArray(arg))
            {
                hasVariableArray = true;
                argCount += Array.getLength(arg);
            }
            else
            {
                ++argCount;
            }
        }

        if (!hasVariableArray)
        {
            // Nothing to flatten: hand the caller's arguments over as they are.
            return graph().addFactor(function, args);
        }

        // Second pass: copy arguments in order, splicing in array elements.
        final Object[] expandedArgs = new Object[argCount];
        int i = 0;
        for (Object arg : args)
        {
            if (isVariableArray(arg))
            {
                for (Variable var : (Variable[])arg)
                {
                    expandedArgs[i++] = var;
                }
            }
            else
            {
                expandedArgs[i++] = arg;
            }
        }

        return graph().addFactor(function, expandedArgs);
    }

    /*------------------
     * Factor functions
     *
     * Unless stated otherwise, each function below creates a new variable, adds a factor
     * with the named factor function to the current graph with the new variable placed
     * first ("...First" helpers) or last ("...Last" helpers) in the factor's argument list,
     * and returns the new variable.
     */

    /**
     * Adds an {@link Abs} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real abs(Real x)
    {
        return addFactorWithRealFirst(new Abs(), x);
    }

    /**
     * Adds a {@link ComplexAbs} factor over a new {@link Complex} (first argument) and {@code x}.
     * <p>
     * NOTE(review): the new variable is created as a {@link Complex}; confirm that this
     * matches what {@link ComplexAbs} expects for its first argument.
     * @return the new variable
     */
    public static Complex abs(Complex x)
    {
        return addFactorWithComplexFirst(new ComplexAbs(), x);
    }

    /**
     * Adds an {@link ACos} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real acos(Real x)
    {
        return addFactorWithRealFirst(new ACos(), x);
    }

    /**
     * Adds an {@link ASin} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real asin(Real x)
    {
        return addFactorWithRealFirst(new ASin(), x);
    }

    /**
     * Adds an {@link ATan} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real atan(Real x)
    {
        return addFactorWithRealFirst(new ATan(), x);
    }

    /**
     * Adds a {@link Bernoulli} factor over a new {@link Bit}.
     * <p>
     * If {@code p} is a {@link Number} its double value is passed to the {@link Bernoulli}
     * constructor and the bit is the only factor argument; otherwise {@code p} is passed as a
     * factor argument ahead of the bit.
     *
     * @param p the parameter, either a {@link Number} or an object accepted as a factor argument
     * @return the new bit
     */
    public static Bit bernoulli(Object p)
    {
        Bit bit = new Bit();
        if (p instanceof Number)
        {
            graph().addFactor(new Bernoulli(toDouble(p)), bit);
        }
        else
        {
            graph().addFactor(new Bernoulli(), p, bit);
        }
        return bit;
    }

    /**
     * Adds a {@link Beta} factor over a new {@link Real} (last argument).
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link Beta} constructor; otherwise both are passed as factor arguments ahead of the
     * new variable.
     * @return the new variable
     */
    public static Real beta(Object alpha, Object beta)
    {
        if (alpha instanceof Number && beta instanceof Number)
        {
            return addFactorWithRealLast(new Beta(toDouble(alpha), toDouble(beta)));
        }
        else
        {
            return addFactorWithRealLast(new Beta(), alpha, beta);
        }
    }

    /**
     * Adds a {@link Binomial} factor with constant {@code N} over {@code p} and a new
     * {@link Discrete} variable whose domain is {@code DiscreteDomain.range(0, N)}.
     * @return the new variable
     */
    public static Discrete binomial(int N, Object p)
    {
        Discrete var = new Discrete(DiscreteDomain.range(0, N));
        graph().addFactor(new Binomial(N), p, var);
        return var;
    }

    /**
     * Adds a {@link Binomial} factor over {@code N}, {@code p} and a new {@link Real}
     * (last argument).
     * <p>
     * NOTE(review): the constant-{@code N} overload creates a {@link Discrete} result while
     * this one creates a {@link Real}; confirm that is intended.
     * @return the new variable
     */
    public static Real binomial(Variable N, Object p)
    {
        return addFactorWithRealLast(new Binomial(), N, p);
    }

    /**
     * Adds a {@link Cos} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real cos(Real x)
    {
        return addFactorWithRealFirst(new Cos(), x);
    }

    /**
     * Adds a {@link Cosh} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real cosh(Real x)
    {
        return addFactorWithRealFirst(new Cosh(), x);
    }

    /**
     * Adds an {@link Exp} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real exp(Real x)
    {
        return addFactorWithRealFirst(new Exp(), x);
    }

    /**
     * Adds a {@link ComplexExp} factor over a new {@link Complex} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Complex exp(Complex x)
    {
        return addFactorWithComplexFirst(new ComplexExp(), x);
    }

    /**
     * Adds a {@link Gamma} factor over a new {@link Real} (last argument).
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link Gamma} constructor; otherwise both are passed as factor arguments ahead of the
     * new variable.
     * @return the new variable
     */
    public static Real gamma(Object alpha, Object beta)
    {
        if (alpha instanceof Number && beta instanceof Number)
        {
            return addFactorWithRealLast(new Gamma(toDouble(alpha), toDouble(beta)));
        }
        else
        {
            return addFactorWithRealLast(new Gamma(), alpha, beta);
        }
    }

    /**
     * Adds an {@link InverseGamma} factor over a new {@link Real} (last argument).
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link InverseGamma} constructor; otherwise both are passed as factor arguments ahead
     * of the new variable.
     * @return the new variable
     */
    public static Real inversegamma(Object alpha, Object beta)
    {
        if (alpha instanceof Number && beta instanceof Number)
        {
            return addFactorWithRealLast(new InverseGamma(toDouble(alpha), toDouble(beta)));
        }
        else
        {
            return addFactorWithRealLast(new InverseGamma(), alpha, beta);
        }
    }

    /**
     * Adds a {@link Log} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real log(Real x)
    {
        return addFactorWithRealFirst(new Log(), x);
    }

    /**
     * Adds a {@link LogNormal} factor over a new {@link Real} (last argument).
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link LogNormal} constructor; otherwise both are passed as factor arguments ahead of
     * the new variable.
     * @return the new variable
     */
    public static Real lognormal(Object mean, Object precision)
    {
        if (mean instanceof Number && precision instanceof Number)
        {
            return addFactorWithRealLast(new LogNormal(toDouble(mean), toDouble(precision)));
        }
        else
        {
            return addFactorWithRealLast(new LogNormal(), mean, precision);
        }
    }

    /**
     * Adds a {@link Negate} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real negate(Real x)
    {
        return addFactorWithRealFirst(new Negate(), x);
    }

    /**
     * Adds a {@link ComplexNegate} factor over a new {@link Complex} (first argument) and
     * {@code x}.
     * @return the new variable
     */
    public static Complex negate(Complex x)
    {
        return addFactorWithComplexFirst(new ComplexNegate(), x);
    }

    /**
     * Adds a {@link Normal} factor over a new {@link Real} (last argument).
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link Normal} constructor; otherwise both are passed as factor arguments ahead of the
     * new variable.
     * @return the new variable
     */
    public static Real normal(Object mean, Object precision)
    {
        if (mean instanceof Number && precision instanceof Number)
        {
            return addFactorWithRealLast(new Normal(toDouble(mean), toDouble(precision)));
        }
        else
        {
            return addFactorWithRealLast(new Normal(), mean, precision);
        }
    }

    /**
     * Adds a single {@link Normal} factor over existing variables.
     * <p>
     * If both parameters are {@link Number}s their double values are passed to the
     * {@link Normal} constructor and the factor arguments are just {@code vars}; otherwise
     * the factor arguments are {@code mean}, {@code precision} and then each element of
     * {@code vars}.
     *
     * @param mean a {@link Number} or an object accepted as a factor argument
     * @param precision a {@link Number} or an object accepted as a factor argument
     * @param vars the variables the factor is attached to
     */
    public static void normal(Object mean, Object precision, Real ... vars)
    {
        if (mean instanceof Number && precision instanceof Number)
        {
            graph().addFactor(new Normal(toDouble(mean), toDouble(precision)), vars);
        }
        else
        {
            // Route through this class's addFactor so that the elements of vars become
            // individual factor arguments instead of one nested array argument.
            addFactor(new Normal(), mean, precision, vars);
        }
    }

    /**
     * Adds a factor relating a new {@link Real} (first argument) to {@code base} raised to a
     * constant exponent.
     * <p>
     * An exponent of exactly {@code 2.0} delegates to {@link #square(Real)}; any other value
     * uses a {@link ConstantPower} factor constructed with {@code exponent}.
     * @return the new variable
     */
    public static Real power(Real base, double exponent)
    {
        if (exponent == 2.0)
        {
            return square(base);
        }
        return addFactorWithRealFirst(new ConstantPower(exponent), base);
    }

    /**
     * Adds a {@link Power} factor over a new {@link Real} (first argument), {@code base} and
     * {@code exponent}.
     * @return the new variable
     */
    public static Real power(Real base, Real exponent)
    {
        return addFactorWithRealFirst(new Power(), base, exponent);
    }

    /**
     * Adds a {@link Product} factor over a new {@link Real} (first argument), {@code x} and
     * {@code y}.
     * @return the new variable
     */
    public static Real product(Real x, Real y)
    {
        return addFactorWithRealFirst(new Product(), x, y);
    }

    /**
     * Adds a {@link ConstantProduct} factor constructed with {@code y} over a new
     * {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real product(Real x, double y)
    {
        return addFactorWithRealFirst(new ConstantProduct(y), x);
    }

    /**
     * Adds a {@link ConstantProduct} factor constructed with {@code x} over a new
     * {@link Real} (first argument) and {@code y}.
     * @return the new variable
     */
    public static Real product(double x, Real y)
    {
        return addFactorWithRealFirst(new ConstantProduct(x), y);
    }

    /**
     * Adds a {@link ComplexProduct} factor over a new {@link Complex} (first argument),
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Complex product(Complex x, Complex y)
    {
        return addFactorWithComplexFirst(new ComplexProduct(), x, y);
    }

    /**
     * Adds a {@link ComplexProduct} factor over a new {@link Complex} (first argument),
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Complex product(Complex x, Variable y)
    {
        return addFactorWithComplexFirst(new ComplexProduct(), x, y);
    }

    /**
     * Adds a {@link ComplexProduct} factor over a new {@link Complex} (first argument),
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Complex product(Variable x, Complex y)
    {
        return addFactorWithComplexFirst(new ComplexProduct(), x, y);
    }

    /**
     * Adds a {@link Square} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real square(Real x)
    {
        return addFactorWithRealFirst(new Square(), x);
    }

    /**
     * Adds a {@link Sin} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real sin(Real x)
    {
        return addFactorWithRealFirst(new Sin(), x);
    }

    /**
     * Adds a {@link Sinh} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real sinh(Real x)
    {
        // Hyperbolic sine needs its own factor function; Sin would model the circular sine.
        return addFactorWithRealFirst(new Sinh(), x);
    }

    /**
     * Adds a {@link Sqrt} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real sqrt(Real x)
    {
        return addFactorWithRealFirst(new Sqrt(), x);
    }

    /**
     * Adds a {@link Sum} factor over a new {@link Real} (first argument), {@code x} and
     * {@code y}.
     * @return the new variable
     */
    public static Real sum(Real x, Real y)
    {
        return addFactorWithRealFirst(new Sum(), x, y);
    }

    // Disabled variadic overload, kept from the original source:
    // public static Real sum(Real ... vars)
    // {
    // return addFactorWithRealFirstOutput(new Sum(), vars);
    // }

    /**
     * Adds a {@link Sum} factor over a new {@link Real} (first argument), {@code x} and the
     * boxed constant {@code y}.
     * @return the new variable
     */
    public static Real sum(Real x, double y)
    {
        return addFactorWithRealFirst(new Sum(), x, y);
    }

    /**
     * Adds a {@link Sum} factor over a new {@link Real} (first argument), the boxed constant
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Real sum(double x, Real y)
    {
        return addFactorWithRealFirst(new Sum(), x, y);
    }

    /**
     * Adds a {@link ComplexSum} factor over a new {@link Complex} (first argument),
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Complex sum(Complex x, Variable y)
    {
        return addFactorWithComplexFirst(new ComplexSum(), x, y);
    }

    /**
     * Adds a {@link ComplexSum} factor over a new {@link Complex} (first argument),
     * {@code x} and {@code y}.
     * @return the new variable
     */
    public static Complex sum(Variable x, Complex y)
    {
        return addFactorWithComplexFirst(new ComplexSum(), x, y);
    }

    /**
     * Adds a {@link Tan} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real tan(Real x)
    {
        return addFactorWithRealFirst(new Tan(), x);
    }

    /**
     * Adds a {@link Tanh} factor over a new {@link Real} (first argument) and {@code x}.
     * @return the new variable
     */
    public static Real tanh(Real x)
    {
        return addFactorWithRealFirst(new Tanh(), x);
    }

    /**
     * Adds an {@link Xor} factor over a new {@link Bit} (first argument) followed by
     * {@code bits}.
     * @return the new bit
     */
    public static Bit xor(Bit ... bits)
    {
        return addFactorWithBitFirst(new Xor(), (Object[])bits);
    }

    /*-----------------
     * Private methods
     */

    /**
     * Indicates whether {@code arg} is a non-null array whose component type is assignable
     * to {@link Variable}.
     */
    private static boolean isVariableArray(Object arg)
    {
        if (arg == null)
        {
            return false;
        }
        final Class<?> argType = arg.getClass();
        return argType.isArray() && Variable.class.isAssignableFrom(argType.getComponentType());
    }

    /**
     * Adds {@code var} to the current graph, then sets its name, and returns it.
     */
    private static <V extends Variable> V nameAndAdd(String name, V var)
    {
        graph().addVariables(var);
        var.setName(name);
        return var;
    }

    /**
     * Adds a factor to the current graph whose arguments are a new {@link Bit} followed by
     * {@code args}, and returns the new bit.
     */
    private static Bit addFactorWithBitFirst(FactorFunction function, Object ... args)
    {
        Bit bit = new Bit();
        graph().addFactor(function, ObjectArrays.concat(bit, args));
        return bit;
    }

    /**
     * Adds a factor to the current graph whose arguments are a new {@link Real} followed by
     * {@code args}, and returns the new variable.
     */
    private static Real addFactorWithRealFirst(FactorFunction function, Object ... args)
    {
        Real var = new Real();
        graph().addFactor(function, ObjectArrays.concat(var, args));
        return var;
    }

    /**
     * Adds a factor to the current graph whose arguments are {@code args} followed by a new
     * {@link Real}, and returns the new variable.
     */
    private static Real addFactorWithRealLast(FactorFunction function, Object ... args)
    {
        Real var = new Real();
        graph().addFactor(function, ObjectArrays.concat(args, var));
        return var;
    }

    /**
     * Adds a factor to the current graph whose arguments are a new {@link Complex} followed
     * by {@code args}, and returns the new variable.
     */
    private static Complex addFactorWithComplexFirst(FactorFunction function, Object ... args)
    {
        Complex var = new Complex();
        graph().addFactor(function, ObjectArrays.concat(var, args));
        return var;
    }

    /**
     * Returns the double value of {@code obj} if it is a {@link Number}, and
     * {@link Double#NaN} otherwise. Callers in this class check for {@link Number} first.
     */
    private static double toDouble(Object obj)
    {
        if (obj instanceof Number)
        {
            return ((Number)obj).doubleValue();
        }
        return Double.NaN;
    }
}