/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct.customFactors; import java.util.List; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.solvers.sumproduct.SFiniteFieldFactor; import com.analog.lyric.dimple.solvers.sumproduct.SumProductFiniteFieldVariable; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph; @SuppressWarnings("deprecation") // TODO remove when SFiniteFieldFactor removed public class CustomFiniteFieldProjection extends SFiniteFieldFactor { private SumProductFiniteFieldVariable _ffVar; private int [] _portIndex2bitIndex; private int [] _bit2port; public CustomFiniteFieldProjection(Factor factor, SumProductSolverGraph parent) { super(factor, parent); final int nVars = factor.getSiblingCount(); if (nVars <= 1) throw new DimpleException("need to specify at least one bit for projection"); final int nEdges = factor.getSiblingCount(); //First variable is the FiniteFieldVariable //Other variables should be bits. _ffVar = (SumProductFiniteFieldVariable)getSibling(0); _portIndex2bitIndex = new int[nEdges]; for (int i = 0; i < nEdges; i++) _portIndex2bitIndex[i] = -1; //get constant value and make sure it's in range final List<Value> constants = factor.getConstantValues(); if (constants.size()!= 1) throw new DimpleException("expected one constant to specify the array of bit positions"); double [] domain = constants.get(0).getDoubleArray(); if (nVars != 1+domain.length) throw new DimpleException("expect finite field variable, bit positions, and bits"); _bit2port = new int[nEdges-1]; for (int i = 1; i < nVars; i++) { //TODO: error check int index = (int)domain[i-1]; if (index < 0 || index >= nEdges-1) throw new DimpleException("index out of range"); if (_bit2port[index] != 0) throw new DimpleException("Tried to set index twice"); //get Variable and make sure it's a bit. Discrete bit = (Discrete)factor.getSibling(i); Object [] bitDomain = bit.getDiscreteDomain().getElements(); if (bitDomain.length != 2 || (Double)bitDomain[0] != 0 || (Double)bitDomain[1] != 1) throw new DimpleException("expected bit"); _bit2port[index] = i; _portIndex2bitIndex[i] = index; } } @Override public void doUpdateEdge(int outPortNum) { if (outPortNum == 0) updateFiniteField(); else { if (outPortNum >= 1) updateBit(outPortNum); } } public void updateFiniteField() { //for every value of the finite field double [] outputs = getSiblingEdgeState(0).factorToVarMsg.representation(); int numBits = _ffVar.getNumBits(); final int nEdges = getSiblingCount(); double prod; double [][] inputMsgs = new double[numBits][]; for (int i = 1; i < nEdges; i++) { inputMsgs[i-1] = getSiblingEdgeState(i).varToFactorMsg.representation(); } //Multiply bit probabilities double sum = 0; for (int i = 0, end = ((Discrete)_ffVar.getVariable()).getDiscreteDomain().size(); i < end; i++) { prod = 1; for (int j = 0; j < numBits; j++) { int p = _bit2port[j]; if (p != 0) { if (((i >> j) & 1) == 1) { //is one prod *= inputMsgs[j][1]; } else { prod *= inputMsgs[j][0]; } } } outputs[i] = prod; sum += prod; } //normalize for (int i = 0; i < outputs.length; i++) outputs[i] /= sum; } public void updateBit(int portNum) { //get output msg for bit double [] outputs = getSiblingEdgeState(portNum).factorToVarMsg.representation(); //init to 1 for each outputs[0] = 0; outputs[1] = 0; int bit = _portIndex2bitIndex[portNum]; //Iterate each value of finite field double [] inputs = getSiblingEdgeState(0).varToFactorMsg.representation(); for (int i = 0; i < inputs.length; i++) { //extract value of bit of interest if (((i >> bit) & 1) == 1) { //bit was one outputs[1] += inputs[i]; } else { //bit was zero outputs[0] += inputs[i]; } } //normalize double sum = outputs[0]+outputs[1]; outputs[0] /= sum; outputs[1] /= sum; } }