/*******************************************************************************
 *   Copyright 2012 Analog Devices, Inc.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 ********************************************************************************/

package com.analog.lyric.dimple.solvers.sumproduct;

import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;

import java.util.Map;
import java.util.Random;

import org.eclipse.jdt.annotation.Nullable;

import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.FiniteFieldVariable;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.options.DimpleOptions;
import com.analog.lyric.dimple.schedulers.DefaultScheduler;
import com.analog.lyric.dimple.schedulers.TreeSchedulerAbstract;
import com.analog.lyric.dimple.solvers.core.BPSolverGraph;
import com.analog.lyric.dimple.solvers.core.NoSolverEdge;
import com.analog.lyric.dimple.solvers.core.ParameterEstimator;
import com.analog.lyric.dimple.solvers.core.multithreading.MultiThreadingManager;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
import com.analog.lyric.dimple.solvers.optimizedupdate.CostEstimationTableWrapper;
import com.analog.lyric.dimple.solvers.optimizedupdate.CostType;
import com.analog.lyric.dimple.solvers.optimizedupdate.Costs;
import com.analog.lyric.dimple.solvers.optimizedupdate.FactorTableUpdateSettings;
import com.analog.lyric.dimple.solvers.optimizedupdate.FactorUpdatePlan;
import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStep;
import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStepEstimator;
import com.analog.lyric.dimple.solvers.optimizedupdate.ISFactorGraphToOptimizedUpdateAdapter;
import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStep;
import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStepEstimator;
import com.analog.lyric.dimple.solvers.optimizedupdate.TableWrapper;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateCostOptimizer;
import com.analog.lyric.dimple.solvers.sumproduct.sampledfactor.SampledFactor;
import com.analog.lyric.options.IOptionKey;
import com.analog.lyric.options.TemporaryOptionSettings;

/**
 * Solver representation of a factor graph under the sum-product solver.
 * <p>
 * @since 0.07
 */
public class SumProductSolverGraph extends BPSolverGraph<ISolverFactor, ISolverVariable, ISolverEdgeState>
{
	/*-------
	 * State
	 */
	
	private double _damping = 0;
	private @Nullable IFactorTable _currentFactorTable = null;
	private static Random _rand = new Random();

	/*--------------
	 * Construction
	 */
	
	public SumProductSolverGraph(FactorGraph factorGraph, @Nullable ISolverFactorGraph parent)
	{
		super(factorGraph, parent);
		setMultithreadingManager(new MultiThreadingManager(this));
		
		// Set default Gibbs options for sampled factors.
		setOption(GibbsOptions.numSamples, SampledFactor.DEFAULT_SAMPLES_PER_UPDATE);
		setOption(GibbsOptions.burnInScans, SampledFactor.DEFAULT_BURN_IN_SCANS_PER_UPDATE);
		setOption(GibbsOptions.scansPerSample, SampledFactor.DEFAULT_SCANS_PER_SAMPLE);
	}

	/*----------------------
	 * ISolverGraph methods
	 */
	
	// TODO - rearrange methods
	
	@Override
	public boolean hasEdgeState()
	{
		return true;
	}
	
	@Override
	public ISolverEdgeState createEdgeState(EdgeState edge)
	{
		final ISolverFactor sfactor = getSolverFactor(edge.getFactor(_model));
		
		ISolverEdgeState sedge = sfactor.createEdge(edge);
		if (sedge != null)
		{
			return sedge;
		}
		
		final Variable var = edge.getVariable(_model);
		
		if (var instanceof Discrete)
		{
			return new SumProductDiscreteEdge((Discrete)var);
		}
		else if (var instanceof Real)
		{
			return new SumProductNormalEdge();
		}
		else if (var instanceof RealJoint)
		{
			return new SumProductMultivariateNormalEdge((RealJoint)var);
		}
		
		return NoSolverEdge.INSTANCE;
	}
	
	@SuppressWarnings("deprecation") // TODO remove when S* classes removed
	@Override
	public ISolverVariable createVariable(Variable var)
	{
		if (var instanceof FiniteFieldVariable)
			return new SFiniteFieldVariable((FiniteFieldVariable)var, this);
		else if (var instanceof RealJoint)
			return new SRealJointVariable((RealJoint)var, this);
		else if (var instanceof Real)
			return new SRealVariable((Real)var, this);
		else if (var instanceof Discrete)
		{
			return new SDiscreteVariable((Discrete)var, this);
		}
		
		throw unsupportedVariableType(var);
	}
	
	@Override
	public ISolverFactor createFactor(Factor factor)
	{
		return SumProductOptions.customFactors.createFactor(factor, this);
	}

	@SuppressWarnings("deprecation") // TODO remove when SFactorGraph removed
	@Override
	public ISolverFactorGraph createSubgraph(FactorGraph subgraph)
	{
		return new SFactorGraph(subgraph, this);
	}
	
	// This should return true only for custom factors that do not have a corresponding
	// FactorFunction of the same name.
	@Deprecated
	@Override
	public boolean customFactorExists(String funcName)
	{
		if (funcName.equals("finiteFieldAdd"))             // For backward compatibility
			return true;
		else if (funcName.equals("finiteFieldMult"))       // For backward compatibility
			return true;
		else if (funcName.equals("finiteFieldProjection")) // For backward compatibility
			return true;
		else if (funcName.equals("multiplexerCPD"))        // For backward compatibility; should use "Multiplexer" instead
			return true;
		else if (funcName.equals("add"))                   // For backward compatibility
			return true;
		else if (funcName.equals("constmult"))             // For backward compatibility
			return true;
		else if (funcName.equals("linear"))                // For backward compatibility
			return true;
		else if (funcName.equals("polynomial"))            // For backward compatibility
			return true;
		else
			return false;
	}
	
	public static Random getRandom()
	{
		return _rand;
	}
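
	/**
	 * Sets the random seed used by this solver graph. This reseeds both the static
	 * {@link Random} returned by {@link #getRandom()} (used for parameter estimation)
	 * and the environment's active random generator (used by sampled factors), so that
	 * runs can be made reproducible.
	 */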
	public void setSeed(long seed)
	{
		_rand = new Random(seed);		// Used for parameter estimation
		activeRandom().setSeed(seed);	// Used for sampled factors
	}
	
	/*
	 * Set the global solver damping parameter. This records the value locally and sets
	 * the corresponding BPOptions.damping option on this graph.
	 */
	public void setDamping(double damping)
	{
		_damping = damping;
		setOption(BPOptions.damping, damping);
	}
	
	public double getDamping()
	{
		return _damping;
	}

	/**
	 * Indicates if this solver supports the optimized update algorithm.
	 *
	 * @since 0.06
	 */
	public boolean isOptimizedUpdateSupported()
	{
		return true;
	}

	private final ISFactorGraphToOptimizedUpdateAdapter _optimizedUpdateAdapter = new SFactorGraphToOptimizedUpdateAdapter(this);

	private static class SFactorGraphToOptimizedUpdateAdapter implements ISFactorGraphToOptimizedUpdateAdapter
	{
		final private SumProductSolverGraph _sumProductSolverGraph;
		
		SFactorGraphToOptimizedUpdateAdapter(SumProductSolverGraph sumProductSolverGraph)
		{
			_sumProductSolverGraph = sumProductSolverGraph;
		}
		
		@Override
		public IUpdateStepEstimator createSparseOutputStepEstimator(CostEstimationTableWrapper tableWrapper)
		{
			return new TableFactorEngineOptimized.SparseOutputStepEstimator(tableWrapper);
		}

		@Override
		public IUpdateStepEstimator createDenseOutputStepEstimator(CostEstimationTableWrapper tableWrapper)
		{
			return new TableFactorEngineOptimized.DenseOutputStepEstimator(tableWrapper);
		}

		@Override
		public IMarginalizationStepEstimator createSparseMarginalizationStepEstimator(CostEstimationTableWrapper tableWrapper,
			int inPortNum,
			int dimension,
			CostEstimationTableWrapper g)
		{
			return new TableFactorEngineOptimized.SparseMarginalizationStepEstimator(tableWrapper, inPortNum, dimension, g);
		}

		@Override
		public IMarginalizationStepEstimator createDenseMarginalizationStepEstimator(CostEstimationTableWrapper tableWrapper,
			int inPortNum,
			int dimension,
			CostEstimationTableWrapper g)
		{
			return new TableFactorEngineOptimized.DenseMarginalizationStepEstimator(tableWrapper, inPortNum, dimension, g);
		}
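
		// The two cost estimators below predict per-update execution time using simple linear
		// models over standardized features of the factor table (non-zero weight count, number
		// of dimensions, marginalization buffer sizes, and estimated memory use). The model
		// coefficients and the centering/scaling constants were determined experimentally.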

		@Override
		public Costs estimateCostOfNormalUpdate(IFactorTable factorTable)
		{
			Costs result = new Costs();
			final int size = factorTable.countNonZeroWeights();
			final int dimensions = factorTable.getDimensions();
			// Coefficients determined experimentally
			double executionTime = 1.73280131035;
			executionTime += -46.4751637511 * (size - 254722.59319) / 9956266.39996;
			executionTime += 342.15344018 * (dimensions * size - 1877809.77842) / 219896210.353;
			result.put(CostType.EXECUTION_TIME, executionTime);
			return result;
		}

		@Override
		public Costs estimateCostOfOptimizedUpdate(IFactorTable factorTable, final double sparseThreshold)
		{
			final Costs costs = FactorUpdatePlan.estimateOptimizedUpdateCosts(factorTable, this, sparseThreshold);
			double dmf = costs.get(CostType.DENSE_MARGINALIZATION_SIZE);
			double smf = costs.get(CostType.SPARSE_MARGINALIZATION_SIZE);
			double fo = costs.get(CostType.OUTPUT_SIZE);
			double mem_cost = costs.get(CostType.MEMORY) * 1024.0 * 1024.0 * 1024.0;
			final double size = factorTable.countNonZeroWeights();
			// Coefficients determined experimentally
			double executionTime = 0.08;
			executionTime += 1.24138327837 * (size - 254722.59319) / 9956266.39996;
			executionTime += 2.18296909944 * (dmf - 316560.301654) / 39676196.0;
			executionTime += 0.883232752009 * (smf - 421208.789658) / 19914546.0;
			executionTime += 1.60951456134 * (fo - 4453.26298626) / 1836974.0;
			executionTime += 1.08967345943 * (mem_cost - 6742787.05688) / 416754545.21;
			executionTime += -0.447862077999 * Math.pow((size - 254722.59319) / 9956266.39996, 2.0);
			executionTime += -0.585003946613 * Math.pow((mem_cost - 6742787.05688) / 416754545.21, 2.0);
			final Costs result = new Costs();
			result.put(CostType.MEMORY, costs.get(CostType.MEMORY));
			result.put(CostType.EXECUTION_TIME, executionTime);
			return result;
		}

		@Override
		public ISolverFactorGraph getSolverGraph()
		{
			return _sumProductSolverGraph;
		}

		@Override
		public int getWorkers(ISolverFactorGraph sfactorGraph)
		{
			SumProductSolverGraph sfg = (SumProductSolverGraph) sfactorGraph;
			if (sfg.useMultithreading())
			{
				return sfg.getMultithreadingManager().getNumWorkers();
			}
			else
			{
				return 1;
			}
		}

		@Override
		public void putFactorTableUpdateSettings(Map<IFactorTable, FactorTableUpdateSettings> optionsValueByFactorTable)
		{
			_sumProductSolverGraph._factorTableUpdateSettings = optionsValueByFactorTable;
		}

		@Override
		public double[] getSparseValues(IFactorTable factorTable)
		{
			return factorTable.getWeightsSparseUnsafe();
		}

		@Override
		public double[] getDenseValues(IFactorTable factorTable)
		{
			return factorTable.getWeightsDenseUnsafe();
		}

		@Override
		public IUpdateStep createSparseOutputStep(int outPortNum, TableWrapper tableWrapper)
		{
			return new TableFactorEngineOptimized.SparseOutputStep(outPortNum, tableWrapper);
		}

		@Override
		public IUpdateStep createDenseOutputStep(int outPortNum, TableWrapper tableWrapper)
		{
			return new TableFactorEngineOptimized.DenseOutputStep(outPortNum, tableWrapper);
		}

		@Override
		public IMarginalizationStep createSparseMarginalizationStep(TableWrapper tableWrapper,
			int inPortNum,
			int dimension,
			IFactorTable g_factorTable,
			Tuple2<int[][], int[]> g_and_msg_indices)
		{
			return new TableFactorEngineOptimized.SparseMarginalizationStep(tableWrapper, this, inPortNum, dimension,
				g_factorTable, g_and_msg_indices);
		}

		@Override
		public IMarginalizationStep createDenseMarginalizationStep(TableWrapper tableWrapper,
			int inPortNum,
			int dimension,
			IFactorTable g_factorTable)
		{
			return new TableFactorEngineOptimized.DenseMarginalizationStep(tableWrapper, this, inPortNum, dimension,
				g_factorTable);
		}

		@Override
		public IOptionKey<UpdateApproach> getUpdateApproachOptionKey()
		{
			return BPOptions.updateApproach;
		}

		@Override
		public IOptionKey<Double> getOptimizedUpdateSparseThresholdKey()
		{
			return BPOptions.optimizedUpdateSparseThreshold;
		}

		@Override
		public IOptionKey<Double> getAutomaticExecutionTimeScalingFactorKey()
		{
			return BPOptions.automaticExecutionTimeScalingFactor;
		}

		@Override
		public IOptionKey<Double> getAutomaticMemoryAllocationScalingFactorKey()
		{
			return BPOptions.automaticMemoryAllocationScalingFactor;
		}
	}

	private @Nullable Map<IFactorTable, FactorTableUpdateSettings> _factorTableUpdateSettings;

	@Nullable FactorTableUpdateSettings getFactorTableUpdateSettings(Factor factor)
	{
		final Map<IFactorTable, FactorTableUpdateSettings> map = _factorTableUpdateSettings;
		FactorTableUpdateSettings result = null;
		if (map != null && factor.hasFactorTable())
		{
			IFactorTable factorTable = factor.getFactorTable();
			result = map.get(factorTable);
		}
		return result;
	}
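
	// NOTE: the deprecated sampled-factor accessors below simply delegate to the corresponding
	// Gibbs options on this graph; e.g. setOption(GibbsOptions.numSamples, n) has the same
	// effect as setSampledFactorSamplesPerUpdate(n).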

	/**
	 * @deprecated Will be removed in a future release. Instead set {@link GibbsOptions#numSamples} on
	 * this object using {@link #setOption}.
	 */
	@Deprecated
	public void setSampledFactorSamplesPerUpdate(int samplesPerUpdate)
	{
		setOption(GibbsOptions.numSamples, samplesPerUpdate);
	}
	
	/**
	 * @deprecated Will be removed in a future release. Instead get {@link GibbsOptions#numSamples}
	 * from this object using {@link #getOption}.
	 */
	@Deprecated
	public int getSampledFactorSamplesPerUpdate()
	{
		return getOptionOrDefault(GibbsOptions.numSamples);
	}
	
	/**
	 * @deprecated Will be removed in a future release. Instead set {@link GibbsOptions#burnInScans} on
	 * this object using {@link #setOption}.
	 */
	@Deprecated
	public void setSampledFactorBurnInScansPerUpdate(int burnInScans)
	{
		setOption(GibbsOptions.burnInScans, burnInScans);
	}

	/**
	 * @deprecated Will be removed in a future release. Instead get {@link GibbsOptions#burnInScans}
	 * from this object using {@link #getOption}.
	 */
	@Deprecated
	public int getSampledFactorBurnInScansPerUpdate()
	{
		return getOptionOrDefault(GibbsOptions.burnInScans);
	}
	
	/**
	 * @deprecated Will be removed in a future release. Instead set {@link GibbsOptions#scansPerSample} on
	 * this object using {@link #setOption}.
	 */
	@Deprecated
	public void setSampledFactorScansPerSample(int scansPerSample)
	{
		setOption(GibbsOptions.scansPerSample, scansPerSample);
	}
	
	/**
	 * @deprecated Will be removed in a future release. Instead get {@link GibbsOptions#scansPerSample}
	 * from this object using {@link #getOption}.
	 */
	@Deprecated
	public int getSampledFactorScansPerSample()
	{
		return getOptionOrDefault(GibbsOptions.scansPerSample);
	}
	
	@Override
	public void baumWelch(IFactorTable[] fts, int numRestarts, int numSteps)
	{
		ParameterEstimator pe = new ParameterEstimator.BaumWelch(_model, fts, SumProductSolverGraph.getRandom());
		pe.run(numRestarts, numSteps);
	}
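
	/**
	 * Simple gradient-descent parameter estimator over sparse factor-table weights. Each step
	 * moves every weight w by -w * (dBFE/dw) * scaleFactor, where the derivative of the Bethe
	 * free energy (BFE) is computed by calculateDerivativeOfBetheFreeEnergyWithRespectToWeight.
	 */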
	class GradientDescent extends ParameterEstimator
	{
		private double _scaleFactor;

		public GradientDescent(FactorGraph fg, IFactorTable[] tables, Random r, double scaleFactor)
		{
			super(fg, tables, r);
			_scaleFactor = scaleFactor;
		}

		@Override
		public void runStep(FactorGraph fg)
		{
			//_factorGraph.solve();
			for (IFactorTable ft : getTables())
			{
				double [] weights = ft.getWeightsSparseUnsafe();
				
				//for each weight
				for (int i = 0; i < weights.length; i++)
				{
					//calculate the derivative
					double derivative = calculateDerivativeOfBetheFreeEnergyWithRespectToWeight(ft, i);
					
					//move the weight in that direction scaled by epsilon
					ft.setWeightForSparseIndex(weights[i] - weights[i]*derivative*_scaleFactor, i);
				}
			}
		}
	}
	
	// NOTE: the two methods below are currently unimplemented placeholders.
	public void pseudoLikelihood(IFactorTable[] fts,
		Variable[] vars,
		Object[][] data,
		int numSteps,
		double stepScaleFactor)
	{
	}
	
	public static @Nullable int[][] convertObjects2Indices(Variable[] vars, Object[][] data)
	{
		return null;
	}

	@Override
	public void estimateParameters(IFactorTable[] fts, int numRestarts, int numSteps, double stepScaleFactor)
	{
		new GradientDescent(_model, fts, getRandom(), stepScaleFactor).run(numRestarts, numSteps);
	}

	@SuppressWarnings("null")
	public double calculateDerivativeOfBetheFreeEnergyWithRespectToWeight(IFactorTable ft, int weightIndex)
	{
		//BFE = InternalEnergy - BetheEntropy
		//InternalEnergy = Sum over all factors (Internal Energy of Factor)
		//                 + Sum over all variables (Internal Energy of Variable)
		//BetheEntropy = Sum over all factors (BetheEntropy(factor))
		//               + sum over all variables (BetheEntropy(variable))
		//So derivative of BFE = Sum over all factors that contain the weight
		//                       (derivative of Internal Energy of Factor
		//                        - derivative of BetheEntropy of Factor)
		//
		_currentFactorTable = ft;
		
		final SolverNodeMapping solvers = getSolverMapping();
		
		for (Factor f : _model.getFactors())
		{
			((SumProductTableFactor)solvers.getSolverFactor(f)).initializeDerivativeMessages(ft.sparseSize());
		}
		for (Variable vb : _model.getVariablesFlat())
		{
			((SumProductDiscrete)solvers.getSolverVariable(vb)).initializeDerivativeMessages(ft.sparseSize());
		}
		
		setCalculateDerivative(true);
		
		double result = 0;
		try
		{
			_model.solve();
			for (Factor f : _model.getFactors())
			{
				SumProductTableFactor stf = (SumProductTableFactor)solvers.getSolverFactor(f);
				result += stf.calculateDerivativeOfInternalEnergyWithRespectToWeight(weightIndex);
				result -= stf.calculateDerivativeOfBetheEntropyWithRespectToWeight(weightIndex);
			}
			for (Variable v : _model.getVariablesFlat())
			{
				SumProductDiscrete sv = (SumProductDiscrete)solvers.getSolverVariable(v);
				result += sv.calculateDerivativeOfInternalEnergyWithRespectToWeight(weightIndex);
				result += sv.calculateDerivativeOfBetheEntropyWithRespectToWeight(weightIndex);
			}
		}
		finally
		{
			setCalculateDerivative(false);
		}
		
		return result;
	}
	
	@SuppressWarnings("null")
	public void setCalculateDerivative(boolean val)
	{
		for (ISolverFactor sfactor : getSolverFactorsRecursive())
		{
			SumProductTableFactor stf = (SumProductTableFactor)sfactor;
			stf.setUpdateDerivative(val);
		}
		for (ISolverVariable svar : getSolverVariablesRecursive())
		{
			SumProductDiscrete sv = (SumProductDiscrete)svar;
			sv.setCalculateDerivative(val);
		}
	}

	// REFACTOR: make this package-protected?
	public @Nullable IFactorTable getCurrentFactorTable()
	{
		return _currentFactorTable;
	}

	@Override
	public void initialize()
	{
		super.initialize();
		UpdateCostOptimizer optimizer = new UpdateCostOptimizer(_optimizedUpdateAdapter);
		optimizer.optimize(this);
		
		final SolverNodeMapping solvers = getSolverMapping();
		for (Factor f : getModelObject().getFactors())
		{
			ISolverFactor sf = solvers.getSolverFactor(f);
			if (sf instanceof SumProductTableFactor)
			{
				SumProductTableFactor tf = (SumProductTableFactor)sf;
				IFactorTable table = tf.getFactorTableIfComputed();
				if (table != null)
				{
					tf.getFactorTable().getIndicesSparseUnsafe();
					tf.getFactorTable().getWeightsSparseUnsafe();
				}
				tf.setupTableFactorEngine();
			}
		}
		
		//
		// Update options
		//
		
		Long seed = getOption(DimpleOptions.randomSeed);
		if (seed != null)
		{
			setSeed(seed);
		}
		
		_damping = getOptionOrDefault(BPOptions.damping);
	}

	/*-------------------------------
	 * SumProductSolverGraph methods
	 */
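
	/**
	 * Computes the log of the partition function for the graph. This is only supported when the
	 * graph is a tree or forest of discrete variables; otherwise an
	 * {@link UnsupportedOperationException} is thrown. The computation temporarily switches to an
	 * edge-only schedule with no damping and the normal update approach, runs inference, and then
	 * reads the value off the discrete variable with the smallest marginalization cost.
	 */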
	public double computeLogPartitionFunction()
	{
		if (!_model.isForest())
		{
			throw new UnsupportedOperationException(String.format(
				"%s is not a tree or forest. Sum-product cannot compute partition function on loopy graph.",
				_model));
		}
		
		SumProductDiscrete minVar = null;
		int minCost = Integer.MAX_VALUE;
		
		for (ISolverVariable svar : getSolverVariablesRecursive())
		{
			if (svar instanceof SumProductDiscrete)
			{
				// Pick variable with smallest marginalization cost.
				SumProductDiscrete sdiscrete = (SumProductDiscrete)svar;
				int computationCost = sdiscrete.getDomain().size() * (svar.getSiblingCount() + 1);
				if (computationCost < minCost)
				{
					minVar = sdiscrete;
					minCost = computationCost;
				}
			}
			else
			{
				throw new UnsupportedOperationException(String.format(
					"Variable %s is not a discrete sum-product variable.", svar));
			}
		}
		
		if (minVar == null)
		{
			// Graph has no variables!
			return 0.0;
		}
		
		try (TemporaryOptionSettings tmp = new TemporaryOptionSettings())
		{
			// We don't implement passing normalization energy when using node updates
			// (which are less useful in tree schedules in any case), so we make sure
			// we are using an edge-only schedule.
			TreeSchedulerAbstract scheduler = new DefaultScheduler();
			scheduler.useOnlyEdgeUpdates();
			tmp.set(this, BPOptions.scheduler, scheduler);
			tmp.set(this, BPOptions.updateApproach, UpdateApproach.NORMAL);
			tmp.set(this, BPOptions.damping, 0.0);
			tmp.set(this, BPOptions.maxMessageSize, Integer.MAX_VALUE);
			for (ISolverVariable svar : getSolverVariablesRecursive())
			{
				tmp.setIfDifferent(svar, BPOptions.damping, 0.0);
			}
			for (ISolverFactor sfactor : getSolverFactorsRecursive())
			{
				tmp.setIfDifferent(sfactor, BPOptions.damping, 0.0);
				tmp.setIfDifferent(sfactor, BPOptions.updateApproach, UpdateApproach.NORMAL);
				tmp.setIfDifferent(sfactor, BPOptions.maxMessageSize, Integer.MAX_VALUE);
			}
			
			// TODO:
			//  - only use messages with normalization energy if this method is called.
			
			initialize();
			iterate();
		}
		
		return minVar.computeLogPartitionFunction();
	}
	
	/*-------------------
	 * Protected methods
	 */
	
	@Override
	protected void doUpdateEdge(int edge)
	{
	}

	@Override
	protected String getSolverName()
	{
		return "sum-product";
	}
}