/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.transform; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Objects; import java.util.Set; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.domains.JointDiscreteDomain; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.util.misc.Internal; import com.google.common.collect.Iterables; /** * Junction tree mapping generated by {@link JunctionTreeTransform}. * <p> * This holds references to the source factor graph, the transformation of the graph into a tree, * and information that describes the mapping between the two. * <p> * @since 0.05 * @author Christopher Barber */ public class JunctionTreeTransformMap { /*------- * State */ private final FactorGraph _sourceModel; private final long _sourceVersion; private final FactorGraph _targetModel; private final @Nullable Map<Factor,Factor> _sourceToTargetFactors; private final @Nullable Map<Variable, Variable> _sourceToTargetVariables; /** * Newly created joint variables that are deterministically computed from component variables. */ private final LinkedHashMap<Variable, AddedJointVariable<?>> _addedDeterministicVariables; private final Set<Variable> _conditionedVariables; /** * Represents a variable that joins two or more other variables along an edge between * two factors in the target model to ensure that it is singly connected. There may be * multiple such joined variables for the same set of underlying variables. * * @param <Var> specifies the variable type. Currently only {@link Discrete} is supported. * @since 0.05 * @author Christopher Barber */ public static abstract class AddedJointVariable<Var extends Variable> { protected final Var _variable; protected final Var[] _inputs; protected AddedJointVariable(Var newVariable, Var[] inputVariables) { _variable = newVariable; _inputs = inputVariables; } @Internal public abstract void updateGuess(); /** * The domain of the joined variable. */ public Domain getDomain() { return _variable.getDomain(); } /** * The joined variable itself. */ public Var getVariable() { return _variable; } /** * The i'th variable that is joined into this one. */ public Var getInput(int i) { return _inputs[i]; } /** * The number of variables that were joined into this one. */ public final int getInputCount() { return _inputs.length; } /** * @category internal */ @Internal public abstract void updateValue(Value newVariableValue, Value[] inputs); } public static class AddedJointDiscreteVariable extends AddedJointVariable<Discrete> { /** * @param newVariable * @param inputVariables */ public AddedJointDiscreteVariable(Discrete newVariable, Discrete[] inputVariables) { super(newVariable, inputVariables); assert(invariantsHold()); } private boolean invariantsHold() { JointDomainIndexer domain = getDomain().getDomainIndexer(); assert(domain.size() == _inputs.length); for (int i = 0; i < _inputs.length; ++i) { assert(domain.get(i) == _inputs[i].getDomain()); } return true; } @Override public JointDiscreteDomain<?> getDomain() { return (JointDiscreteDomain<?>) getVariable().getDomain(); } @Override public void updateGuess() { final JointDomainIndexer indexer = getDomain().getDomainIndexer(); final int[] indices = indexer.allocateIndices(null); boolean allWereSet = true; for (int i = 0; i < _inputs.length; ++i) { Discrete input = getInput(i); allWereSet &= input.guessWasSet() || input.hasFixedValue(); indices[i] = getInput(i).getGuessIndex(); } Discrete var = getVariable(); if (allWereSet) { var.setGuessIndex(indexer.jointIndexFromIndices(indices)); } else { var.setGuess(null); } } @Override public void updateValue(Value newVariableValue, Value[] inputs) { JointDomainIndexer indexer = getDomain().getDomainIndexer(); newVariableValue.setIndex(indexer.jointIndexFromValues(inputs)); } } /*-------------- * Construction */ protected JunctionTreeTransformMap(FactorGraph source, FactorGraph target) { final boolean identity = (source == target); _sourceModel = source; _sourceVersion = source.structureVersion(); _targetModel = target; _sourceToTargetVariables = identity? null : new HashMap<Variable,Variable>(source.getVariableCount()); _sourceToTargetFactors = identity? null : new HashMap<Factor,Factor>(source.getFactorCount()); _addedDeterministicVariables = new LinkedHashMap<Variable, AddedJointVariable<?>>(); _conditionedVariables = new LinkedHashSet<Variable>(); } protected JunctionTreeTransformMap(FactorGraph source) { this(source, source); } static JunctionTreeTransformMap create(FactorGraph source, FactorGraph target) { return new JunctionTreeTransformMap(source, target); } static JunctionTreeTransformMap identity(FactorGraph model) { return new JunctionTreeTransformMap(model); } /*--------- * Methods */ public Iterable<AddedJointVariable<?>> addedJointVariables() { return Iterables.unmodifiableIterable(_addedDeterministicVariables.values()); } public @Nullable <Var extends Variable> AddedJointVariable<Var> getAddedDeterministicVariable(Var targetVariable) { @SuppressWarnings("unchecked") AddedJointVariable<Var> var = (AddedJointVariable<Var>) _addedDeterministicVariables.get(targetVariable); return var; } /** * Unmodifiable set of source variables that have been conditioned out of * the target graph. */ public Set<Variable> conditionedVariables() { return Collections.unmodifiableSet(_conditionedVariables); } /** * True if mapping is the identity mapping, which is a simple copy of the graph. */ public boolean isIdentity() { return _sourceToTargetVariables == null; } /** * True if the current mapping is up-to-date with respect to the current state of * the {@link #source()} model (and therefore can be reused for inference). */ public boolean isValid() { if (_sourceVersion != _sourceModel.structureVersion()) { return false; } for (Variable sourceVar : _conditionedVariables) { if (!sourceVar.hasFixedValue()) return false; Variable targetVar = sourceToTargetVariable(sourceVar); if (!targetVar.hasFixedValue()) return false; if (!Objects.equals(sourceVar.getPrior(), targetVar.getPrior())) return false; } return true; } /** * The original model from which the transformation was generated. */ public FactorGraph source() { return _sourceModel; } /** * Returns the target factor that subsumes the given {@code sourceFactor}. * <p> * As long as the transform {@link #isValid()} this is guaranteed to return a * non-null variable in {@link #target()} for every variable in {@link #source()}. * Note that unlike {@link #sourceToTargetVariable(Variable)} the target factor * may not exactly correspond to the source factor. Instead it may represent the * product of multiple factors. * <p> * @see #sourceToTargetFactors() */ public Factor sourceToTargetFactor(Factor sourceFactor) { final Map<Factor,Factor> sourceToTargetFactors = _sourceToTargetFactors; if (sourceToTargetFactors == null) { return sourceFactor; } return sourceToTargetFactors.get(sourceFactor); } /** * Returns a read-only mapping from factors in {@link #source()} to factors * in {@link #target()}. * * @see #sourceToTargetFactor(Factor) */ public Map<Factor,Factor> sourceToTargetFactors() { if (_sourceToTargetFactors == null) { return Collections.emptyMap(); } return Collections.unmodifiableMap(_sourceToTargetFactors); } /** * Returns the target variable corresponding to the given {@code sourceVariable}. * <p> * As long as the transform {@link #isValid()} this is guaranteed to return a * non-null variable in {@link #target()} for every variable in {@link #source()}. * <p> * @see #sourceToTargetVariables() */ public Variable sourceToTargetVariable(Variable sourceVariable) { final Map<Variable, Variable> sourceToTargetVariables = _sourceToTargetVariables; if (sourceToTargetVariables == null) { return sourceVariable; } return sourceToTargetVariables.get(sourceVariable); } /** * Returns a read-only mapping from variables in {@link #source()} to variables * in {@link #target()}. * * @see #sourceToTargetVariable(Variable) */ public Map<Variable,Variable> sourceToTargetVariables() { if (_sourceToTargetVariables == null) { return Collections.emptyMap(); } return Collections.unmodifiableMap(_sourceToTargetVariables); } /** * Value of {@link FactorGraph#getVersionId()} of {@link #source()} when * transform map was created. */ public long sourceVersion() { return _sourceVersion; } /** * The generated target model generated from {@link #source()} by {@link JunctionTreeTransform}. * <p> * As long as {@link #isValid()} this will have variables corresponding to the ones in the source model. * <p> * @see #sourceToTargetVariable(Variable) * @see #sourceToTargetFactor(Factor) */ public FactorGraph target() { return _targetModel; } /*------------------ * Internal methods */ /** * @category internal */ @Internal public void updateGuesses() { for (Map.Entry<Variable,Variable> entry : sourceToTargetVariables().entrySet()) { Variable sourceVar = entry.getKey(); Variable targetVar = entry.getValue(); if (!sourceVar.guessWasSet()) { targetVar.setGuess(null); } else { if (sourceVar instanceof Discrete) { ((Discrete)targetVar).setGuessIndex(((Discrete)sourceVar).getGuessIndex()); } else { targetVar.setGuess(sourceVar.getGuess()); } } } for (AddedJointVariable<?> added : addedJointVariables()) { added.updateGuess(); } } /*----------------- * Package methods */ void addConditionedVariable(Variable variable) { assert(variable.hasFixedValue()); _conditionedVariables.add(variable); } void addDeterministicVariable(AddedJointVariable<?> addedVar) { _addedDeterministicVariables.put(addedVar.getVariable(), addedVar); } void addFactorMapping(Factor sourceFactor, Factor targetFactor) { Objects.requireNonNull(_sourceToTargetFactors).put(sourceFactor, targetFactor); } void addVariableMapping(Variable sourceVariable, Variable targetVariable) { Objects.requireNonNull(_sourceToTargetVariables).put(sourceVariable, targetVariable); } }