/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.FactorFunctions.core; import static java.util.Objects.*; import java.util.Arrays; import java.util.Random; import java.util.concurrent.TimeUnit; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.factorfunctions.core.FactorTable; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.google.common.base.Stopwatch; public class FactorTablePerformanceTester { private final IFactorTable _table; private final Stopwatch _timer; private final int _iterations; private final Random _random; public boolean showLog = true; /*--------------- * Construction */ FactorTablePerformanceTester(IFactorTable table, int iterations) { _table = table; _iterations = iterations; _random = new Random(42); _timer = Stopwatch.createUnstarted(); } /*------------ * Test cases */ public void testEvalAsFactorFunction() { _random.setSeed(23); final Object [][] argrows = new Object[_iterations][]; for (int i = 0; i < _iterations; ++i) { argrows[i] = randomArguments(); } Runnable test = new Runnable() { @Override public void run() { for (Object[] args : argrows) { _table.getWeightForElements(args); } } }; runTest("evalAsFactorFunction", test); } public void testGetWeightIndexFromTableIndices() { _random.setSeed(42); final int [][] rows = new int[_iterations][]; for (int i = 0; i < _iterations; ++i) { rows[i] = randomIndices(); } Runnable test = new Runnable() { @Override public void run() { for (int[] indices : rows) { _table.sparseIndexFromIndices(indices); } } }; runTest("getWeightIndexFromTableIndices", test); } public void testGetWeightForIndices() { _random.setSeed(43); final int [][] rows = new int[500][]; for (int i = 0; i < rows.length; ++i) { rows[i] = randomIndices(); } Runnable test = new Runnable() { @SuppressWarnings("unused") double total = 0.0; @Override public void run() { for (int i = _iterations; --i>=0;) { for (int[] indices : rows) { total += _table.getWeightForIndices(indices); } } } }; runTest("getWeightForIndices", rows.length, "call", 42, test); } public void testSumProductUpdate() { final JointDomainIndexer domains = _table.getDomainIndexer(); final int numPorts_ = domains.size(); final double[][] outMsgs_ = new double[numPorts_][]; final double[][] inMsgs_ = new double[numPorts_][]; for (int i = 0; i < numPorts_; ++i) { int domainSize = domains.getDomainSize(i); outMsgs_[i] = new double[domainSize]; inMsgs_[i] = new double[domainSize]; for (int j = 0; j < domainSize; ++j) { double d = _random.nextDouble(); inMsgs_[i][j] = d; } } if (getNewTable() == null) { Runnable original = new Runnable() { @Override public void run() { final int numPorts = numPorts_; final double[][] outMsgs = outMsgs_; final double[][] inMsgs = inMsgs_; for (int iteration = _iterations; --iteration>=0;) { int[][] table = _table.getIndicesSparseUnsafe(); double[] values = _table.getWeightsSparseUnsafe(); int tableLength = table.length; for (int outPortNum = 0; outPortNum < numPorts; outPortNum++) { double[] outputMsgs = outMsgs[outPortNum]; int outputMsgLength = outputMsgs.length; for (int i = 0; i < outputMsgLength; i++) outputMsgs[i] = 0; for (int tableIndex = 0; tableIndex < tableLength; tableIndex++) { double prob = values[tableIndex]; int[] tableRow = table[tableIndex]; int outputIndex = tableRow[outPortNum]; for (int inPortNum = 0; inPortNum < numPorts; inPortNum++) if (inPortNum != outPortNum) { prob *= inMsgs[inPortNum][tableRow[inPortNum]]; } outputMsgs[outputIndex] += prob; } double sum = 0; for (int i = 0; i < outputMsgLength; i++) sum += outputMsgs[i]; for (int i = 0; i < outputMsgLength; i++) outputMsgs[i] /= sum; } } } }; runTest("sumProductUpdateOriginal", "call", 42, original); Runnable faster = new Runnable() { @Override public void run() { final int numPorts = numPorts_; final double[][] outMsgs = outMsgs_; final double[][] inMsgs = inMsgs_; for (int iteration = _iterations; --iteration>=0;) { final int[][] table = _table.getIndicesSparseUnsafe(); final double[] values = _table.getWeightsSparseUnsafe(); final int tableLength = table.length; for (int outPortNum = 0; outPortNum < numPorts; ++outPortNum) { double[] outputMsgs = outMsgs[outPortNum]; int outputMsgLength = outputMsgs.length; Arrays.fill(outputMsgs, 0); for (int tableIndex = 0; tableIndex < tableLength; tableIndex++) { double prob = values[tableIndex]; int[] tableRow = table[tableIndex]; int outputIndex = tableRow[outPortNum]; for (int inPortNum = 0; inPortNum < outPortNum; ++inPortNum) prob *= inMsgs[inPortNum][tableRow[inPortNum]]; for (int inPortNum = outPortNum + 1; inPortNum < numPorts; inPortNum++) prob *= inMsgs[inPortNum][tableRow[inPortNum]]; outputMsgs[outputIndex] += prob; } double sum = 0; for (double w : outputMsgs) sum += w; for (int i = 0; i < outputMsgLength; ++i) outputMsgs[i] /= sum; } } } }; runTest("sumProductUpdateFaster", "call", 42, faster); } if (getNewTable() != null) { Runnable unsafeIndices = new Runnable() { @Override public void run() { final FactorTable ftable = requireNonNull(getNewTable()); final int numPorts = numPorts_; final double[][] outMsgs = outMsgs_; final double[][] inMsgs = inMsgs_; for (int iteration = _iterations; --iteration>=0;) { final int[][] table = ftable.getIndicesSparseUnsafe(); final int tableLength = table.length; final double[] values = ftable.getWeightsSparseUnsafe(); for (int outPortNum = 0; outPortNum < numPorts; ++outPortNum) { double[] outputMsgs = outMsgs[outPortNum]; int outputMsgLength = outputMsgs.length; Arrays.fill(outputMsgs, 0); for (int tableIndex = 0; tableIndex < tableLength; tableIndex++) { double prob = values[tableIndex]; int[] tableRow = table[tableIndex]; int outputIndex = tableRow[outPortNum]; for (int inPortNum = 0; inPortNum < outPortNum; ++inPortNum) prob *= inMsgs[inPortNum][tableRow[inPortNum]]; for (int inPortNum = outPortNum + 1; inPortNum < numPorts; inPortNum++) prob *= inMsgs[inPortNum][tableRow[inPortNum]]; outputMsgs[outputIndex] += prob; } double sum = 0; for (double w : outputMsgs) sum += w; for (int i = 0; i < outputMsgLength; ++i) outputMsgs[i] /= sum; } } } }; runTest("sumProductUpdateNew", "call", 42, unsafeIndices); } } public void testGibbsUpdateMessage() { final JointDomainIndexer domains = _table.getDomainIndexer(); final int numPorts_ = domains.size(); final double[][] outMsgs_ = new double[numPorts_][]; final int[] inMsgs_ = new int[numPorts_]; for (int i = 0; i < numPorts_; ++i) { int domainSize = domains.getDomainSize(i); outMsgs_[i] = new double[domainSize]; inMsgs_[i] = _random.nextInt(domainSize); } // Modified version of gibbs.STableFactor.updateEdgeMessage(int) Runnable original = new Runnable() { @Override public void run() { final int numPorts = numPorts_; final double[][] outMsgs = outMsgs_; final int[] inMsgs = inMsgs_; final IFactorTable factorTable = _table; for (int iteration = _iterations; --iteration>=0;) { for (int outPortNum = 0; outPortNum < numPorts; ++outPortNum) { double[] outMessage = outMsgs[outPortNum]; int outputMsgLength = outMessage.length; double[] factorTableWeights = factorTable.getEnergiesSparseUnsafe(); int[] inPortMsgs = new int[numPorts]; for (int port = 0; port < numPorts; port++) inPortMsgs[port] = inMsgs[port]; for (int outIndex = 0; outIndex < outputMsgLength; outIndex++) { inPortMsgs[outPortNum] = outIndex; int weightIndex = factorTable.sparseIndexFromIndices(inPortMsgs); if (weightIndex >= 0) outMessage[outIndex] = factorTableWeights[weightIndex]; else outMessage[outIndex] = Double.POSITIVE_INFINITY; } } } } }; runTest("gibbsUpdateMessageOriginal", numPorts_, "call", 42, original); Runnable modified = new Runnable() { @Override public void run() { final int numPorts = numPorts_; final double[][] outMsgs = outMsgs_; final int[] inMsgs = inMsgs_; final IFactorTable factorTable = _table; for (int iteration = _iterations; --iteration>=0;) { for (int outPortNum = 0; outPortNum < numPorts; ++outPortNum) { double[] outMessage = outMsgs[outPortNum]; int outputMsgLength = outMessage.length; int[] inPortMsgs = new int[numPorts]; for (int port = 0; port < numPorts; port++) inPortMsgs[port] = inMsgs[port]; for (int outIndex = 0; outIndex < outputMsgLength; outIndex++) { inPortMsgs[outPortNum] = outIndex; outMessage[outIndex] = factorTable.getEnergyForIndices(inPortMsgs); } } } } }; runTest("gibbsUpdateMessageModified", numPorts_, "call", 42, modified); } /*----------------- * Private methods */ private @Nullable FactorTable getNewTable() { if (_table instanceof FactorTable) { return (FactorTable)_table; } return null; } private Object[] randomArguments() { JointDomainIndexer domains = _table.getDomainIndexer(); Object[] arguments = new Object[domains.size()]; for (int i = 0; i < arguments.length; ++i) { arguments[i] = domains.get(i).getElement(_random.nextInt(domains.getDomainSize(i))); } return arguments; } private int[] randomIndices() { JointDomainIndexer domains = _table.getDomainIndexer(); int[] indices = new int[domains.size()]; for (int i = 0; i < indices.length; ++i) { indices[i] = _random.nextInt(domains.getDomainSize(i)); } return indices; } private double runTest(String name, int unitMultiplier, String unit, int seed, Runnable test) { _random.setSeed(seed); // Warmup test.run(); _random.setSeed(seed); _timer.reset(); _timer.start(); test.run(); _timer.stop(); long ns = _timer.elapsed(TimeUnit.NANOSECONDS); return logTime(name, ns / (_iterations * unitMultiplier), unit); } private double runTest(String name, String unit, int seed, Runnable test) { return runTest(name, 1, unit, seed, test); } private double runTest(String name, Runnable test) { return runTest(name, "call", 42, test); } private double logTime(String name, double time, String unit) { if (showLog) { System.out.format("%s.%s: %f/%s\n", _table.getClass().getSimpleName(), name, time, unit); } return time; } }