/*******************************************************************************
*   Copyright 2012 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.sumproduct;

import java.util.Arrays;

import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.exceptions.NormalizationException;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;

/**
 * Provides the {@link #update()} and {@link #updateEdge(int)} logic for a
 * sum-product table factor.
 * <p>
 * Both operations read the factor's sparse table (one row of per-port domain
 * indices plus one weight per row) and the current input message arrays, and
 * overwrite the output message array(s) of the owning
 * {@link SumProductTableFactor} in place.
 */
public class TableFactorEngine
{
    final SumProductTableFactor _tableFactor;
    final Factor _factor;

    /**
     * @param tableFactor solver factor whose messages this engine updates; its
     *        model {@link Factor} is looked up once and cached.
     */
    public TableFactorEngine(SumProductTableFactor tableFactor)
    {
        _tableFactor = tableFactor;
        _factor = _tableFactor.getFactor();
    }

    /**
     * Recomputes the output message for a single port.
     * <p>
     * If damping is in use and the damping parameter for this port is non-zero,
     * the new message is normalized by the sum of the row contributions and then
     * blended with the previous message. Otherwise the normalization energy of
     * the outgoing message is set to the sum of the normalization energies of
     * all other ports' incoming messages, and the message is normalized through
     * {@link DiscreteMessage#normalize()}.
     *
     * @param outPortNum index of the port whose output message is recomputed.
     * @throws DimpleException if the damped path finds a zero sum, or if
     *         {@code normalize()} raises a {@link NormalizationException} on the
     *         undamped path.
     */
    public void updateEdge(int outPortNum)
    {
        final SumProductTableFactor tableFactor = _tableFactor;
        final int[][] table = tableFactor.getFactorTable().getIndicesSparseUnsafe();
        final double[] values = tableFactor.getFactorTable().getWeightsSparseUnsafe();
        final int numPorts = _factor.getSiblingCount();
        final double[] outputMsgs = tableFactor.getOutPortMsg(outPortNum);
        final double[][] inputMsgs = tableFactor.getInPortMsgs();
        final int outputMsgLength = outputMsgs.length;

        final double damping = tableFactor._dampingInUse ? tableFactor._dampingParams[outPortNum] : 0.0;

        if (damping != 0.0)
        {
            // Scratch copy of the previous message; only the first outputMsgLength entries are used.
            final double[] saved = DimpleEnvironment.doubleArrayCache.allocateAtLeast(outputMsgLength);
            try
            {
                System.arraycopy(outputMsgs, 0, saved, 0, outputMsgLength);

                final double sum = accumulateRows(table, values, inputMsgs, numPorts, outPortNum, outputMsgs);
                if (sum == 0)
                {
                    throw allZeroFailure("UpdateEdge", outPortNum);
                }

                divideAll(outputMsgs, sum);
                blendWithSaved(outputMsgs, saved, 0, damping);
            }
            finally
            {
                // Hand the scratch array back to the cache even when the update fails.
                DimpleEnvironment.doubleArrayCache.release(saved);
            }
        }
        else
        {
            // The normalization energy is only maintained on this undamped path; the
            // original author's note is that it probably won't be useful with damping.
            // NOTE(review): the original also asked whether this should be optional.
            final DiscreteMessage outMsg = _tableFactor.getSiblingEdgeState(outPortNum).factorToVarMsg;

            // Sum the incoming energies of every port except outPortNum (high ports
            // first, then low ports, both in descending order).
            double normalizationEnergy = 0.0;
            for (int i = numPorts; --i > outPortNum;)
            {
                normalizationEnergy += _tableFactor.getSiblingEdgeState(i).varToFactorMsg.getNormalizationEnergy();
            }
            for (int i = outPortNum; --i >= 0;)
            {
                normalizationEnergy += _tableFactor.getSiblingEdgeState(i).varToFactorMsg.getNormalizationEnergy();
            }
            outMsg.setNormalizationEnergy(normalizationEnergy);

            accumulateRows(table, values, inputMsgs, numPorts, outPortNum, outputMsgs);

            try
            {
                // NOTE(review): presumably outMsg is backed by the same array as
                // outputMsgs, so this normalizes the values just written - confirm.
                outMsg.normalize();
            }
            catch (NormalizationException ex)
            {
                final DimpleException failure = allZeroFailure("UpdateEdge", outPortNum);
                // Keep the underlying failure in the stack trace.
                // Assumes DimpleException(String) leaves the cause unset - TODO confirm.
                failure.initCause(ex);
                throw failure;
            }
        }
    }

