/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions.core;
import java.util.BitSet;
import java.util.Objects;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.values.Value;
/**
 * Common base for factor table implementations that may hold a sparse representation
 * (sparse energies, sparse weights and/or sparse element indices).
 * <p>
 * Holds the sparse arrays, the representation bit mask and a cache of lazily computed
 * properties ({@link #_computedMask}). Concrete subclasses supply normalization, direction
 * handling and representation conversion via the package-private abstract methods.
 *
 * @since 0.05
 * @author Christopher Barber
 */
public abstract class SparseFactorTableBase extends FactorTableBase implements IFactorTable
{
	private static final long serialVersionUID = 1L;

	/**
	 * Set if table is known to be in normalized form (all weights add up to one).
	 */
	static final int NORMALIZED = 0x01;

	/**
	 * Set if value of {@link #NORMALIZED} bit has been computed.
	 */
	static final int NORMALIZED_COMPUTED = 0x02;

	/**
	 * Set if table is directed and has been conditionally normalized (so that the total weight
	 * is the same for every combination of input values).
	 */
	static final int CONDITIONAL = 0x04;

	/**
	 * Set if value of {@link #CONDITIONAL} bit has been computed.
	 */
	static final int CONDITIONAL_COMPUTED = 0x08;

	/**
	 * Bit mask indicating how the contents of the table are represented. Exposed
	 * by {@link #getRepresentation()} and {@link #setRepresentation(FactorTableRepresentation)}.
	 * <p>
	 * This is a combination of the bits:
	 * <ul>
	 * <li>{@link FactorTable#DENSE_ENERGY},
	 * <li>{@link FactorTable#DENSE_WEIGHT},
	 * <li>{@link FactorTable#SPARSE_ENERGY},
	 * <li>{@link FactorTable#SPARSE_WEIGHT}
	 * <li>{@link FactorTable#SPARSE_INDICES}.
	 * </ul>
	 */
	int _representation;

	/** Energies of the sparse entries; empty unless the table has a sparse energy representation. */
	double[] _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;

	/** Weights of the sparse entries; empty unless the table has a sparse weight representation. */
	double[] _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;

	/**
	 * Count of table entries with non-zero weight/non-infinite energy.
	 * <p>
	 * When table has no sparse representation, this is returned as the {@link #sparseSize()}.
	 */
	int _nonZeroWeights;

	/**
	 * Same information as FactorTable's _sparseIndexToJointIndex but instead of storing joint indices stores
	 * arrays of element indices.
	 */
	int[][] _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;

	/**
	 * Information computed about the table based on its values. This field is zeroed out whenever
	 * table weights or energies are changed.
	 * <p>
	 * Consists of the bits:
	 * <ul>
	 * <li>{@link FactorTable#DETERMINISTIC_COMPUTED}
	 * <li>{@link #NORMALIZED}
	 * <li>{@link #NORMALIZED_COMPUTED}
	 * <li>{@link #CONDITIONAL}
	 * <li>{@link #CONDITIONAL_COMPUTED}.
	 * </ul>
	 */
	int _computedMask = 0;

	/*--------------
	 * Construction
	 */

	/**
	 * Constructs an empty table over the given domains, initially using the
	 * {@link FactorTable#SPARSE_ENERGY} representation with no non-zero entries.
	 */
	SparseFactorTableBase(JointDomainIndexer domains)
	{
		super(domains);
		_nonZeroWeights = 0;
		_representation = FactorTable.SPARSE_ENERGY;
	}

	/**
	 * Copy constructor. The sparse arrays are cloned (via {@link ArrayUtil#cloneArray}), so the
	 * new table does not share those arrays with {@code that}.
	 */
	SparseFactorTableBase(SparseFactorTableBase that)
	{
		super(that);
		_nonZeroWeights = that._nonZeroWeights;
		_representation = that._representation;
		_computedMask = that._computedMask;
		_sparseEnergies = Objects.requireNonNull(ArrayUtil.cloneArray(that._sparseEnergies));
		_sparseWeights = Objects.requireNonNull(ArrayUtil.cloneArray(that._sparseWeights));
		_sparseIndices = Objects.requireNonNull(ArrayUtil.cloneArray(that._sparseIndices));
	}

	/*--------------------------
	 * IFactorTableBase methods
	 */

	/**
	 * Converts this table using {@code converter}; delegates to {@link FactorTable#convert}.
	 */
	@Override
	public final IFactorTable convert(JointDomainReindexer converter)
	{
		return FactorTable.convert(this, converter);
	}

	/**
	 * Returns the cached count of entries with non-zero weight (non-infinite energy).
	 */
	@Override
	public final int countNonZeroWeights()
	{
		return _nonZeroWeights;
	}

	/** True if {@link FactorTable#SPARSE_ENERGY} is part of the current representation. */
	@Override
	public final boolean hasSparseEnergies()
	{
		return (_representation & FactorTable.SPARSE_ENERGY) != 0;
	}

