/******************************************************************************* * Copyright 2014 Analog Devices, Inc. Licensed under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance with the License. You may obtain a * copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable * law or agreed to in writing, software distributed under the License is distributed on an "AS IS" * BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License * for the specific language governing permissions and limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct; import java.util.Arrays; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.collect.Tuple2; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.solvers.optimizedupdate.CostEstimationTableWrapper; import com.analog.lyric.dimple.solvers.optimizedupdate.CostType; import com.analog.lyric.dimple.solvers.optimizedupdate.Costs; import com.analog.lyric.dimple.solvers.optimizedupdate.FactorUpdatePlan; import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStep; import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStepEstimator; import com.analog.lyric.dimple.solvers.optimizedupdate.ISFactorGraphToOptimizedUpdateAdapter; import com.analog.lyric.dimple.solvers.optimizedupdate.ISTableFactorSupportingOptimizedUpdate; import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStep; import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStepEstimator; import com.analog.lyric.dimple.solvers.optimizedupdate.TableWrapper; /** * Implements a factor update approach for sum-product that optimally shares partial results. Since * it computes all output edges, it only overrides the update method, and does not override the * update_edge method. * * @since 0.06 * @author jking */ public class TableFactorEngineOptimized extends TableFactorEngine { /** * The update plan for the factor related to this update engine. */ private final FactorUpdatePlan _updatePlan; /** * Dimple creates an instance of this class per factor. * * @param tableFactor * @since 0.06 */ public TableFactorEngineOptimized(SumProductTableFactor tableFactor, FactorUpdatePlan updatePlan) { super(tableFactor); _updatePlan = updatePlan; } @Override public void update() { _updatePlan.apply(_tableFactor); } static final class DenseMarginalizationStep implements IMarginalizationStep { private final TableWrapper _f; private final TableWrapper _g; private final int _inPortNum; private final int _d; private final int _p; DenseMarginalizationStep(final TableWrapper f, final ISFactorGraphToOptimizedUpdateAdapter helper, final int inPortNum, final int dimension, final IFactorTable g_factorTable) { _f = f; JointDomainIndexer f_indexer = f.getFactorTable().getDomainIndexer(); _p = f_indexer.getStride(dimension); _d = f_indexer.getDomainSize(dimension); _g = new TableWrapper(g_factorTable, true, helper, f.getSparseThreshold()); _inPortNum = inPortNum; } @Override public void apply(ISTableFactorSupportingOptimizedUpdate tableFactor) { final double[] f_values = _f.getValues().get(); final double[] g_values = _g.getValues().get(); final double[] inputMsg = tableFactor.getInPortMsg(_inPortNum); Arrays.fill(g_values, 0.0); int c = 0; int msg_index = 0; int g_index = 0; int g_index_limit = _p; double input_value = inputMsg[0]; for (final double value : f_values) { g_values[g_index] += value * input_value; if (++g_index == g_index_limit) { if (++msg_index == _d) { msg_index = 0; c = g_index_limit; g_index_limit += _p; } g_index = c; input_value = inputMsg[msg_index]; } } } @Override public TableWrapper getAuxiliaryTable() { return _g; } } static final class SparseMarginalizationStep implements IMarginalizationStep { private final TableWrapper _f; private final int _inPortNum; private final TableWrapper _g; private final int[] _msg_indices; private final int[] _g_sparse_indices; SparseMarginalizationStep(final TableWrapper f, final ISFactorGraphToOptimizedUpdateAdapter isFactorGraphToCostOptimizerAdapter, final int inPortNum, final int dimension, final IFactorTable g_factorTable, final Tuple2<int[][], int[]> g_and_msg_indices) { _f = f; _inPortNum = inPortNum; _g = new TableWrapper(g_factorTable, true, isFactorGraphToCostOptimizerAdapter, f.getSparseThreshold()); final int[][] g_indices = g_and_msg_indices.first; _msg_indices = g_and_msg_indices.second; _g_sparse_indices = new int[g_indices.length]; if (g_factorTable.hasSparseRepresentation()) { for (int i = 0; i < g_indices.length; i++) { _g_sparse_indices[i] = g_factorTable.sparseIndexFromIndices(g_indices[i]); } } else { final JointDomainIndexer g_domainIndexer = g_factorTable.getDomainIndexer(); for (int i = 0; i < g_indices.length; i++) { _g_sparse_indices[i] = g_domainIndexer.jointIndexFromIndices(g_indices[i]); } } } @Override public void apply(ISTableFactorSupportingOptimizedUpdate tableFactor) { final double[] f_values = _f.getValues().get(); final double[] g_values = _g.getValues().get(); final double[] inputMsg = tableFactor.getInPortMsg(_inPortNum); Arrays.fill(g_values, 0.0); int n = 0; for (final double value : f_values) { final double input_value = inputMsg[_msg_indices[n]]; g_values[_g_sparse_indices[n]] += value * input_value; n += 1; } } @Override public TableWrapper getAuxiliaryTable() { return _g; } } static final class DenseOutputStep implements IUpdateStep { private final int _outPortNum; private final TableWrapper _f; DenseOutputStep(int outPortNum, final TableWrapper f) { _outPortNum = outPortNum; _f = f; } @Override public void apply(ISTableFactorSupportingOptimizedUpdate tableFactor) { final double[] outputMsg = tableFactor.getOutPortMsg(_outPortNum); final int outputMsgLength = outputMsg.length; final double damping = tableFactor.getDamping(_outPortNum); final boolean useDamping = damping != 0; final double[] saved = useDamping? DimpleEnvironment.doubleArrayCache.allocateAtLeast(outputMsgLength) : ArrayUtil.EMPTY_DOUBLE_ARRAY; if (useDamping) { System.arraycopy(outputMsg, 0, saved, 0, outputMsgLength); } double sum = 0.0; int f_index = 0; for (final double prob : _f.getValues().get()) { outputMsg[f_index] = prob; sum += prob; f_index += 1; } if (sum == 0) { throw new DimpleException( "Update failed in SumProduct Solver. All probabilities were zero when calculating message for port " + _outPortNum + " on factor " + tableFactor.getFactor().getLabel()); } for (int i = 0; i < outputMsgLength; i++) { outputMsg[i] /= sum; } if (damping != 0) { for (int i = 0; i < outputMsgLength; i++) { outputMsg[i] = (1 - damping) * outputMsg[i] + damping * saved[i]; } DimpleEnvironment.doubleArrayCache.release(saved); } } } static final class SparseOutputStep implements IUpdateStep { private final TableWrapper _f; private final int _outPortNum; private final IFactorTable _factorTable; /** * @param outPortNum * @since 0.06 */ SparseOutputStep(int outPortNum, final TableWrapper f) { _outPortNum = outPortNum; _f = f; _factorTable = _f.getFactorTable(); } @Override public void apply(ISTableFactorSupportingOptimizedUpdate tableFactor) { final double[] outputMsg = tableFactor.getOutPortMsg(_outPortNum); final double damping = tableFactor.getDamping(_outPortNum); final int outputMsgLength = outputMsg.length; final boolean useDamping = damping != 0; final double[] saved = useDamping ? DimpleEnvironment.doubleArrayCache.allocateAtLeast(outputMsgLength) : ArrayUtil.EMPTY_DOUBLE_ARRAY; if (useDamping) { System.arraycopy(outputMsg, 0, saved, 0, outputMsg.length); } double sum = 0.0; Arrays.fill(outputMsg, 0); int sparseIndex = 0; for (final double prob : _f.getValues().get()) { final int f_index = _factorTable.sparseIndexToJointIndex(sparseIndex); outputMsg[f_index] = prob; sum += prob; sparseIndex += 1; } if (sum == 0) { throw new DimpleException( "Update failed in SumProduct Solver. All probabilities were zero when calculating message for port " + _outPortNum + " on factor " + tableFactor.getFactor().getLabel()); } for (int i = 0; i < outputMsgLength; i++) { outputMsg[i] /= sum; } if (useDamping) { for (int i = 0; i < outputMsgLength; i++) { outputMsg[i] = (1 - damping) * outputMsg[i] + damping * saved[i]; } DimpleEnvironment.doubleArrayCache.release(saved); } } } static final class DenseMarginalizationStepEstimator implements IMarginalizationStepEstimator { private final CostEstimationTableWrapper _f; private final CostEstimationTableWrapper _g; DenseMarginalizationStepEstimator(final CostEstimationTableWrapper f, final int inPortNum, final int dimension, final CostEstimationTableWrapper g) { _f = f; _g = g; } @Override public CostEstimationTableWrapper getAuxiliaryTable() { return _g; } @Override public Costs estimateCosts() { Costs result = new Costs(); result.put(CostType.DENSE_MARGINALIZATION_SIZE, _f.getSize()); result.add(_g.estimateCosts()); return result; } } static final class SparseMarginalizationStepEstimator implements IMarginalizationStepEstimator { private final CostEstimationTableWrapper _f; private final CostEstimationTableWrapper _g; private final int _msg_indices_length; SparseMarginalizationStepEstimator(final CostEstimationTableWrapper f, final int inPortNum, final int dimension, final CostEstimationTableWrapper g) { _f = f; _g = g; _msg_indices_length = (int) _f.getSize(); } @Override public CostEstimationTableWrapper getAuxiliaryTable() { return _g; } @Override public Costs estimateCosts() { Costs result = new Costs(); result.put(CostType.SPARSE_MARGINALIZATION_SIZE, _f.getSize()); // allocate 4 bytes each entry (int arrays) for 1. message indices and 2. g indices final double g_size = _g.getSize(); double allocations = (_msg_indices_length + g_size) * 4 / 1024.0 / 1024.0 / 1024.0; result.put(CostType.MEMORY, allocations); // add costs from auxiliary table result.add(_g.estimateCosts()); return result; } } static final class DenseOutputStepEstimator implements IUpdateStepEstimator { private final CostEstimationTableWrapper _f; DenseOutputStepEstimator(final CostEstimationTableWrapper f) { _f = f; } @Override public Costs estimateCosts() { Costs result = new Costs(); result.put(CostType.OUTPUT_SIZE, _f.getSize()); return result; } } static final class SparseOutputStepEstimator implements IUpdateStepEstimator { private final CostEstimationTableWrapper _f; SparseOutputStepEstimator(final CostEstimationTableWrapper f) { _f = f; } @Override public Costs estimateCosts() { Costs result = new Costs(); result.put(CostType.OUTPUT_SIZE, _f.getSize()); return result; } } }