    /**
     * Recomputes the output messages for every port of the factor.
     * <p>
     * Each output message is rebuilt from the sparse table and the input
     * messages, then divided by the sum of its entries. When damping is in use,
     * ports with a non-zero damping parameter are additionally blended with
     * their previous message.
     *
     * @throws DimpleException if the entries of any rebuilt message sum to zero.
     */
    public void update()
    {
        final SumProductTableFactor tableFactor = _tableFactor;
        final IFactorTable table = tableFactor.getFactorTable();
        final int[][] tableIndices = table.getIndicesSparseUnsafe();
        final double[] values = table.getWeightsSparseUnsafe();
        final int numPorts = _factor.getSiblingCount();
        final double[][] inMsgs = tableFactor.getInPortMsgs();

        if (tableFactor._dampingInUse)
        {
            // One scratch array holds the previous messages of all ports back to back;
            // savedOffset marks where the current port's slice begins.
            final double[] saved =
                DimpleEnvironment.doubleArrayCache.allocateAtLeast(table.getDomainIndexer().getSumOfDomainSizes());
            try
            {
                for (int outPortNum = 0, savedOffset = 0; outPortNum < numPorts; outPortNum++)
                {
                    final double[] outputMsgs = tableFactor.getOutPortMsg(outPortNum);
                    final int outputMsgLength = outputMsgs.length;
                    final double damping = tableFactor._dampingParams[outPortNum];

                    if (damping != 0)
                    {
                        System.arraycopy(outputMsgs, 0, saved, savedOffset, outputMsgLength);
                    }

                    accumulateRows(tableIndices, values, inMsgs, numPorts, outPortNum, outputMsgs);
                    normalizeBySumOfEntries(outputMsgs, outPortNum);

                    if (damping != 0)
                    {
                        blendWithSaved(outputMsgs, saved, savedOffset, damping);
                    }

                    // The offset advances for every port, damped or not.
                    savedOffset += outputMsgLength;
                }
            }
            finally
            {
                // Hand the scratch array back to the cache even when the update fails.
                DimpleEnvironment.doubleArrayCache.release(saved);
            }
        }
        else
        {
            for (int outPortNum = numPorts; --outPortNum >= 0;)
            {
                final double[] outputMsgs = tableFactor.getOutPortMsg(outPortNum);
                accumulateRows(tableIndices, values, inMsgs, numPorts, outPortNum, outputMsgs);
                normalizeBySumOfEntries(outputMsgs, outPortNum);
            }
        }
    }

    /**
     * Zeroes {@code out} and then, for every table row (last row first), adds the
     * row weight multiplied by the input-message entries of all ports other than
     * {@code outPortNum} to the {@code out} slot selected by that row.
     * <p>
     * Ports are multiplied in descending order, skipping {@code outPortNum}, so
     * the floating-point result does not depend on which caller invoked it.
     *
     * @return the sum of all row contributions, accumulated in row-visit order;
     *         callers that normalize over the entries of {@code out} ignore it.
     */
    private static double accumulateRows(
        int[][] rows, double[] weights, double[][] inMsgs, int numPorts, int outPortNum, double[] out)
    {
        Arrays.fill(out, 0);

        double total = 0.0;
        for (int rowIndex = rows.length; --rowIndex >= 0;)
        {
            double prob = weights[rowIndex];
            final int[] row = rows[rowIndex];
            final int outputIndex = row[outPortNum];

            int port = numPorts;
            while (--port > outPortNum)
            {
                prob *= inMsgs[port][row[port]];
            }
            // port == outPortNum here; the pre-decrement below steps past it.
            while (--port >= 0)
            {
                prob *= inMsgs[port][row[port]];
            }

            out[outputIndex] += prob;
            total += prob;
        }
        return total;
    }

    /**
     * Divides {@code msg} by the sum of its entries (summed from the last entry
     * to the first).
     *
     * @throws DimpleException if that sum is exactly zero.
     */
    private void normalizeBySumOfEntries(double[] msg, int outPortNum)
    {
        double sum = 0;
        for (int i = msg.length; --i >= 0;)
        {
            sum += msg[i];
        }

        if (sum == 0)
        {
            throw allZeroFailure("Update", outPortNum);
        }

        divideAll(msg, sum);
    }

    /** Divides every entry of {@code msg} by {@code divisor} in place. */
    private static void divideAll(double[] msg, double divisor)
    {
        for (int i = msg.length; --i >= 0;)
        {
            msg[i] /= divisor;
        }
    }

    /**
     * Replaces each entry of {@code msg} with
     * {@code (1 - damping) * msg[i] + damping * saved[savedOffset + i]}.
     */
    private static void blendWithSaved(double[] msg, double[] saved, int savedOffset, double damping)
    {
        final double inverseDamping = 1 - damping;
        for (int i = msg.length; --i >= 0;)
        {
            msg[i] = inverseDamping * msg[i] + damping * saved[i + savedOffset];
        }
    }

    /**
     * Builds the exception reported when a message cannot be normalized.
     *
     * @param operation prefix of the message, {@code "UpdateEdge"} or {@code "Update"}.
     */
    private DimpleException allZeroFailure(String operation, int outPortNum)
    {
        return new DimpleException(operation
            + " failed in SumProduct Solver.  All probabilities were zero when calculating message for port "
            + outPortNum + " on factor " + _factor.getLabel());
    }
}