/*******************************************************************************
* Copyright 2012-2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct.customFactors;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.solvers.core.SolverFactorCreationException;
import com.analog.lyric.dimple.solvers.sumproduct.SFiniteFieldFactor;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductDiscreteEdge;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
@SuppressWarnings("deprecation") // TODO remove when SFiniteFieldFactor removed
/**
 * Custom sum-product implementation of a three-argument finite-field addition factor.
 * <p>
 * Each outgoing message is computed as the XOR-convolution of the other two incoming
 * messages, evaluated in the Walsh-Hadamard domain. The same computation is used for
 * every port. NOTE(review): this assumes domain index {@code i} is the bit-vector
 * representation of the field element, so that field addition is bitwise XOR — confirm
 * against the finite-field variable domain.
 */
public class CustomFiniteFieldAdd extends SFiniteFieldFactor
{
    /**
     * Creates the solver factor.
     *
     * @param factor the model factor; must have exactly 3 arguments and no constants
     * @param parent the owning sum-product solver graph
     * @throws SolverFactorCreationException if {@code factor} has constant arguments or
     *         does not have exactly 3 arguments
     */
    public CustomFiniteFieldAdd(Factor factor, SumProductSolverGraph parent)
    {
        super(factor, parent);
        if (factor.hasConstants())
            throw new SolverFactorCreationException("%s does not support constant arguments",
                getClass().getSimpleName());
        if (factor.getArgumentCount() != 3)
            throw new SolverFactorCreationException("%s expects 3 arguments",
                getClass().getSimpleName());
    }

    /**
     * Recomputes the factor-to-variable message on {@code outPortNum} from the
     * variable-to-factor messages on the other two ports.
     * <p>
     * The result is written into the output message's representation array, with negative
     * values clamped to zero and normalized to sum to one.
     *
     * @param outPortNum index (0, 1 or 2) of the port whose outgoing message is updated
     */
    @Override
    @SuppressWarnings("null")
    public void doUpdateEdge(int outPortNum)
    {
        // The two input ports are the ones other than outPortNum, in ascending order.
        final int inPort1 = (outPortNum == 0) ? 1 : 0;
        final int inPort2 = (outPortNum == 2) ? 1 : 2;

        final SumProductDiscreteEdge inEdge1 = getSiblingEdgeState(inPort1);
        final SumProductDiscreteEdge inEdge2 = getSiblingEdgeState(inPort2);
        final SumProductDiscreteEdge outEdge = getSiblingEdgeState(outPortNum);

        final double[] inputs1 = inEdge1.varToFactorMsg.representation();
        final double[] inputs2 = inEdge2.varToFactorMsg.representation();
        final double[] outputs = outEdge.factorToVarMsg.representation();

        // Exact floor(log2(length)); the domain size is expected to be a power of two.
        final int n = 31 - Integer.numberOfLeadingZeros(inputs1.length);

        // Transform both inputs. fast_hadamard leaves its input untouched, so no
        // defensive copies of the message arrays are required.
        final double[] transformed1 = new double[inputs1.length];
        final double[] transformed2 = new double[inputs2.length];
        fast_hadamard(n, inputs1, transformed1);
        fast_hadamard(n, inputs2, transformed2);

        // A pointwise product in the Hadamard domain is an XOR-convolution in the
        // original domain.
        for (int i = 0; i < outputs.length; i++)
            transformed1[i] *= transformed2[i];

        // Transform back. The unnormalized inverse is the same transform; the 2^n scale
        // factor is absorbed by the normalization below.
        fast_hadamard(n, transformed1, outputs);

        // The +/- butterflies can leave small negative round-off values. Clamp them
        // before summing so that the normalized message really sums to one.
        double sum = 0;
        for (int i = 0; i < outputs.length; i++)
        {
            if (outputs[i] < 0)
                outputs[i] = 0;
            sum += outputs[i];
        }
        for (int i = 0; i < outputs.length; i++)
            outputs[i] /= sum;
    }

    /**
     * Computes the unnormalized Walsh-Hadamard transform of the first {@code 2^n} entries
     * of {@code in}, storing the result in the first {@code 2^n} entries of {@code out}.
     * <p>
     * The transform is its own inverse up to a factor of {@code 2^n}. {@code in} is not
     * modified. {@code in} and {@code out} may be the same array, in which case the
     * transform is done in place.
     *
     * @param n   log2 of the transform size; must be in {@code [0, 30]}
     * @param in  source values (at least {@code 2^n} entries)
     * @param out destination (at least {@code 2^n} entries)
     * @throws IllegalArgumentException if {@code n} is outside {@code [0, 30]}
     * @throws IndexOutOfBoundsException if either array has fewer than {@code 2^n} entries
     */
    public static void fast_hadamard(int n, double[] in, double[] out)
    {
        if (n < 0 || n > 30)
            throw new IllegalArgumentException("n must be in [0, 30], got " + n);

        final int size = 1 << n;
        if (in != out)
            System.arraycopy(in, 0, out, 0, size);

        // Butterfly stage for each bit, lowest bit first. Each stage pairs index j (bit
        // clear) with j + half (bit set) and replaces them with their sum and difference.
        for (int half = 1; half < size; half <<= 1)
        {
            for (int base = 0; base < size; base += half << 1)
            {
                for (int j = base; j < base + half; j++)
                {
                    final double a = out[j];
                    final double b = out[j + half];
                    out[j] = a + b;
                    out[j + half] = a - b;
                }
            }
        }
    }
}