/*******************************************************************************
* Copyright 2012-2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.minsum;
import java.util.Map;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.solvers.core.BPSolverGraph;
import com.analog.lyric.dimple.solvers.core.NoSolverEdge;
import com.analog.lyric.dimple.solvers.core.multithreading.MultiThreadingManager;
import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.dimple.solvers.optimizedupdate.CostEstimationTableWrapper;
import com.analog.lyric.dimple.solvers.optimizedupdate.CostType;
import com.analog.lyric.dimple.solvers.optimizedupdate.Costs;
import com.analog.lyric.dimple.solvers.optimizedupdate.FactorTableUpdateSettings;
import com.analog.lyric.dimple.solvers.optimizedupdate.FactorUpdatePlan;
import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStep;
import com.analog.lyric.dimple.solvers.optimizedupdate.IMarginalizationStepEstimator;
import com.analog.lyric.dimple.solvers.optimizedupdate.ISFactorGraphToOptimizedUpdateAdapter;
import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStep;
import com.analog.lyric.dimple.solvers.optimizedupdate.IUpdateStepEstimator;
import com.analog.lyric.dimple.solvers.optimizedupdate.TableWrapper;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateCostOptimizer;
import com.analog.lyric.options.IOptionKey;
/**
* Solver-specific factor graph for min-sum solver.
* <p>
* <em>Previously was com.analog.lyric.dimple.solvers.minsum.SFactorGraph</em>
* <p>
* @since 0.07
*/
public class MinSumSolverGraph extends BPSolverGraph<ISolverFactor,ISolverVariable,ISolverEdgeState>
{
    /**
     * Cached copy of the {@link BPOptions#damping} option value. It is refreshed from the
     * option in {@link #initialize()} and overwritten directly by {@link #setDamping(double)}.
     */
    protected double _damping = 0;

    /**
     * Creates the min-sum solver graph for {@code factorGraph} and installs a
     * {@link MultiThreadingManager} bound to this solver graph.
     *
     * @param factorGraph the model graph handed to the superclass constructor
     * @param parent the parent solver graph handed to the superclass constructor; may be null
     */
    public MinSumSolverGraph(FactorGraph factorGraph, @Nullable ISolverFactorGraph parent)
    {
        super(factorGraph, parent);
        setMultithreadingManager(new MultiThreadingManager(this));
    }

    /**
     * @return always {@code true} for this solver graph
     */
    @Override
    public boolean hasEdgeState()
    {
        return true;
    }

    /**
     * Initializes the solver graph.
     * <p>
     * In order: refreshes the cached damping value from {@link BPOptions#damping}, runs the
     * superclass initialization, runs an {@link UpdateCostOptimizer} over this graph, and finally
     * calls {@code setupTableFactorEngine()} on every {@link MinSumTableFactor} returned by
     * {@code getSolverFactorsRecursive()}.
     */
    @Override
    public void initialize()
    {
        _damping = getOptionOrDefault(BPOptions.damping);
        super.initialize();

        // NOTE(review): the optimizer runs before the table factor engines are set up, presumably
        // so the engines can use the settings it stores through the adapter's
        // putFactorTableUpdateSettings -- confirm before reordering.
        final UpdateCostOptimizer optimizer = new UpdateCostOptimizer(_optimizedUpdateAdapter);
        optimizer.optimize(this);

        for (ISolverFactor solverFactor : getSolverFactorsRecursive())
        {
            if (solverFactor instanceof MinSumTableFactor)
            {
                ((MinSumTableFactor)solverFactor).setupTableFactorEngine();
            }
        }
    }

    /**
     * Creates the solver variable for a model variable.
     * <p>
     * Only {@link Discrete} variables are handled here; for any other variable the exception
     * produced by {@code unsupportedVariableType(var)} is thrown.
     *
     * @param var the model variable
     * @return a new solver variable wrapping {@code var}
     */
    @SuppressWarnings("deprecation") // TODO remove when SVariable removed
    @Override
    public ISolverVariable createVariable(Variable var)
    {
        if (var instanceof Discrete)
        {
            return new SVariable((Discrete)var, this);
        }

        throw unsupportedVariableType(var);
    }