	/** True if {@link FactorTable#SPARSE_INDICES} is part of the current representation. */
	@Override
	public final boolean hasSparseIndices()
	{
		return (_representation & FactorTable.SPARSE_INDICES) != 0;
	}

	/** True if {@link FactorTable#SPARSE_WEIGHT} is part of the current representation. */
	@Override
	public final boolean hasSparseWeights()
	{
		return (_representation & FactorTable.SPARSE_WEIGHT) != 0;
	}

	/**
	 * Reports whether the table is known to be normalized.
	 * <p>
	 * The first call after the values change (when {@link #NORMALIZED_COMPUTED} is clear) runs
	 * {@link #normalizeUndirected(boolean)} in check-only mode for undirected tables. For directed
	 * tables no check is run here; the result is whatever {@link #NORMALIZED} bit is already set.
	 */
	@Override
	public final boolean isNormalized()
	{
		if ((_computedMask & NORMALIZED_COMPUTED) == 0)
		{
			if (!isDirected())
			{
				// Check-only pass; presumably updates the NORMALIZED bit in _computedMask.
				normalizeUndirected(true);
			}
			_computedMask |= NORMALIZED_COMPUTED;
		}
		return (_computedMask & NORMALIZED) != 0;
	}

	/**
	 * Normalizes an undirected table.
	 *
	 * @throws UnsupportedOperationException if the table is directed; use
	 * {@link #normalizeConditional()} instead.
	 */
	@Override
	public final void normalize()
	{
		if (isDirected())
		{
			throw new UnsupportedOperationException(
				"normalize() not supported for directed factor table. Use normalizeConditional() instead");
		}
		normalizeUndirected(false);
	}

	/**
	 * Conditionally normalizes a directed table, without ignoring zero-weight inputs.
	 * Equivalent to {@code normalizeConditional(false)} with the result discarded.
	 */
	@Override
	public final void normalizeConditional()
	{
		normalizeConditional(false);
	}

	/**
	 * Conditionally normalizes a directed table.
	 *
	 * @param ignoreZeroWeightInputs passed through to {@link #normalizeDirected(boolean, boolean)}.
	 * @return the value returned by {@link #normalizeDirected(boolean, boolean)}.
	 * @throws UnsupportedOperationException if the table is not directed; use {@link #normalize()} instead.
	 */
	@Override
	public final int normalizeConditional(boolean ignoreZeroWeightInputs)
	{
		if (!isDirected())
		{
			throw new UnsupportedOperationException(
				"normalizeConditional() not supported for undirected factor table. Use normalize() instead");
		}
		return normalizeDirected(false, ignoreZeroWeightInputs);
	}

	/**
	 * Sets the output dimensions of the table without asserting conditional normalization.
	 *
	 * @param outputSet the output dimensions, or {@code null}.
	 */
	@Override
	public final void setDirected(@Nullable BitSet outputSet)
	{
		setDirected(outputSet, false);
	}

	/*----------------------
	 * IFactorTable methods
	 */

	/**
	 * Creates a new table that adds {@code additionalDomains} to this table's domains, inserted at
	 * position {@code domains.size()} (presumably after all existing dimensions).
	 */
	@Override
	public final IFactorTable createTableWithNewVariables(DiscreteDomain[] additionalDomains)
	{
		final JointDomainIndexer domains = getDomainIndexer();
		final JointDomainReindexer converter =
			JointDomainReindexer.createAdder(domains, domains.size(), additionalDomains);
		return convert(converter);
	}

	/** Delegates to {@code getEnergySlice(null, sliceDimension, indices)}. */
	@Override
	public final double[] getEnergySlice(int sliceDimension, int... indices)
	{
		return getEnergySlice(null, sliceDimension, indices);
	}

	/** Delegates to {@code getEnergySlice(null, sliceDimension, values)}. */
	@Override
	public final double[] getEnergySlice(int sliceDimension, Value ... values)
	{
		return getEnergySlice(null, sliceDimension, values);
	}

	/** Returns the representation corresponding to the current {@link #_representation} mask. */
	@Override
	public final FactorTableRepresentation getRepresentation()
	{
		return FactorTableRepresentation.forMask(_representation);
	}

	/** Delegates to {@code getWeightSlice(null, sliceDimension, indices)}. */
	@Override
	public final double[] getWeightSlice(int sliceDimension, int... indices)
	{
		return getWeightSlice(null, sliceDimension, indices);
	}

	/** Delegates to {@code getWeightSlice(null, sliceDimension, values)}. */
	@Override
	public final double[] getWeightSlice(int sliceDimension, Value ... values)
	{
		return getWeightSlice(null, sliceDimension, values);
	}

	/**
	 * Replaces the energy of every sparse entry, in sparse-index order.
	 *
	 * @throws IllegalArgumentException if {@code energies.length} differs from {@link #sparseSize()}.
	 */
	@Override
	public final void replaceEnergiesSparse(double[] energies)
	{
		final int size = energies.length;
		requireSparseLength(size);
		for (int si = 0; si < size; ++si)
		{
			setEnergyForSparseIndex(energies[si], si);
		}
	}

