/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.optimizedupdate;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Set;
import java.util.TreeSet;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.Comparators;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.util.misc.Internal;
/**
 * Internal factor table representation used by the optimized update algorithm. The root of
 * the marginalization tree wraps the factor's own factor table; tables inside the tree are
 * auxiliary tables held in a lighter-weight form.
 * <p>
 * Stored values may be weights or energies, depending on the solver. Solver-specific code
 * supplies the root-level values and derives the auxiliary table values from them.
 * <p>
 * Auxiliary tables employ thread-local storage when multithreading is enabled.
 *
 * @since 0.06
 * @author jking
 */
@Internal
public final class TableWrapper
{
    private final IFactorTable _factorTable;
    private final ThreadLocal<double[]> _values;
    private final int _size;
    private final boolean _isSparse;
    private final ISFactorGraphToOptimizedUpdateAdapter _helper;
    private final double _sparseThreshold;

    /**
     * Wraps a factor table.
     *
     * @param factorTable the table to wrap; its current representation decides whether this
     *        wrapper is treated as sparse or dense
     * @param useThreadLocalValues if {@code true}, each thread lazily receives its own
     *        zero-filled value array; if {@code false}, every thread sees the single array
     *        obtained from the adapter
     * @param isFactorGraphToCostOptimizerAdapter solver-specific adapter that supplies values
     *        and creates update steps
     * @param sparseThreshold density fraction used by
     *        {@link #createMarginalizationStep(int, int)} to decide when an auxiliary table
     *        is switched to a dense representation
     */
    public TableWrapper(IFactorTable factorTable, boolean useThreadLocalValues, ISFactorGraphToOptimizedUpdateAdapter isFactorGraphToCostOptimizerAdapter, double sparseThreshold)
    {
        _factorTable = factorTable;
        _isSparse = factorTable.hasSparseRepresentation();
        _helper = isFactorGraphToCostOptimizerAdapter;
        _sparseThreshold = sparseThreshold;

        // The adapter is consulted exactly once, whichever storage mode is requested. In
        // thread-local mode only the length of the returned array is kept.
        final double[] tableValues;
        if (_isSparse)
        {
            tableValues = _helper.getSparseValues(factorTable);
        }
        else
        {
            // Without thread-local values this wrapper sits at the root of the tree and holds
            // a model factor table; the sum-product solver sets those to sparse, so the dense
            // path is most likely never taken in that case. It is handled regardless.
            tableValues = _helper.getDenseValues(factorTable);
        }
        _size = tableValues.length;

        if (useThreadLocalValues)
        {
            _values = perThreadHolder(_size);
        }
        else
        {
            _values = sharedHolder(tableValues);
        }
    }

    /**
     * Wraps a factor table without per-thread storage; equivalent to passing {@code false}
     * for {@code useThreadLocalValues} to the four-argument constructor.
     *
     * @param factorTable the table to wrap
     * @param helper solver-specific adapter
     * @param sparseThreshold density fraction, see the four-argument constructor
     */
    public TableWrapper(final IFactorTable factorTable, ISFactorGraphToOptimizedUpdateAdapter helper, double sparseThreshold)
    {
        this(factorTable, false, helper, sparseThreshold);
    }

    /**
     * Builds a holder whose initial value, for any thread, is the one given array.
     */
    private static ThreadLocal<double[]> sharedHolder(final double[] values)
    {
        return new ThreadLocal<double[]>() {
            @Override
            protected double[] initialValue()
            {
                return values;
            }
        };
    }

    /**
     * Builds a holder that gives each thread a newly allocated array of the given length.
     */
    private static ThreadLocal<double[]> perThreadHolder(final int length)
    {
        return new ThreadLocal<double[]>() {
            @Override
            protected double[] initialValue()
            {
                return new double[length];
            }
        };
    }

    /**
     * Asks the adapter for an output step matching this table's representation.
     *
     * @param outPortNum port number handed through to the adapter
     * @return the sparse or dense output step produced by the adapter
     */
    public IUpdateStep createOutputStep(final int outPortNum)
    {
        if (!_isSparse)
        {
            return _helper.createDenseOutputStep(outPortNum, this);
        }
        return _helper.createSparseOutputStep(outPortNum, this);
    }

