/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.junctiontree; import static java.util.Objects.*; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.Map; import java.util.Map.Entry; import org.eclipse.jdt.annotation.NonNull; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.FactorTable; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap; import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.AddedJointVariable; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.solvers.core.SFactorBase; import com.analog.lyric.dimple.solvers.core.STableFactorBase; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor; import cern.colt.list.IntArrayList; /** * @since 0.05 * @author Christopher Barber * */ public class JunctionTreeSolverFactor extends SFactorBase { /*------- * State */ private final JunctionTreeSolverGraphBase<?> _root; private @Nullable ISolverFactor _delegate; private @Nullable JointDomainReindexer _reindexer; private boolean _reindexerComputed = false; /*-------------- * Construction */ JunctionTreeSolverFactor(Factor modelFactor, JunctionTreeSolverGraphBase<?> parent) { super(modelFactor, parent); _root = parent; } /*--------------------- * ISolverNode methods */ @Override public double getBetheEntropy() { double sum = 0; final double [] beliefs = getBelief(); for (double belief : beliefs) { sum -= belief * Math.log(belief); } return sum; } @Override public JunctionTreeSolverGraphBase<?> getRootSolverGraph() { return _root; } @Deprecated @Override public double getScore() { throw unsupported("getScore"); } @Override public void doUpdateEdge(int outPortNum) { throw unsupported("updateEdge"); } /*----------------------- * ISolverFactor methods */ @Override public double[] getBelief() { final ISolverFactor delegate = requireNonNull(getDelegate()); final double[] beliefs = (double[]) delegate.getBelief(); final JointDomainReindexer reindexer = getDelegateReindexer(); if (reindexer == null) { // No conversion necessary return beliefs; } final IFactorTable delegateTable = ((STableFactorBase)delegate).getFactorTable(); delegateTable.getDomainIndexer(); final IFactorTable beliefTable = FactorTable.create(delegateTable.getDomainIndexer()); beliefTable.setWeightsSparse(delegateTable.getIndicesSparseUnsafe(), beliefs); final IFactorTable convertedTable = FactorTable.convert(beliefTable, reindexer); convertedTable.setDirected(null); convertedTable.normalize(); return convertedTable.getWeightsSparseUnsafe(); } @Override public int[][] getPossibleBeliefIndices() { final ISolverFactor delegate = requireNonNull(getDelegate()); final JointDomainReindexer reindexer = getDelegateReindexer(); if (reindexer == null) { // No conversion necessary return delegate.getPossibleBeliefIndices(); } // TODO: perhaps we should cache this state with getBelief() final IFactorTable delegateTable = ((STableFactorBase)delegate).getFactorTable(); delegateTable.getDomainIndexer(); final IFactorTable beliefTable = FactorTable.create(delegateTable.getDomainIndexer()); beliefTable.setWeightsSparse(delegateTable.getIndicesSparseUnsafe(), (double[])delegate.getBelief()); final IFactorTable convertedTable = FactorTable.convert(beliefTable, reindexer); return convertedTable.getIndicesSparseUnsafe(); } /*----------------- * Private methods */ private @Nullable ISolverFactor getDelegate() { final ISolverFactor delegate = _delegate; if (delegate != null) { return delegate; } else { final Factor sourceFactor = getFactor(); final Factor targetFactor = requireNonNull(_root.getTransformMap()).sourceToTargetFactor(sourceFactor); return _delegate = targetFactor.getSolver(); } } private @Nullable JointDomainReindexer getDelegateReindexer() { if (!_reindexerComputed) { _reindexerComputed = true; // TODO: detect when no conversion is needed final JunctionTreeTransformMap transformMap = requireNonNull(_root.getTransformMap()); final Factor sourceFactor = getFactor(); final Factor targetFactor = transformMap.sourceToTargetFactor(sourceFactor); // Create mapping from target vars to their index in source factor final int nSourceVars = sourceFactor.getSiblingCount(); final Map<Variable,Integer> targetVarToSourceIndex = new LinkedHashMap<Variable, Integer>(nSourceVars); for (int si = 0; si < nSourceVars; ++si) { final Variable sourceVar = sourceFactor.getSibling(si); final Variable targetVar = transformMap.sourceToTargetVariable(sourceVar); if (null != targetVarToSourceIndex.put(targetVar, si)) { // FIXME - junction tree support duplicate variables throw new DimpleException("junction tree does not support factor with duplicate variables"); } } final int nTargetVars = targetFactor.getSiblingCount(); int removeIndex = nSourceVars; final IntArrayList targetToSourceIndex = new IntArrayList(nTargetVars); final IntArrayList targetVarsToSplit = new IntArrayList(); for (int ti = 0; ti < nTargetVars; ++ti) { final Variable targetVar = targetFactor.getSibling(ti); final AddedJointVariable<?> addedJointVar = transformMap.getAddedDeterministicVariable(targetVar); if (addedJointVar == null) { Integer index = targetVarToSourceIndex.remove(targetVar); if (index != null) { targetToSourceIndex.add(index); } else { targetToSourceIndex.add(removeIndex++); } } else { final int nInputs = addedJointVar.getInputCount(); for (int i = 0; i < nInputs; ++i) { final Variable inputVar = addedJointVar.getInput(i); Integer index = targetVarToSourceIndex.remove(inputVar); if (index != null) { targetToSourceIndex.add(index); } else { targetToSourceIndex.add(removeIndex++); } } targetVarsToSplit.add(ti); } } @NonNull JointDomainIndexer fromDomains = targetFactor.getFactorTable().getDomainIndexer(); // TODO: get target domains without forcing instantiation of factor table? JointDomainIndexer toDomains = sourceFactor.getFactorTable().getDomainIndexer(); if (!targetVarsToSplit.isEmpty()) { targetVarsToSplit.trimToSize(); JointDomainReindexer reindexer = _reindexer = JointDomainReindexer.createSplitter(fromDomains, targetVarsToSplit.elements()); fromDomains = reindexer.getToDomains(); } // Remaining entries in targetVarToSourceIndex should be conditioned variables // that were removed from the factor final int nConditioned = targetVarToSourceIndex.size(); if (nConditioned > 0) { final int fromSize = fromDomains.size(); final int[] conditionedValues = new int[fromSize + nConditioned]; Arrays.fill(conditionedValues, -1); final DiscreteDomain[] conditionedDomains = new DiscreteDomain[nConditioned]; int i = 0; for (Entry<Variable, Integer> entry : targetVarToSourceIndex.entrySet()) { final Discrete variable = entry.getKey().asDiscreteVariable(); final int sourceIndex = entry.getValue(); targetToSourceIndex.add(sourceIndex); conditionedValues[i + fromSize] = variable.getPriorIndex(); conditionedDomains[i] = variable.getDomain(); ++i; } fromDomains = JointDomainIndexer.concatNonNull(fromDomains, JointDomainIndexer.create(conditionedDomains)); JointDomainReindexer deconditioner = JointDomainReindexer.createConditioner(fromDomains, conditionedValues).getInverse(); _reindexer = deconditioner.appendTo(_reindexer); } targetToSourceIndex.trimToSize(); final JointDomainReindexer permuter = JointDomainReindexer.createPermuter(fromDomains, toDomains, targetToSourceIndex.elements()); _reindexer = permuter.appendTo(_reindexer); } return _reindexer; } private RuntimeException unsupported(String method) { return DimpleException.unsupportedMethod(getClass(), method); } }