/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.schedulers; import static java.util.Objects.*; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.INode; import com.analog.lyric.dimple.schedulers.schedule.FixedSchedule; import com.analog.lyric.dimple.schedulers.schedule.ISchedule; import com.analog.lyric.dimple.schedulers.scheduleEntry.EdgeScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.NodeScheduleEntry; import com.analog.lyric.util.misc.IMapList; /** * @author jeffb * * If this graph is a tree, or any of it's sub-graphs are trees, this * class generates a tree-schedule. Otherwise it generates uses a * scheduler to be defined in a sub-class. * * This scheduler respects any schedulers already assigned to * sub-graphs. That is, if a sub-graph already has a scheduler * associated with it, that scheduler will be used for that sub-graph * instead of this one. */ public abstract class TreeSchedulerAbstract extends BPSchedulerBase { private static final long serialVersionUID = 1L; /*------- * State */ /** * Will use node-update if number of edges to update is greater than threshold, otherwise will use edge update. */ protected int _nodeUpdateThreshold = 1; /*-------------- * Construction */ protected TreeSchedulerAbstract() { } protected TreeSchedulerAbstract(TreeSchedulerAbstract other) { _nodeUpdateThreshold = other._nodeUpdateThreshold; } /*---------------- * Object methods */ @Override public int hashCode() { return getClass().hashCode() + 13 * _nodeUpdateThreshold; } /*---------------------- * IOptionValue methods */ /** * This type of scheduler is mutable. * @see #setNodeUpdateThreshold(int) */ @Override public boolean isMutable() { return true; } /*-------------------- * IScheduler methods */ @Override public ISchedule createSchedule(FactorGraph g) { if (g.isForest()) // The graph is a tree return createTreeSchedule(g); else // Not a tree return createNonTreeSchedule(g); } // To be overridden to specify the desired non-tree scheduler. // Note that sub-graphs in the non-tree schedule should be scheduled using // the tree-scheduler sub-class. protected abstract ISchedule createNonTreeSchedule(FactorGraph g) ; @SuppressWarnings("unchecked") protected ISchedule createTreeSchedule(FactorGraph g) { FixedSchedule schedule = new FixedSchedule(this, g); HashMap<Long,NodeUpdateState> updateState = new HashMap<>(); IMapList<INode> allIncludedNodes = g.getNodes(); ArrayList<INode> startingNodes = new ArrayList<INode>(); // For all nodes, set up the node update state // Edges connected to nodes outside the graph have already been updated if (g.hasParentGraph()) { for (INode node : allIncludedNodes) { List<? extends INode> siblings = node.getSiblings(); int numSiblings = siblings.size(); NodeUpdateState nodeState = new NodeUpdateState(numSiblings); int numSiblingsInSubGraph = 0; for (int index = 0; index < numSiblings; index++) if (!allIncludedNodes.contains(siblings.get(index))) nodeState.inputUpdated(index); else numSiblingsInSubGraph++; updateState.put(node.getGlobalId(), nodeState); if (numSiblingsInSubGraph <= 1) startingNodes.add(node); } } else // If there's no parent, then nothing has already been updated { for (INode node : allIncludedNodes) { int numSiblings = node.getSiblingCount(); updateState.put(node.getGlobalId(), new NodeUpdateState(numSiblings)); if (numSiblings <= 1) startingNodes.add(node); } } // Start with leaf nodes int numStartingNodes = startingNodes.size(); for (int i = 0; i < numStartingNodes; i++) { INode node = startingNodes.get(i); boolean moreInThisPath = true; while (moreInThisPath) { NodeUpdateState nodeState = updateState.get(requireNonNull(node).getGlobalId()); INode nextNode = null; if (!nodeState.doneUpdatingAllOutputs() && nodeState.readyToUpdateAllOutputs()) { // Update all output edges that have not already been updated if (nodeState.getNumOutputPortsNotUpdated() > _nodeUpdateThreshold) { // Use node update schedule.add(new NodeScheduleEntry(node)); int nextNodeCount = 0; List<? extends INode> siblings = node.getSiblings(); int numSiblings = siblings.size(); for (int index = 0; index < numSiblings; index++) { if (!nodeState.isOutputUpdated(index)) { nodeState.outputUpdated(index); INode sibling = siblings.get(index); NodeUpdateState siblingNodeState = updateState.get(sibling.getGlobalId()); if (siblingNodeState != null) { siblingNodeState.inputUpdated(node.getReverseSiblingNumber(index)); if (nextNodeCount++ == 0) nextNode = sibling; // Do the first one next else { startingNodes.add(sibling); // Will need to come back and revisit these paths numStartingNodes++; } } } } if (nextNodeCount == 0) moreInThisPath = false; // No variables that aren't already done or boundary variables } else { // Use edge update int nextNodeCount = 0; List<? extends INode> siblings = node.getSiblings(); int numSiblings = siblings.size(); for (int index = 0; index < numSiblings; index++) { if (!nodeState.isOutputUpdated(index)) { schedule.add(new EdgeScheduleEntry(node, index)); nodeState.outputUpdated(index); INode sibling = siblings.get(index); NodeUpdateState siblingNodeState = updateState.get(sibling.getGlobalId()); if (siblingNodeState != null) { siblingNodeState.inputUpdated(node.getReverseSiblingNumber(index)); if (nextNodeCount++ == 0) nextNode = sibling; // Do the first one next else { startingNodes.add(sibling); // Will need to come back and revisit these paths numStartingNodes++; } } } } if (nextNodeCount == 0) moreInThisPath = false; // No variables that aren't already done or boundary variables } } else if (!nodeState.doneUpdatingSingleOutput() && nodeState.readyToUpdateSingleOutput()) { // We're ready to update one output, so update that output int portId = nodeState.outputToUpdate(); schedule.add(new EdgeScheduleEntry(node, portId)); nodeState.outputUpdated(portId); INode sibling = node.getSibling(portId); NodeUpdateState siblingNodeState = updateState.get(sibling.getGlobalId()); if (siblingNodeState != null) { siblingNodeState.inputUpdated(node.getReverseSiblingNumber(portId)); nextNode = sibling; } else // No node state, must be a boundary variable moreInThisPath = false; } else // Next node isn't ready to update moreInThisPath = false; node = nextNode; } } return schedule; } // Set/get the threshold determining when to use node vs. edge updates when a node is ready to update all its remaining output edges public void setNodeUpdateThreshold(int threshold) {_nodeUpdateThreshold = threshold;} public int getNodeUpdateThreshold() {return _nodeUpdateThreshold;} public void useOnlyEdgeUpdates() {_nodeUpdateThreshold = Integer.MAX_VALUE;} public void useDefaultUpdateRule() {_nodeUpdateThreshold = 1;} protected class NodeUpdateState { protected int _portCount = 0; protected int _inputUpdateCount = 0; protected int _outputUpdateCount = 0; protected boolean[] _inputUpdated; protected boolean[] _outputUpdated; protected int _outputToUpdate = -1; protected boolean _readyToUpdateSingleOutput = false; protected boolean _doneUpdatingSingleOutput = false; protected boolean _readyToUpdateAllOutputs = false; protected boolean _doneUpdatingAllOutputs = false; public NodeUpdateState(int portCount) { _portCount = portCount; _inputUpdated = new boolean[portCount]; // Note: assumes ports are indexed sequentially from 0 _outputUpdated = new boolean[portCount]; for (int i = 0; i < portCount; i++) { _inputUpdated[i] = false; _outputUpdated[i] = false; } if (portCount == 1) { _readyToUpdateSingleOutput = true; _outputToUpdate = 0; } } public final void inputUpdated(int portId) { if (!_inputUpdated[portId]) // Make sure it hasn't already been marked as updated { _inputUpdated[portId] = true; _inputUpdateCount++; if (_inputUpdateCount == _portCount) { // If all ports have been updated, then we can update any of the outputs _readyToUpdateAllOutputs = true; } else if (_inputUpdateCount == _portCount - 1) { // If all ports but one have been updated, we're ready to update the // output on the port that hasn't yet been updated _readyToUpdateSingleOutput = true; for (int i = 0; i < _portCount; i++) { if (!_inputUpdated[i]) { _outputToUpdate = i; break; } } } } } public final void outputUpdated(int portId) { if (!_outputUpdated[portId]) // Make sure it hasn't already been marked as updated { _outputUpdated[portId] = true; _outputUpdateCount++; _doneUpdatingSingleOutput = true; if (_outputUpdateCount == _portCount) _doneUpdatingAllOutputs = true; } } public final int outputToUpdate() {return _outputToUpdate;} public final boolean readyToUpdateSingleOutput() {return _readyToUpdateSingleOutput;} public final boolean doneUpdatingSingleOutput() {return _doneUpdatingSingleOutput;} public final boolean readyToUpdateAllOutputs() {return _readyToUpdateAllOutputs;} public final boolean doneUpdatingAllOutputs() {return _doneUpdatingAllOutputs;} public final boolean isInputUpdated(int portId) {return _inputUpdated[portId];} public final boolean isOutputUpdated(int portId) {return _outputUpdated[portId];} public final int getPortCount() {return _portCount;} public final int getInputUpdateCount() {return _inputUpdateCount;} public final int getOutputUpdateCount() {return _outputUpdateCount;} public final int getNumInputPortsNotUpdated() {return _portCount - _inputUpdateCount;} public final int getNumOutputPortsNotUpdated() {return _portCount - _outputUpdateCount;} } }