/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct.customFactors; import static java.util.Objects.*; import org.jtransforms.fft.DoubleFFT_1D; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.solvers.sumproduct.SFiniteFieldFactor; import com.analog.lyric.dimple.solvers.sumproduct.SFiniteFieldVariable; import com.analog.lyric.dimple.solvers.sumproduct.SumProductFiniteFieldVariable; import com.analog.lyric.dimple.solvers.sumproduct.SumProductFiniteFieldVariable.LookupTables; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph; @SuppressWarnings("deprecation") // TODO remove when SFiniteFieldFactor removed public class CustomFiniteFieldMult extends SFiniteFieldFactor { //TODO: store this here or cache somewhere else? // Maybe create a multition for this? //TODO: Do we want double or float multiply? //TODO: Do we want complexForward or realForward? private DoubleFFT_1D _fft; @SuppressWarnings("null") public CustomFiniteFieldMult(Factor factor, SumProductSolverGraph parent) { super(factor, parent); if (factor.getSiblingCount() != 3) throw new DimpleException("Only supports 3 arguments"); //TODO: check all have same prime poly final Discrete v0 = requireNonNull((Discrete)factor.getSibling(0)); //TODO: error check int poly = ((SumProductFiniteFieldVariable)v0.getSolver()).getTables().getPoly(); for (int i = 1; i < 3; i++) { if (((SumProductFiniteFieldVariable)getSibling(i)).getTables().getPoly() != poly) { //TODO: better error message throw new DimpleException("polys don't match"); } } //((Discrete)variables.getByIndex(0)); _fft = new DoubleFFT_1D(v0.getDiscreteDomain().size()-1); } @Override protected void doUpdate() { updateToZ(); updateToX(); updateToY(); } public void updateToX() { double [] xOutput = getSiblingEdgeState(0).factorToVarMsg.representation(); double [] yInput = getSiblingEdgeState(1).varToFactorMsg.representation(); double [] zInput = getSiblingEdgeState(2).varToFactorMsg.representation(); updateBackward(yInput,zInput,xOutput); } public void updateToY() { double [] yOutput = getSiblingEdgeState(1).factorToVarMsg.representation(); double [] xInput = getSiblingEdgeState(0).varToFactorMsg.representation(); double [] zInput = getSiblingEdgeState(2).varToFactorMsg.representation(); updateBackward(xInput,zInput,yOutput); } public void updateBackward(double [] yInput,double [] zInput, double [] xOutput) { @SuppressWarnings("null") final LookupTables tables = ((SFiniteFieldVariable)getSibling(0)).getTables(); int [] dlogTable = tables.getDlogTable(); int [] powTable = tables.getPowerTable(); //Sort x, y, and z so that probs are stored in logs double [] dlogx = new double[(xOutput.length-1)*2]; double [] dlogy = new double[dlogx.length]; double [] dlogz = new double[dlogx.length]; for (int i = 1; i < dlogTable.length; i++) { int dlog = dlogTable[i]; int tmp = (dlogTable.length - 1 - dlog)%(dlogTable.length-1); dlogz[dlog*2] = zInput[i]; dlogy[tmp*2] = yInput[i]; } xOutput[0] = zInput[0]; //perform fft on two inputs _fft.complexForward(dlogy); _fft.complexForward(dlogz); //pointwise multiply for (int i = 0; i < dlogz.length; i+=2) { dlogx[i] = dlogy[i]*dlogz[i] - dlogy[i+1]*dlogz[i+1]; dlogx[i+1] = dlogy[i]*dlogz[i+1] + dlogy[i+1]*dlogz[i]; } //compute inverse FFT //TODO: the scaling could be slow? Can I avoid scaling? _fft.complexInverse(dlogx,true); //sort back double sum = xOutput[0]; for (int i = 0; i < dlogx.length; i += 2) { double val = 0; //threshold negative to zero if (dlogx[i] > 0) val = dlogx[i]; xOutput[powTable[i/2]] = val + yInput[0]*zInput[0]; sum += val; } //normalize, considering 0 for (int i = 0; i < xOutput.length; i++) xOutput[i] /= sum; } public void updateToZ() { //p(Z=0) = p(X=0) + p(Y=0) - p(X=0)*p(Y=0) //p(Z=a) = p(X*Y=a) = p(dlog(x) + dlog(y) = dlog(a)) // = SUM (over i) p(dlog(x) == i) * p(dlog(y) == dlog(a) - i) double [] xInput = getSiblingEdgeState(0).varToFactorMsg.representation(); double [] yInput = getSiblingEdgeState(1).varToFactorMsg.representation(); double [] zOutput = getSiblingEdgeState(2).factorToVarMsg.representation(); @SuppressWarnings("null") final LookupTables tables = ((SumProductFiniteFieldVariable)getSibling(0)).getTables(); int [] dlogTable = tables.getDlogTable(); int [] powTable = tables.getPowerTable(); //Sort x, y, and z so that probs are stored in logs double [] dlogx = new double[(xInput.length-1)*2]; double [] dlogy = new double[dlogx.length]; double [] dlogz = new double[dlogx.length]; for (int i = 1; i < dlogTable.length; i++) { int dlog = dlogTable[i]; dlogx[dlog*2] = xInput[i]; dlogy[dlog*2] = yInput[i]; } //calculate p(0) zOutput[0] = xInput[0] + yInput[0] - xInput[0]*yInput[0]; //perform fft on two inputs _fft.complexForward(dlogx); _fft.complexForward(dlogy); //pointwise multiply for (int i = 0; i < dlogz.length; i+=2) { dlogz[i] = dlogx[i]*dlogy[i] - dlogx[i+1]*dlogy[i+1]; dlogz[i+1] = dlogx[i]*dlogy[i+1] + dlogx[i+1]*dlogy[i]; } //compute inverse FFT //TODO: is scaling slower or faster? _fft.complexInverse(dlogz,false); //unsort back //double sum = zOutput[0]; double sum = 0; for (int i = 0; i < dlogz.length; i += 2) { double val = 0; //threshold negative to zero if (dlogz[i] > 0) val = dlogz[i]; zOutput[powTable[i/2]] = val; sum += val; } //TODO: Is this enough to avoid NaNs? if (sum > 0) for (int i = 1; i < zOutput.length; i++) zOutput[i] = (1-zOutput[0])*zOutput[i]/sum; } @Override public void doUpdateEdge(int outPortNum) { if (outPortNum == 0) updateToX(); else if (outPortNum == 1) updateToY(); else if (outPortNum == 2) updateToZ(); else throw new DimpleException("unexpected port num"); } }