/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.junctiontree; import static java.util.Objects.*; import java.util.Map.Entry; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.repeated.BlastFromThePastFactor; import com.analog.lyric.dimple.model.transform.JunctionTreeTransform; import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap; import com.analog.lyric.dimple.model.transform.OptionVariableEliminatorCostList; import com.analog.lyric.dimple.model.transform.VariableEliminator; import com.analog.lyric.dimple.model.transform.VariableEliminator.CostFunction; import com.analog.lyric.dimple.model.transform.VariableEliminator.VariableCost; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.options.BPOptions; import com.analog.lyric.dimple.options.DimpleOptions; import com.analog.lyric.dimple.schedulers.SchedulerOptionKey; import com.analog.lyric.dimple.solvers.core.NoSolverEdge; import com.analog.lyric.dimple.solvers.core.NoSolverVariableBlock; import 
com.analog.lyric.dimple.solvers.core.proxy.ProxySolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.IFactorGraphFactory;
import com.analog.lyric.dimple.solvers.interfaces.ISolverBlastFromThePastFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.util.misc.Matlab;
import com.analog.lyric.util.misc.Misc;

/**
 * Base class for solver graphs using junction tree algorithm to transform graph into a tree
 * for exact inference using belief propagation.
 * <p>
 * The source model is transformed by a {@link JunctionTreeTransform} into a separate target graph
 * (see {@link #getTransformMap()}). The solver of that target graph, created with
 * {@link #getDelegateSolverFactory()}, is the delegate to which solve/iterate requests are forwarded.
 * The transform is (re)computed on demand whenever it is missing or no longer valid.
 *
 * @param <Delegate> specifies the type of the solver that will be used on the transformed graph and to
 * which this solver will delegate.
 *
 * @since 0.05
 * @author Christopher Barber
 */
