/******************************************************************************* * Copyright 2012-2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.factors; import static java.lang.String.*; import static java.util.Objects.*; import java.util.AbstractList; import java.util.Arrays; import java.util.BitSet; import java.util.Collections; import java.util.List; import java.util.Set; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.factorfunctions.core.TableFactorFunction; import com.analog.lyric.dimple.model.core.EdgeDirection; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.FactorPort; import com.analog.lyric.dimple.model.core.Ids; import com.analog.lyric.dimple.model.core.NodeType; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.domains.DomainList; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Constant; import com.analog.lyric.dimple.model.variables.IConstantOrVariable; import com.analog.lyric.dimple.model.variables.IVariableToValue; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableList; import com.analog.lyric.dimple.model.variables.VariablePredicates; import com.analog.lyric.dimple.options.DimpleOptions; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.util.misc.Internal; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import cern.colt.list.IntArrayList; public class Factor extends FactorBase implements Cloneable { /*------- * State */ /** * Sentinel value for {@link #_directedTo} indicating it has not yet been set. */ private static final int[] NOT_YET_SET = new int[0]; @Deprecated private String _modelerFunctionName = ""; private FactorFunction _factorFunction; protected @Nullable int [] _directedTo = NOT_YET_SET; @Nullable int [] _directedFrom = null; /** * Represents arguments to factor function. * <p> * Includes all of the edge ids found in {@link _siblingEdges} in the same order, but with additional constant * ids inserted. If there are no constants, this field will simply point to {@link _siblingEdges}. */ private IntArrayList _argids; /** * Mapping from {@link _factorArguments} offset to sibling number offset in {@code _siblingEdges}. Offsets * corresponding to constants, which have no corresponding edge, will instead map to offsets after the * edges (i.e. the same format used by * {@link JointDomainReindexer#createPermuter(JointDomainIndexer, JointDomainIndexer, int[])} * <p> * Computed lazily */ private int[] _argToEdge = NOT_YET_SET; /** * Mapping from sibling number offset in {@code _siblingEdges} to corresponding offset in {@link _factorArguments} * <p> * The entries from [0, {@link #getSiblingCount()} - 1] maps edges to factor arguments. Entries after that map * constants to factor arguments. * <p> * Computed lazily. */ private int[] _edgeToArgNumber = NOT_YET_SET; /*-------------- * Construction */ /** * @category internal */ @Internal public Factor(FactorFunction factorFunc) { super(Ids.INITIAL_FACTOR_ID); _factorFunction = factorFunc; _modelerFunctionName = factorFunc.getName(); _argids = _siblingEdges; } protected Factor(Factor that) { super(that); _modelerFunctionName = that._modelerFunctionName; _factorFunction = that._factorFunction; // clone? // Note that this does not copy the constants or the edges _argids = _siblingEdges; int[] directedTo = _directedTo = that._directedTo; if (directedTo != null && directedTo != NOT_YET_SET && directedTo != ArrayUtil.EMPTY_INT_ARRAY) { _directedTo = directedTo.clone(); } int[] directedFrom = _directedFrom = that._directedFrom; if (directedFrom != null && directedFrom != ArrayUtil.EMPTY_INT_ARRAY) { _directedFrom = directedFrom.clone(); } } @Override public Factor clone() { return new Factor(this); } /*--------------- * INode methods */ @Override public final Factor asFactor() { return this; } /*-------------- * Node methods */ @Override protected void fixSiblingEdgeStateIndex(EdgeState edge) { super.fixSiblingEdgeStateIndex(edge); if (hasConstants()) { _argids.set(siblingNumberToArgIndex(edge.getSiblingIndex(this)), edge.edgeIndex(this)); } } /*---------------- * Factor methods */ @Override public NodeType getNodeType() { return NodeType.FACTOR; } @Override public FactorPort getPort(int siblingNumber) { return new FactorPort(this, siblingNumber); } @Override public @Nullable ISolverFactor getSolver() { final FactorGraph fg = getParentGraph(); if (fg != null) { final ISolverFactorGraph sfg = fg.getSolver(); if (sfg != null) { return sfg.getSolverFactor(this, false); } } return null; } @Override @Internal public ISolverFactor requireSolver(String method) { return requireSolver(method, getSolver()); } /** * Returns sibling edge number corresponding to given factor argument index. * <p> * If index refers to a sibling edge, this will return the corresponding sibling number that. If index * refers to a constant argument, this will return the {@linkplain #getSiblingCount() number of sibling edges} * plus the index of the constant in the ordered list of indexes. * <p> * {@link #siblingNumberToArgIndex(int)} is the inverse of this method. * <p> * This is primarily intended for use by solver implementations. * <p> * @param index must be non-negative and less than the {@linkplain #getArgumentCount() number of factor arguments}. * <p> * @since 0.08 * @see #getArgument(int) */ public int argIndexToSiblingNumber(int index) { return computeConstantInfo() ? _argToEdge[index] : index; } /** * Evaluates energy for factor for values computed from its variables and constants. * <p> * Builds a {@link Value} list using {@code v2v} mapping for variable arguments, and * the constant value directly for constant arguments, and then computes the energy * using the underlying {@linkplain #getFactorFunction() factor function}. * <p> * For use internally in solver implementations. * @since 0.08 */ public double evalEnergy(IVariableToValue v2v) { return _factorFunction.evalEnergy(fillInArgumentValues(v2v, null)); } /** * Evaluates energy for factor for given values. * <p> * Calls underlying {@linkplain #getFactorFunction() factor function} after inserting * any constant values held by the factor. * <p> * @param values specifies the values of the sibling variables to be evaluated. * @since 0.08 */ public double evalEnergy(Object[] values) { final int nEdges = getSiblingCount(); final int nArgs = getArgumentCount(); if (values.length != getSiblingCount()) { throw new IllegalArgumentException("Values length does not match number of edges"); } if (nEdges != nArgs) { // Fill in constant values Value[] tmp = new Value[nArgs]; for (int i = 0, j = 0; i < nArgs; ++i) { Value value = getConstantValueByIndex(i); if (value != null) { tmp[i] = value; } else { Variable var = getSibling(j); tmp[i] = Value.create(var.getDomain(), values[j]); ++j; } } values = tmp; } return _factorFunction.evalEnergy(values); } /** * Evaluates energy for factor for given values. * <p> * Calls underlying {@linkplain #getFactorFunction() factor function} after inserting * any constant values held by the factor. * <p> * @param values specifies the values of the sibling variables to be evaluated. * @since 0.08 */ public double evalEnergy(Value[] values) { final int nEdges = getSiblingCount(); final int nArgs = getArgumentCount(); if (values.length != getSiblingCount()) { throw new IllegalArgumentException("Values length does not match number of edges"); } if (nEdges != nArgs) { // Fill in constant values Value[] tmp = new Value[nArgs]; for (int i = 0, j = 0; i < nArgs; ++i) { Value value = getConstantValueByIndex(i); tmp[i] = value != null ? value : values[j++]; } values = tmp; } return _factorFunction.evalEnergy(values); } /** * Fills in array of values corresponding to factor's arguments. * <p> * For use internally in solver implementations. * <p> * @param v2v maps {@link Variable} instances to corresponding {@link Value} * @param values array to fill in. Only used if non-null and has length equal to * {@link #getArgumentCount() the number of factor arguments}. * @return filled in array * @since 0.08 * @category internal */ @Internal public Value[] fillInArgumentValues(IVariableToValue v2v, @Nullable Value[] values) { final int nargs = getArgumentCount(); if (values == null || values.length != nargs) { values = new Value[nargs]; } for (int i = 0; i < nargs; ++i) { IConstantOrVariable arg = getArgument(i); if (arg instanceof Constant) { values[i] = ((Constant)arg).value(); } else { Variable var = (Variable)arg; Value value = v2v.varToValue(var); if (value == null) { throw new IllegalStateException(format("There is no value for %s", var)); } values[i] = value; } } return values; } /** * Filters out constant factor argument indices. * <p> * This is primarily intended for use by solver implementations. * <p> * @param indices contains factor argument indices in any order. * @return copy of {@code indices} after removing any indices that refer to constants. If * the factor has no constants, this simply returns the original indices. * @since 0.08 */ public int[] filterConstantArgIndices(int[] indices) { if (hasConstants()) { // Filter out constant indices IntArrayList tmp = new IntArrayList(getSiblingCount()); for (int index : indices) { if (!hasConstantAtIndex(index)) { tmp.add(argIndexToSiblingNumber(index)); } } tmp.trimToSize(); indices = tmp.elements(); } return indices; } /** * @category internal * @since 0.08 */ @Internal public int[] getArgIndexToToSiblingNumberMapping() { return computeConstantInfo() ? _argToEdge.clone() : ArrayUtil.EMPTY_INT_ARRAY; } @SuppressWarnings("null") public IConstantOrVariable getArgument(int argIndex) { final FactorGraph graph = requireParentGraph(); int id = _argids.get(argIndex); int index = Ids.indexFromLocalId(id); switch (Ids.typeIndexFromLocalId(id)) { case Ids.UNKNOWN_TYPE: case Ids.EDGE_TYPE: return graph.getGraphEdgeState(index).getVariable(graph); case Ids.CONSTANT_TYPE: return graph.getConstantByLocalId(id); default: throw new IllegalStateException(); } } public int getArgumentCount() { return _argids.size(); } /** * The domains of the {@link #getArgument(int) factor arguments} in order. * <p> * Similar to {@link #getDomainList()} but for factor arguments instead of sibling * variables. If factor does not {@linkplain #hasConstants() have constants}, this is * the same as {@link #getDomainList()}. If there are constants, they will be represented * using single-element {@link DiscreteDomain}s. * <p> * @since 0.08 */ public DomainList<?> getArgumentDomains() { int nArgs = getArgumentCount(); Domain [] domains = new Domain[nArgs]; for (int i = 0; i < nArgs; i++) { IConstantOrVariable arg = getArgument(i); if (arg instanceof Constant) { // Make a single-element discrete domain for the constant value. domains[i] = DiscreteDomain.create(requireNonNull(((Constant)arg).value().getObject())); } else { domains[i] = ((Variable)arg).getDomain(); } } return DomainList.create(getDirectedTo(), domains); } /** * Returns a read-only list view of the arguments to the factor. * <p> * The arguments will contain either {@link Variable}s or {@link Constant}s. * @since 0.08 * @see #getArgument(int) */ public List<IConstantOrVariable> getArguments() { return new AbstractList<IConstantOrVariable>() { @Override public IConstantOrVariable get(int index) { return getArgument(index); } @Override public int size() { return _argids.size(); } }; } /** * The number of constant arguments to this factor. * @since 0.08 * @see #getConstants() */ public final int getConstantCount() { return _argids.size() - _siblingEdges.size(); } /** * Returns copy of the factor argument indices that refer to constants. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public int[] getConstantIndices() { return computeConstantInfo() ? Arrays.copyOfRange(_edgeToArgNumber, _siblingEdges.size(), _edgeToArgNumber.length) : ArrayUtil.EMPTY_INT_ARRAY; } /** * Returns a read-only view of the constant arguments. * @since 0.08 * @see #getArguments() * @see #getConstantValues() */ public final List<Constant> getConstants() { if (!computeConstantInfo()) return Collections.emptyList(); final int nEdges = getSiblingCount(); final FactorGraph graph = requireParentGraph(); return new AbstractList<Constant> () { @Override public Constant get(int index) { final int argNumber = _edgeToArgNumber[index + nEdges]; return requireNonNull(graph.getConstantByLocalId(_argids.get(argNumber))); } @Override public int size() { return getConstantCount(); } }; } /** * Returns {@link Constant#value} for given factor argument {@code index} or null if not a constant. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public @Nullable Value getConstantValueByIndex(int index) { FactorGraph graph = _parentGraph; if (graph != null) { Constant constant = graph.getConstantByLocalId(_argids.get(index)); if (constant != null) { return constant.value(); } } return null; } /** * Returns a read-only view of the constant argument values. * @since 0.08 * @see #getArguments() * @see #getConstants() */ public List<Value> getConstantValues() { if (!computeConstantInfo()) return Collections.emptyList(); final int nEdges = getSiblingCount(); final FactorGraph graph = requireParentGraph(); return new AbstractList<Value> () { @Override public Value get(int index) { final int argNumber = _edgeToArgNumber[index + nEdges]; return requireNonNull(graph.getConstantByLocalId(_argids.get(argNumber))).value(); } @Override public int size() { return getConstantCount(); } }; } public @Nullable int [] getDirectedFrom() { ensureDirectedToSet(); return _directedFrom; } public @Nullable int [] getDirectedTo() { ensureDirectedToSet(); return _directedTo; } public VariableList getDirectedToVariables() { VariableList vl = null; final int[] directedTo = getDirectedTo(); if (directedTo != null) { vl = new VariableList(directedTo.length); for (int i = 0; i < directedTo.length; i++) { vl.add(getSibling(directedTo[i])); } } else { vl = new VariableList(); } return vl; } /** * The domains of the adjacent variables in order. */ public DomainList<?> getDomainList() { int numVariables = getSiblingCount(); Domain [] domains = new Domain[numVariables]; for (int i = 0; i < numVariables; i++) { domains[i] = getSibling(i).getDomain(); } return DomainList.create(getDirectedTo(), domains); } /** * Describes direction of given edge. * <p> * @throws IndexOutOfBoundsException if {@code edgeNumber} is negative or not less than {@link #getSiblingCount()}. * @since 0.08 */ public EdgeDirection getEdgeDirection(int edgeNumber) { getSiblingEdgeIndex(edgeNumber); // call this to do a range check ensureDirectedToSet(); if (_directedTo == null) { return EdgeDirection.UNDIRECTED; } else if (isDirectedTo(edgeNumber)) { return EdgeDirection.FROM_FACTOR; } else { return EdgeDirection.TO_FACTOR; } } public FactorFunction getFactorFunction() { return _factorFunction; } public IFactorTable getFactorTable() { throw new UnsupportedOperationException("Factor tables only available on DiscreteFactors"); } @Override public String getLabel() { String name = getOption(DimpleOptions.label); if(name == null) { name = _name; if(name == null) { name = getModelerFunctionName() + "_" + getLocalId(); } } return name; } public String getModelerFunctionName() { return _modelerFunctionName; } /** * True if there is a constant argument at given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public boolean hasConstantAtIndex(int index) { return Ids.typeIndexFromLocalId(_argids.get(index)) == Ids.CONSTANT_TYPE; } /** * True if there is a constant argument with given value type at given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @param index * @param type is the type of the constant value's Object type, not the subclass of * {@link Value}. E.g., use {@code Real.class} rather than {@code RealValue.class}. * @since 0.08 */ public boolean hasConstantAtIndexOfType(int index, Class<?> type) { Value value = getConstantValueByIndex(index); return value != null && type.isInstance(value.getObject()); } /** * True if there are constants with argument index not less than given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 * @see #numConstantsAtOrAboveIndex(int) */ public boolean hasConstantAtOrAboveIndex(int index) { return numConstantsAtOrAboveIndex(index) > 0; } /** * True if there are constants with argument index not greater than given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 * @see #numConstantsAtOrBelowIndex(int) */ public boolean hasConstantAtOrBelowIndex(int index) { return numConstantsAtOrBelowIndex(index) > 0; } /** * True if factor has any constant arguments. * @since 0.08 * @see #getConstants */ public final boolean hasConstants() { return _argids != _siblingEdges; } public boolean hasConstantsInIndexRange(int minIndex, int maxIndex) { return numConstantsInIndexRange(minIndex, maxIndex) > 0; } public boolean hasFactorTable() { return getFactorFunction().factorTableExists(this); } /** * Model-specific initialization for factors. * <p> * Assumes that model variables in same graph have already been initialized. * Does NOT invoke solver factor initialize. */ @Override public void initialize() { final FactorFunction func = _factorFunction; if (func.isDirected()) // Automatically set direction if inherent in factor function { setDirectedTo(filterConstantArgIndices(requireNonNull(func.getDirectedToIndices(getArgumentCount())))); if (_factorFunction.isDeterministicDirected()) { final int[] directedTo = requireNonNull(_directedTo); for (int to : directedTo) { getSibling(to).setDeterministicOutput(); } if (directedTo.length > 0) { final int[] directedFrom = requireNonNull(_directedFrom); for (int from : directedFrom) { getSibling(from).setDeterministicInput(); } } } } } public boolean isDirected() { ensureDirectedToSet(); return _directedTo != null; } public boolean isDirectedTo(int edge) { final int[] to = getDirectedTo(); if (to == null) return false; // Assume _directedTo is sorted: final int toRange = to.length - 1; final int first = to[0]; final int last = to[toRange]; if (last - first == toRange) { return edge <= last && edge >= first; } return Arrays.binarySearch(to, edge) >= 0; } public boolean isDiscrete() { return Iterables.all(getSiblings(), VariablePredicates.isDiscrete()); } @Override public final boolean isFactor() { return true; } /** * Returns the number of constants with argument index not less than given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public int numConstantsAtOrAboveIndex(int index) { if (!computeConstantInfo()) return 0; final int nEdges = _siblingEdges.size(); final int x = _argToEdge[index]; final int nConstants = getConstantCount(); if (x < nEdges) { // An edge index return nConstants + x - index; } else { // A constant return nConstants + nEdges - x; } } /** * Returns the number of constants with argument index not greater than given index. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public int numConstantsAtOrBelowIndex(int index) { if (!computeConstantInfo()) return 0; final int nEdges = _siblingEdges.size(); final int x = _argToEdge[index]; if (x < nEdges) { // An edge index return index - x; } else { // A constant return 1 + x - nEdges; } } /** * Returns the number of constants with argument index in given inclusive range. * <p> * This is primarily intended for use by solver implementations. * <p> * @since 0.08 */ public int numConstantsInIndexRange(int minIndex, int maxIndex) { int n = numConstantsAtOrBelowIndex(maxIndex); if (n > 0 && minIndex > 0) { n -= numConstantsAtOrBelowIndex(minIndex - 1); } return n; } /** * Removes edges to all variables that have fixed values and replace with constants. * <p> * For {@link DiscreteFactor}s, this will replace the factor function with one using a newly * generated factor table. * </p> * @return the number of variable edges that were removed. * * @see Variable#hasFixedValue() */ public final int removeFixedVariables() { assertNotFrozen(); final int nEdges = getSiblingCount(); IFactorTable oldFactorTable = null; final FactorGraph parent = requireNonNull(getParentGraph()); // Visit in reverse order so that disconnect is safe. int[] valueIndices = ArrayUtil.EMPTY_INT_ARRAY; IntArrayList factorArguments = _argids; int nRemoved = 0; for (int i = nEdges; --i>=0;) { final EdgeState edge = getSiblingEdgeState(i); final Variable var = edge.getVariable(parent); Value value = var.getPriorValue(); if (value != null) { if (factorArguments == _argids) { valueIndices = new int[factorArguments.size()]; Arrays.fill(valueIndices, -1); factorArguments = factorArguments.copy(); factorArguments.trimToSize(); if (isDiscrete()) { // Before disconnecting siblings, force the factor table // to be instantiated if discrete, since it may depend on // the original edges. oldFactorTable = getFactorTable(); } } final int argIndex = siblingNumberToArgIndex(i); valueIndices[argIndex] = value.getIndex(); factorArguments.set(argIndex, parent.addConstant(value).getLocalId()); removeSiblingEdge(edge); ++nRemoved; } } if (nRemoved > 0) { setArguments(factorArguments.elements()); if (_factorFunction instanceof TableFactorFunction) { IFactorTable newTable = requireNonNull(oldFactorTable).createTableConditionedOn(valueIndices); setFactorFunction(TableFactorFunction.forFactor(this, newTable)); } else { setFactorFunction(_factorFunction); } } return nRemoved; } @Internal public void replaceVariablesWithJoint(Variable [] variablesToJoin, Variable newJoint) { assertNotFrozen(); throw new DimpleException("not implemented"); } public void setDirectedTo(int [] directedTo) { assertNotFrozen(); if (!canBeDirected()) { throw new UnsupportedOperationException(String.format("'%s' does not support setting direction", getClass().getSimpleName())); } final JointDomainIndexer curDomains = getDomainList().asJointDomainIndexer(); BitSet toSet = new BitSet(directedTo.length); final int nVariables = getSiblingCount(); final int[] directedFrom = _directedFrom = new int[nVariables-directedTo.length]; boolean sort = false; int prev = -1; for (int toVarIndex : directedTo) { if (toSet.get(toVarIndex) || toVarIndex > nVariables) throw new DimpleException("invalid edge"); if (toVarIndex < prev) { sort = true; } prev = toVarIndex; toSet.set(toVarIndex); } for (int i = 0, fromVarIndex = 0; (fromVarIndex = toSet.nextClearBit(fromVarIndex)) < nVariables; ++fromVarIndex, ++i) { directedFrom[i] = fromVarIndex; } if (sort) { Arrays.sort(directedTo); } _directedTo = directedTo; notifyConnectionsChanged(); if (curDomains != null) { JointDomainIndexer newDomains = getDomainList().asJointDomainIndexer(); if (!curDomains.equals(newDomains)) { getFactorFunction().convertFactorTable(curDomains, newDomains); } } // FIXME - don't force creation ISolverFactor sfactor = getSolver(); if (sfactor != null) { sfactor.setDirectedTo(directedTo); } } /** * Sets all edges to specified variables as output edges. * @param variables * @since 0.08 */ public final void setDirectedTo(Set<Variable> variables) { assertNotFrozen(); final IntArrayList directedTo = new IntArrayList(variables.size()); for (int i = 0, n = getSiblingCount(); i < n; ++i) { if (variables.contains(getSibling(i))) { directedTo.add(i); } } directedTo.trimToSize(); setDirectedTo(directedTo.elements()); } /** * Sets all edges to specified variables as output edges. */ public final void setDirectedTo(Variable ... variables) { setDirectedTo(Sets.newHashSet(variables)); } /** * Sets all edges to specified variables as output edges. */ public final void setDirectedTo(VariableList vl) { setDirectedTo(Sets.newHashSet(vl)); } public void setFactorFunction(FactorFunction function) { assertNotFrozen(); _factorFunction = function; if (_factorFunction.isDirected()) { // Automatically set direction if inherent in factor function int[] to = requireNonNull(_factorFunction.getDirectedToIndices(getArgumentCount())); setDirectedTo(filterConstantArgIndices(to)); } } /** * Makes factor undirected. * * @since 0.07 */ public void setUndirected() { assertNotFrozen(); _directedTo = null; _directedFrom = null; if (hasFactorTable()) { getFactorTable().setDirected(null); } } /** * Returns factor argument index corresponding to given sibling edge number. * <p> * If {@code siblingNumber} is less than {@link #getSiblingCount}, this returns the corresponding * factor argument index for that sibling edge. If it is larger than that, then this will return * the factor argument index for the nth constant where n is {@code siblingNumber} minus the * sibling count. * <p> * {@link #argIndexToSiblingNumber(int)} is the inverse of this method. * <p> * @param siblingNumber must be non-negative and less than the {@linkplain #getArgumentCount() number of * factor arguments} * @since 0.08 */ public int siblingNumberToArgIndex(int siblingNumber) { return computeConstantInfo() ? _edgeToArgNumber[siblingNumber] : siblingNumber; } /*------------------- * Protected methods */ /** * @category internal */ @Internal @Override protected void addSiblingEdgeState(EdgeState edge) { super.addSiblingEdgeState(edge); if (_argids != _siblingEdges) { final int edgeNumber = edge.getFactorToVariableEdgeNumber(); _argids.add(edgeNumber); forgetConstantInfo(); } } /** * Return true if factor supports setting direction through one of the * {@link #setDirectedTo} methods. * <p> * Returns true by default. * <p> * @since 0.07 */ @Internal protected boolean canBeDirected() { return true; } /** * @category internal */ @Internal public void createSolverObject(@Nullable ISolverFactorGraph factorGraph) { if (factorGraph != null) { factorGraph.getSolverFactor(this, true); } } /** * @category internal */ @Internal @Override protected void removeSiblingEdgeState(EdgeState edge) { if (_argids != _siblingEdges) { final int edgeNumber = edge.getFactorToVariableEdgeNumber(); int argumentIndex; if (_edgeToArgNumber != NOT_YET_SET) { argumentIndex = _edgeToArgNumber[edgeNumber]; } else { argumentIndex = -1; for (int sibling = 0; sibling <= edgeNumber; ) { ++argumentIndex; if (Ids.typeIndexFromLocalId(_argids.get(argumentIndex)) != Ids.CONSTANT_TYPE) { ++sibling; } } } _argids.remove(argumentIndex); forgetConstantInfo(); } super.removeSiblingEdgeState(edge); } /** * @category internal */ @Override @Internal protected void replaceSiblingEdgeState(EdgeState oldEdge, EdgeState newEdge) { super.replaceSiblingEdgeState(oldEdge, newEdge); if (_argids != _siblingEdges) { final int newEdgeNumber = newEdge.getFactorToVariableEdgeNumber(); final int argumentIndex = siblingNumberToArgIndex(newEdgeNumber); _argids.set(argumentIndex, newEdgeNumber); } } protected <T extends ISolverFactor> T requireSolver(String method, @Nullable T solverFactor) { if (solverFactor == null) { throw new NullPointerException(String.format("solver must be set before using '%s'.", method)); } return solverFactor; } /** * @category internal */ @Override @Internal protected void setArguments(int[] argids) { _argids = new IntArrayList(argids); forgetConstantInfo(); } @Override protected void trimToSize() { super.trimToSize(); _argids.trimToSize(); } /*-------------------- * Deprecated methods */ @Deprecated @Override public String getClassLabel() { return "Factor"; } /** * @category internal */ @Deprecated @Internal public void setSolver(@Nullable ISolverFactor sfactor) { throw new UnsupportedOperationException("Factor.setSolver no longer supported"); } /*----------------- * Private methods */ /** * Compute fields related to constant argument indexing * @return true if factor has constants * @since 0.08 */ private boolean computeConstantInfo() { if (_edgeToArgNumber == NOT_YET_SET) { final IntArrayList factorArguments = _argids; if (factorArguments == _siblingEdges) { // There are no constants. _edgeToArgNumber = ArrayUtil.EMPTY_INT_ARRAY; return false; } final int nEdges = _siblingEdges.size(); final int nArgs = factorArguments.size(); final int[] edgeToArgNumber = new int[nArgs]; final int[] argToEdgeNumber = new int[nArgs]; for (int edgei = 0, argi = 0, constanti = nEdges; argi < nArgs; ++argi) { if (Ids.typeIndexFromLocalId(_argids.get(argi)) == Ids.CONSTANT_TYPE) { edgeToArgNumber[constanti] = argi; argToEdgeNumber[argi] = constanti; ++constanti; } else { argToEdgeNumber[argi] = edgei; edgeToArgNumber[edgei] = argi; ++edgei; } } _edgeToArgNumber = edgeToArgNumber; _argToEdge = argToEdgeNumber; return true; } return _edgeToArgNumber.length != 0; } private void ensureDirectedToSet() { if (_directedTo == NOT_YET_SET) { _directedTo = null; if (canBeDirected()) { setFactorFunction(getFactorFunction()); } } } private void forgetConstantInfo() { _edgeToArgNumber = NOT_YET_SET; _argToEdge = NOT_YET_SET; } }