    /**
     * Creates the solver factor for {@code factor} by delegating to
     * {@code MinSumOptions.customFactors}.
     */
    @Override
    public ISolverFactor createFactor(Factor factor)
    {
        return MinSumOptions.customFactors.createFactor(factor, this);
    }

    /**
     * Creates the solver edge state for {@code edge}.
     *
     * @return a new {@link MinSumDiscreteEdge} when the variable on the edge is {@link Discrete},
     * otherwise the shared {@link NoSolverEdge#INSTANCE}
     */
    @Override
    public ISolverEdgeState createEdgeState(EdgeState edge)
    {
        final Variable var = edge.getVariable(_model);

        if (var instanceof Discrete)
        {
            return new MinSumDiscreteEdge((Discrete)var);
        }

        return NoSolverEdge.INSTANCE;
    }

    /**
     * Creates the solver graph for a nested {@code subgraph}, with this solver graph as its parent.
     */
    @SuppressWarnings("deprecation") // TODO remove when SFactorGraph removed
    @Override
    public ISolverFactorGraph createSubgraph(FactorGraph subgraph)
    {
        return new SFactorGraph(subgraph, this);
    }

    /**
     * Reports whether {@code funcName} names a custom factor known to this solver.
     * <p>
     * For backward compatibility only; it is preferable to use the "Xor" factor function, which can
     * be evaluated for scoring or other purposes, but still uses the custom factor. This may be
     * removed at some point. This should return true only for custom factors that do not have a
     * corresponding FactorFunction of the same name.
     *
     * @return {@code true} only for the names "CustomXor" and "customXor"
     */
    @Override
    public boolean customFactorExists(String funcName)
    {
        return funcName.equals("CustomXor") || funcName.equals("customXor");
    }

    /**
     * Sets the solver damping parameter.
     * <p>
     * The value is stored in the {@link BPOptions#damping} option on this graph and is also cached
     * locally so that {@link #getDamping()} reflects it immediately. This method does not itself
     * visit any factors.
     *
     * @param damping the new damping value
     */
    public void setDamping(double damping)
    {
        setOption(BPOptions.damping, damping);
        _damping = damping;
    }

    /**
     * @return the cached damping value, as last set by {@link #setDamping(double)} or read from
     * {@link BPOptions#damping} during {@link #initialize()}
     */
    public double getDamping()
    {
        return _damping;
    }

    /**
     * Indicates if this solver supports the optimized update algorithm.
     *
     * @since 0.06
     */
    public boolean isOptimizedUpdateSupported()
    {
        return true;
    }

    /**
     * Intentionally empty: this class performs no work of its own for a per-edge update.
     */
    @Override
    protected void doUpdateEdge(int edge)
    {
    }

    /**
     * @return the solver name, {@code "min-sum"}
     */
    @Override
    protected String getSolverName()
    {
        return "min-sum";
    }

    /** Adapter handed to {@link UpdateCostOptimizer} in {@link #initialize()}. */
    private final ISFactorGraphToOptimizedUpdateAdapter _optimizedUpdateAdapter = new SFactorGraphToOptimizedUpdateAdapter(this);

    /**
     * Min-sum implementation of {@link ISFactorGraphToOptimizedUpdateAdapter}. It supplies the
     * {@link TableFactorEngineOptimized} steps and estimators, the {@link BPOptions} keys, the
     * factor table energies and the cost estimates for this solver.
     */
    private static class SFactorGraphToOptimizedUpdateAdapter implements ISFactorGraphToOptimizedUpdateAdapter
    {
        /** The solver graph this adapter was created for. */
        final private MinSumSolverGraph _minSumSolverGraph;

        SFactorGraphToOptimizedUpdateAdapter(MinSumSolverGraph minSumSolverGraph)
        {
            _minSumSolverGraph = minSumSolverGraph;
        }

        @Override
        public IUpdateStepEstimator createSparseOutputStepEstimator(CostEstimationTableWrapper tableWrapper)
        {
            return new TableFactorEngineOptimized.SparseOutputStepEstimator(tableWrapper);
        }

