/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.core; import static java.util.Objects.*; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Queue; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Variable; /** * A utility class for topologically sorting directed nodes. * <p> * @since 0.07 * @author Christopher Barber */ public class DirectedNodeSorter { /** * Computes a topological ordering for directed nodes in graph. * <p> * This assigns a non-negative integer to each node with at least one * directed edge such that each node has a value that is greater than * the value of the nodes that precede it. This function does not detect * directed cycles and will not assign meaningful orders to nodes in a cycle. * <p> * @param fg is the factor graph whose nodes will be ordered. * @return a mapping from each node with at least one directed edge to an * integer. Nodes without directed edges will not be included in the map. * @since 0.07 */ public static Map<Node, Integer> orderDirectedNodes(FactorGraph fg) { // // Build representation of directed portion of graph. // final Map<Node,NodeInfo> nodes = new HashMap<>(); for (Factor factor : fg.getFactors()) { final int[] directedTo = factor.getDirectedTo(); if (directedTo == null) { continue; } final int[] directedFrom = requireNonNull(factor.getDirectedFrom()); final int toSize = directedTo.length; final int fromSize = directedFrom.length; VarInfo[] toVars = new VarInfo[toSize]; for (int i = 0; i < toSize; ++i) { Variable variable = factor.getSibling(directedTo[i]); VarInfo varInfo = (VarInfo)nodes.get(variable); if (varInfo == null) { varInfo = new VarInfo(variable); nodes.put(variable, varInfo); } varInfo.addPrevFactor(); toVars[i] = varInfo; } FactorInfo factorInfo = new FactorInfo(factor, toVars, fromSize); nodes.put(factor, factorInfo); for (int i = 0; i < fromSize; ++i) { Variable variable = factor.getSibling(directedFrom[i]); VarInfo varInfo = (VarInfo)nodes.get(variable); if (varInfo == null) { varInfo = new VarInfo(variable); nodes.put(variable, varInfo); } varInfo.addNextFactor(factorInfo); } } if (nodes.isEmpty()) { return Collections.emptyMap(); } // // // final Map<Node,Integer> orderingMap = new HashMap<>(nodes.size()); Queue<NodeInfo> queue = new ArrayDeque<>(nodes.size()); for (NodeInfo nodeInfo : nodes.values()) { if (nodeInfo._order == 0) { queue.add(nodeInfo); } } int order = 0; while (!queue.isEmpty()) { NodeInfo nodeInfo = queue.remove(); orderingMap.put(nodeInfo._node, nodeInfo._order); for (int i = 0, n = nodeInfo.nNextNodes(); i < n; ++i) { NodeInfo nextNode = nodeInfo.nextNode(i); if (++nextNode._order == 0) { nextNode._order = ++order; queue.add(nextNode); } } } return orderingMap; } private abstract static class NodeInfo { private final Node _node; protected int _order = 0; private NodeInfo(Node node) { _node = node; } protected abstract int nNextNodes(); protected abstract NodeInfo nextNode(int i); } private static class FactorInfo extends NodeInfo { private final VarInfo[] _nextVars; private FactorInfo(Factor factor, VarInfo[] nextVars, int nPrevVars) { super(factor); _nextVars = nextVars; _order = -nPrevVars; } @Override protected int nNextNodes() { return _nextVars.length; } @Override protected VarInfo nextNode(int i) { return _nextVars[i]; } } private static class VarInfo extends NodeInfo { private final ArrayList<FactorInfo> _nextFactors = new ArrayList<>(); private VarInfo(Variable var) { super(var); } @Override protected int nNextNodes() { return _nextFactors.size(); } @Override protected FactorInfo nextNode(int i) { return _nextFactors.get(i); } private void addNextFactor(FactorInfo factorInfo) { _nextFactors.add(factorInfo); } private void addPrevFactor() { --_order; } } }