/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.core; import static java.util.Objects.*; import java.util.AbstractCollection; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Iterator; import org.eclipse.jdt.annotation.NonNull; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ExtendedArrayList; import com.analog.lyric.collect.NonNullListIndices; import com.analog.lyric.collect.PrimitiveIterable; import com.analog.lyric.collect.ReleasableIterator; import com.analog.lyric.dimple.data.DataLayer; import com.analog.lyric.dimple.data.IDatum; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.environment.DimpleThread; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.IFactorGraphChild; import com.analog.lyric.dimple.model.core.Ids; import com.analog.lyric.dimple.model.core.Node; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.factors.FactorBase; import 
com.analog.lyric.dimple.model.repeated.BlastFromThePastFactor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableBlock; import com.analog.lyric.dimple.options.BPOptions; import com.analog.lyric.dimple.options.SolverOptions; import com.analog.lyric.dimple.schedulers.EmptyScheduler; import com.analog.lyric.dimple.schedulers.IScheduler; import com.analog.lyric.dimple.schedulers.SchedulerOptionKey; import com.analog.lyric.dimple.schedulers.schedule.ISchedule; import com.analog.lyric.dimple.schedulers.schedule.ScheduleValidationException; import com.analog.lyric.dimple.schedulers.scheduleEntry.BlockScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.EdgeScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.IScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.NodeScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.SubgraphScheduleEntry; import com.analog.lyric.dimple.schedulers.validator.ScheduleValidator; import com.analog.lyric.dimple.schedulers.validator.ScheduleValidatorOptionKey; import com.analog.lyric.dimple.solvers.core.multithreading.MultiThreadingManager; import com.analog.lyric.dimple.solvers.interfaces.IParameterizedSolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverBlastFromThePastFactor; import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverNode; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariableBlock; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; import com.analog.lyric.options.IOptionHolder; import com.analog.lyric.util.misc.Internal; import 
com.analog.lyric.util.misc.Matlab; import com.google.common.collect.UnmodifiableIterator; /** * Standard base implementation of {@link IParameterizedSolverFactorGraph}. * <p> * @param <SFactor> type of solver factor objects for this graph * @param <SVariable> type of solver variable objects for this graph * @param <SEdge> type of solver edge objects for this graph. Graphs that do not support solver edges can * use {@link NoSolverEdge} for this type. * @param <SBlock> type of solver variable block objects for this graph if applicable. Graphs that do not support * solver variable blocks can use {@link NoSolverVariableBlock} for this type. * @since 0.08 * @author Christopher Barber */ public abstract class SFactorGraphBase <SFactor extends ISolverFactor, SVariable extends ISolverVariable, SEdge extends ISolverEdgeState, SBlock extends ISolverVariableBlock> extends SNode<FactorGraph> implements IParameterizedSolverFactorGraph<SFactor, SVariable, SEdge, SBlock> { /** * Bits in {@link #_flags} reserved by this class and its superclasses. */ @SuppressWarnings("hiding") protected static final int RESERVED_FLAGS = 0xFFFF0000; private @Nullable ISolverFactorGraph _parent; protected int _numIterations = 1; // Default number of iterations unless otherwise specified private @Nullable MultiThreadingManager _multithreader; // = new MultiThreadingManager(); protected boolean _useMultithreading = false; /** * Solver factors belonging to {@link this} indexed by {@link Factor}s local index. */ private final ExtendedArrayList<SFactor> _factors; /** * Solver variables belonging to {@link this} indexed by {@link Variable}s local index. */ private final ExtendedArrayList<SVariable> _variables; /** * Solver subgraphs belonging to {@link this} indexed by each {@link FactorGraph}s local index. 
*/ private final ExtendedArrayList<ISolverFactorGraph> _subgraphs; private final ExtendedArrayList<SEdge> _edges; private final ExtendedArrayList<SBlock> _blocks; private SolverNodeMapping _solverNodeMapping; protected @Nullable ISchedule _schedule; /*-------------- * Construction */ protected SFactorGraphBase(FactorGraph graph, @Nullable ISolverFactorGraph parent) { super(graph); _factors = new ExtendedArrayList<>(graph.getFactorCount(0)); _variables = new ExtendedArrayList<>(graph.getVariableCount(0)); _subgraphs = new ExtendedArrayList<>(graph.getOwnedGraphs().size()); _edges = new ExtendedArrayList<SEdge>(hasEdgeState() ? graph.getGraphEdgeStateMaxIndex() + 1: 0); _blocks = new ExtendedArrayList<>(graph.getOwnedVariableBlocks().size()); _solverNodeMapping = new StandardSolverNodeMapping(this); _parent = parent; if (parent == null) { // Inherit default conditioning layer from model graph if this is the root. setConditioningLayer(graph.getDefaultConditioningLayer()); } } /*------------------ * IVariableToValue */ @Override public @Nullable Value varToValue(Variable var) { Value value = var.getPriorValue(); if (value == null) { DataLayer<?> layer = getConditioningLayer(); if (layer != null) { IDatum datum = layer.get(var); if (datum instanceof Value) { value = (Value)datum; } } } return value; } /*---------------------------- * ISolverEventSource methods */ @Override public SFactorGraphBase<SFactor,SVariable,SEdge, SBlock> getContainingSolverGraph() { return this; } /*--------------------- * ISolverNode methods */ public FactorGraph getModel() { return _model; } @Override public @Nullable ISolverFactorGraph getParentGraph() { return _parent; } @Override @Internal public void setParent(ISolverFactorGraph parent) { _parent = parent; _solverNodeMapping = parent.getSolverMapping(); _solverNodeMapping.addSolverGraph(this); } @Override public ISchedule getSchedule() { ISchedule schedule = _schedule; if (schedule != null && schedule.isUpToDateForSolver(this)) { return 
schedule; } final IScheduler scheduler = getScheduler(); _schedule = schedule = scheduler.createSchedule(this); schedule.setScheduler(scheduler); return schedule; } @Override public void setSchedule(@Nullable ISchedule schedule) { if (schedule == null && getSchedulerKey() == null) { throw new UnsupportedOperationException(String.format("%s does not support schedules", this)); } _schedule = schedule; } @Override public IScheduler getScheduler() { final SchedulerOptionKey schedulerKey = getSchedulerKey(); if (schedulerKey == null) { return EmptyScheduler.INSTANCE; } IScheduler scheduler = getLocalOption(schedulerKey); if (scheduler == null || scheduler.isDefaultScheduler()) { // If this is the default scheduler, we need to check to see if it has been // superceded by a schedule set further in the delegation chain. ReleasableIterator<? extends IOptionHolder> delegates = getOptionDelegates(); delegates.next(); // skip this solver graph IScheduler nextScheduler = null; while (delegates.hasNext()) { nextScheduler = delegates.next().getLocalOption(schedulerKey); if (nextScheduler != null && schedulerKey.validForDelegator(nextScheduler, this)) { break; } nextScheduler = null; } if (nextScheduler != null || scheduler != null && scheduler.getClass() != schedulerKey.defaultClass()) { // If we found a scheduler in the delegation chain or if the default class // is not this class, then clear the currently saved schedule and use the // delegate version instead. scheduler = nextScheduler; unsetOption(schedulerKey); } delegates.release(); } if (scheduler == null) { // If no schedule was found, create a default instance and save it locally scheduler = schedulerKey.defaultValue(); setOption(schedulerKey, scheduler); } return scheduler; } /** * Validates schedule. * <p> * This can be used to validate schedules before use. This is invoked by {@link #initialize()}. 
* <p> * The implementation currently only validates custom schedules but could be used to validate * other schedules for debugging purposes subject to an option setting. * <p> * This uses the the current {@linkplain #getSchedulerKey() scheduler key} to find what * {@linkplain SchedulerOptionKey#getValidatorKey() validator} to use, if any. * @since 0.08 */ protected void validateSchedule(ISchedule schedule) throws ScheduleValidationException { if (schedule.isCustom()) { final SchedulerOptionKey schedulerKey = getSchedulerKey(); if (schedulerKey != null) { ScheduleValidatorOptionKey validatorKey = schedulerKey.getValidatorKey(); if (validatorKey != null) { ScheduleValidator validator = validatorKey.instantiate(this); validator.validate(schedule); } } } } @Override public void setScheduler(@Nullable IScheduler scheduler) { final SchedulerOptionKey schedulerKey = getSchedulerKey(); if (schedulerKey != null) { if (scheduler == null) { schedulerKey.unset(this); } else { schedulerKey.set(this, scheduler); } } else if (scheduler != null) { throw new UnsupportedOperationException(String.format("%s does not support schedulers", this)); } } @Override public ISolverFactorGraph getRootSolverGraph() { return _solverNodeMapping.getRootSolverGraph(); } @Override public final SolverNodeMapping getSolverMapping() { return _solverNodeMapping; } /*---------------------------- * ISolverFactorGraph methods */ @SuppressWarnings("null") @Override public SEdge createEdgeState(EdgeState edge) { return null; } @NonNull // FIXME - workaround for Eclipse JDT bug (467610?) @Override public abstract SFactor createFactor(Factor factor); @NonNull // FIXME - workaround for Eclipse JDT bug (467610?) @Override public abstract SVariable createVariable(Variable variable); @SuppressWarnings("null") @Override public SBlock createVariableBlock(VariableBlock block) { return null; } @Override public @Nullable DataLayer<? 
extends IDatum> getConditioningLayer() { return _solverNodeMapping.getConditioningLayer(); } @Override public @Nullable SEdge getSolverEdge(EdgeState edge) { return getSolverEdge(edge, true); } @Override public @Nullable SEdge getSolverEdge(int edgeIndex) { return getSolverEdge(edgeIndex, true); } /** * {@inheritDoc} * <p> * The default implementation simply returns {@link Factor#getSolver()}, which * assumes that the {@code factor}'s model is currently attached to this solver graph. * Subclasses may override this to return a more precise type or to support solvers that * can still be used when they are detached from the model. */ @SuppressWarnings("unchecked") @Override public ISolverFactor getSolverFactor(Factor factor) { return _solverNodeMapping.getSolverFactor(factor); } @Override public @Nullable ISolverNode getSolverNodeByLocalId(int localId) { final int index = Ids.indexFromLocalId(localId); switch (localId >>> Ids.LOCAL_ID_TYPE_OFFSET) { case Ids.FACTOR_TYPE: return getSolverFactorByIndex(index); case Ids.GRAPH_TYPE: return getSolverSubgraphByIndex(index); case Ids.VARIABLE_TYPE: return getSolverVariableByIndex(index); default: return null; } } @Override public ISolverFactorGraph getSolverSubgraph(FactorGraph subgraph) { return _solverNodeMapping.getSolverGraph(subgraph); } @SuppressWarnings("unchecked") @Override public ISolverVariable getSolverVariable(Variable variable) { return _solverNodeMapping.getSolverVariable(variable); } @Override public @Nullable ISolverVariableBlock getSolverVariableBlock(VariableBlock block) { return _solverNodeMapping.getSolverVariableBlock(block); } /** * {@inheritDoc} * */ @Override public boolean hasEdgeState() { return false; } @Override public void moveMessages(ISolverNode other) { @SuppressWarnings("unchecked") SFactorGraphBase<SFactor,SVariable,SEdge,SBlock> sother = (SFactorGraphBase<SFactor,SVariable,SEdge,SBlock>)other; FactorGraph otherGraph = sother.getModelGraph(); final ExtendedArrayList<SEdge> edges = _edges; 
for (int i = 0, n = edges.size(); i < n; ++i) { SEdge thisEdge = edges.get(i); if (thisEdge != null) { SEdge thatEdge = requireNonNull(sother.getSolverEdge(i)); thisEdge.setFrom(thatEdge); thatEdge.reset(); } } for (SVariable svar : getSolverVariables()) { final int localId = svar.getModelObject().getLocalId(); final Variable thatVar = requireNonNull(otherGraph.getVariableByLocalId(localId)); SVariable thatSVar = requireNonNull(sother.getSolverVariable(thatVar, true)); svar.moveNonEdgeSpecificState(thatSVar); } for (ISolverFactorGraph ssubgraph : getSolverSubgraphs()) { final int localId = ssubgraph.getModelObject().getLocalId(); final FactorGraph subgraph2 = requireNonNull(otherGraph.getGraphByLocalId(localId)); ISolverFactorGraph ssubgraph2 = requireNonNull(sother.getSolverSubgraph(subgraph2, true)); ssubgraph.moveMessages(ssubgraph2); } } @Override public final void removeSolverEdge(int edgeIndex) { _edges.set(edgeIndex, null); } @Override public void removeSolverEdge(EdgeState edge) { removeSolverEdge(edge.edgeIndexInParent(_model)); if (!edge.isLocal()) { final SolverNodeMapping solvers = getSolverMapping(); FactorGraph factorParent = edge.getFactorParent(_model); if (factorParent != _model) { solvers.getSolverGraph(factorParent).removeSolverEdge(edge.factorEdgeIndex()); } else { solvers.getSolverGraph(edge.getVariableParent(_model)).removeSolverEdge(edge.variableEdgeIndex()); } } } @Override public void removeSolverFactor(ISolverFactor sfactor) { removeSolverNode(sfactor, _factors); } @Override public void removeSolverGraph(ISolverFactorGraph subgraph) { removeSolverNode(subgraph, _subgraphs); } @Override public void removeSolverVariable(ISolverVariable svariable) { // FIXME - what if boundary variable? 
removeSolverNode(svariable, _variables); } private void removeSolverNode(ISolverNode snode, ExtendedArrayList<?> list) { if (snode.getParentGraph() != this) { throw new IllegalArgumentException(String.format("'%s' does not belong to '%s'", snode, this)); } list.set(Ids.indexFromLocalId(snode.getModelObject().getLocalId()), null); } @SuppressWarnings("deprecation") // for SUBSCHEDULE @Override public void runScheduleEntry(IScheduleEntry entry) { switch (entry.type()) { case VARIABLE_BLOCK: { runBlockScheduleEntry((BlockScheduleEntry)entry); break; } case EDGE: { runEdgeScheduleEntry((EdgeScheduleEntry)entry); break; } case NODE: { runNodeScheduleEntry((NodeScheduleEntry)entry); break; } case SUBGRAPH: { runSubgraphEntry((SubgraphScheduleEntry)entry); break; } case SUBSCHEDULE: { runSubScheduleEntry((com.analog.lyric.dimple.schedulers.scheduleEntry.SubScheduleEntry)entry); break; } case CUSTOM: runCustomScheduleEntry(entry); break; } } protected void runBlockScheduleEntry(BlockScheduleEntry blockEntry) { final VariableBlock block = blockEntry.getBlock(); final ISolverVariableBlock sblock = getSolverVariableBlock(block); if (sblock == null || ! 
blockEntry.getBlockUpdater().update(sblock)) { for (Variable var : block) { _solverNodeMapping.getSolverVariable(var).update(); } } } protected void runEdgeScheduleEntry(EdgeScheduleEntry edgeEntry) { ISolverNode snode = _solverNodeMapping.getSolverNode(edgeEntry.getNode()); snode.updateEdge(edgeEntry.getPortNum()); } protected void runNodeScheduleEntry(NodeScheduleEntry nodeEntry) { ISolverNode snode = _solverNodeMapping.getSolverNode(nodeEntry.getNode()); snode.update(); } protected void runSubgraphEntry(SubgraphScheduleEntry subgraphEntry) { ISolverFactorGraph subgraph = _solverNodeMapping.getSolverGraph(subgraphEntry.getSubgraph()); subgraph.update(); } @Deprecated protected void runSubScheduleEntry(com.analog.lyric.dimple.schedulers.scheduleEntry.SubScheduleEntry subSchedule) { for (IScheduleEntry subentry : subSchedule.getSchedule()) { runScheduleEntry(subentry); } } protected void runCustomScheduleEntry(IScheduleEntry entry) { DimpleEnvironment.logError("Cannot handle custom schedule entry '%s'", entry); } @Deprecated @Override public boolean customFactorExists(String funcName) { return false; } @Override public void setConditioningLayer(@Nullable DataLayer<? extends IDatum> layer) { _solverNodeMapping.setConditioningLayer(layer); } /** * Sets number of solver iterations. * <p> * Sets {@link #getNumIterations()} and {@link BPOptions#iterations} option * to specified value. */ @Override public void setNumIterations(int numIter) { setOption(BPOptions.iterations, numIter); _numIterations = numIter; } /** * Number of solver iterations * <p> * This is set from {@link BPOptions#iterations} during {@link #initialize}. * <p> * This value is not meaningful to all solvers. 
*/
    @Override
    public int getNumIterations()
    {
        return _numIterations;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Runs every entry of the current {@linkplain #getSchedule() schedule} once, in order.
     */
    @Override
    public void update()
    {
        for (IScheduleEntry scheduleEntry : getSchedule())
        {
            runScheduleEntry(scheduleEntry);
        }
    }

    /**
     * Not supported for solver graphs.
     *
     * @throws DimpleException always.
     */
    @Override
    public void updateEdge(int outPortNum)
    {
        throw new DimpleException("Not supported");
    }

    /**
     * Runs a single iteration.
     */
    @Override
    public void iterate()
    {
        iterate(1);
    }

    /**
     * {@inheritDoc}
     * <p>
     * If a multithreading manager is installed and multithreading is enabled the work is handed
     * to it. Otherwise {@link #update()} is invoked up to {@code numIters} times on the calling
     * thread, returning early if the thread was interrupted.
     */
    @Override
    public void iterate(int numIters)
    {
        final MultiThreadingManager manager = _multithreader;

        if (manager != null && _useMultithreading)
        {
            // *** Multiple threads
            manager.iterate(numIters);
            return;
        }

        // *** Single thread
        for (int completed = 0; completed < numIters; ++completed)
        {
            update();

            // Allow interruption (if the solver is run as a thread); currently interruption is
            // allowed only between iterations, not within a single iteration.
            if (Thread.interrupted())
            {
                return;
            }
        }
    }

    /**
     * Runs {@link #getNumIterations()} iterations.
     */
    @Override
    public void solveOneStep()
    {
        iterate(_numIterations);
    }

    /**
     * Initializes the model, solves one step and then {@linkplain #continueSolve() continues}.
     */
    @Override
    public void solve()
    {
        _model.initialize();
        solveOneStep();
        continueSolve();
    }

    /**
     * {@inheritDoc}
     * <p>
     * While the model reports another step, advances it and solves one step. Unless the model's
     * step count is infinite, at most {@code getNumSteps()} steps are taken.
     */
    @Override
    public void continueSolve()
    {
        final int stepLimit = _model.getNumSteps();
        final boolean unbounded = _model.getNumStepsInfinite();

        for (int step = 0; getModel().hasNext(); ++step)
        {
            if (!unbounded && step >= stepLimit)
            {
                break;
            }

            getModel().advance();
            solveOneStep();
        }
    }

    /**
     * Internal energy minus Bethe entropy.
     */
    @Override
    public double getBetheFreeEnergy()
    {
        return getInternalEnergy() - getBetheEntropy();
    }

    /**
     * Not supported by this base implementation.
     *
     * @throws DimpleException always.
     */
    @Override
    public void estimateParameters(IFactorTable[] tables, int numRestarts, int numSteps, double stepScaleFactor)
    {
        throw new DimpleException("not supported by this solver");
    }

    /**
     * Not supported by this base implementation.
     *
     * @throws DimpleException always.
     */
    @Override
    public void baumWelch(IFactorTable [] tables,int numRestarts,int numSteps)
    {
        throw new DimpleException("not supported by this solver");
    }

    /**
     * {@inheritDoc}
     * <p>
     * Adds the Bethe entropy of every factor and subtracts, for every variable, its Bethe entropy
     * weighted by one less than its sibling count.
     */
    @Override
    public double getBetheEntropy()
    {
        double entropy = 0;

        // Sum up factor entropy
        for (Factor factor : _model.getFactors())
        {
            entropy += factor.getBetheEntropy();
        }

        // The following would be unnecessary if we implemented inputs as single node factors
        for (Variable variable : _model.getVariablesFlat())
        {
            entropy -= variable.getBetheEntropy() * (variable.getSiblingCount() - 1);
        }

        return entropy;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Sums the scores of the model's top-level variables followed by its top-level factors.
     */
    @Deprecated
    @Override
    public double getScore()
    {
        double total = 0;

        // FIXME: get*Top() methods copy all the objects into a new collection.
        // That should not be necessary.

        for (Variable variable : getModel().getVariablesTop())
        {
            total += variable.getScore();
        }

        for (FactorBase factor : getModel().getFactorsTop())
        {
            total += factor.getScore();
        }

        return total;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Sums the internal energy of every factor followed by every variable in the model.
     */
    @Override
    public double getInternalEnergy()
    {
        double energy = 0;

        // Sum up factor internal energy
        for (Factor factor : _model.getFactors())
        {
            energy += factor.getInternalEnergy();
        }

        // The following would be unnecessary if we implemented inputs as single node factors
        for (Variable variable : _model.getVariablesFlat())
        {
            energy += variable.getInternalEnergy();
        }

        return energy;
    }

    @Override
    public ISolverBlastFromThePastFactor createBlastFromThePast(BlastFromThePastFactor f)
    {
        return new SBlastFromThePast(f, this);
    }

    /**
     * Records {@code subgraph}'s current solver as this graph's solver for that subgraph.
     */
    @Override
    public void recordDefaultSubgraphSolver(FactorGraph subgraph)
    {
        setSubgraphSolver(subgraph, subgraph.getSolver());
    }

    /***********************************************
     *
     * Threading for Ctrl+C
     *
     ***********************************************/

    // FIXME: this is not really thread safe! There is nothing to prevent you from calling
    // these methods before the previous thread is done.

    // For running as a thread, which allows the solver to be interrupted.
    // This is backward compatible with versions of the modeler that call solve() directly.
/** Thread started by the most recent {@code start...} call, or null if none was started. */
    private volatile @Nullable Thread _thread;

    /**
     * Exception caught on the solver thread that has not yet been reported to the client by
     * {@link #isSolverRunning()}.
     * <p>
     * This is written by the solver thread and read by the client's thread, so it must be
     * volatile for the failure to be reliably visible to the reader.
     */
    private volatile @Nullable Exception _exception = null;

    /**
     * Runs {@link #continueSolve()} on a newly started thread.
     */
    @Override
    public void startContinueSolve()
    {
        startSolverThread(new Runnable()
        {
            @Override
            public void run()
            {
                continueSolve();
            }
        });
    }

    /**
     * Runs {@link #solveOneStep()} on a newly started thread.
     */
    @Override
    public void startSolveOneStep()
    {
        startSolverThread(new Runnable()
        {
            @Override
            public void run()
            {
                solveOneStep();
            }
        });
    }

    /**
     * Runs {@link #solve()} on a newly started thread.
     */
    @Override
    public void startSolver()
    {
        startSolverThread(new Runnable()
        {
            @Override
            public void run()
            {
                solve();
            }
        });
    }

    /**
     * Starts a new {@link DimpleThread} that runs {@code task} and remembers it as the current
     * solver thread. Any {@link Exception} thrown by {@code task} is captured so that it can be
     * rethrown to the client by {@link #isSolverRunning()}.
     */
    private void startSolverThread(final Runnable task)
    {
        final Thread thread = _thread = new DimpleThread(new Runnable()
        {
            @Override
            public void run()
            {
                try
                {
                    task.run();
                }
                catch (Exception e)
                {
                    // Pass any exceptions to the main thread so they can be passed to the client
                    _exception = e;
                }
            }
        });
        thread.start();
    }

    /**
     * Interrupts the solver thread, if one has been started.
     */
    @Override
    public void interruptSolver()
    {
        final Thread thread = _thread;
        if (thread != null)
        {
            System.out.println(">>> Interrupting solver");
            thread.interrupt();
        }
    }

    /**
     * Indicates whether a thread started by one of the {@code start...} methods is still alive.
     *
     * @return false if no thread has been started or it has finished.
     * @throws DimpleException wrapping the exception thrown on the solver thread, if there was
     * one. The exception is reported only once; the next call returns normally.
     */
    @Override
    public boolean isSolverRunning()
    {
        final Exception e = _exception;
        if (e != null)
        {
            // Clear the exception; it should be reported only once.
            _exception = null;
            // Pass the exception up to the client
            throw new DimpleException(e);
        }

        final Thread thread = _thread;
        return thread != null && thread.isAlive();
    }

    /**
     * Allows interruption (if the solver is run as a thread).
     *
     * @throws InterruptedException if the current thread has been interrupted. The thread's
     * interrupt status is restored before the exception is rethrown.
     */
    protected void interruptCheck() throws InterruptedException
    {
        try
        {
            Thread.sleep(0);
        }
        catch (InterruptedException e)
        {
            Thread.currentThread().interrupt();
            throw e;
        }
    }

    /***********************************************
     *
     * For multi-threaded computation
     *
     ***********************************************/

@Override public void useMultithreading(boolean use) { if (_multithreader == null) throw new DimpleException("Multithreading is not currently supported by this solver."); else _useMultithreading = use; setOption(SolverOptions.enableMultithreading, use); } @Override public boolean useMultithreading() { return _useMultithreading; } @Matlab public MultiThreadingManager getMultithreadingManager() { final MultiThreadingManager multithreader = _multithreader; if (multithreader == null) throw new DimpleException("Multithreading is not currently supported by this solver."); else return multithreader; } protected void setMultithreadingManager(@Nullable MultiThreadingManager manager) { _multithreader = manager; } /*********************************************** * * Initialization methods * ***********************************************/ /** * Initialize solver graph. * <p> * Default implementation does the following: * <ul> * <li>Initializes {@linkplain #getNumIterations() iterations} and multithreading from options. * <li>Builds and {@linkplain #validateSchedule(ISchedule) validates} the schedule. * <li>{@linkplain #initializeSolverEdges() Initializes solver edge state}. * <li>Invokes {@linkplain ISolverNode#initialize() initialize} on contents of graph in this order * <ol> * <li>owned solver variables * <li>boundary solver variables (only if this is the root solver graph) * <li>solver variable blocks * <li>solver factors * <li>solver subgraphs * </ol> * </ul> */ @Override public void initialize() { _numIterations = getOptionOrDefault(BPOptions.iterations); _useMultithreading = getOptionOrDefault(SolverOptions.enableMultithreading); validateSchedule(getSchedule()); initializeSolverEdges(); FactorGraph fg = _model; for (Variable variable : fg.getOwnedVariables()) { requireNonNull(getSolverVariable(variable, true)).initialize(); } if (!fg.hasParentGraph()) // FIXME: redundant? 
{ for (int i = 0, end = fg.getBoundaryVariableCount(); i <end; ++i) { getSolverVariable(fg.getBoundaryVariable(i)).initialize(); } } for (VariableBlock block : fg.getOwnedVariableBlocks()) { ISolverVariableBlock sblock = getSolverVariableBlock(block); if (sblock != null) { sblock.initialize(); } } for (Factor f : fg.getOwnedFactors()) { requireNonNull(getSolverFactor(f, true)).initialize(); } for (FactorGraph g : fg.getOwnedGraphs()) { requireNonNull(getSolverSubgraph(g, true)).initialize(); } } /*********************************************** * * Stuff for rolled up graphs * ***********************************************/ @Deprecated @Override public @Nullable Object getInputMsg(int portIndex) { return null; } @Deprecated @Override public @Nullable Object getOutputMsg(int portIndex) { return null; } @Override public void postAdvance() { } @Override public void postAddFactor(Factor f) { } @Override public void postSetSolverFactory() { } /** * {@inheritDoc} * * The default implementation always returns null. */ @Override public @Nullable String getMatlabSolveWrapper() { return null; } @Override public boolean checkAllEdgesAreIncludedInSchedule() { return true; // By default assume all edges must be included unless told otherwise; TODO: should this be the default? } /*-------------------------- * Protected helper methods */ /** * Description name for solver for use in error messages. * @since 0.08 */ abstract protected String getSolverName(); protected DimpleException unsupportedVariableType(Variable var) { return new DimpleException("'%s' solver does not support %s variables", getSolverName(), var.getClass().getSimpleName()); } /*--------------- * Inner classes * * These provide iterable views of solver objects in this graph and subgraphs. 
*/

    // TODO - make implicit instantiation optional

    /**
     * Read-only iterator that walks a collection of model nodes and yields the solver object
     * that {@link #map} produces for each.
     */
    private abstract static class SNodeIterator<N extends Node,SN extends ISolverNode> extends UnmodifiableIterator<SN>
    {
        private final Iterator<N> _modelNodes;

        private SNodeIterator(Collection<N> collection)
        {
            _modelNodes = collection.iterator();
        }

        @Override
        public final boolean hasNext()
        {
            return _modelNodes.hasNext();
        }

        @Override
        public final SN next()
        {
            final N modelNode = _modelNodes.next();
            return map(modelNode);
        }

        /** Converts a model node to its solver counterpart. */
        abstract SN map(N node);
    }

    /**
     * Collection view of solver nodes backed by a collection of model nodes; its size is the
     * size of the backing collection.
     */
    private abstract static class SNodes<N extends Node, SN extends ISolverNode> extends AbstractCollection<SN>
    {
        final Collection<N> _nodes;

        private SNodes(Collection<N> nodes)
        {
            _nodes = nodes;
        }

        @Override
        public int size()
        {
            return _nodes.size();
        }
    }

    /**
     * Iterates over solver factors for the given model factors, growing the solver factor list
     * to the collection's size up front.
     */
    private class OwnedSFactorIterator extends SNodeIterator<Factor,SFactor>
    {
        private OwnedSFactorIterator(Collection<Factor> iterable)
        {
            super(iterable);
            _factors.growSize(iterable.size());
        }

        @NonNull // FIXME - workaround for Eclipse JDT bug (467610?)
        @Override
        public SFactor map(Factor factor)
        {
            return requireNonNull(getSolverFactor(factor, true));
        }
    }

    /** View of the solver factors for the model graph's owned factors. */
    private class OwnedSFactors extends SNodes<Factor, SFactor>
    {
        private OwnedSFactors()
        {
            super(getModelGraph().getOwnedFactors());
        }

        @Override
        public Iterator<SFactor> iterator()
        {
            return new OwnedSFactorIterator(_nodes);
        }
    }

    /**
     * Iterates over solver variables for the given model variables, growing the solver variable
     * list to the collection's size up front.
     */
    private class OwnedSVarIterator extends SNodeIterator<Variable,SVariable>
    {
        private OwnedSVarIterator(Collection<Variable> iterable)
        {
            super(iterable);
            _variables.growSize(iterable.size());
        }

        @NonNull // FIXME - workaround for Eclipse JDT bug (467610?)
        @Override
        public SVariable map(Variable variable)
        {
            return requireNonNull(getSolverVariable(variable, true));
        }
    }

    /** View of the solver variables for the model graph's owned variables. */
    private class OwnedSVars extends SNodes<Variable, SVariable>
    {
        private OwnedSVars()
        {
            super(getModelGraph().getOwnedVariables());
        }

        @Override
        public Iterator<SVariable> iterator()
        {
            return new OwnedSVarIterator(_nodes);
        }
    }

    /**
     * Iterates over solver subgraphs for the given model subgraphs, growing the solver subgraph
     * list to the collection's size up front.
     */
    private class OwnedSubgraphIterator extends SNodeIterator<FactorGraph,ISolverFactorGraph>
    {
        private OwnedSubgraphIterator(Collection<FactorGraph> iterable)
        {
            super(iterable);
            _subgraphs.growSize(iterable.size());
        }

        @Override
        public ISolverFactorGraph map(FactorGraph subgraph)
        {
            return instantiateSubgraph(subgraph);
        }
    }

    /** View of the solver subgraphs for the model graph's owned graphs. */
    private class OwnedSubgraphs extends SNodes<FactorGraph, ISolverFactorGraph>
    {
        private OwnedSubgraphs()
        {
            super(getModelGraph().getOwnedGraphs());
        }

        @Override
        public Iterator<ISolverFactorGraph> iterator()
        {
            return new OwnedSubgraphIterator(_nodes);
        }
    }

    /**
     * Collection of subgraphs rooted at this in breadth-first order
     */
    private class RecursiveSubgraphs extends ArrayList<ISolverFactorGraph>
    {
        private static final long serialVersionUID = 1L;

        private RecursiveSubgraphs()
        {
            super();
            add(SFactorGraphBase.this);

            // Add all subgraphs recursively in breadth-first order: the list grows while it is
            // scanned, so size() must be re-read on every pass.
            for (int cursor = 0; cursor < size(); ++cursor)
            {
                final ISolverFactorGraph current = get(cursor);
                addAll(current.getSolverSubgraphs());
            }
        }
    }

    /**
     * Read-only iterator that visits, graph by graph, the nodes {@link #getNodes} returns for
     * this solver graph and all of its recursive subgraphs.
     */
    private abstract class RecursiveSNodeIterator<SN extends ISolverNode> extends UnmodifiableIterator<SN>
    {
        private final Iterator<ISolverFactorGraph> _graphs = getSolverSubgraphsRecursive().iterator();
        private Iterator<? extends SN> _current = Collections.emptyIterator();

        @Override
        public boolean hasNext()
        {
            // Skip forward past graphs that have no (more) nodes.
            while (!_current.hasNext() && _graphs.hasNext())
            {
                final ISolverFactorGraph graph = _graphs.next();
                _current = getNodes(graph).iterator();
            }

            return _current.hasNext();
        }

        @Override
        public SN next()
        {
            hasNext();
            return _current.next();
        }

        /**
         * Sums the node collection sizes of all graphs not yet visited, consuming the
         * graph iterator in the process.
         */
        int count()
        {
            int total = 0;
            while (_graphs.hasNext())
            {
                total += getNodes(_graphs.next()).size();
            }
            return total;
        }

        /** The nodes of interest belonging directly to {@code graph}. */
        abstract Collection<? extends SN> getNodes(ISolverFactorGraph graph);
    }

    /**
     * Collection view over a {@link RecursiveSNodeIterator}; the size is computed by counting
     * with a fresh iterator.
     */
    private abstract class RecursiveSNodes<SN extends ISolverNode> extends AbstractCollection<SN>
    {
        @Override
        public abstract RecursiveSNodeIterator<SN> iterator();

        @Override
        public int size()
        {
            final RecursiveSNodeIterator<SN> counter = iterator();
            return counter.count();
        }
    }

    /** Recursive iterator over solver factors. */
    private class RecursiveSFactorIterator extends RecursiveSNodeIterator<ISolverFactor>
    {
        @Override
        Collection<? extends ISolverFactor> getNodes(ISolverFactorGraph graph)
        {
            return graph.getSolverFactors();
        }
    }

    /** Recursive collection of solver factors. */
    private class RecursiveSFactors extends RecursiveSNodes<ISolverFactor>
    {
        @Override
        public RecursiveSFactorIterator iterator()
        {
            return new RecursiveSFactorIterator();
        }
    }

    /** Recursive iterator over solver variables. */
    private class RecursiveSVariableIterator extends RecursiveSNodeIterator<ISolverVariable>
    {
        @Override
        Collection<? extends ISolverVariable> getNodes(ISolverFactorGraph graph)
        {
            return graph.getSolverVariables();
        }
    }

    /** Recursive collection of solver variables. */
    private class RecursiveSVariables extends RecursiveSNodes<ISolverVariable>
    {
        @Override
        public RecursiveSVariableIterator iterator()
        {
            return new RecursiveSVariableIterator();
        }
    }

    /*---------
     * Methods
     */

    /**
     * The model graph for this solver graph; same as {@code getModelObject()}.
     */
    public final FactorGraph getModelGraph()
    {
        return this.getModelObject();
    }

    /**
     * Returns this object.
     */
    public final IParameterizedSolverFactorGraph<SFactor,SVariable,SEdge,SBlock> getSolverGraph()
    {
        return this;
    }

    /**
     * Unmodifiable collection over owned solver factors, implicitly instantiated if necessary.
     * @since 0.08
     */
    @Override
    public Collection<SFactor> getSolverFactors()
    {
        return new OwnedSFactors();
    }

    @Override
    public Collection<?
extends ISolverFactor> getSolverFactorsRecursive() { return new RecursiveSFactors(); } /** * Unmodifiable collection over owned solver variables, implicitly instantiated if necessary. * @since 0.08 */ @Override public Collection<SVariable> getSolverVariables() { return new OwnedSVars(); } @Override public Collection<? extends ISolverVariable> getSolverVariablesRecursive() { return new RecursiveSVariables(); } @Override public @Nullable ISolverFactorGraph getSolverSubgraph(FactorGraph subgraph, boolean create) { assertSameGraph(subgraph); if (create) { return instantiateSubgraph(subgraph); } else { return _subgraphs.getOrNull(Ids.indexFromLocalId(subgraph.getLocalId())); } } @Override public @Nullable ISolverFactorGraph getSolverSubgraphByIndex(int index) { return _subgraphs.getOrNull(index); } @Override public Collection<ISolverFactorGraph> getSolverSubgraphs() { return new OwnedSubgraphs(); } @Override public Collection<ISolverFactorGraph> getSolverSubgraphsRecursive() { return new RecursiveSubgraphs(); } @SuppressWarnings("unchecked") public @Nullable SEdge getSolverEdge(EdgeState edge, boolean create) { final ExtendedArrayList<SEdge> edges = _edges; final int index = edge.edgeIndexInParent(_model); SEdge result = edges.getOrNull(index); if (result == null) { if (create) { FactorGraph factorParent = edge.getFactorParent(_model); if (factorParent != _model) { // If the factor is from a different graph, get the edge from there. 
result = (SEdge) _solverNodeMapping.getSolverGraph(factorParent).getSolverEdge(edge); } else { result = this.createEdgeState(edge); } } edges.set(index, result); } return result; } public @Nullable SEdge getSolverEdge(int edgeIndex, boolean create) { final ExtendedArrayList<SEdge> edges = _edges; SEdge result = edges.getOrNull(edgeIndex); if (result == null) { if (create) { final EdgeState modelEdge = _model.getGraphEdgeState(edgeIndex); if (modelEdge != null) { return getSolverEdge(modelEdge, true); } } } return result; } @Override @SuppressWarnings("unchecked") public @Nullable SFactor getSolverFactor(Factor factor, boolean create) { assertSameGraph(factor); final int index = Ids.indexFromLocalId(factor.getLocalId()); final ExtendedArrayList<SFactor> factors = _factors; @SuppressWarnings("unchecked") SFactor sfactor = factors.getOrNull(index); if (sfactor == null || sfactor.getModelObject() != factor) { if (create) { if (factor instanceof BlastFromThePastFactor) { // FIXME - hacky sfactor = (SFactor)this.createBlastFromThePast((BlastFromThePastFactor)factor); } else { sfactor = this.createFactor(factor); factors.set(index, sfactor); } } else { sfactor = null; } factors.set(index, sfactor); } return sfactor; } @Override public @Nullable ISolverFactor getSolverFactorByIndex(int index) { return _factors.getOrNull(index); } @SuppressWarnings("null") @Override public ISolverFactor getSolverFactorForEdge(EdgeState edge) { final FactorGraph fg = _model; final FactorGraph factorParent = edge.getFactorParent(fg); final ISolverFactorGraph sfactorParent = fg == factorParent ? 
this : _solverNodeMapping.getSolverGraph(factorParent); return sfactorParent.getSolverFactorByIndex(edge.factorIndex()); } public void initializeSolverEdges() { ExtendedArrayList<SEdge> edges = _edges; if (hasEdgeState()) { final int n = getModelGraph().getGraphEdgeStateMaxIndex() + 1; edges.setSize(n); for (int i = 0; i < n; ++i) { SEdge sedge = getSolverEdge(i, true); if (sedge != null) { sedge.reset(); } } } } public ISolverFactorGraph instantiateSubgraph(FactorGraph subgraph) { assertSameGraph(subgraph); final int index = Ids.indexFromLocalId(subgraph.getLocalId()); final ExtendedArrayList<ISolverFactorGraph> graphs = _subgraphs; @SuppressWarnings("unchecked") ISolverFactorGraph sgraph = graphs.getOrNull(index); if (sgraph == null || sgraph.getModelObject() != subgraph) { sgraph = this.createSubgraph(subgraph); sgraph.setParent(this); graphs.set(index, sgraph); if (this == _model.getSolver()) { // If parent is default solver for it's graph, make this the default solver for the subgraph. 
subgraph.setSolver(sgraph); } } return sgraph; } @Override public @Nullable SVariable getSolverVariable(Variable variable, boolean create) { assertSameGraph(variable); final int index = Ids.indexFromLocalId(variable.getLocalId()); final ExtendedArrayList<SVariable> variables = _variables; @SuppressWarnings("unchecked") SVariable svar = variables.getOrNull(index); if (svar == null || svar.getModelObject() != variable) { if (create) { svar = this.createVariable(variable); svar.createNonEdgeSpecificState(); } else { svar = null; } variables.set(index, svar); } return svar; } public PrimitiveIterable.OfInt getSolverVariableIndices() { return new NonNullListIndices(_variables); } @Override public @Nullable ISolverVariable getSolverVariableByIndex(int index) { return _variables.getOrNull(index); } @Override public @Nullable SBlock getSolverVariableBlock(VariableBlock block, boolean create) { assertSameGraph(block); final int index = Ids.indexFromLocalId(block.getLocalId()); final ExtendedArrayList<SBlock> sblocks = _blocks; SBlock sblock = sblocks.getOrNull(index); if (sblock == null || sblock.getModelObject() != block) { if (create) { sblock = createVariableBlock(block); } sblocks.set(index, sblock); } return sblock; } @Override public @Nullable SBlock getSolverVariableBlockByIndex(int index) { return _blocks.getOrNull(index); } @Override public Collection<SBlock> getSolverVariableBlocks() { return Collections.unmodifiableList(_blocks); } public void setSubgraphSolver(FactorGraph subgraph, @Nullable ISolverFactorGraph sgraph) { assertSameGraph(subgraph); final int index = Ids.indexFromLocalId(subgraph.getLocalId()); final ExtendedArrayList<ISolverFactorGraph> graphs = _subgraphs; if (sgraph != null) { sgraph.setParent(this); } graphs.set(index, sgraph); } /*----------------- * Private methods */ private void assertSameGraph(IFactorGraphChild child) { if (child.getParentGraph() != this.getModelObject()) { throw new IllegalArgumentException(String.format("'%s' does not 
belong to graph.", child)); } } }