	/**
	 * Replaces the weight of every sparse entry, in sparse-index order.
	 *
	 * @throws IllegalArgumentException if {@code weights.length} differs from {@link #sparseSize()}.
	 */
	@Override
	public final void replaceWeightsSparse(double[] weights)
	{
		final int size = weights.length;
		requireSparseLength(size);
		for (int si = 0; si < size; ++si)
		{
			setWeightForSparseIndex(weights[si], si);
		}
	}

	/**
	 * Sets the output dimensions and requests that conditional normalization be asserted
	 * (calls {@code setDirected(outputSet, true)}).
	 *
	 * @throws NullPointerException if {@code outputSet} is null.
	 */
	@Override
	public final void setConditional(BitSet outputSet)
	{
		Objects.requireNonNull(outputSet, "outputSet");
		setDirected(outputSet, true);
	}

	/** Sets sparse energies; delegates to {@link #setSparseValues} with {@link FactorTable#SPARSE_ENERGY}. */
	@Override
	public final void setEnergiesSparse(int[][] indices, double[] energies)
	{
		setSparseValues(indices, energies, FactorTable.SPARSE_ENERGY);
	}

	/** Converts the table to the given representation; delegates to {@link #setRepresentation(int)}. */
	@Override
	public final void setRepresentation(FactorTableRepresentation representation)
	{
		setRepresentation(representation.mask());
	}

	/** Sets sparse weights; delegates to {@link #setSparseValues} with {@link FactorTable#SPARSE_WEIGHT}. */
	@Override
	public final void setWeightsSparse(int[][] indices, double[] weights)
	{
		setSparseValues(indices, weights, FactorTable.SPARSE_WEIGHT);
	}

	/**
	 * Makes the table directed with the given outputs and then conditionally normalizes it.
	 *
	 * @throws NullPointerException if {@code outputSet} is null.
	 */
	@Override
	public final void makeConditional(BitSet outputSet)
	{
		Objects.requireNonNull(outputSet, "outputSet");
		setDirected(outputSet, false);
		normalizeConditional();
	}

	/*-----------------
	 * Package methods
	 */

	/**
	 * @throws DimpleException if {@link #isConditional()} returns false.
	 */
	final void assertIsConditional()
	{
		if (!isConditional())
		{
			throw new DimpleException("weights must be normalized correctly for directed factors");
		}
	}

	/**
	 * Subclass hook for directed (conditional) normalization.
	 *
	 * @param justCheck when true, presumably only checks normalization without modifying values.
	 * @param ignoreZeroWeightInputs see {@link #normalizeConditional(boolean)}.
	 */
	abstract int normalizeDirected(boolean justCheck, boolean ignoreZeroWeightInputs);

	/**
	 * Handles the case of a directed table whose weights for some input add up to zero.
	 *
	 * @return 1 if {@code justCheck} is true, otherwise throws an exception.
	 * @throws DimpleException if {@code justCheck} is false.
	 */
	final int normalizeDirectedHandleZeroForInput(boolean justCheck)
	{
		if (!justCheck)
		{
			throw new DimpleException("Cannot normalize directed factor table with zero total weight for some input");
		}
		return 1;
	}

	/**
	 * Returns (does not throw) an exception with message indicating an attempt to normalize an
	 * undirected table whose weights add up to zero.
	 */
	final DimpleException normalizeUndirectedHandleZero()
	{
		return new DimpleException("Cannot normalize undirected factor table with zero total weight");
	}

	/**
	 * Subclass hook for undirected normalization.
	 *
	 * @param justCheck when true, presumably only checks normalization without modifying values.
	 */
	abstract boolean normalizeUndirected(boolean justCheck);

	/**
	 * Subclass hook that sets the output dimensions.
	 *
	 * @param outputSet the output dimensions, or {@code null}.
	 * @param assertConditional true when called from {@link #setConditional(BitSet)}.
	 */
	abstract void setDirected(@Nullable BitSet outputSet, boolean assertConditional);

	/** Subclass hook that converts the table to the representation given by the bit mask. */
	abstract void setRepresentation(int newRep);

	/** Subclass hook that sets sparse values in the given representation. */
	abstract void setSparseValues(int[][] indices, double[] values, int representation);

	/*-----------------
	 * Private methods
	 */

	/**
	 * Verifies that an array of {@code size} elements matches the table's current sparse size.
	 * Evaluates {@link #sparseSize()} once and uses it for both the comparison and the message.
	 *
	 * @throws IllegalArgumentException if {@code size} differs from {@link #sparseSize()}.
	 */
	private void requireSparseLength(int size)
	{
		final int expected = sparseSize();
		if (size != expected)
		{
			throw new IllegalArgumentException(
				String.format("Array size (%d) does not match sparse size (%d).", size, expected));
		}
	}
}