/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.core.kbest; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.util.misc.IndexCounter; /* * This class provides an implementation for update and updateEdge that can * be used by solver factor classes. It implements a k best algorithm. * * pseudocode of algorithm * * updateEdge(outPort) * For each input msg * sort by probability and pick the k most likely * * initialize outputMsg to zero (or equivalent) for all values * * For every single value of the output message (not just the kbset) * * For the n^k combinations of inputs (where n is number of input edges) * prod = calculate factor function (or equivalent for minsum) * prod *= all of the input probabilities for those values * * sum the prod with the current value for the output message at this value (or equivalent for minsum) * * Normalize outputmsg (subtract smallest value) * * There is no optimization for update(all) */ public class KBestFactorEngine { private int _k; private IKBestFactor _kbestFactor; private double [][] _outPortMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; private double [][] _inPortMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; private void updateCache() { _inPortMsgs = _kbestFactor.getInPortMsgs(); _outPortMsgs = _kbestFactor.getOutPortMsgs(); } public KBestFactorEngine(IKBestFactor f) { _kbestFactor = f; } public void update() { updateCache(); for (int i = 0; i < _outPortMsgs.length; i++) updateEdgeInternal(i); } public void setK(int k) { _k = k; } /* * Code for updating given no factor table but java factor function */ public void updateEdge(int outPortNum) { updateCache(); updateEdgeInternal(outPortNum); } protected void updateEdgeInternal(int outPortNum) { //Initialize the outputMsg to Infinite potentials. double [] outputMsg = _outPortMsgs[outPortNum]; _kbestFactor.initMsg(outputMsg); //Cache the input messages. //Cache the domains Object [][] domains = new Object[_inPortMsgs.length][]; for (int i = 0; i < _inPortMsgs.length; i++) domains[i] = ((Discrete)_kbestFactor.getFactor().getConnectedNodeFlat(i)).getDiscreteDomain().getElements(); //We will store the kbest indices in this array int [][] domainIndices = new int[_inPortMsgs.length][]; //We will store the truncated domainlengths here. int [] domainLengths = new int[_inPortMsgs.length]; //For each port for (int i = 0; i < _inPortMsgs.length; i++) { double [] inPortMsg = _inPortMsgs[i]; //If this is the output port, we only store one value at a time. if (i == outPortNum) domainIndices[i] = new int[]{0}; else { //Here we check to see that k is actually less than the domain length if (_k < inPortMsg.length) domainIndices[i] = _kbestFactor.findKBestForMsg(inPortMsg,_k); else { //If it's not, we just map indices one to one. domainIndices[i] = new int[inPortMsg.length]; for (int j = 0; j < domainIndices[i].length; j++) domainIndices[i][j] = j; } } domainLengths[i] = domainIndices[i].length; } //cache the factor function. //FactorFunction ff = _kbestFactor.getFactorFunction(); //Used to iterate all combinations of truncated domains. IndexCounter ic = new IndexCounter(domainLengths); int [] inputIndices = new int[_inPortMsgs.length]; //We fill out a value for every value for the output message (no truncating to k) for (int outputIndex = 0; outputIndex < outputMsg.length; outputIndex++) { //Here we set the output port's index appropriately domainIndices[outPortNum][0] = outputIndex; //For all elements of cartesian product for (int [] indices : ic) { //initialize the sum double sum = _kbestFactor.initAccumulator(); for (int i = 0; i < indices.length; i++) { //Don't count the output port if (i != outPortNum) //i == port index, indices[i] == which of the truncated domain indices to retrieve //domainIndices[i][indices[i]] == the actual index of the input msg. sum = _kbestFactor.accumulate(sum, _inPortMsgs[i][domainIndices[i][indices[i]]]); //Here we set the input value for this port inputIndices[i] = domainIndices[i][indices[i]]; //ffInput[i] = domains[i][domainIndices[i][indices[i]]]; } //Evaluate the factor function and add that potential to the sum. double result = getFactorFunctionValueForIndices(inputIndices,domains); sum = _kbestFactor.accumulate(sum, result); outputMsg[outputIndex] = _kbestFactor.combine(outputMsg[outputIndex] , sum); } } _kbestFactor.normalize(outputMsg); } protected double getFactorFunctionValueForIndices(int [] inputIndices, Object [][] domains) { Object [] ffInput = new Object[inputIndices.length]; for (int i = 0; i < ffInput.length; i++) ffInput[i] = domains[i][inputIndices[i]]; return _kbestFactor.evalFactorFunction(ffInput); } protected IKBestFactor getIKBestFactor() { return _kbestFactor; } }