        @Override
        public IUpdateStepEstimator createDenseOutputStepEstimator(CostEstimationTableWrapper tableWrapper)
        {
            return new TableFactorEngineOptimized.DenseOutputStepEstimator(tableWrapper);
        }

        @Override
        public IMarginalizationStepEstimator
            createSparseMarginalizationStepEstimator(CostEstimationTableWrapper tableWrapper,
                int inPortNum,
                int dimension,
                CostEstimationTableWrapper g)
        {
            return new TableFactorEngineOptimized.SparseMarginalizationStepEstimator(tableWrapper, inPortNum,
                dimension, g);
        }

        @Override
        public IMarginalizationStepEstimator
            createDenseMarginalizationStepEstimator(CostEstimationTableWrapper tableWrapper,
                int inPortNum,
                int dimension,
                CostEstimationTableWrapper g)
        {
            return new TableFactorEngineOptimized.DenseMarginalizationStepEstimator(tableWrapper, inPortNum,
                dimension, g);
        }

        /**
         * Estimates the cost of the normal (non-optimized) update for {@code factorTable}.
         * <p>
         * The estimate is a linear model over the number of non-zero weights and over that number
         * multiplied by the number of dimensions.
         *
         * @return costs containing only {@link CostType#EXECUTION_TIME}
         */
        @Override
        public Costs estimateCostOfNormalUpdate(IFactorTable factorTable)
        {
            // Held as doubles so that dimensions * size is computed in floating point. With int
            // operands the product silently wraps once it exceeds Integer.MAX_VALUE.
            final double size = factorTable.countNonZeroWeights();
            final double dimensions = factorTable.getDimensions();

            // Coefficients determined experimentally
            double executionTime = 3.30461648566;
            executionTime += 1.51472189501 * (size - 2397282.13878) / 4990159.0;
            executionTime += 12.0304854157 * (dimensions * size - 24636832.1724) / 114805021.0;

            final Costs result = new Costs();
            result.put(CostType.EXECUTION_TIME, executionTime);
            return result;
        }

        /**
         * Estimates the cost of the optimized update for {@code factorTable}.
         * <p>
         * The raw sizes come from {@link FactorUpdatePlan#estimateOptimizedUpdateCosts}; the
         * execution time is a linear model over those sizes and the number of non-zero weights.
         *
         * @param sparseThreshold passed through unchanged to the update plan estimate
         * @return costs containing {@link CostType#MEMORY} (copied from the plan estimate) and
         * {@link CostType#EXECUTION_TIME}
         */
        @Override
        public Costs estimateCostOfOptimizedUpdate(IFactorTable factorTable, final double sparseThreshold)
        {
            final Costs planCosts = FactorUpdatePlan.estimateOptimizedUpdateCosts(factorTable, this, sparseThreshold);
            final double denseMarginalizationSize = planCosts.get(CostType.DENSE_MARGINALIZATION_SIZE);
            final double sparseMarginalizationSize = planCosts.get(CostType.SPARSE_MARGINALIZATION_SIZE);
            final double outputSize = planCosts.get(CostType.OUTPUT_SIZE);
            final double size = factorTable.countNonZeroWeights();

            // Coefficients determined experimentally
            double executionTime = 1.29764000525;
            executionTime += 6.92791055163 * (denseMarginalizationSize - 3705266.58065) / 25293812.93;
            executionTime += 4.29121266133 * (sparseMarginalizationSize - 3224351.19011) / 14900000.0;
            executionTime += -0.330453110368 * (outputSize - 12588.026109) / 724853.0;
            executionTime += -1.36970402596 * (size - 2397282.13878) / 4990159.0;

            final Costs result = new Costs();
            result.put(CostType.MEMORY, planCosts.get(CostType.MEMORY));
            result.put(CostType.EXECUTION_TIME, executionTime);
            return result;
        }

        @Override
        public ISolverFactorGraph getSolverGraph()
        {
            return _minSumSolverGraph;
        }

