/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.core; import static java.util.Objects.*; import java.io.Serializable; import java.util.AbstractCollection; import java.util.AbstractList; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Objects; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.BitSetUtil; import com.analog.lyric.dimple.events.DimpleEvent; import com.analog.lyric.dimple.events.IDimpleEventListener; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.options.DimpleOptions; import com.analog.lyric.dimple.solvers.interfaces.ISolverNode; import com.analog.lyric.options.IOptionKey; import com.analog.lyric.util.misc.IMapList; import com.analog.lyric.util.misc.Internal; import com.analog.lyric.util.misc.MapList; import com.google.common.collect.Iterators; import com.google.common.collect.UnmodifiableIterator; import cern.colt.list.IntArrayList; public abstract class Node extends FactorGraphChild implements INode { /*----------- * Constants */ /** * {@link #_topologicalFlags} value used by {@link #isMarked()} */ private static final int MARKED = 0x10000000; /** * {@link #_topologicalFlags} value used by {@link #wasVisited()} */ private static final int VISITED = 0x20000000; /** * Flags that are reserved for use by this class and should not be * used by subclasses when invoking {@link #setFlags(int)} or {@link #clearFlags()}. */ protected static final int RESERVED = 0xFF000000; /*------- * State */ protected @Nullable String _name; /** * Identifies the edges that connect to this node. * <p> * Contains integer indexes into the parent graph's edge list. */ protected final IntArrayList _siblingEdges = new IntArrayList(); /** * Temporary flags that can be used to mark the node during the execution of various algorithms * or to mark non-static attributes of the node. * <p> * The flags are automatically cleared by {@link #initialize()}. */ protected int _flags; private class SiblingEdgeStateIterator extends UnmodifiableIterator<EdgeState> { private final int _size = getSiblingCount(); private int _index; @Override public boolean hasNext() { return _index < _size; } @Override public EdgeState next() { return getSiblingEdgeState(_index++); } } private class SiblingEdgeStateIterable extends AbstractCollection<EdgeState> { @Override public Iterator<EdgeState> iterator() { return new SiblingEdgeStateIterator(); } @Override public int size() { return getSiblingCount(); } } /*-------------- * Construction */ @Internal protected Node(int id) { _id = id; } protected Node(Node other) { super(other); _id = other._id | Ids.LOCAL_ID_INDEX_MAX; _name = other._name; } /*----------------------- * IOptionHolder methods */ @Override public void clearLocalOptions() { assertNotFrozen(); super.clearLocalOptions(); } @Override public <T extends Serializable> void setOption(IOptionKey<T> key, T value) { assertNotFrozen(); super.setOption(key, value); } @Override public void unsetOption(IOptionKey<?> key) { assertNotFrozen(); super.unsetOption(key); } /*---------------------------- * IDimpleEventSource methods */ @Override public String getEventSourceName() { // FIXME - determine what this should be return toString(); } @Override public Node getModelEventSource() { return this; } @Override public void notifyListenerChanged() { clearFlags(getEventMask()); } /*--------------- * INode methods */ // FIXME - give this a better name! e.g. getReverseSiblingEdgeNumber @Override public int getReverseSiblingNumber(int index) { final EdgeState edge = getSiblingEdgeState(index); return isVariable() ? edge._factorToVariableEdgeNumber : edge._variableToFactorEdgeNumber; } @Override public @Nullable Factor asFactor() { return null; } @Override public @Nullable FactorGraph asFactorGraph() { return null; } @Override public @Nullable Variable asVariable() { return null;} @Override public boolean isFactor() { return false; } @Override public boolean isFactorGraph() { return false; } @Override public boolean isVariable() { return false; } /** * Returns newly allocated list of ancestor graphs from the root down to the parent of this node. */ public List<FactorGraph> getAncestors() { LinkedList<FactorGraph> ancestors = new LinkedList<FactorGraph>(); FactorGraph ancestor = this.getParentGraph(); while (ancestor != null) { ancestors.addFirst(ancestor); ancestor = ancestor.getParentGraph(); } return ancestors; } @Override public @Nullable FactorGraph getAncestorAtHeight(int height) { FactorGraph ancestor = this.getParentGraph(); while (height-- > 0 && ancestor != null) { ancestor = ancestor.getParentGraph(); } return ancestor; } @Override public final double getBetheEntropy() { return requireSolver("getBetheEntropy").getBetheEntropy(); } /** * Returns the closest common ancestor graph containing both this node and {@code other} * or null if there isn't one. * * @param other is another node with which to compare. * @param uncommonAncestors if non-null, then any ancestors that are not in common will * be added to this list in order from top to bottom. * * @see #getCommonAncestor(Node) */ public @Nullable FactorGraph getCommonAncestor(Node other, @Nullable List<FactorGraph> uncommonAncestors) { // First try some common special cases to avoid computation of full path to the root. FactorGraph thisParent = getParentGraph(); FactorGraph otherParent = other.getParentGraph(); if (thisParent == otherParent) { return thisParent; } if (thisParent == null || otherParent == null) { return null; } if (this == otherParent) { return otherParent; } if (other == thisParent) { return thisParent; } Iterator<FactorGraph> theseAncestors = getAncestors().iterator(); Iterator<FactorGraph> otherAncestors = other.getAncestors().iterator(); FactorGraph ancestor = null; while (theseAncestors.hasNext() && otherAncestors.hasNext()) { FactorGraph thisAncestor = theseAncestors.next(); FactorGraph otherAncestor = otherAncestors.next(); if (thisAncestor == otherAncestor) { ancestor = thisAncestor; } else { if (uncommonAncestors != null) { // Add remaining ancestors to set, if provided uncommonAncestors.add(thisAncestor); Iterators.addAll(uncommonAncestors, theseAncestors); uncommonAncestors.add(otherAncestor); Iterators.addAll(uncommonAncestors, otherAncestors); } break; } } return ancestor; } /** * Returns the closest common ancestor graph containing both this node and {@code other} * or null if there isn't one. * * @see #getCommonAncestor(Node, List) */ public @Nullable FactorGraph getCommonAncestor(Node other) { return getCommonAncestor(other, null); } @Override public INode getConnectedNodeFlat(int portNum) { return getSibling(portNum); } @Override public final double getInternalEnergy() { return requireSolver("getInternalEnergy").getInternalEnergy(); } @SuppressWarnings("deprecation") // FIXME - deprecated this method as well @Override public final double getScore() { return requireSolver("getScore").getScore(); } @Override public List<? extends INode> getSiblings() { return new AbstractList<INode>() { @Override public Node get(int index) { return getSibling(index); } @Override public int size() { return getSiblingCount(); } }; } @Override public int getSiblingCount() { return _siblingEdges.size(); } @Override public Node getSibling(int i) { return getSiblingEdgeState(i).getSibling(this); } @Override public IMapList<INode> getConnectedNodes() { return getConnectedNodesFlat(); } @Override public INode getConnectedNode(int relativeDepth, int portNum) { if (relativeDepth < 0) relativeDepth = 0; int myDepth = getDepth(); //int desiredDepth = siblingDepth - relativeDepth; int desiredDepth = myDepth+relativeDepth; //Avoid overflow if (desiredDepth < 0) desiredDepth = Integer.MAX_VALUE; INode node = getSibling(portNum); // TODO: Instead of computing depths, which is O(depth), could we instead // just look for matching parent. For example, if relativedDepth is zero // can we just walk through the sibling node's parents until we find a match // for the parent of the node for this side of the connection? for (int depth = node.getDepth(); depth > desiredDepth; --depth) { node = requireNonNull(node.getParentGraph()); } return Objects.requireNonNull(node); } @Override public ArrayList<INode> getConnectedNodeAndParents(int index) { ArrayList<INode> retval = new ArrayList<INode>(); INode n = getSibling(index); while (n != null) { retval.add(n); n = n.getParentGraph(); } return retval; } @Override public IMapList<INode> getConnectedNodes(int relativeNestingDepth) { MapList<INode> list = new MapList<INode>(); for (int i = 0, end = getSiblingCount(); i < end; i++) { list.add(getConnectedNode(relativeNestingDepth,i)); } return list; } @Override public int getDepth() { int depth = 0; for (FactorGraph parent = this.getParentGraph(); parent != null; parent = parent.getParentGraph()) { ++depth; } return depth; } @Override public int getDepthBelowAncestor(FactorGraph ancestor) { int depth = 0; for (FactorGraph parent = this.getParentGraph(); parent != null; parent = parent.getParentGraph()) { if (parent == ancestor) { return depth; } ++depth; } return -depth - 1; } @Override public IMapList<INode> getConnectedNodesFlat() { return getConnectedNodes(Integer.MAX_VALUE); } @Override public IMapList<INode> getConnectedNodesTop() { return getConnectedNodes(0); } @Override public Collection<Port> getPorts() { final int size = _siblingEdges.size(); ArrayList<Port> ports = new ArrayList<Port>(size); for (int i = 0; i < size; i++ ) ports.add(getPort(i)); return ports; } @Override public boolean hasParentGraph() { return _parentGraph != null; } @Override public void setName(@Nullable String name) { assertNotFrozen(); // TODO restrict name to valid Java identifier // Note that there are currently factors in test code with names like "f(a,b)" if(name != null && name.contains(".")) { throw new DimpleException("ERROR '.' is not a valid character in names"); } final FactorGraph parentGraph = _parentGraph; if(parentGraph != null) { parentGraph.setChildNameInternal(this, name); } this._name = name; } @Override public void setLabel(@Nullable String name) { if (name != null) { setOption(DimpleOptions.label, name); } else { unsetOption(DimpleOptions.label); } } /** * {@inheritDoc} * <p> * For regular nodes implicitly generated name will be computed * by {@link Ids#defaultNameForLocalId(int)} using the * value of {@link #getId()}. */ @Override public String getName() { String name = _name; return name != null ? name : Ids.defaultNameForLocalId(_id); } /** * @deprecated as of release 0.08 */ @Deprecated abstract public String getClassLabel(); @Override public String getQualifiedName() { StringBuilder sb = new StringBuilder(); buildQualifiedName(sb); return sb.toString(); } protected void buildQualifiedName(StringBuilder sb) { final FactorGraph parent = getParentGraph(); if (parent != null) { parent.buildQualifiedName(sb); sb.append('.'); } sb.append(getName()); } @Override public String getLabel() { String name = getOption(DimpleOptions.label); if (name == null) { name = getName(); } return name; } @Override public String getQualifiedLabel() { String s = getLabel(); final FactorGraph parentGraph = _parentGraph; if (parentGraph != null) { s = parentGraph.getQualifiedLabel() + "." + s; } return s; } @Override public @Nullable String getExplicitName() { return _name; } @Override public String toString() { return getLabel(); } @Override public final int findSibling(INode node) { return findSibling(node, 0); } @Override public final int findSibling(INode node, int start) { for (int i = start, n = getSiblingCount(); i < n; ++i) { if (node == getSibling(i)) { return i; } } return -1; } @Override @Deprecated public final int getPortNum(INode node) { int port = findSibling(node); if (port < 0) { throw new DimpleException("Nodes are not connected: " + this + " and " + node); } return port; } @Override public void initialize() { clearFlags(); } @Override public final void update() { requireSolver("update").update(); } @Override public final void updateEdge(int siblingNumber) { requireSolver("updateEdge").updateEdge(siblingNumber); } @Deprecated @Override public void updateEdge(INode other) { int num = findSibling(other); updateEdge(num); } @Override public final boolean isConnected(INode node) { INode a,b; if (getSiblingCount() <= node.getSiblingCount()) { a = this; b = node; } else { a = node; b = this; } return a.findSibling(b) >= 0; } /*-------------- * Node methods */ /** * Gets representation of i'th edge. * <p> * Note that unlike {@link #getSiblingEdgeState(int)}, this returns a temporary object * that fully describes the edge. To avoid excess allocation do not use this method within * inner loops. * <p> * @param i should be between 0 (inclusive) and {@link #getSiblingCount()} (exclusive) * @since 0.08 * @throws IndexOutOfBoundsException if {@code i} is not in range. * @since 0.08 * @see #getSiblingEdgeState(int) */ @Override public Edge getSiblingEdge(int i) { return new Edge(requireNonNull(_parentGraph), getSiblingEdgeState(i)); } /** * Returns the graph index of the i'th edge connected to this node. * @param i is a number in the range from 0 to {@link #getSiblingCount()} - 1. * @return non-negative index in parent graph for the sibling edge, which can be used with * {@link FactorGraph#getGraphEdgeState(int)} to retrieve the edge state. * @since 0.08 */ public final int getSiblingEdgeIndex(int i) { return _siblingEdges.get(i); } /** * Get state for i'th edge. * <p> * @param i should be between 0 (inclusive) and {@link #getSiblingCount()} (exclusive) * @since 0.08 * @throws IndexOutOfBoundsException if {@code i} is not in range. */ @SuppressWarnings("null") @Override public EdgeState getSiblingEdgeState(int i) { return requireParentGraph().getGraphEdgeState(_siblingEdges.get(i)); } /** * A view of the sibling edge state objects connected to this node. * @since 0.08 */ public Collection<EdgeState> getSiblingEdgeState() { return new SiblingEdgeStateIterable(); } /** * Returns the index of the edge state * @param edge an edge attached to this node * @return the index of the edge or -1 if edge is not currently attached. * @since 0.08 */ public int indexOfSiblingEdgeState(EdgeState edge) { return isVariable() ? edge._variableToFactorEdgeNumber : edge._factorToVariableEdgeNumber; } /*------------------ * Internal methods */ @Override @Internal public void clearMarked() { assertNotFrozen(); clearFlags(MARKED); } @Override @Internal public void clearVisited() { assertNotFrozen(); clearFlags(VISITED); } @Override @Internal public final boolean isMarked() { return isFlagSet(MARKED); } @Override @Internal public final boolean wasVisited() { return isFlagSet(VISITED); } @Override @Internal public final void setMarked() { assertNotFrozen(); setFlags(MARKED); } @Override @Internal public final void setVisited() { assertNotFrozen(); setFlags(VISITED); } /*------------------- * Protected methods */ /** * @category internal */ @Internal protected void addEdge(Factor factor, Variable variable) { requireNonNull(_parentGraph).addEdge(factor, variable); } /** * @category internal */ @Internal protected void addSiblingEdgeState(EdgeState edge) { final int i = _siblingEdges.size(); if (isVariable()) { edge._variableToFactorEdgeNumber = i; } else { edge._factorToVariableEdgeNumber = i; } _siblingEdges.add(edge.edgeIndex(this)); notifyConnectionsChanged(); } /** * Reset index for given edge. * <p> * To be invoked after edge index has been changed to a lower value. * <p> * @category internal * @since 0.08 */ @Internal protected void fixSiblingEdgeStateIndex(EdgeState edge) { _siblingEdges.set(edge.getSiblingIndex(this), edge.edgeIndex(this)); } /** * @category internal */ @Internal protected void removeSiblingEdge(EdgeState edge) { requireNonNull(_parentGraph).removeSiblingEdge(edge); } /** * @category internal */ @Internal protected void removeSiblingEdgeState(EdgeState edge) { if (isVariable()) { final int i = edge._variableToFactorEdgeNumber; _siblingEdges.remove(i); edge._variableToFactorEdgeNumber = -1; for (int j = _siblingEdges.size(); --j >= i;) { getSiblingEdgeState(j)._variableToFactorEdgeNumber = j; } } else { final int i = edge._factorToVariableEdgeNumber; _siblingEdges.remove(i); edge._factorToVariableEdgeNumber = -1; for (int j = _siblingEdges.size(); --j >= i;) { getSiblingEdgeState(j)._factorToVariableEdgeNumber = j; } } notifyConnectionsChanged(); } /** * @category internal */ @Internal protected void replaceSiblingEdgeState(EdgeState oldEdge, EdgeState newEdge) { assert(isFactor()); final int i = oldEdge._factorToVariableEdgeNumber; oldEdge._factorToVariableEdgeNumber = -1; _siblingEdges.set(i, newEdge.edgeIndex(this)); newEdge._factorToVariableEdgeNumber = i; notifyConnectionsChanged(); } /** * Clear all flag values. Invoked automatically by {@link #initialize()}. */ protected void clearFlags() { _flags = 0; } /** * Clear flags in given mask. * <p> * Subclasses should not use bits in the {@link #RESERVED} mask. */ protected void clearFlags(int mask) { _flags = BitSetUtil.clearMask(_flags, mask); } /** * Return mask of flag bits that are used to determine whether to * generate events. This is used by {@link #notifyListenerChanged()} * to clear the specified flag bits. It is assumed that the value of * all zeros indicates that the object needs to recompute its flags * based on the listener. * <p> * The default implementation returns zero. * * @since 0.06 */ protected int getEventMask() { return 0; } /** * True if all of the bits in {@code mask} are set in the flags. * <p> * Subclasses should not use bits in the {@link #RESERVED} mask. */ protected boolean isFlagSet(int mask) { return BitSetUtil.isMaskSet(_flags, mask); } /** * Invoked when a change is made to the siblings list. * Subclasses may override this to clear cached state that * was computed from the siblings. The default implementation * does nothing. * * @since 0.07 */ protected void notifyConnectionsChanged() { } protected final void raiseEvent(@Nullable DimpleEvent event) { if (event != null) { final IDimpleEventListener listener = getEventListener(); final boolean handled = listener != null && listener.raiseEvent(event); if (!handled) { // Listener configuration probably changed. Reconfigure source to // prevent further event creation. notifyListenerChanged(); } } } protected abstract ISolverNode requireSolver(String method); /** * @category internal */ @Internal protected void setArguments(int[] argids) { } /** * Sets all of the bits in {@code mask} in the flags. * <p> * Subclasses should not use bits in the {@link #RESERVED} mask. */ protected final void setFlags(int mask) { _flags = BitSetUtil.setMask(_flags, mask); } /** * Sets bits of flag specified by {@code mask} to {@code value}. */ @Internal protected final void setFlagValue(int mask, int value) { _flags = BitSetUtil.setMaskedValue(_flags, mask, value); } /** * Trims representation of sibling edges to their size to save memory. * @category internal * @since 0.08 */ @Internal protected void trimToSize() { _siblingEdges.trimToSize(); } /*-------------------- * Deprecated methods */ @Deprecated @Override public int getSiblingPortIndex(int siblingNumber) { return getReverseSiblingNumber(siblingNumber); } @Deprecated @Override public final void initialize(int siblingNumber) { } }