/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct.customFactors; import java.util.List; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.solvers.sumproduct.SFiniteFieldFactor; import com.analog.lyric.dimple.solvers.sumproduct.SumProductFiniteFieldVariable; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph; @SuppressWarnings("deprecation") // TODO remove when SFiniteFieldFactor removed public class CustomFiniteFieldConstantMult extends SFiniteFieldFactor { private int _constant; private int _dlogConstant; private SumProductFiniteFieldVariable _varInput; private SumProductFiniteFieldVariable _varOutput; //private Port _varInputPort; //private Port _varOutputPort; public CustomFiniteFieldConstantMult(Factor factor, SumProductSolverGraph parent) { super(factor, parent); assertHasConstants(factor); final List<Value> constants = _model.getConstantValues(); if (constants.size() != 1) throw new DimpleException("expected one constant"); if (factor.getSiblingCount() != 2) throw new DimpleException("finiteFieldMult expects two variable arguments"); //ArrayList<Port> ports = _factor.getPorts(); _varInput = (SumProductFiniteFieldVariable)getSibling(0); //_varInputPort = ports.get(0); _varOutput = (SumProductFiniteFieldVariable)getSibling(1); //_varOutputPort = ports.get(1); assignConstant(constants.get(0).getInt()); if (_varInput.getTables().getPoly() != _varOutput.getTables().getPoly()) { throw new DimpleException("Variables have different primitive polynomials. This is not currently supported"); } //Make sure primitive polynomials match //TODO: make sure this is possible before casting } private void assignConstant(int val) { //TODO: error check before casting int poly = val; _constant = poly; //FiniteFieldVariable.convertDoubleArray2poly(poly); if (_constant == 0) throw new DimpleException("Multiplication by zero not supported"); _dlogConstant = _varInput.getTables().getDlogTable()[_constant]; } public int getConstant() { return -1; } @Override public void doUpdateEdge(int outPortNum) { switch (outPortNum) { case 0: //we want to update the multiply input updateMultInputEdge(); break; case 1: //we want to update the multiply output updateMultOutputEdge(); break; default: throw new DimpleException("unexpected port num"); } } private void updateMultInputEdge() { double [] inputMsg = getSiblingEdgeState(0).factorToVarMsg.representation(); double [] outputMsg = getSiblingEdgeState(1).varToFactorMsg.representation(); int [] outputDlogTable = _varOutput.getTables().getDlogTable(); int [] inputPowerTable = _varInput.getTables().getPowerTable(); int dlogSum = 0; int length = inputMsg.length-1; int index = 0; //TODO: special case multiplication by constant of zero. //TODO: check same size? inputMsg[0] = outputMsg[0]; for (int i = 1; i < inputMsg.length; i++) { dlogSum = (outputDlogTable[i]-_dlogConstant+length)%length; index = inputPowerTable[dlogSum]; inputMsg[index] = outputMsg[i]; } } private void updateMultOutputEdge() { double [] inputMsg = getSiblingEdgeState(0).varToFactorMsg.representation(); double [] outputMsg = getSiblingEdgeState(1).factorToVarMsg.representation(); int [] inputDlogTable = _varInput.getTables().getDlogTable(); int [] outputPowerTable = _varOutput.getTables().getPowerTable(); int dlogSum = 0; int length = inputMsg.length-1; int index = 0; //TODO: special case multiplication by constant of zero. //TODO: check same size? outputMsg[0] = inputMsg[0]; for (int i = 1; i < inputMsg.length; i++) { dlogSum = (inputDlogTable[i]+_dlogConstant)%length; index = outputPowerTable[dlogSum]; outputMsg[index] = inputMsg[i]; } } @Override public void initialize() { } }