/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.minsum;
import java.util.Arrays;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.factors.Factor;
/*
* Provides the update and updateEdge logic for minsum
*/
public class TableFactorEngine
{
final MinSumTableFactor _tableFactor;
final Factor _factor;
public TableFactorEngine(MinSumTableFactor tableFactor)
{
_tableFactor = tableFactor;
_factor = _tableFactor.getFactor();
}
public void updateEdge(int outPortNum)
{
final int[][] table = _tableFactor.getFactorTable().getIndicesSparseUnsafe();
final double[] values = _tableFactor.getFactorTable().getEnergiesSparseUnsafe();
final int tableLength = table.length;
final int numPorts = _factor.getSiblingCount();
final double[] outputMsgs = _tableFactor.getOutPortMsg(outPortNum);
final int outputMsgLength = outputMsgs.length;
double[] saved = ArrayUtil.EMPTY_DOUBLE_ARRAY;
if (_tableFactor._dampingInUse)
{
double damping = _tableFactor._dampingParams[outPortNum];
if (damping != 0)
{
saved = DimpleEnvironment.doubleArrayCache.allocateAtLeast(outputMsgLength);
System.arraycopy(outputMsgs, 0, saved, 0, outputMsgLength);
}
}
Arrays.fill(outputMsgs, Double.POSITIVE_INFINITY);
final double [][] inPortMsgs = _tableFactor.getInPortMsgs();
// Run through each row of the function table
for (int tableIndex = tableLength; --tableIndex>=0;)
{
double L = values[tableIndex];
final int[] tableRow = table[tableIndex];
final int outputIndex = tableRow[outPortNum];
int inPortNum = numPorts;
while (--inPortNum > outPortNum)
L += inPortMsgs[inPortNum][tableRow[inPortNum]];
while (--inPortNum >= 0)
L += inPortMsgs[inPortNum][tableRow[inPortNum]];
if (L < outputMsgs[outputIndex])
outputMsgs[outputIndex] = L; // Use the minimum value
}
// Damping
if (_tableFactor._dampingInUse)
{
double damping = _tableFactor._dampingParams[outPortNum];
if (damping != 0)
{
final double inverseDamping = 1.0 - damping;
for (int i = outputMsgLength; --i>=0;)
{
outputMsgs[i] = inverseDamping*outputMsgs[i] + damping*saved[i];
}
}
}
if (saved.length > 0)
{
DimpleEnvironment.doubleArrayCache.release(saved);
}
// Normalize the outputs
double minPotential = outputMsgs[0];
for (int i = outputMsgLength; --i>=0;)
{
minPotential = Math.min(minPotential, outputMsgs[i]);
}
// Normalize min value
if (minPotential != 0.0)
{
for (int i = outputMsgLength; --i>=0;)
{
outputMsgs[i] -= minPotential;
}
}
}
public void update()
{
final IFactorTable table = _tableFactor.getFactorTable();
final JointDomainIndexer indexer = table.getDomainIndexer();
final int[][] tableIndices = table.getIndicesSparseUnsafe();
final double[] values = table.getEnergiesSparseUnsafe();
final int tableLength = tableIndices.length;
final int numPorts = _factor.getSiblingCount();
double [][] outPortMsgs = _tableFactor.getOutPortMsgs();
final boolean useDamping = _tableFactor._dampingInUse;
double[] saved = ArrayUtil.EMPTY_DOUBLE_ARRAY;
if (useDamping)
{
saved = DimpleEnvironment.doubleArrayCache.allocateAtLeast(indexer.getSumOfDomainSizes());
for (int port = 0, savedOffset = 0; port < numPorts; port++)
{
final double[] outputMsgs = outPortMsgs[port];
final int outputMsgLength = outputMsgs.length;
if (useDamping)
{
double damping = _tableFactor._dampingParams[port];
if (damping != 0)
{
System.arraycopy(outputMsgs, 0, saved, savedOffset, outputMsgLength);
}
}
Arrays.fill(outputMsgs, Double.POSITIVE_INFINITY);
savedOffset += outputMsgLength;
}
}
else
{
for (double[] outMsg : outPortMsgs)
{
Arrays.fill(outMsg, Double.POSITIVE_INFINITY);
}
}
final double [][] inPortMsgs = _tableFactor.getInPortMsgs();
// Run through each row of the function table
for (int tableIndex = tableLength; --tableIndex>=0;)
{
final int[] tableRow = tableIndices[tableIndex];
// Sum up the function value plus the messages on all ports
double L = values[tableIndex];
for (int port = numPorts; --port>=0;)
L += inPortMsgs[port][tableRow[port]];
// Run through each output port
for (int outPortNum = numPorts; --outPortNum>=0;)
{
final double[] outputMsgs = outPortMsgs[outPortNum];
final int outputIndex = tableRow[outPortNum]; // Index for the output value
final double LThisPort = L - inPortMsgs[outPortNum][outputIndex]; // Subtract out the message from this output port
outputMsgs[outputIndex] = Math.min(outputMsgs[outputIndex], LThisPort);
}
}
// Damping
if (useDamping)
{
for (int port = 0, savedOffset = 0; port < numPorts; port++)
{
final double[] outputMsgs = outPortMsgs[port];
final int outputMsgLength = outputMsgs.length;
final double damping = _tableFactor._dampingParams[port];
if (damping != 0)
{
final double inverseDamping = 1.0 - damping;
for (int i = outputMsgLength; --i>=0;)
{
outputMsgs[i] = inverseDamping*outputMsgs[i] + damping*saved[i+savedOffset];
}
}
savedOffset += outputMsgLength;
}
DimpleEnvironment.doubleArrayCache.release(saved);
}
// Normalize the outputs
for (int port = numPorts; --port>=0;)
{
double[] outputMsgs = outPortMsgs[port];
int outputMsgLength = outputMsgs.length;
double minPotential = Double.POSITIVE_INFINITY;
for (int i = outputMsgLength; --i>=0;)
{
minPotential = Math.min(minPotential, outputMsgs[i]);
}
if (minPotential != 0.0)
{
for (int i = outputMsgLength; --i>=0;)
outputMsgs[i] -= minPotential; // Normalize min value
}
}
}
}