public abstract class JunctionTreeSolverGraphBase<Delegate extends ISolverFactorGraph>
	extends ProxySolverFactorGraph<JunctionTreeSolverFactor, IJunctionTreeSolverVariable<?>, NoSolverEdge, NoSolverVariableBlock, Delegate>
{
	/*-------
	 * State
	 */

	/** Performs the transformation; its settings are refreshed from options in {@link #initialize()}. */
	private final JunctionTreeTransform _transformer;

	/** Solver factory installed on the transformed graph when it is created. May be null. */
	private final @Nullable IFactorGraphFactory<?> _solverFactory;

	/** The most recently computed transformation, or null if none has been computed yet. */
	private @Nullable JunctionTreeTransformMap _transformMap = null;

	/*--------------
	 * Construction
	 */

	/**
	 * @param sourceModel the model this solver graph is attached to
	 * @param parent the parent solver graph, if any
	 * @param solverFactory factory used to create the solver for the transformed graph; may be null
	 */
	protected JunctionTreeSolverGraphBase(FactorGraph sourceModel,
		@Nullable JunctionTreeSolverGraphBase<Delegate> parent,
		@Nullable IFactorGraphFactory<?> solverFactory)
	{
		super(sourceModel, parent);
		_transformer = new JunctionTreeTransform();
		_solverFactory = solverFactory;
	}

	/*---------------------
	 * ISolverNode methods
	 */

	/**
	 * Computes the Bethe entropy over the source model: the sum of the factor entropies minus,
	 * for each variable, its entropy weighted by one less than its number of siblings.
	 */
	@Override
	public double getBetheEntropy()
	{
		final FactorGraph sourceModel = getModelObject();

		double entropy = 0;

		// Sum up factor entropy
		for (Factor factor : sourceModel.getFactors())
		{
			entropy += factor.getBetheEntropy();
		}

		// The following would be unnecessary if we implemented inputs as single node factors
		for (Variable variable : sourceModel.getVariablesFlat())
		{
			entropy -= variable.getBetheEntropy() * (variable.getSiblingCount() - 1);
		}

		return entropy;
	}

	/**
	 * Returns {@link #getInternalEnergy()} minus {@link #getBetheEntropy()}.
	 */
	@Override
	public double getBetheFreeEnergy()
	{
		return getInternalEnergy() - getBetheEntropy();
	}

	/**
	 * Sums the internal energy of the factors of the transformed graph and of the variables of
	 * the source model.
	 * <p>
	 * Returns {@link Double#NaN} if the transformation has not been computed yet.
	 */
	@Override
	public double getInternalEnergy()
	{
		final JunctionTreeTransformMap transformMap = getTransformMap();
		if (transformMap == null)
		{
			return Double.NaN;
		}

		double energy = 0;

		// Sum up factor internal energy
		for (Factor factor : transformMap.target().getFactors())
		{
			energy += factor.getInternalEnergy();
		}

		// The following would be unnecessary if we implemented inputs as single node factors
		for (Variable variable : getModelObject().getVariablesFlat())
		{
			energy += variable.getInternalEnergy();
		}

		return energy;
	}

	@Override
	public abstract @Nullable JunctionTreeSolverGraphBase<Delegate> getParentGraph();

	/**
	 * Updates the guesses in the transformed graph, then sums the scores of the source model's
	 * variables and the transformed graph's factors.
	 * <p>
	 * Returns {@link Double#NaN} if the transformation has not been computed yet.
	 */
	@Deprecated
	@Override
	public double getScore()
	{
		final JunctionTreeTransformMap transformMap = getTransformMap();
		if (transformMap == null)
		{
			return Double.NaN;
		}

		transformMap.updateGuesses();

		double energy = 0.0;

		for (Variable variable : getModelObject().getVariables())
		{
			energy += variable.getScore();
		}

		for (Factor factor : transformMap.target().getFactors())
		{
			energy += factor.getScore();
			if (Double.isInfinite(energy))
			{
				// NOTE(review): debugging hook, presumably a place to set a breakpoint when the
				// score first becomes infinite; confirm Misc.breakpoint() has no side effects.
				Misc.breakpoint();
			}
		}

		return energy;
	}

	@Override
	public abstract JunctionTreeSolverGraphBase<Delegate> getRootSolverGraph();

	/*-------------------------
	 * ProxySolverNode methods
	 */

	/**
	 * Returns the solver of the transformed graph, or null if the transformation has not been
	 * computed yet.
	 */
	@Override
	public @Nullable Delegate getDelegate()
	{
		final JunctionTreeTransformMap transformMap = _transformMap;
		if (transformMap != null)
		{
			@SuppressWarnings("unchecked")
			Delegate delegate = (Delegate) transformMap.target().getSolver();
			return delegate;
		}
		return null;
	}

	/*----------------------------
	 * ISolverFactorGraph methods
	 */

	/**
	 * Not supported by the junction tree solver; always throws.
	 */
	@Override
	public ISolverBlastFromThePastFactor createBlastFromThePast(BlastFromThePastFactor factor)
	{
		// FIXME - blast from the past factor in junction tree
		throw unsupported("createBlastFromThePast");
	}

	/**
	 * Creates a {@link JunctionTreeDiscreteSolverVariable} for {@link Discrete} variables and a
	 * generic {@link JunctionTreeSolverVariable} for all others.
	 */
	@Override
	public IJunctionTreeSolverVariable<?> createVariable(Variable var)
	{
		if (var instanceof Discrete)
		{
			return new JunctionTreeDiscreteSolverVariable((Discrete)var, this);
		}
		else
		{
			return new JunctionTreeSolverVariable<Variable>(var, this);
		}
	}

	@Override
	public JunctionTreeSolverFactor createFactor(Factor factor)
	{
		return new JunctionTreeSolverFactor(factor, this);
	}

	/**
	 * {@inheritDoc}
	 * <p>
	 * For the junction tree solver, the number of iterations is always one. Additional iterations
	 * should not modify the result.
	 */
	@Override
	public int getNumIterations()
	{
		return super.getNumIterations();
	}

	/**
	 * {@inheritDoc}
	 * <p>
	 * The junction tree solver always uses a single iteration, so {@code numIterations} is ignored
	 * and the iteration count is set to one.
	 */
	@Override
	public void setNumIterations(int numIterations)
	{
		super.setNumIterations(1);
	}

	/**
	 * {@inheritDoc}
	 * @return {@link BPOptions#scheduler}
	 */
	@Override
	public @Nullable SchedulerOptionKey getSchedulerKey()
	{
		return BPOptions.scheduler;
	}

	/**
	 * Refreshes the transformer settings from the {@link JunctionTreeOptions} and
	 * {@link DimpleOptions#randomSeed} options. Then either copies the current priors onto the
	 * existing transformed graph, if it is still valid, or computes a new transformation.
	 * Finally initializes the delegate solver.
	 */
	@Override
	public void initialize()
	{
		// Configure settings from options.
		_transformer.useConditioning(getOptionOrDefault(JunctionTreeOptions.useConditioning));
		_transformer.maxTransformationAttempts(getOptionOrDefault(JunctionTreeOptions.maxTransformationAttempts));
		OptionVariableEliminatorCostList costFunctions =
			getOptionOrDefault(JunctionTreeOptions.variableEliminatorCostFunctions);
		_transformer.variableEliminatorCostFunctions(costFunctions.toArray(new CostFunction[costFunctions.size()]));
		Long seed = getOption(DimpleOptions.randomSeed);
		if (seed != null)
		{
			_transformer.random().setSeed(seed);
		}

		if (isTransformValid())
		{
			final JunctionTreeTransformMap transformMap = requireNonNull(_transformMap);

			// Copy inputs/fixed values to transformed model in case they have changed.
			for (Entry<Variable,Variable> entry : transformMap.sourceToTargetVariables().entrySet())
			{
				final Variable sourceVar = entry.getKey();
				if (sourceVar != null)
				{
					final Variable targetVar = entry.getValue();
					targetVar.setPrior(sourceVar.getPrior());
				}
			}
		}
		else
		{
			// No valid transform yet (or it was invalidated): compute a new one.
			updateDelegate();
			// FIXME: update proxy factor mappings
		}

		requireNonNull(getDelegate()).initialize();
	}

	/** Ensures the transformed graph is current, then forwards to the delegate's {@code iterate()}. */
	@Override
	public void iterate()
	{
		updateDelegate();
		requireDelegate("iterate").iterate();
	}

	/**
	 * Initializes the source model, ensures the transformed graph is current, then forwards to the
	 * delegate's {@code solve()}.
	 */
	@Override
	public void solve()
	{
		getModelObject().initialize();
		updateDelegate();
		requireDelegate("solve").solve();
	}

	/** Ensures the transformed graph is current, then forwards to the delegate's {@code solveOneStep()}. */
	@Override
	public void solveOneStep()
	{
		updateDelegate();
		requireDelegate("solveOneStep").solveOneStep();
	}

	/** Ensures the transformed graph is current, then forwards to the delegate's {@code startSolver()}. */
	@Override
	public void startSolver()
	{
		updateDelegate();
		requireDelegate("startSolver").startSolver();
	}

	/*---------------------------------
	 * JunctionTreeSolverGraph methods
	 */

	/**
	 * The solver factory that is installed on the transformed graph. May be null.
	 */
	public @Nullable IFactorGraphFactory<?> getDelegateSolverFactory()
	{
		return _solverFactory;
	}

	/**
	 * The object that implements the junction tree transformation.
	 */
	public JunctionTreeTransform getTransformer()
	{
		return _transformer;
	}

	/**
	 * Returns transformed graph and accompanying mapping data.
	 * <p>
	 * May be null if the transformation has not been computed yet. It is computed by
	 * {@link #initialize()} and by the solve/iterate methods.
	 */
	public @Nullable JunctionTreeTransformMap getTransformMap()
	{
		return _transformMap;
	}

	/**
	 * If true, then the transformation will condition out any variables that have a fixed value.
	 * This will produce a more efficient graph but will prevent it from being reused if the fixed
	 * value changes.
	 * <p>
	 * False by default.
	 * @see #useConditioning(boolean)
	 */
	public boolean useConditioning()
	{
		return _transformer.useConditioning();
	}

	/**
	 * Sets {@link #useConditioning()} to specified value.
	 * <p>
	 * Also stores the value in the {@link JunctionTreeOptions#useConditioning} option so that it
	 * survives the option refresh performed by {@link #initialize()}.
	 * @return this
	 */
	public JunctionTreeSolverGraphBase<Delegate> useConditioning(boolean yes)
	{
		_transformer.useConditioning(yes);
		setOption(JunctionTreeOptions.useConditioning, yes);
		return this;
	}

	/**
	 * The cost functions used by {@link VariableEliminator} to determine the variable
	 * elimination ordering. If empty (the default), then all of the standard {@link VariableCost}
	 * functions will be tried.
	 *
	 * @see #variableEliminatorCostFunctions(VariableEliminator.CostFunction...)
	 * @see #variableEliminatorCostFunctions(VariableEliminator.VariableCost...)
	 */
	public CostFunction[] variableEliminatorCostFunctions()
	{
		return _transformer.variableEliminatorCostFunctions();
	}

	/**
	 * Sets {@link #variableEliminatorCostFunctions()} to specified value.
	 * <p>
	 * Also stores the value in the {@link JunctionTreeOptions#variableEliminatorCostFunctions} option.
	 * @return this
	 * @see #variableEliminatorCostFunctions(VariableEliminator.VariableCost...)
	 */
	public JunctionTreeSolverGraphBase<Delegate> variableEliminatorCostFunctions(CostFunction ... costFunctions)
	{
		_transformer.variableEliminatorCostFunctions(costFunctions);
		setOption(JunctionTreeOptions.variableEliminatorCostFunctions,
			new OptionVariableEliminatorCostList(costFunctions));
		return this;
	}

	/**
	 * Sets {@link #variableEliminatorCostFunctions()} to specified value.
	 * @return this
	 * @see #variableEliminatorCostFunctions(VariableEliminator.CostFunction...)
	 */
	public JunctionTreeSolverGraphBase<Delegate> variableEliminatorCostFunctions(VariableCost ... costFunctions)
	{
		return variableEliminatorCostFunctions(VariableCost.toFunctions(costFunctions));
	}

	/**
	 * Sets {@link #variableEliminatorCostFunctions()} from {@link VariableCost} constant names.
	 * Each name is resolved with {@code VariableCost.valueOf}.
	 * @return this
	 */
	@Matlab
	public JunctionTreeSolverGraphBase<Delegate> variableEliminatorCostFunctions(String ... costFunctionNames)
	{
		final int n = costFunctionNames.length;
		CostFunction[] costFunctions = new CostFunction[n];
		for (int i = 0; i < n; ++i)
		{
			costFunctions[i] = VariableCost.valueOf(costFunctionNames[i]).function();
		}
		return variableEliminatorCostFunctions(costFunctions);
	}

	/**
	 * Specifies the maximum number of times to attempt to determine an optimal junction tree
	 * transformation.
	 * <p>
	 * Specifies the number of iterations of the {@link VariableEliminator} algorithm when
	 * attempting to determine the variable elimination ordering that determines the junction tree
	 * transformation. Each iteration will pick a cost function from
	 * {@link #variableEliminatorCostFunctions()} at random and will randomize the order of
	 * variables that have equivalent costs. A higher number of iterations may produce a better
	 * ordering.
	 * <p>
	 * Default value is specified by
	 * {@link JunctionTreeTransform#DEFAULT_MAX_TRANSFORMATION_ATTEMPTS}.
	 *
	 * @see #maxTransformationAttempts(int)
	 */
	public int maxTransformationAttempts()
	{
		return _transformer.maxTransformationAttempts();
	}

	/**
	 * Sets {@link #maxTransformationAttempts()} to the specified value.
	 * <p>
	 * Also stores the value in the {@link JunctionTreeOptions#maxTransformationAttempts} option.
	 * @return this
	 */
	public JunctionTreeSolverGraphBase<Delegate> maxTransformationAttempts(int iterations)
	{
		_transformer.maxTransformationAttempts(iterations);
		setOption(JunctionTreeOptions.maxTransformationAttempts, iterations);
		return this;
	}

	/*-----------------
	 * Package methods
	 */

	/**
	 * Returns the solver of the target-graph variable corresponding to {@code var}'s model
	 * variable, or null if no transformation has been computed yet.
	 */
	@Nullable ISolverVariable getDelegateSolverVariable(IJunctionTreeSolverVariable<?> var)
	{
		final JunctionTreeTransformMap transformMap = _transformMap;
		if (transformMap != null)
		{
			// NOTE(review): assumes every source variable has a target mapping; confirm this holds
			// when conditioning is enabled.
			return transformMap.sourceToTargetVariable(var.getModelObject()).getSolver();
		}
		return null;
	}

	/*-----------------
	 * Private methods
	 */

	/** True if a transformation has been computed and still reports itself valid. */
	private boolean isTransformValid()
	{
		final JunctionTreeTransformMap transformMap = _transformMap;
		return transformMap != null && transformMap.isValid();
	}

	/**
	 * Recomputes the transformation if it is missing or invalid, installing the delegate solver
	 * factory on the new target graph. Then passes the current delegate to
	 * {@code notifyNewDelegate} and returns its result.
	 */
	private @Nullable ISolverFactorGraph updateDelegate()
	{
		if (!isTransformValid())
		{
			final JunctionTreeTransformMap transformMap = _transformer.transform(getModelObject());
			_transformMap = transformMap;
			transformMap.target().setSolverFactory(_solverFactory);
		}

		return notifyNewDelegate(getDelegate());
	}

	@Override
	public boolean checkAllEdgesAreIncludedInSchedule()
	{
		return true;
	}
}