/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.factorfunctions.core; import static java.util.Objects.*; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.MatrixProduct; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.domains.DomainList; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer; import com.analog.lyric.dimple.model.factors.DiscreteFactor; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.IndexedValue; import com.analog.lyric.dimple.model.values.RealValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage; import com.analog.lyric.util.misc.Matlab; import net.jcip.annotations.ThreadSafe; @ThreadSafe public abstract class FactorFunction implements IFactorFunction { // FIXME remove default FactorFunction.getParameterizedMessage implementation // This is only here to shut up errors until I get around to implementing this on // implementors of IParametricFactorFunction public @Nullable IParameterizedMessage getParameterizedMessage() { return null; } /*------- * State */ // FIXME - make factor table cache weak // Cache of factor tables for this function by domain. private AtomicReference<ConcurrentMap<JointDomainIndexer, IFactorTable>> _factorTables = new AtomicReference<ConcurrentMap<JointDomainIndexer, IFactorTable>>(); private final String _name; /*-------------- * Construction */ protected FactorFunction() { this((String)null); } protected FactorFunction(@Nullable String name) { _name = name != null ? name : getClass().getSimpleName(); } protected FactorFunction(FactorFunction other) { _name = other._name; } @Override public FactorFunction clone() throws CloneNotSupportedException { throw new CloneNotSupportedException(String.format("%s objects do not support cloning", getClass().getSimpleName())); } /*------------------ * Abstract methods */ @Override public abstract double evalEnergy(Value[] values); /*------------------------ * FactorFunction methods */ /** * Single argument version of {@link #evalEnergy(Value[])} * <p> * Simply wraps the {@code value} in an an array and calls the former. * <p> * The existence of this method helps to avoid accidental invocation of {@link #evalEnergy(Object...)}. * @since 0.08 */ public double evalEnergy(Value value) { return evalEnergy(new Value[] { value }); } // Evaluate the factor function energy using unwrapped object arguments @Matlab public double evalEnergy(Object... arguments) { final int size = arguments.length; final Value[] values = new Value[size]; for (int i = 0; i < size; ++i) values[i] = Value.create(arguments[i]); final double energy = evalEnergy(values); if (energy != energy) // Faster isNaN return Double.POSITIVE_INFINITY; return energy; } @Matlab public double evalEnergy(Object value) { final double energy = evalEnergy(Value.create(value)); return energy == energy ? energy : Double.POSITIVE_INFINITY; } // Evaluate the factor and return a weight rather than an energy value @Override public double eval(Value[] values) { final double energy = evalEnergy(values); if (energy != energy) // Faster isNaN return Double.POSITIVE_INFINITY; return Math.exp(-energy); } /** * Single argument version of {@link #eval(Value[])} * <p> * Simply wraps the {@code value} in an an array and calls the former. * <p> * The existence of this method helps to avoid accidental invocation of {@link #eval(Object...)}. * @since 0.08 */ public double eval(Value value) { return eval(new Value[] { value }); } /** * Evaluate the factor and return a weight value using unwrapped object arguments * @since 0.08 */ @Matlab public double eval(Object... arguments) { return Math.exp(-evalEnergy(arguments)); } /** * For deterministic-directed factor functions, set the value of the output variables given the input variables, *<p> * The default implementation does nothing; any deterministic-directed factor function must override this method. */ @Override public void evalDeterministic(Value[] arguments) { } // Used by MATLAB core Dimple code for discrete variables only @Matlab public @Nullable Object getDeterministicFunctionValue(Object... arguments) { Value[] fullArgumentList = new Value[arguments.length + 1]; for (int i = 0; i < arguments.length; i++) fullArgumentList[i + 1] = Value.create(arguments[i]); fullArgumentList[0] = RealValue.create(); // Ok to use RealValue since it will be a number, but we don't know what evalDeterministic(fullArgumentList); return fullArgumentList[0].getObject(); } /** * Required domains of factor function arguments, if any. * <p> * Factor functions that require a fixed number of arguments with specific domains * can specify them through this method. * <p> * The default implementation returns null. * <p> * @since 0.08 */ public @Nullable DomainList<?> getDomains() { return null; } /** * Run {@link #evalDeterministic} without modifying the arguments. * @param arguments will be copied to return * @return freshly allocated array of {@link Value}s holding references the original {@code Value} * for each input argument, and new clones for output arguments. */ public Value[] evalDeterministicToCopy(Value[] arguments) { final Value[] copy = cloneOutputArguments(arguments); evalDeterministic(copy); return copy; } /** * @since 0.05 */ public boolean factorTableExists(@Nullable JointDomainIndexer domains) { boolean exists = false; if (domains != null) { ConcurrentMap<JointDomainIndexer, IFactorTable> factorTables = _factorTables.get(); exists = factorTables != null && factorTables.containsKey(domains); } return exists; } /** * @since 0.05 */ public boolean factorTableExists(Factor factor) { return factorTableExists(factor.getArgumentDomains().asJointDomainIndexer()); } public boolean convertFactorTable(@Nullable JointDomainIndexer oldDomains, @Nullable JointDomainIndexer newDomains) { boolean converted = false; if (oldDomains != null && newDomains != null) { ConcurrentMap<JointDomainIndexer, IFactorTable> tables = _factorTables.get(); if (tables != null) { IFactorTable table = tables.get(oldDomains); if (table != null) { table.setConditional(Objects.requireNonNull(newDomains.getOutputSet())); } } } return converted; } public @Nullable int[] getDirectedToIndices(int numEdges) {return getDirectedToIndices();} // May depend on the number of edges protected @Nullable int[] getDirectedToIndices() {return null;} // This can be overridden instead, if result doesn't depend on the number of edges /** * Returns the output indices that can be changed when specified input is changed or else null * if the same as the full set of output edges. * <p> * The default implementation returns null. * <p> * This may be overridden for functions that have multiple outputs and inputs for which * a single input may only affect a subset of the full outputs (e.g. {@link MatrixProduct}). * * @since 0.05 */ public @Nullable int[] getDirectedToIndicesForInput(Factor factor, int inputEdge) { return null; } /** * @since 0.05 */ public final IFactorTable getFactorTable(Domain [] domains) { return getFactorTable(DomainList.create(domains).asJointDomainIndexer()); } /** * @since 0.05 */ public IFactorTable getFactorTable(@Nullable JointDomainIndexer domains) { if (domains == null) { throw new DimpleException("only support getFactorTable for discrete domains"); } ConcurrentMap<JointDomainIndexer, IFactorTable> factorTables = _factorTables.get(); if (factorTables == null) { _factorTables.compareAndSet(null, new ConcurrentHashMap<JointDomainIndexer, IFactorTable>()); factorTables = _factorTables.get(); } IFactorTable factorTable = factorTables.get(domains); if (factorTable == null) { IFactorTable newTable = createTableForDomains(domains); factorTable = factorTables.putIfAbsent(domains, newTable); if (factorTable == null) { factorTable = newTable; } } return factorTable; } /** * Create a factor table for given factor. * <p> * Intended for internal use in {@link DiscreteFactor#getFactorTable()} * <p> * @param factor must be a {@link DiscreteFactor} * @since 0.05 */ public IFactorTable getFactorTable(Factor factor) { final JointDomainIndexer argDomains = requireNonNull(factor.getArgumentDomains().asJointDomainIndexer()); IFactorTable table = getFactorTable(argDomains); if (factor.hasConstants()) { // Get rid of constant dimensions final JointDomainIndexer edgeDomains = requireNonNull(factor.getDomainList().asJointDomainIndexer()); final int nArgs = argDomains.size(); final int[] oldToNew = new int[nArgs]; for (int i = 0, keep = 0, remove = edgeDomains.size(); i < nArgs; ++i) { oldToNew[i] = factor.hasConstantAtIndex(i) ? remove++ : keep++; } JointDomainReindexer converter = JointDomainReindexer.createPermuter(argDomains, edgeDomains, oldToNew); table = table.convert(converter); } return table; } /** * @since 0.05 */ public @Nullable IFactorTable getFactorTableIfExists(@Nullable JointDomainIndexer domains) { IFactorTable factorTable = null; if (domains != null) { ConcurrentMap<JointDomainIndexer, IFactorTable> factorTables = _factorTables.get(); if (factorTables != null) { factorTable = factorTables.get(domains); } } return factorTable; } /** * @since 0.05 */ public @Nullable IFactorTable getFactorTableIfExists(Factor factor) { return getFactorTableIfExists(factor.getArgumentDomains().asJointDomainIndexer()); } @Override public String getName() { return _name; } @Override public boolean isDeterministicDirected() {return false;} @Override public boolean isDirected() {return false;} @Override public boolean isParametric() { return IParametricFactorFunction.class.isInstance(this); } /** * The maximum number of variable updates beyond which {@link #updateDeterministic} * should not be called. * <p> * Default implementation returns 0, indicating that {@link #updateDeterministic} should * never be called. * <p> * @param numEdges is the number of edges (variables) to consider. It corresponds to the * size of the first argument to {@link #updateDeterministic}. * * @since 0.05 */ public int updateDeterministicLimit(int numEdges) { return 0; } /** * Deterministically update output values in {@code values} array incrementally based on changed input * values. * <p> * For functions that support it, this can allow for more efficient computation when there are many * inputs and or outputs and only a small subset of inputs have changed (e.g. one when doing a single * Gibbs update). * <p> * The default implementation delegates back to {@link #evalDeterministic(Value[])}, which * will do a full update. * <p> * @param values is the array of output and input values that are maintained persistently. When this * method is called, it may be assumed that the contents contains the current values of all input * variables and the last computed values of all output variables (which were based on previous values * of inputs). * @param oldValues contains descriptions of the variable number and old value of each input. Only indexes * of input variables should be specified. This list should not contain more than * {@link #updateDeterministicLimit(int)} elements. * @param changedOutputsHolder should be set by the function to contain the list of indexes of output variables * that were changed or else set to contain null if all of the outputs were modified. * @return true if update was done incrementally (i.e not all inputs were processed), false if full * update was done. * * @throws IndexOutOfBoundsException if an index in {@code oldValues} does not refer to an input variable. */ public boolean updateDeterministic(Value[] values, Collection<IndexedValue> oldValues, AtomicReference<int[]> changedOutputsHolder) { evalDeterministic(values); changedOutputsHolder.set(null); return false; } /** * Run {@link #updateDeterministic} without modifying the output arguments. * @param arguments will be copied to return * @return freshly allocated array of {@link Value}s holding references the original {@code Value} * for each input argument, and new clones for output arguments. */ public Value[] updateDeterministicToCopy(Value[] arguments, Collection<IndexedValue> oldValues, AtomicReference<int[]> changedOutputsHolder) { final Value[] copy = cloneOutputArguments(arguments); updateDeterministic(copy, oldValues, changedOutputsHolder); return copy; } /** * Indicates whether to use {@link #updateEnergy} * <p> * Functions that override {@link #updateEnergy} should override this method to return * true when that function is expected to provide a performance benefit over {@link #evalEnergy(Value[])}. * <p> * The default implementation returns false. * <p> * @param values the values that will be passed to {@code updateEnergy}. * @param nChangedValues the number of {@code oldValues} that will be provided to {@code updateEnergy}. * @since 0.08 */ public boolean useUpdateEnergy(Value[] values, int nChangedValues) { return false; } /** * Optimized energy evaluation based on previously computed value. * <p> * Implementations of this method should produce the same energy value you * would get by calling {@link #evalEnergy} with the specified {@code values}, * and that is what the default implementation does. * <p> * When the same function needs to be evaluated many times with only small changes * in the input arguments (as is the case in the Gibbs solver), the energy can * sometimes be more efficiently computed based on the previous energy values together * with information about the value changes. * <p> * Functions that implement this method should also implement {@code #useUpdateEnergy}. * <p> * @param values the argument values for which the energy is to be computed * @param oldValues specifies one or more old values. Each {@link IndexedValue} argument * indicates the value and its position in the argument list. * @param oldEnergy the previous energy computed from {@code values} except for those * overridden by {@code oldValues}. * * @since 0.08 */ public double updateEnergy(Value[] values, IndexedValue[] oldValues, double oldEnergy) { return evalEnergy(values); } /*------------------- * Protected methods */ /** * Generate a factor table for this function over the given domains. * <p> * Invoked implicitly by {@link #getFactorTable(JointDomainIndexer)} the first time * a factor table is needed for specified domains. */ protected IFactorTable createTableForDomains(JointDomainIndexer domains) { return FactorTable.create(this, domains); } /********* * Methods from FactorFunctionWithConstants that allow calling even if there are no constants */ /** * Return whether or not there are constants in the factor function instance * * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public boolean hasConstants() { return false; } /** * Return number of constants built into the factor function instance. * <p> * Default implementation returns zero. * * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public int getConstantCount() { return 0; } /** * Return the innermost FactorFunction object, in case it is wrapped in a containing class. * In this case, there is no containing class * * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public FactorFunction getContainedFactorFunction() { return this; } /** * Returns constant at edge identified by {@code index} or null if specified * edge is not a constant. * <p> * Default implementation returns null. * * @since 0.08 * @deprecated since release 0.08 */ @Deprecated public @Nullable Value getConstantValueByIndex(int index) { return null; } /** * @deprecated as of release 0.08 use {@link #getConstantValueByIndex(int)} * instead. * <p> * Returns constant at edge identified by {@code index} or null if specified * edge is not a constant. * <p> * Default implementation returns null. * * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public @Nullable Object getConstantByIndex(int index) { return null; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public boolean isConstantIndex(int index) { return false; } /** * Returns whether or not the index range corresponds to a constant * @since 0.06 * @deprecated since release 0.08 */ @Deprecated public boolean hasConstantsInIndexRange(int minIndex, int maxIndex) { return false; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public boolean hasConstantAtOrAboveIndex(int index) { return false; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public boolean hasConstantAtOrBelowIndex(int index) { return false; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public int numConstantsInIndexRange(int minIndex, int maxIndex) { return 0; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public int numConstantsAtOrAboveIndex(int index) { return 0; } /** * Returns whether or not the index corresponds to a constant * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public int numConstantsAtOrBelowIndex(int index) { return 0; } /** * Return the edge number associated with the specified factor index. * For factors with not constants, these are identical. * * @since 0.05 * @deprecated since release 0.08 */ @Deprecated public int getEdgeByIndex(int index) { return index; } /** * Return all edges within the range of indices specified * @since 0.06 */ @Deprecated public @Nullable int[] getEdgesByIndexRange(int minIndex, int maxIndex) { int[] edges = new int[maxIndex - minIndex + 1]; for (int i = 0, index = minIndex; index <= maxIndex; i++, index++) edges[i] = index; return edges; } /** * @since 0.05 */ @Deprecated public int getIndexByEdge(int edge) { return edge; } /** * @deprecated as of release 0.08 use {@link Factor#getConstantValues} instead. * @since 0.08 */ @Deprecated public List<Value> getConstantValues() { return Collections.emptyList(); } /** * @deprecated as of release 0.08 use {@link Factor#getConstantValues} instead. * @since 0.05 */ @Deprecated public Object[] getConstants() { return ArrayUtil.EMPTY_OBJECT_ARRAY; } /** * @since 0.05 * @deprecated since release 0.08 use {@link Factor#getConstantIndices} instead. */ @Deprecated public int[] getConstantIndices() { return ArrayUtil.EMPTY_INT_ARRAY; } /*--------------------------- * Parameter utility methods */ /** * Looks up a value or default from a map. * <p> * Returns result from {@code map.get(key)} if non-null, otherwise * returns {@code defaultValue}. * <p> * This can be used to read parameters in FactorFunction constructors that take a parameter map. * <p> * @since 0.07 * @see #getFirstOrDefault(Map, Object, Object...) */ public static <K,V> V getOrDefault(Map<K,V> map, K key, V defaultValue) { final V value = map.get(key); return value != null ? value : defaultValue; } /** * Looks up a value using multiple keys * <p> * Returns first non-null result from {@code map.get(key}} for each key * in {@code keys}. If none is found, null is returned. * @since 0.07 * @see #getOrDefault(Map, Object, Object) */ @SafeVarargs public static @Nullable <K,V> V getFirst(Map<K,V> map, K ... keys) { for (K key : keys) { V value = map.get(key); if (value != null) { return value; } } return null; } /** * Looks up a value using multiple keys and returns value or default. * <p> * Returns first non-null result from {@code map.get(key}} for each key * in {@code keys}. If none is found, the {@code defaultValue} is returned instead. * @since 0.07 * @see #getOrDefault(Map, Object, Object) */ @SafeVarargs public static <K,V> V getFirstOrDefault(Map<K,V> map, V defaultValue, K ... keys) { V value = getFirst(map, keys); return value != null ? value : defaultValue; } @SafeVarargs public static <K,V> V require(Map<K,V> map, K ... keys) { V value = getFirst(map, keys); if (value != null) { return value; } throw new IllegalArgumentException(String.format("Expected parameter named '%s'", keys[0])); } /*----------------- * Private methods */ /** * Returns a copy of arguments with output arguments replaced with mutable clones. * @since 0.08 */ private Value[] cloneOutputArguments(Value[] arguments) { final Value[] copy = arguments.clone(); // Clone the Values for output indices only for (int i : requireNonNull(getDirectedToIndices(arguments.length))) { copy[i] = copy[i].mutableClone(); // Assumes a deep clone } return copy; } }