/******************************************************************************* * Copyright 2012-2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.variables; import static java.lang.String.*; import java.util.Comparator; import java.util.List; import org.eclipse.jdt.annotation.NonNull; import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.data.DataLayer; import com.analog.lyric.dimple.data.IDatum; import com.analog.lyric.dimple.events.IDataEventSource; import com.analog.lyric.dimple.events.IDimpleEventListener; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.Equality; import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.Ids; import com.analog.lyric.dimple.model.core.Node; import com.analog.lyric.dimple.model.core.NodeType; import com.analog.lyric.dimple.model.core.VariablePort; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.factors.FactorBase; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.solvers.core.SNode; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable; import com.analog.lyric.util.misc.Internal; import com.google.common.primitives.Longs; /** * Base class for model variables in Dimple * * @since 0.07 */ public abstract class Variable extends Node implements Cloneable, IDataEventSource, IConstantOrVariable { /*----------- * Constants */ /** * {@link #_topologicalFlags} value used by {@link #isDeterministicInput()} */ private static final byte DETERMINISTIC_INPUT = 0x04; /** * {@link #_topologicalFlags} value used by {@link #isDeterministicOutput()} */ private static final byte DETERMINISTIC_OUTPUT = 0x08; protected static final int RESERVED_FLAGS = 0xFFFF0000; private static final int EVENT_MASK = 0x000F0000; private static final int CHANGE_EVENT_KNOWN = 0x00010000; private static final int PRIOR_CHANGE_EVENT = 0x00020000; @Deprecated private static final int FIXED_VALUE_CHANGE_EVENT = 0x00040000; @Deprecated private static final int INPUT_CHANGE_EVENT = 0x00080000; private static final int CHANGE_EVENT_MASK = 0x000F0000; private static final int NO_CHANGE_EVENT = 0x00010000; /*------- * State */ protected @Nullable IDatum _prior; @Deprecated protected String _modelerClassName; private final Domain _domain; public static Comparator<Variable> orderById = new Comparator<Variable>() { @Override @NonNullByDefault(false) public int compare(Variable var1, Variable var2) { return Longs.compare(var1.getGlobalId(), var2.getGlobalId()); } }; /*-------------- * Construction */ public Variable(Domain domain) { this(domain, "Variable"); } @Deprecated public Variable(Domain domain, String modelerClassName) { super(Ids.INITIAL_VARIABLE_ID); _modelerClassName = modelerClassName; _domain = domain; } protected Variable(Variable that) { super(that); _modelerClassName = that._modelerClassName; _domain = that._domain; IDatum prior = that._prior; _prior = prior != null ? prior.clone() : null; } /*---------------- * Object methods */ @Override public abstract @NonNull Variable clone(); /*--------------- * INode methods */ @Override public final Variable asVariable() { return this; } @Override public final boolean isVariable() { return true; } @Override public NodeType getNodeType() { return NodeType.VARIABLE; } /** * @deprecated as of release 0.08 */ @Deprecated @Override public String getClassLabel() { return "Variable"; } @Override public final VariablePort getPort(int siblingNumber) { return new VariablePort(this, siblingNumber); } @Override public Factor getSibling(int i) { // Variables should only be connected to factors return (Factor)super.getSibling(i); } @Override public List<Factor> getSiblings() { @SuppressWarnings("unchecked") List<Factor> siblings = (List<Factor>)super.getSiblings(); return siblings; } /** * Returns the solver-specific variable instance associated with this model variable if any. */ @Override public @Nullable ISolverVariable getSolver() { final FactorGraph fg = getParentGraph(); if (fg != null) { final ISolverFactorGraph sfg = fg.getSolver(); if (sfg != null) { return sfg.getSolverVariable(this); } } return null; } /** * Model-specific initialization for variables. * <p> * Clears {@link #isDeterministicInput()} and {@link #isDeterministicOutput()}. * Does NOT invoke solver variable initialize. */ @Override public void initialize() { super.initialize(); } /*-------------- * Node methods */ @Override protected int getEventMask() { return super.getEventMask() | EVENT_MASK; } /*------------------ * Variable methods */ /** * Casts this object to a {@link Discrete}. * @throws ClassCastException if this is not an instance of {@link Discrete}. */ public Discrete asDiscreteVariable() { return (Discrete)this; } /** * Get condition value associated with variable in default conditioning layer, if any. * <p> * Returns the value associated with this variable from the {@linkplain FactorGraph#getDefaultConditioningLayer() * default conditioning layer} associated with this variable's {@linkplain #getParentGraph() parent}. Returns * null if there is no parent or no default conditioning layer. * <p> * This method is a convenience method to make it easy access condition values through model variables, * but it should not be used in situations when underlying solvers may use a different conditioning * layer. * <p> * @since 0.08 * @see #setCondition(Object) */ public @Nullable IDatum getCondition() { FactorGraph parent = _parentGraph; if (parent != null) { DataLayer<? extends IDatum> layer = parent.getDefaultConditioningLayer(); if (layer != null) { return layer.get(this); } } return null; } /** * Set condition value associated with variable in default conditioning layer, creating one if missing. * <p> * Associates the value with this variable in the {@linkplain FactorGraph#createDefaultConditioningLayer() * default conditioning layer} associated with this variable's {@linkplain #getParentGraph() parent}. * <p> * This method is a convenience method to make it easy set condition values through model variables, * but it should not be used in situations when underlying solvers may use a different conditioning * layer. * <p> * @since 0.08 * @see #getCondition() */ public void setCondition(@Nullable Object value) { assertNotFrozen(); FactorGraph parent = _parentGraph; if (parent != null) { DataLayer<? extends IDatum> layer = value == null ? parent.getDefaultConditioningLayer() : parent.createDefaultConditioningLayer(); if (layer != null) { layer.set(this.asVariable(), value); } } else if (value != null) { throw new IllegalStateException("Cannot set condition on parentless variable"); } } /** * Get prior value associated with variable, if any. * <p> * This may either be a {@link Value} object specifying a fixed value (i.e. making * the variable a named constant) or a {@link IUnaryFactorFunction} or {@link IParameterizedMessage} * representing a prior distribution. * <p> * Note: this attribute replaces the various "fixed value" and "input" methods that * will eventually be phased out. * <p> * @since 0.08 * @see #getInputObject() * @see #getFixedValueObject() */ public @Nullable IDatum getPrior() { return _prior; } /** * Get prior if it is a {@link IUnaryFactorFunction}, else null. * @since 0.08 * @see #getPrior() */ public final @Nullable IUnaryFactorFunction getPriorFunction() { final IDatum prior = _prior; return prior instanceof IUnaryFactorFunction ? (IUnaryFactorFunction)prior : null; } /** * Get prior if it is a {@link Value}, else null. * @since 0.08 * @see #getPrior() */ public final @Nullable Value getPriorValue() { final IDatum prior = _prior; return prior instanceof Value ? (Value)prior : null; } /** * Associates a prior with the variable. * <p> * Sets the value of the {@linkplain #getPrior prior}. * <p> * @param prior may be one of the following: * <ul> * <li>{@code null} to remove any existing prior * <li>a {@link Value} object specifying fixed value for variable * <li>a value of the variable's {@linkplain #getDomain domain} specifying a fixed value * <li>a {@link IParameterizedMessage} appropriate to the variable's type * <li>any {@link IUnaryFactorFunction} appropriate to the variables domain. However, note that * not all solvers currently support such priors. They may be safely used with the Gibbs solver. * <li>({@link Discrete} only) an array of double used to implicitly create a {@link DiscreteWeightMessage} * </ul> * @return previous value of prior * @since 0.08 */ public @Nullable IDatum setPrior(@Nullable Object prior) { assertNotFrozen(); final IDatum priorPrior = _prior; if (prior == null || prior instanceof IDatum) { _prior = (IDatum)prior; } else if (_domain.inDomain(prior)) { _prior = Value.create(_domain, prior); } else { throw new ClassCastException(format("'%s' is not a %s and is not a member of variable's domain", prior, // Use Class instead of hard-coding name so that we can rename it easily IDatum.class.getSimpleName())); } priorChanged(priorPrior, _prior); return priorPrior; } protected @Nullable Object priorToFixedValue(@Nullable IDatum prior) { return prior instanceof Value ? ((Value)prior).getObject() : null; } protected @Nullable Object priorToInput(@Nullable IDatum prior) { return prior instanceof Value ? null : prior; } @SuppressWarnings("deprecation") private void priorChanged(@Nullable IDatum priorPrior, @Nullable IDatum newPrior) { final int eventFlags = getChangeEventFlags(); if (eventFlags != NO_CHANGE_EVENT) { if ((eventFlags & PRIOR_CHANGE_EVENT) != 0) { raiseEvent(new VariablePriorChangeEvent(this, priorPrior, newPrior)); } // Deprecated cases if ((eventFlags & FIXED_VALUE_CHANGE_EVENT) != 0 && (newPrior instanceof Value || priorPrior instanceof Value)) { raiseEvent(new VariableFixedValueChangeEvent(this, priorToFixedValue(priorPrior), priorToFixedValue(newPrior))); } if ((eventFlags & INPUT_CHANGE_EVENT) != 0) { final Object prevInput = priorPrior instanceof Value ? null : priorToInput(priorPrior); final Object newInput = newPrior instanceof Value ? null : priorToInput(newPrior); if (prevInput != newInput) { raiseEvent(new VariableInputChangeEvent(this, prevInput, newInput)); } } } } public Domain getDomain() { return _domain; } /** * Returns the solver-specific variable instance associated with this model variable if it is * an instance of the specified {@code solverVariableClass}, otherwise returns null. */ public @Nullable <T extends ISolverVariable> T getSolverIfType(Class<? extends T> solverVariableClass) { final ISolverVariable svar = getSolver(); T result = null; if (svar != null && solverVariableClass.isAssignableFrom(svar.getClass())) { result = solverVariableClass.cast(svar); } return result; } /** * Returns the solver-specific variable instance associated with this model variable if it is * an instance of the specified {@code solverVariableClass} and has {@link SNode#getParentGraph()} equal to * {@code solverGraph}, otherwise returns null. */ public @Nullable <T extends ISolverVariable> T getSolverIfTypeAndGraph( Class<? extends T> solverVariableClass, ISolverFactorGraph solverGraph) { T svar = getSolverIfType(solverVariableClass); if (svar != null && svar.getParentGraph() != solverGraph) { svar = null; } return svar; } /** * True if variable is an input to a directed deterministic function. * <p> * This attribute is not valid until after graph initialization has occurred * (see {@link FactorGraph#initialize()}). * * @since 0.05 */ public final boolean isDeterministicInput() { return isFlagSet(DETERMINISTIC_INPUT); } /** * True if variable is an output from a directed deterministic function. * <p> * This attribute is not valid until after graph initialization has occurred * (see {@link FactorGraph#initialize()}). * * @since 0.05 */ public final boolean isDeterministicOutput() { return isFlagSet(DETERMINISTIC_OUTPUT); } public void setGuess(@Nullable Object guess) { requireSolver("setGuess").setGuess(guess); } public boolean guessWasSet() { ISolverVariable svar = getSolver(); return svar != null && svar.guessWasSet(); } public @Nullable Object getGuess() { return requireSolver("getGuess").getGuess(); } /** * @category internal */ @Internal public void moveInputs(Variable other) { setPrior(other.getPrior()); } /** * @category internal */ @Internal public void createSolverObject(@Nullable ISolverFactorGraph factorGraph) { if (factorGraph != null) { factorGraph.getSolverVariable(this, true); } } // For setting the variable to a fixed value in lieu of an input /** * True if {@link #getPriorValue()} is non-null. */ public final boolean hasFixedValue() { return _prior instanceof Value; } /** * @deprecated as of release 0.08 */ @Deprecated public String getModelerClassName() { return _modelerClassName; } public @Nullable Object getBeliefObject() { final ISolverVariable svar = getSolver(); if (svar != null) return svar.getBelief(); else return getPrior(); } public Factor [] getFactors() { return getFactorsFlat(); } public FactorBase [] getFactors(int relativeNestingDepth) { int nSiblings = getSiblingCount(); FactorBase [] retval = new FactorBase[nSiblings]; for (int i = 0; i < nSiblings; i++) { retval[i] = (FactorBase)getConnectedNode(relativeNestingDepth,i); } return retval; } public FactorBase [] getFactorsTop() { return getFactors(0); } public Factor [] getFactorsFlat() { int nSiblings = getSiblingCount(); Factor [] retval = new Factor[nSiblings]; for (int i = 0; i < nSiblings; i++) { retval[i] = (Factor)getConnectedNodeFlat(i); } return retval; } @Internal public Variable split(FactorGraph fg,Factor [] factorsToBeMovedToCopy) { assertNotFrozen(); //create a copy of this variable Variable mycopy = clone(); mycopy.createSolverObject(null); mycopy.setInputObject(null); mycopy.setName(null); fg.addFactor(new Equality(), this,mycopy); //for each factor to be moved for (int i = 0; i < factorsToBeMovedToCopy.length; i++) { Factor factor = factorsToBeMovedToCopy[i]; //Replace the connection from this variable to the copy in the factor for (int j = 0, endj = factor.getSiblingCount(); j < endj; j++) { EdgeState edge = factor.getSiblingEdgeState(j); if (edge.getVariable(fg) == this) { fg.replaceEdge(factor, j, mycopy); } } } //set the solvers to null for this variable, the copied variable, and all the factors that were moved. ISolverFactorGraph sfg = fg.getSolver(); if (sfg != null) { createSolverObject(fg.getSolver()); mycopy.createSolverObject(fg.getSolver()); for (int i = 0; i < factorsToBeMovedToCopy.length; i++) factorsToBeMovedToCopy[i].createSolverObject(fg.getSolver()); } return mycopy; } /*-------------------- * Deprecated methods */ /** * @deprecated use {@link #getPriorFunction()} instead. */ @Deprecated public @Nullable Object getInputObject() { return getPriorFunction(); } /** * @deprecated use {@link #getPriorValue()} and {@link Value#getObject()} instead. */ @Deprecated public final @Nullable Object getFixedValueAsObject() { final Value value = getPriorValue(); return value != null ? value.getObject() : null; } /** * @deprecated use {@link #setPrior(Object)} instead. */ @Deprecated public final void setFixedValueFromObject(@Nullable Object value) { setPrior(value); } /** * @deprecated use {@link #getPriorValue()} instead * <p> * Note that for {@link Discrete} variables this will return the <b>index</b> of the value! */ @Deprecated public @Nullable Object getFixedValueObject() { IDatum datum = getPrior(); if (datum instanceof Value) { return ((Value) datum).getObject(); } return null; } /** * @deprecated use {@link #setPrior} or {@link Discrete#setPriorIndex} instead. */ @Deprecated public void setFixedValueObject(@Nullable Object value) { setPrior(value != null ? Value.create(_domain, value) : null); } /** * @deprecated use {@link #setPrior(Object)} instead */ @Deprecated public void setInputObject(@Nullable Object value) { setPrior(value); } /** * @category internal */ @Deprecated @Internal public void setSolver(@Nullable ISolverVariable svar) { throw new UnsupportedOperationException("Variable.setSolver no longer supported"); } /*------------------- * Internal methods */ /** * Creates a new variable that combines the domains of this variable with additional {@code variables}. * <p> * For use by {@link FactorGraph#join(Variable...)}. Currently only supported for {@link Discrete} * variables. * <p> * @param variables specifies at least one additional variables to join with this one. As a convenience, this * may begin with this variable, in which case there must be at least one other variable. * * @category internal */ @Internal public Variable createJointNoFactors(Variable ... variables) { throw new DimpleException("not implemented"); } /** * Returns solver variable or throws an exception if not yet set. * <p> * For internal use only. * <p> * @since 0.06 */ @Override @Internal public ISolverVariable requireSolver(String methodName) { final ISolverVariable svar = getSolver(); if (svar == null) { throw new NullPointerException(String.format("solver must be set before using '%s'", methodName)); } return svar; } /** * Sets {@link #isDeterministicInput()} to true. * * @since 0.05 * * @category internal */ @Internal public final void setDeterministicInput() { setFlags(DETERMINISTIC_INPUT); } /** * Sets {@link #isDeterministicOutput()} to true. * * @since 0.05 * * @category internal */ @Internal public final void setDeterministicOutput() { setFlags(DETERMINISTIC_OUTPUT); } /*----------------- * Private methods */ @SuppressWarnings("deprecation") private int getChangeEventFlags() { final int prevFlags = _flags & CHANGE_EVENT_MASK; if ((prevFlags & CHANGE_EVENT_KNOWN) != 0) { return prevFlags; } int flags = 0; final IDimpleEventListener listener = getEventListener(); if (listener != null) { if (listener.isListeningFor(VariablePriorChangeEvent.class, this)) { flags |= PRIOR_CHANGE_EVENT; } if (listener.isListeningFor(VariableInputChangeEvent.class, this)) { flags |= INPUT_CHANGE_EVENT; } if (listener.isListeningFor(VariableFixedValueChangeEvent.class, this)) { flags |= FIXED_VALUE_CHANGE_EVENT; } } setFlagValue(CHANGE_EVENT_MASK, flags); return flags; } }