    /**
     * Splits every sparse index row of the wrapped table into the entry at {@code dimension}
     * and the row that remains once that entry is removed, and gives weight 1.0 to each
     * distinct remaining row in {@code g_factorTable}.
     *
     * @param dimension position of the entry to strip from each row
     * @param g_factorTable table over the remaining dimensions; its sparse weights are set here
     * @return a pair of (rows with {@code dimension} removed, removed entry of each row), both
     *         in the order of the wrapped table's sparse indices
     */
    private Tuple2<int[][], int[]> processIndices(final int dimension, final IFactorTable g_factorTable)
    {
        final int[][] sourceRows = _factorTable.getIndicesSparseUnsafe();
        final int[] removedEntries = new int[_size];
        final int[][] reducedRows = new int[_size][];

        for (int row = 0; row < sourceRows.length; ++row)
        {
            final int[] fullRow = sourceRows[row];
            removedEntries[row] = fullRow[dimension];
            reducedRows[row] = ArrayUtil.removeIntArrayEntry(fullRow, dimension);
        }

        markUsedRows(reducedRows, g_factorTable);
        return new Tuple2<int[][], int[]>(reducedRows, removedEntries);
    }

    /**
     * Sets a non-zero weight (1.0) in {@code table} for each distinct row in {@code rows}.
     * Rows are de-duplicated and ordered with the reverse lexical int-array comparator before
     * being passed to the table.
     */
    private static void markUsedRows(final int[][] rows, final IFactorTable table)
    {
        final Comparator<int[]> rowOrder = Comparators.reverseLexicalIntArray();
        final Set<int[]> distinctRows = new TreeSet<int[]>(rowOrder);
        distinctRows.addAll(Arrays.asList(rows));

        final int[][] usedRows = distinctRows.toArray(new int[distinctRows.size()][]);
        final double[] unitWeights = new double[usedRows.length];
        Arrays.fill(unitWeights, 1.0);
        table.setWeightsSparse(usedRows, unitWeights);
    }

    /**
     * Creates the step that marginalizes {@code dimension} out of this table, along with the
     * auxiliary table over the remaining dimensions.
     * <p>
     * For a dense wrapper the auxiliary table is always given a dense weight representation.
     * For a sparse wrapper it is switched to dense weights only when its domains support joint
     * indexing and its count of non-zero weights is at least the joint cardinality multiplied
     * by the sparse threshold.
     *
     * @param inPortNum port number handed through to the adapter
     * @param dimension dimension to remove
     * @return the sparse or dense marginalization step produced by the adapter
     */
    public IMarginalizationStep createMarginalizationStep(final int inPortNum, final int dimension)
    {
        final JointDomainIndexer sourceDomains = _factorTable.getDomainIndexer();
        final JointDomainReindexer remover = JointDomainReindexer.createRemover(sourceDomains, dimension);
        final JointDomainIndexer reducedDomains = remover.getToDomains();
        final IFactorTable reducedTable = FactorTable.create(reducedDomains);

        if (!_isSparse)
        {
            reducedTable.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
            return _helper.createDenseMarginalizationStep(this, inPortNum, dimension, reducedTable);
        }

        final Tuple2<int[][], int[]> indexPair = processIndices(dimension, reducedTable);

        // The cardinality is only queried when joint indexing is supported.
        final boolean denseEnough = reducedDomains.supportsJointIndexing()
            && (reducedTable.countNonZeroWeights() >= reducedDomains.getCardinality() * _sparseThreshold);
        if (denseEnough)
        {
            reducedTable.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
        }

        return _helper.createSparseMarginalizationStep(this, inPortNum, dimension, reducedTable, indexPair);
    }

    /**
     * @return the wrapped factor table
     */
    public IFactorTable getFactorTable()
    {
        return _factorTable;
    }

    /**
     * @return the thread-local holder of this table's value array
     */
    public ThreadLocal<double[]> getValues()
    {
        return _values;
    }

    /**
     * @return the sparse threshold supplied at construction
     */
    public double getSparseThreshold()
    {
        return _sparseThreshold;
    }
}