/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct.customFactors;
import java.util.ArrayList;
import org.eclipse.jdt.annotation.NonNull;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Multiplexer;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.core.STableFactorDoubleArray;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductDiscreteEdge;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
/**
* Custom factor for {@link Multiplexer}
* <p>
* The Multiplexer factor is a directed factor
* <blockquote><pre>
* a z(1) z(2) ...
* \ | / /
* ++--+----+
* |
* y
* </pre></blockquote>
* such that P(Y=y|a,z(1),z(2),...) = Identity(y == z(a))
*/
public class CustomMultiplexer extends STableFactorDoubleArray
{
/*-------
* State
*/
private int _yDomainSize;
private int _aDomainSize;
//Create a mapping between a yIndex and all the possible zs that could
//have been selected to achieve that value of y
private ArrayList<int []> [] _yIndex2zIndices;
//Create a mapping between a z index and the y
private int [][] _zIndices2yIndex;
/*--------------
* Construction
*/
@SuppressWarnings("unchecked")
public CustomMultiplexer(Factor factor, SumProductSolverGraph parent)
{
super(factor, parent);
assertDiscrete(factor);
final int nVars = factor.getSiblingCount();
if (nVars < 2)
throw new DimpleException("Must specify at least Y and A");
final Variable y = factor.getSibling(0);
final Variable a = factor.getSibling(1);
final DiscreteDomain yDomain = y.asDiscreteVariable().getDiscreteDomain();
_yDomainSize = yDomain.size();
_aDomainSize = a.asDiscreteVariable().getDiscreteDomain().size();
if (_aDomainSize+2 != nVars)
throw new DimpleException("Must specify " + _aDomainSize + " Zs");
//calculate the list of z index pairs for each y
_yIndex2zIndices = new ArrayList [_yDomainSize];
//Generate the mapping from Ys to Zs
for (int i = 0; i < _yDomainSize; i++)
{
_yIndex2zIndices[i] = new ArrayList<int[]>();
for (int j = 0; j < _aDomainSize; j++)
{
DiscreteDomain zDomain = factor.getSibling(2+j).asDiscreteVariable().getDiscreteDomain();
for (int k = 0, end = zDomain.size(); k < end; k++)
{
if (yDomain.getElement(i).equals(zDomain.getElement(k)))
{
_yIndex2zIndices[i].add(new int [] {j,k});
break;
}
}
}
}
_zIndices2yIndex = new int[_aDomainSize][];
//Generate the mappings from zs to Y
for (int i = 0; i < _aDomainSize; i++)
{
DiscreteDomain zDomain = factor.getSibling(2+i).asDiscreteVariable().getDiscreteDomain();
_zIndices2yIndex[i] = new int [zDomain.size()];
for (int j = 0; j < _zIndices2yIndex[i].length; j++)
{
_zIndices2yIndex[i][j] = -1;
for (int k = 0; k < _yDomainSize; k++)
{
if (yDomain.getElement(k).equals(zDomain.getElement(j)))
{
_zIndices2yIndex[i][j] = k;
break;
}
}
}
}
}
@Override
public void doUpdateEdge(int outPortNum)
{
if (outPortNum == 0)
updateToY();
else if (outPortNum == 1)
updateToA();
else
updateToZ(outPortNum-2);
}
@Override
protected boolean createFactorTableOnInit()
{
return false;
}
@Override
protected void setTableRepresentation(@NonNull IFactorTable table)
{
}
public void updateToA()
{
//p(a=x) = sum_{z in za} p(y=z)p(za=z)
double total = 0;
final double[] yWeights = getSiblingEdgeState(0).varToFactorMsg.representation();
final double[] aWeights = getSiblingEdgeState(1).factorToVarMsg.representation();
for (int i = 0; i < _aDomainSize; i++)
{
final double[] zWeights = getSiblingEdgeState(i+2).varToFactorMsg.representation();
double sm = 0;
for (int j = 0; j < _zIndices2yIndex[i].length; j++)
{
int yIndex = _zIndices2yIndex[i][j];
sm += yWeights[yIndex] * zWeights[j];
}
aWeights[i] = sm;
total += sm;
}
//normalize
for (int i = 0; i < _aDomainSize; i++)
aWeights[i] /= total;
}
public void updateToY()
{
//P(Y=y) = sum_{a} p(a)p(za=y)
double [] outMsg = getSiblingEdgeState(0).factorToVarMsg.representation();
double [] aInputMsg = getSiblingEdgeState(1).varToFactorMsg.representation();
double total = 0;
for (int i = 0; i < _yDomainSize; i++)
{
ArrayList<int []> zIndices = _yIndex2zIndices[i];
double sm = 0;
for (int [] tmp : zIndices)
{
int a = tmp[0];
int z = tmp[1];
sm += aInputMsg[a] * getSiblingEdgeState(a+2).varToFactorMsg.getWeight(z);
}
outMsg[i] = sm;
total += sm;
}
//normalize
for (int i = 0; i < _yDomainSize; i++)
outMsg[i] /= total;
}
public void updateToZ(int index)
{
//TODO: Can we optimize update all edges to calculate this once
// and then subtract off parts for each Z?
//P(Zi=x) = p(a=i)p(y=x) + sum_{j not i} sum_{z in za} p(aj)p(y=z)p(za=z)
double [] zBelief = getSiblingEdgeState(index+2).factorToVarMsg.representation();
double [] yWeights = getSiblingEdgeState(0).varToFactorMsg.representation();
double [] aWeights = getSiblingEdgeState(1).varToFactorMsg.representation();
double offset = 0;
for (int j = 0; j < _aDomainSize; j++)
{
if (j != index)
{
final double a = aWeights[j];
final int[] zIndices2yIndex = _zIndices2yIndex[j];
final double[] zWeights = getSiblingEdgeState(j+2).varToFactorMsg.representation();
for (int k = 0, nk = zIndices2yIndex.length; k < nk; k++)
{
int yIndex = zIndices2yIndex[k];
offset += a * yWeights[yIndex] * zWeights[k];
}
}
}
double total = 0;
for (int i = 0; i < _zIndices2yIndex[index].length; i++)
{
int yIndex = _zIndices2yIndex[index][i];
zBelief[i] = aWeights[index] * yWeights[yIndex] + offset;
total += zBelief[i];
}
//normalize
for (int i = 0; i < zBelief.length; i++)
zBelief[i] /= total;
}
@SuppressWarnings("null")
@Override
public SumProductDiscreteEdge getSiblingEdgeState(int siblingIndex)
{
return (SumProductDiscreteEdge)getSiblingEdgeState_(siblingIndex);
}
}