        /**
         * @param sfactorGraph must be a {@link MinSumSolverGraph}; it is cast unconditionally
         * @return the multithreading manager's worker count when the graph uses multithreading,
         * otherwise 1
         */
        @Override
        public int getWorkers(ISolverFactorGraph sfactorGraph)
        {
            final MinSumSolverGraph sfg = (MinSumSolverGraph) sfactorGraph;
            return sfg.useMultithreading() ? sfg.getMultithreadingManager().getNumWorkers() : 1;
        }

        /**
         * Stores the given map on the owning solver graph, replacing any previous one. The map is
         * kept by reference, not copied.
         */
        @Override
        public void
            putFactorTableUpdateSettings(Map<IFactorTable, FactorTableUpdateSettings> optionsValueByFactorTable)
        {
            _minSumSolverGraph._factorTableUpdateSettings = optionsValueByFactorTable;
        }

        /** Min-sum reads the table's energies for its sparse values. */
        @Override
        public double[] getSparseValues(IFactorTable factorTable)
        {
            return factorTable.getEnergiesSparseUnsafe();
        }

        /** Min-sum reads the table's energies for its dense values. */
        @Override
        public double[] getDenseValues(IFactorTable factorTable)
        {
            return factorTable.getEnergiesDenseUnsafe();
        }

        @Override
        public IUpdateStep createSparseOutputStep(int outPortNum, TableWrapper tableWrapper)
        {
            return new TableFactorEngineOptimized.SparseOutputStep(outPortNum, tableWrapper);
        }

        @Override
        public IUpdateStep createDenseOutputStep(int outPortNum, TableWrapper tableWrapper)
        {
            return new TableFactorEngineOptimized.DenseOutputStep(outPortNum, tableWrapper);
        }

        @Override
        public IMarginalizationStep createSparseMarginalizationStep(TableWrapper tableWrapper,
            int inPortNum,
            int dimension,
            IFactorTable g_factorTable,
            Tuple2<int[][], int[]> g_and_msg_indices)
        {
            return new TableFactorEngineOptimized.SparseMarginalizationStep(tableWrapper, this, inPortNum,
                dimension, g_factorTable, g_and_msg_indices);
        }

        @Override
        public IMarginalizationStep createDenseMarginalizationStep(TableWrapper tableWrapper,
            int inPortNum,
            int dimension,
            IFactorTable g_factorTable)
        {
            return new TableFactorEngineOptimized.DenseMarginalizationStep(tableWrapper, this, inPortNum,
                dimension, g_factorTable);
        }

        @Override
        public IOptionKey<UpdateApproach> getUpdateApproachOptionKey()
        {
            return BPOptions.updateApproach;
        }

        @Override
        public IOptionKey<Double> getOptimizedUpdateSparseThresholdKey()
        {
            return BPOptions.optimizedUpdateSparseThreshold;
        }

        @Override
        public IOptionKey<Double> getAutomaticExecutionTimeScalingFactorKey()
        {
            return BPOptions.automaticExecutionTimeScalingFactor;
        }

        @Override
        public IOptionKey<Double> getAutomaticMemoryAllocationScalingFactorKey()
        {
            return BPOptions.automaticMemoryAllocationScalingFactor;
        }
    }

    /**
     * Per-factor-table update settings, installed through the adapter's
     * {@code putFactorTableUpdateSettings}. Null until that has been called.
     */
    private @Nullable Map<IFactorTable, FactorTableUpdateSettings> _factorTableUpdateSettings;

    /**
     * Looks up the update settings recorded for the factor table of {@code factor}.
     *
     * @return the settings mapped to the factor's table, or null if no settings map has been
     * installed, the factor has no factor table, or the map has no entry for that table
     */
    @Nullable FactorTableUpdateSettings getFactorTableUpdateSettings(Factor factor)
    {
        final Map<IFactorTable, FactorTableUpdateSettings> settingsByTable = _factorTableUpdateSettings;

        if (settingsByTable == null || !factor.hasFactorTable())
        {
            return null;
        }

        return settingsByTable.get(factor.getFactorTable());
    }
}