/******************************************************************************* * Copyright 2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.gibbs; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableBlock; import com.analog.lyric.dimple.solvers.core.SChild; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariableBlock; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; /** * Solver variable block state for Gibbs solver. * @since 0.08 * @author Christopher Barber */ public class GibbsVariableBlock extends SChild<VariableBlock> implements ISolverVariableBlock { /*------- * State */ /** * The owner of this object. */ private final GibbsSolverGraph _parent; /** * The root Gibbs solver graph above the parent. */ private final GibbsSolverGraph _root; /** * The solver variables in the same order as in the model block. */ private final ISolverVariableGibbs[] _vars; /** * The domains of the variables in order. */ private final Domain[] _domains; /** * Copies of the appropriate {@link Value} objects for the variables in order. */ private final Value[] _values; /** * The neighbor solver nodes for the block (i.e. the Markov blanket) */ private final ISolverNodeGibbs[] _neighbors; /** * Count of the number of updates performed since last reset. */ private long _updateCount; /** * Count of number of rejected updates since the last reset. */ private long _rejectCount; /*-------------- * Construction */ GibbsVariableBlock(final VariableBlock block, GibbsSolverGraph parent) { super(block); _parent = parent; // FIXME - assumes that Gibbs is the root, this will change when we support nesting solvers _root = (GibbsSolverGraph)_parent.getRootSolverGraph(); final int nVars = block.size(); _vars = new ISolverVariableGibbs[nVars]; _values = new Value[nVars]; _domains = new Domain[nVars]; final Set<ISolverNodeGibbs> neighborSet = new HashSet<ISolverNodeGibbs>(); for (int i = 0; i < nVars; ++i) { Variable var = block.get(i); ISolverVariableGibbs svar = parent.getSolverVariable(block.get(i)); _values[i] = svar.getCurrentSampleValue().clone(); _vars[i] = svar; _domains[i] = var.getDomain(); GibbsNeighbors neighbors = GibbsNeighbors.create(svar); if (neighbors == null) // No deterministic dependents, neighbors are same as siblings { for (Factor f : var.getSiblings()) neighborSet.add((ISolverNodeGibbs)f.getSolver()); } else // Has deterministic dependents { for (ISolverNodeGibbs neighbor : neighbors) neighborSet.add(neighbor); } } _neighbors = neighborSet.toArray(new ISolverNodeGibbs[neighborSet.size()]); } /*--------------------------------- * ISolverFactorGraphChild methods */ @Override @Nullable public ISolverFactorGraph getParentGraph() { return _parent; } @Override public ISolverFactorGraph getRootSolverGraph() { return _parent.getRootSolverGraph(); } @Override public SolverNodeMapping getSolverMapping() { return _parent.getSolverMapping(); } @Override public ISolverFactorGraph getContainingSolverGraph() { return _parent; } @Override public void initialize() { } /*------------------------------ * ISolverVariableBlock methods */ @Override public List<ISolverVariableGibbs> getSolverVariables() { return Collections.unmodifiableList(Arrays.asList(_vars)); } /*---------------------------- * GibbsVariableBlock methods */ /** * Computes the sample score of the block. * <p> * This is the sum of the {@linkplain ISolverNodeGibbs#getPotential() potentials} of * the block's solver {@linkplain #getSolverVariables() variables} and * {@linkplain #getSolverNeighbors() neighbors}. * @since 0.08 */ public double getCurrentSampleScore() { double score = 0; for (ISolverVariableGibbs v : _vars) { score += v.getPotential(); } for (ISolverNodeGibbs v : _neighbors) { score += v.getPotential(); } return score; } /** * Array of domains for variables in block. * <p> * The array should be treated as immutable! * @since 0.08 */ public Domain[] getDomains() { return _domains; } /** * The count of rejections since the last reset. * <p> * Counts the number of times that {@link #updateReject} have * been invoked since last call to {@link #resetCounts()}. * @since 0.08 */ public final long getRejectionCount() { return _rejectCount; } /** * Computes the sample score of the block and its neighbors given the specified values. * <p> * This will set the sample values on the graph and return its {@linkplain #getCurrentSampleScore() score}. * Since the sample values are expected to be set again when the update finishes, this method does not * restore the variables to their previous values. * <p> * @param sampleValues * @since 0.08 */ public double getSampleScore(Value[] sampleValues) { // WARNING: Side effect is that the current sample value changes to this sample value // Could change back but less efficient to do this, since we'll be updating the sample value anyway setCurrentSample(sampleValues); return getCurrentSampleScore(); } /** * Immutable view of neighbor solver nodes of the block. * @since 0.08 */ public List<ISolverNodeGibbs> getSolverNeighbors() { return Collections.unmodifiableList(Arrays.asList(_neighbors)); } /** * The count of updates since the last reset. * <p> * Counts the number of times that {@link #updateFinish} and {@link #updateReject} have * been invoked since last call to {@link #resetCounts()}. * @since 0.08 */ public final long getUpdateCount() { return _updateCount; } /** * Clear the rejection rate statistics * <p> * Resets counts of {@linkplain #getUpdateCount() updates} and {@linkplain #getRejectionCount() rejections} * to zero. * @since 0.08 */ public final void resetCounts() { _updateCount = 0; _rejectCount = 0; } /** * Initiate update of variable block. * <p> * This saves a copy of the current sample values of variables in the block and returns * it. The update should be finished by calling one of {@link #updateFinish(Value[])} or * {@link #updateReject()}. * <p> * @since 0.08 */ public Value[] updateStart() { for (int i = 0, n = _vars.length; i < n; ++i) { _values[i].setFrom(_vars[i].getCurrentSampleValue()); } return _values; } /** * Terminate update of variable block with rejection. * <p> * This will restore the sample values of variables in the block back to * the values saved by {@link #updateStart()}. This will also increment * the {@linkplain #getUpdateCount() update} and {@linkplain #getRejectionCount() rejection} * counts. * * @since 0.08 */ public void updateReject() { ++_updateCount; ++_rejectCount; setCurrentSample(_values); } /** * Terminate update of variable block with final sample values. * <p> * This will set the sample values of variables in the block to the specifed values. * This will also increment the {@linkplain #getUpdateCount() update} count. * @param values must contain non-null {@link Value} compatible with the corresponding * variables in the block. * @since 0.08 */ public void updateFinish(Value[] values) { ++_updateCount; setCurrentSample(values); } /*----------------- * Private methods */ private void setCurrentSample(Value[] sampleValues) { _root.deferDeterministicUpdates(); final ISolverVariableGibbs[] vars = _vars; for (int i = 0, n = vars.length; i < n; i++) { vars[i].setCurrentSample(sampleValues[i]); } _root.processDeferredDeterministicUpdates(); } }