/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions.core;
import static com.analog.lyric.dimple.model.domains.JointDomainReindexer.*;
import static com.analog.lyric.math.Utilities.*;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Map;
import java.util.Objects;
import net.jcip.annotations.NotThreadSafe;
import org.eclipse.jdt.annotation.Nullable;
import cern.colt.map.OpenIntDoubleHashMap;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.collect.Tuple2;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer.Indices;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.math.Utilities;
import com.analog.lyric.util.misc.Misc;
import com.google.common.math.DoubleMath;
import com.google.common.primitives.Ints;
@NotThreadSafe
public class FactorTable extends SparseFactorTableBase
{
/*-----------
* Constants
*/
private static final long serialVersionUID = 1L;
// _representation values
// Layout of _representation: the low four bits (ALL_VALUES) say which value arrays are populated,
// and SPARSE_INDICES (0x10) is an independent flag. The named unions below exist so that the
// accessor switch statements can list every possible bit pattern as a constant case label.
/**
* If the low-order four-bits are zero, then the table represents a directed deterministic
* function, with non-zero entries indicated by the values in {@link #_sparseIndexToJointIndex}
* but weights and values are not stored explicitly because all the non-zero weights will be one.
* <p>
* May be combined with {@link #SPARSE_INDICES}.
*/
static final int DETERMINISTIC = 0x0;
/**
* If set, the table stores energies in dense representation in {@link #_denseEnergies}.
*/
static final int DENSE_ENERGY = 0x1;
/**
* If set, the table stores weights in dense representation in {@link #_denseWeights}.
*/
static final int DENSE_WEIGHT = 0x2;
/**
* If set, the table stores energies in sparse representation in {@link #_sparseEnergies}
* and mapping from sparse to joint indices in {@link #_sparseIndexToJointIndex}.
*/
static final int SPARSE_ENERGY = 0x4;
/**
* If set, the table stores weights in sparse representation in {@link #_sparseWeights}
* and mapping from sparse to joint indices in {@link #_sparseIndexToJointIndex}.
*/
static final int SPARSE_WEIGHT = 0x8;
/**
* If set, the table stores mapping from sparse indices to joint table indices in {@link #_sparseIndices}.
*/
static final int SPARSE_INDICES = 0x10;
// Unions of the flags above. Entries marked "// Invalid?" are combinations flagged as possibly
// never occurring; they are still defined because the accessor switches list them as cases.
static final int ALL_DENSE = DENSE_ENERGY | DENSE_WEIGHT;
static final int ALL_SPARSE = SPARSE_ENERGY | SPARSE_WEIGHT;
static final int ALL_WEIGHT = DENSE_WEIGHT | SPARSE_WEIGHT;
static final int ALL_ENERGY = DENSE_ENERGY | SPARSE_ENERGY;
static final int ALL_VALUES = ALL_DENSE | ALL_SPARSE;
static final int ALL_DENSE_WITH_INDICES = ALL_DENSE | SPARSE_INDICES; // Invalid?
static final int ALL_SPARSE_WITH_INDICES = ALL_SPARSE | SPARSE_INDICES;
static final int ALL_WEIGHT_WITH_INDICES = ALL_WEIGHT | SPARSE_INDICES;
static final int ALL_ENERGY_WITH_INDICES = ALL_ENERGY | SPARSE_INDICES;
static final int ALL = ALL_VALUES | SPARSE_INDICES;
static final int SPARSE_ENERGY_DENSE_WEIGHT = SPARSE_ENERGY | DENSE_WEIGHT;
static final int DENSE_ENERGY_SPARSE_WEIGHT = DENSE_ENERGY | SPARSE_WEIGHT;
static final int NOT_SPARSE_WEIGHT = ALL_ENERGY | DENSE_WEIGHT;
static final int NOT_SPARSE_ENERGY = ALL_WEIGHT | DENSE_ENERGY;
static final int NOT_DENSE_WEIGHT = ALL_ENERGY | SPARSE_WEIGHT;
static final int NOT_DENSE_ENERGY = ALL_WEIGHT | SPARSE_ENERGY;
static final int DETERMINISTIC_WITH_INDICES = SPARSE_INDICES;
static final int DENSE_ENERGY_WITH_INDICES = DENSE_ENERGY | SPARSE_INDICES; // Invalid?
static final int DENSE_WEIGHT_WITH_INDICES = DENSE_WEIGHT | SPARSE_INDICES; // Invalid?
static final int SPARSE_ENERGY_WITH_INDICES = SPARSE_ENERGY | SPARSE_INDICES;
static final int SPARSE_WEIGHT_WITH_INDICES = SPARSE_WEIGHT | SPARSE_INDICES;
static final int SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES = SPARSE_ENERGY_DENSE_WEIGHT | SPARSE_INDICES;
static final int DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES = DENSE_ENERGY_SPARSE_WEIGHT | SPARSE_INDICES;
static final int NOT_SPARSE_WEIGHT_WITH_INDICES = NOT_SPARSE_WEIGHT | SPARSE_INDICES;
static final int NOT_SPARSE_ENERGY_WITH_INDICES = NOT_SPARSE_ENERGY | SPARSE_INDICES;
static final int NOT_DENSE_WEIGHT_WITH_INDICES = NOT_DENSE_WEIGHT | SPARSE_INDICES;
static final int NOT_DENSE_ENERGY_WITH_INDICES = NOT_DENSE_ENERGY | SPARSE_INDICES;
// _computedMask values
/**
* Set if {@link #isDeterministicDirected()} has been invoked since the last time the values or
* representation of the table were changed.
* <p>
* This bit lives in {@code _computedMask}, not {@code _representation}, so sharing the value 0x10 with
* {@link #SPARSE_INDICES} is harmless. NOTE(review): confirm it does not overlap CONDITIONAL or
* CONDITIONAL_COMPUTED, which are declared in the base class.
*/
static final int DETERMINISTIC_COMPUTED = 0x10;
/*-------
* State
*/
/**
* Energies indexed by joint index. Populated when {@link #DENSE_ENERGY} is set in
* {@code _representation}; otherwise normally the empty array.
*/
private double[] _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
/**
* Weights indexed by joint index. Populated when {@link #DENSE_WEIGHT} is set in
* {@code _representation}; otherwise normally the empty array.
*/
private double[] _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
/**
* Maps sparse indexes to joint indexes. Empty if table is in dense form
* (because in that case the location and joint index are the same).
* <p>
* If not dense or deterministic (&directed) this lookup will require a
* binary search.
* <p>
* For a {@link #DETERMINISTIC} table this holds one joint index per input index, namely the
* single entry for that input whose weight is one.
*/
int[] _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
/*--------------
* Construction
*/
FactorTable(JointDomainIndexer domains)
{
    super(domains);
    // New tables begin in sparse-energy form; no value arrays are allocated here.
    this._representation = SPARSE_ENERGY;
}
/**
* Construct as a copy of another table instance.
*/
FactorTable(FactorTable that)
{
super(that);
_denseEnergies = Objects.requireNonNull(ArrayUtil.cloneArray(that._denseEnergies));
_denseWeights = Objects.requireNonNull(ArrayUtil.cloneArray(that._denseWeights));
_sparseIndexToJointIndex = Objects.requireNonNull(ArrayUtil.cloneArray(that._sparseIndexToJointIndex));
}
/**
 * Constructs a new table by converting the contents of {@code other} table using
 * {@code converter}, whose "from" domains must match {@code other}'s domains.
 * <p>
 * The new table keeps the representation reported by {@code other.getRepresentation()}.
 */
FactorTable(IFactorTable other, JointDomainReindexer converter)
{
    this(other, converter, other.getRepresentation());
}
/**
 * Constructs a new table with given representation by converting the contents of {@code other} table using
 * {@code converter} whose "from" domains must match {@code other}'s domains.
 * <p>
 * The contents are first converted into a single representation that can be derived directly
 * from {@code other}, and the result is then switched to {@code representation}.
 */
FactorTable(IFactorTable other, JointDomainReindexer converter, FactorTableRepresentation representation)
{
    super(converter.getToDomains());
    final JointDomainIndexer domains = getDomainIndexer();
    _representation = representation.mask();
    _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
    _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
    _computedMask = 0;

    //
    // Convert using single representation, then switch to desired representation.
    //
    if (other instanceof FactorTable)
    {
        final FactorTable ft = (FactorTable)other;
        // Compare only the value bits. DETERMINISTIC_WITH_INDICES is also deterministic and has no
        // dense arrays, so it must not fall through to the DENSE_ENERGY branch below (which would
        // convert an empty energy array).
        if ((ft._representation & ALL_VALUES) == DETERMINISTIC)
        {
            _sparseIndexToJointIndex = converter.convertSparseToJointIndex(ft._sparseIndexToJointIndex);
            _representation = DETERMINISTIC;
        }
        else if ((ft._representation & ALL_SPARSE) != 0)
        {
            _sparseIndexToJointIndex = converter.convertSparseToJointIndex(ft._sparseIndexToJointIndex);
            // A sparse table holding an entry for every joint index can be converted as if dense.
            boolean denseSparse = ft.sparseSize() == ft.jointSize();
            // FIXME - convert sparse indices
            if ((ft._representation & SPARSE_WEIGHT) != 0)
            {
                _sparseWeights = denseSparse ?
                    converter.convertDenseWeights(ft._sparseWeights) :
                    converter.convertSparseWeights(ft._sparseWeights, ft._sparseIndexToJointIndex,
                        _sparseIndexToJointIndex);
                _representation = SPARSE_WEIGHT;
            }
            else // SPARSE_ENERGY
            {
                _sparseEnergies = denseSparse ?
                    converter.convertDenseEnergies(ft._sparseEnergies) :
                    converter.convertSparseEnergies(ft._sparseEnergies, ft._sparseIndexToJointIndex,
                        _sparseIndexToJointIndex);
                _representation = SPARSE_ENERGY;
            }
        }
        else if ((ft._representation & DENSE_WEIGHT) != 0)
        {
            _denseWeights = converter.convertDenseWeights(ft._denseWeights);
            _representation = DENSE_WEIGHT;
        }
        else // DENSE_ENERGY
        {
            _denseEnergies = converter.convertDenseEnergies(ft._denseEnergies);
            _representation = DENSE_ENERGY;
        }

        if (converter.getRemovedDomains() == null)
        {
            // Nothing removed: each source non-zero entry is replicated once per combination
            // of values in the added domains.
            _nonZeroWeights = ft.countNonZeroWeights() * converter.getAddedCardinality();
        }
        else if (ft.hasMaximumDensity())
        {
            _nonZeroWeights = domains.getCardinality();
        }
        else
        {
            // Need to count them explicitly
            computeNonZeroWeights();
        }
    }
    else
    {
        // SparseFactorTable
        // Do initial conversion to SPARSE_WEIGHT representation, summing the weights of all
        // source entries that map onto the same target joint index.
        _representation = SPARSE_WEIGHT;
        final Indices scratch = converter.getScratch();
        final int otherSparseSize = other.sparseSize();
        final OpenIntDoubleHashMap jointIndexToWeight = new OpenIntDoubleHashMap(otherSparseSize);
        try
        {
            for (int si = 0; si < otherSparseSize; ++si)
            {
                other.sparseIndexToIndices(si, scratch.fromIndices);
                converter.convertIndices(scratch);
                // A negative target index marks a source entry with no image in the target domains.
                if (scratch.toIndices[0] >= 0)
                {
                    int ji = domains.jointIndexFromIndices(scratch.toIndices);
                    // OpenIntDoubleHashMap.get returns 0.0 if no entry is found, which is exactly what we want here.
                    jointIndexToWeight.put(ji, jointIndexToWeight.get(ji) + other.getWeightForSparseIndex(si));
                }
            }
        }
        finally
        {
            // Hand the scratch indices back even if a conversion step throws.
            scratch.release();
        }

        final int sparseSize = jointIndexToWeight.size();
        final int[] sparseToJoint = Arrays.copyOf(jointIndexToWeight.keys().elements(), sparseSize);
        Arrays.sort(sparseToJoint);
        final double[] sparseWeights = new double[sparseSize];
        for (int si = 0; si < sparseSize; ++ si)
        {
            double weight = jointIndexToWeight.get(sparseToJoint[si]);
            if (weight != 0.0)
            {
                ++_nonZeroWeights;
            }
            sparseWeights[si] = weight;
        }
        _sparseIndexToJointIndex = sparseToJoint;
        _sparseWeights = sparseWeights;
    }

    setRepresentation(representation);
}
public static IFactorTable create(Object table, DiscreteDomain[] domains)
{
    // Flatten the n-dimensional weight array into parallel (indices, weights) arrays.
    final Object[] indicesAndWeights = Misc.nDimensionalArray2indicesAndValues(table);
    final int[][] indices = (int[][])indicesAndWeights[0];
    final double[] weights = (double[])indicesAndWeights[1];
    return FactorTable.create(indices, weights, domains);
}
public static IFactorTable create(int[][] indices, double[] weights, Discrete... variables)
{
    // Build the table over the discrete domains of the given variables, in argument order.
    final DiscreteDomain[] variableDomains = new DiscreteDomain[variables.length];
    int position = 0;
    for (Discrete variable : variables)
    {
        variableDomains[position++] = variable.getDiscreteDomain();
    }
    return create(indices, weights, variableDomains);
}
public static IFactorTable create(int[][] indices, double[] weights, DiscreteDomain... domains)
{
    // Start from an empty table over the domains, then load the sparse entries.
    final IFactorTable result = create(domains);
    result.setWeightsSparse(indices, weights);
    return result;
}
public static IFactorTable create(JointDomainIndexer domains)
{
    // FactorTable needs joint indexing; fall back to the sparse implementation otherwise.
    if (domains.supportsJointIndexing())
    {
        return new FactorTable(domains);
    }
    return new SparseFactorTable(domains);
}
public static IFactorTable create(DiscreteDomain... domains)
{
    final JointDomainIndexer indexer = JointDomainIndexer.create(domains);
    return create(indexer);
}
public static IFactorTable create(@Nullable BitSet outputSet, DiscreteDomain ... domains)
{
    // outputSet (if non-null) selects which dimensions are outputs of a directed table.
    final JointDomainIndexer indexer = JointDomainIndexer.create(outputSet, domains);
    return create(indexer);
}
public static IFactorTable create(FactorFunction function, JointDomainIndexer domains)
{
    // Allocate an empty table over the domains and fill it by evaluating the function.
    final IFactorTable result = create(domains);
    result.populateFromFunction(function);
    return result;
}
/**
 * Create a directed deterministic factor table that marginalizes out the specified subdomain
 * of the joint input domain.
 * <p>
 * @param outputDomainIndex a number in the range [0, inputDomain.getDimensions() - 1] that specifies
 * which subdomain to produce as the first dimension and output.
 * @param inputDomain will be the second dimension in the table and the input.
 * @since 0.05
 */
public static IFactorTable createMarginal(int outputDomainIndex, JointDiscreteDomain<?> inputDomain)
{
    final JointDomainIndexer subdomains = inputDomain.getDomainIndexer();
    final DiscreteDomain outputDomain = subdomains.get(outputDomainIndex);
    final int outputSize = outputDomain.size();

    // Number of consecutive input joint indices that share the same value of the output subdomain.
    final int runLength = subdomains.getStride(outputDomainIndex);

    // For each input joint index, the element index it has in the output subdomain.
    final int[] outputIndices = new int[inputDomain.size()];
    for (int joint = 0; joint < outputIndices.length; ++joint)
    {
        outputIndices[joint] = (joint / runLength) % outputSize;
    }

    // Dimension 0 (the output subdomain) is the output; dimension 1 (the joint domain) is the input.
    final IFactorTable table = create(BitSetUtil.bitsetFromIndices(2, 0), outputDomain, inputDomain);
    table.setDeterministicOutputIndices(outputIndices);
    return table;
}
public static IFactorTable convert(IFactorTable oldTable, JointDomainReindexer converter)
{
    // Pick the implementation that can index the converter's target domains.
    final boolean jointIndexable = converter.getToDomains().supportsJointIndexing();
    return jointIndexable
        ? new FactorTable(oldTable, converter)
        : new SparseFactorTable(oldTable, converter);
}
public static IFactorTable convert(
    IFactorTable oldTable,
    JointDomainReindexer converter,
    FactorTableRepresentation representation)
{
    // Same dispatch as the two-argument overload, but with an explicit target representation.
    final boolean jointIndexable = converter.getToDomains().supportsJointIndexing();
    return jointIndexable
        ? new FactorTable(oldTable, converter, representation)
        : new SparseFactorTable(oldTable, converter, representation);
}
/**
 * Constructs a new factor table that is the product of all of the given tables.
 * <p>
 * Equivalent to {@link #product(ArrayList, FactorTableRepresentation)} with a null representation,
 * which lets the representation be chosen from the density of the inputs.
 */
public static @Nullable IFactorTable product(ArrayList<Tuple2<IFactorTable,int[]>> entries)
{
    final FactorTableRepresentation noPreferredRepresentation = null;
    return product(entries, noPreferredRepresentation);
}
/**
 * Constructs a new factor table that is the product of all of the given tables.
 *
 * @param entries maps factor tables to an array of index mappings that indicates where
 * each dimension in the table is located in the table to be constructed. Thus the
 * index mapping array for a given factor table must have length equal to the factor table's dimensions
 * and each entry must be in the range [0,N] where N is one less than the number of dimensions of the
 * new table. Every value from 0 to N must be represented in at least one index mapping array. The
 * number of dimensions of the new table (N+1) ranges from the number of dimensions of the largest factor
 * up to the sum of the number of dimensions of all factors (when the factors share no dimensions).
 * When two tables both map a dimension to the same dimension in the target table, the domains must match.
 * <p>
 * @param representation is the representation to use for the table to be constructed. If null, the representation
 * will be set to either {@link FactorTableRepresentation#DENSE_ENERGY} or
 * {@link FactorTableRepresentation#SPARSE_ENERGY} based on the density of the tables.
 * <p>
 * @return Newly constructed table, or null if {@code entries} is empty.
 * @throws IllegalArgumentException if an index mapping array's length differs from its table's
 * dimensions, contains a negative value, maps two different domains onto the same target dimension,
 * or if some target dimension in [0,N] is not produced by any mapping.
 */
public static @Nullable IFactorTable product(
    ArrayList<Tuple2<IFactorTable,int[]>> entries, @Nullable FactorTableRepresentation representation)
{
    final int nFactors = entries.size();
    if (nFactors < 1)
    {
        return null;
    }

    int nDimensions = 0;
    for (int i = 0; i < nFactors; ++i)
    {
        nDimensions = Math.max(nDimensions, 1 + Ints.max(entries.get(i).second));
    }

    // Build target domain list, validating each index mapping along the way.
    final DiscreteDomain[] toDomains = new DiscreteDomain[nDimensions];
    for (Tuple2<IFactorTable, int[]> tuple : entries)
    {
        final IFactorTable table = tuple.getKey();
        final int[] old2New = tuple.getValue();
        final JointDomainIndexer tableDomains = table.getDomainIndexer();
        if (old2New.length != tableDomains.size())
        {
            throw new IllegalArgumentException(
                String.format("Index mapping for %s does not match table dimensions", table));
        }
        for (int i = 0, end = tableDomains.size(); i < end; ++i)
        {
            final int j = old2New[i];
            if (j < 0)
            {
                throw new IllegalArgumentException(
                    String.format("Negative index mapping for %s", table));
            }
            final DiscreteDomain domain = tableDomains.get(i);
            final DiscreteDomain curDomain = toDomains[j];
            if (curDomain == null)
            {
                toDomains[j] = domain;
            }
            else if (!curDomain.equals(domain))
            {
                // Report the offending position within the index map (i), not its target (j).
                throw new IllegalArgumentException(
                    String.format("Conflicting domain mapping for entry %d of index map for %s", i, table));
            }
        }
    }

    // Every target dimension must be the image of some mapping; otherwise its domain is unknown and a
    // null would be handed to JointDomainIndexer.create.
    for (int j = 0; j < nDimensions; ++j)
    {
        if (toDomains[j] == null)
        {
            throw new IllegalArgumentException(
                String.format("No index mapping targets dimension %d of the product table", j));
        }
    }

    final JointDomainIndexer toIndexer = JointDomainIndexer.create(toDomains);
    final boolean supportsJoint = toIndexer.supportsJointIndexing();
    final int toCardinality = supportsJoint ? toIndexer.getCardinality() : -1;

    FactorTableRepresentation tableRep = FactorTableRepresentation.SPARSE_ENERGY;
    if (supportsJoint)
    {
        // The sparsest table is going to put an upper bound
        // on the sparsity of the final table. If the minimum
        // is high enough, we use a dense representation when
        // building the table.
        int minNonZeroWeights = Integer.MAX_VALUE;
        for (int i = 0; i < nFactors; ++i)
        {
            IFactorTable table = entries.get(i).first;
            final int oldNzw = table.countNonZeroWeights();
            final int oldCardinality = table.jointSize();
            // Each source entry is replicated across the target dimensions the table does not cover.
            final int newNzw = oldNzw * (toCardinality / oldCardinality);
            minNonZeroWeights = Math.min(minNonZeroWeights, newNzw);
        }

        // If table looks like it will be dense enough, use a dense representation.
        if (minNonZeroWeights >= toCardinality * .90)
        {
            tableRep = FactorTableRepresentation.DENSE_ENERGY;
        }
    }

    IFactorTable newTable = null;
    for (Map.Entry<IFactorTable, int[]> entry : entries)
    {
        final IFactorTable oldTable = entry.getKey();
        final int[] old2New = entry.getValue();

        // Convert factor table to new format.
        final JointDomainIndexer fromIndexer = oldTable.getDomainIndexer();
        final JointDomainReindexer converter = JointDomainReindexer.createPermuter(fromIndexer, toIndexer, old2New);
        final IFactorTable convertedTable = FactorTable.convert(oldTable, converter, tableRep);

        if (newTable == null)
        {
            newTable = convertedTable;
        }
        else
        {
            // Merge results by adding energies (i.e. multiplying weights)
            if (tableRep.hasDense())
            {
                for (int ji = 0; ji < toCardinality; ++ji)
                {
                    double energy = newTable.getEnergyForJointIndex(ji);
                    energy += convertedTable.getEnergyForJointIndex(ji);
                    newTable.setEnergyForJointIndex(energy, ji);
                }
            }
            else
            {
                // Only entries already present in newTable can remain non-zero in the product.
                final int[] indices = toIndexer.allocateIndices(null);
                for (int si = 0, end = newTable.sparseSize(); si < end; ++si)
                {
                    double energy = newTable.getEnergyForSparseIndex(si);
                    newTable.sparseIndexToIndices(si, indices);
                    energy += convertedTable.getEnergyForIndices(indices);
                    newTable.setEnergyForIndices(energy, indices);
                }
                // Compact table if it became more sparse
                newTable.compact();
            }
        }
    }

    // Convert to target representation
    if (representation != null && newTable != null)
    {
        newTable.setRepresentation(representation);
    }

    return newTable;
}
/*---------------
* Serialization
*/
/**
 * Override the default serialization to decrease size of serialized representation.
 * <p>
 * Stream layout, in order: the real {@code _representation}, {@code _nonZeroWeights} and
 * {@code _computedMask} (ints); the single representation actually transmitted (int); a count (int)
 * followed, unless the transmitted representation is {@link #DETERMINISTIC}, by that many doubles;
 * and finally the sparse-to-joint index map (ints) when it is shorter than the joint size.
 * {@code readObject} rebuilds the real representation from the transmitted one.
 */
private void writeObject(ObjectOutputStream out) throws IOException
{
    out.writeInt(_representation);
    out.writeInt(_nonZeroWeights);
    out.writeInt(_computedMask);

    // Transmit only one set of values, preferring weights over energies.
    int wireRep = -1;
    double[] wireValues = null;
    if (!hasSparseRepresentation())
    {
        if (hasDenseWeights())
        {
            wireRep = DENSE_WEIGHT;
            wireValues = _denseWeights;
        }
        else if (hasDenseEnergies())
        {
            wireRep = DENSE_ENERGY;
            wireValues = _denseEnergies;
        }
    }
    else if (hasSparseWeights())
    {
        wireRep = SPARSE_WEIGHT;
        wireValues = _sparseWeights;
    }
    else if (hasSparseEnergies())
    {
        wireRep = SPARSE_ENERGY;
        wireValues = _sparseEnergies;
    }
    else
    {
        // No stored values: weights are implicitly one.
        wireRep = DETERMINISTIC;
    }

    assert(wireRep >= 0);
    out.writeInt(wireRep);

    if (wireValues == null)
    {
        out.writeInt(sparseSize());
    }
    else
    {
        out.writeInt(wireValues.length);
        for (int i = 0; i < wireValues.length; ++i)
        {
            out.writeDouble(wireValues[i]);
        }
    }

    final int[] sparseToJoint = _sparseIndexToJointIndex;
    if (sparseToJoint.length < jointSize())
    {
        for (int si = 0; si < sparseToJoint.length; ++si)
        {
            out.writeInt(sparseToJoint[si]);
        }
    }
}
/**
 * Restores the state written by {@code writeObject}.
 * <p>
 * The single transmitted representation is read into the matching array and the table is then
 * switched to the representation it had when written.
 *
 * @throws InvalidObjectException if the stream names a representation {@code writeObject} never emits.
 */
private void readObject(ObjectInputStream in) throws IOException
{
    _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
    _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;

    final int representation = in.readInt();
    _nonZeroWeights = in.readInt();
    _computedMask = in.readInt();
    // Temporarily adopt the transmitted representation so the has*Representation() checks below
    // describe what is actually in the stream.
    _representation = in.readInt();
    final int size = in.readInt();

    double[] values = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    if (!hasDeterministicRepresentation())
    {
        values = new double[size];
        for (int i = 0; i < size; ++i)
        {
            values[i] = in.readDouble();
        }
    }

    if (hasSparseRepresentation())
    {
        if (size < jointSize())
        {
            // Read sparse to joint index map
            _sparseIndexToJointIndex = new int[size];
            for (int si = 0; si < size; ++si)
            {
                _sparseIndexToJointIndex[si] = in.readInt();
            }
        }
        else if (size == jointSize())
        {
            // writeObject omits the map when it is not shorter than the joint size. With one sparse
            // entry per joint index (distinct, ascending) the map can only be the identity, so rebuild
            // it rather than leave an empty array that e.g. deterministic lookups index into.
            _sparseIndexToJointIndex = new int[size];
            for (int si = 0; si < size; ++si)
            {
                _sparseIndexToJointIndex[si] = si;
            }
        }
    }

    switch (_representation)
    {
    case DETERMINISTIC:
        break;
    case DENSE_ENERGY:
        _denseEnergies = values;
        break;
    case DENSE_WEIGHT:
        _denseWeights = values;
        break;
    case SPARSE_ENERGY:
        _sparseEnergies = values;
        break;
    case SPARSE_WEIGHT:
        _sparseWeights = values;
        break;
    default:
        // Previously only an assert: with assertions off a corrupt stream was silently accepted.
        throw new InvalidObjectException(
            String.format("Invalid transmitted factor table representation %d", _representation));
    }

    // Switch to the real representation.
    setRepresentation(representation);
}
/*----------------
* Object methods
*/
/**
 * Returns a copy of this table built with the copy constructor, so the dense and
 * sparse-to-joint arrays declared in this class are not shared with the original.
 */
@Override
public FactorTable clone()
{
    final FactorTable copy = new FactorTable(this);
    return copy;
}
/*-----------------------------
* INewFactorTableBase methods
*/
/**
 * Overwrites the output entries of {@code arguments} with the unique output selected by the
 * input entries.
 *
 * @throws DimpleException if the table is not directed and deterministic.
 */
@Override
public void evalDeterministic(Value[] arguments)
{
    if (isDeterministicDirected())
    {
        final JointDomainIndexer indexer = getDomainIndexer();
        final int outputCardinality = indexer.getOutputCardinality();
        final int inputIndex = indexer.inputIndexFromValues(arguments);
        // Joint index layout is inputIndex * outputCardinality + outputIndex.
        final int outputIndex = _sparseIndexToJointIndex[inputIndex] - inputIndex * outputCardinality;
        indexer.outputIndexToValues(outputIndex, arguments);
        return;
    }
    throw new DimpleException("Table is not deterministic");
}
/** Reads {@code _denseEnergies} directly; assumes a dense-energy representation is present. */
@Override
public final double getEnergyForIndicesDense(int ... indices)
{
    final int jointIndex = getDomainIndexer().jointIndexFromIndices(indices);
    return _denseEnergies[jointIndex];
}
/** Reads {@code _denseEnergies} directly; assumes a dense-energy representation is present. */
@Override
public final double getEnergyForValuesDense(Value ... values)
{
    final int jointIndex = getDomainIndexer().jointIndexFromValues(values);
    return _denseEnergies[jointIndex];
}
/** Reads {@code _denseWeights} directly; assumes a dense-weight representation is present. */
@Override
public final double getWeightForIndicesDense(int ... indices)
{
    final int jointIndex = getDomainIndexer().jointIndexFromIndices(indices);
    return _denseWeights[jointIndex];
}
/** Reads {@code _denseWeights} directly; assumes a dense-weight representation is present. */
@Override
public final double getWeightForValuesDense(Value ... values)
{
    final int jointIndex = getDomainIndexer().jointIndexFromValues(values);
    return _denseWeights[jointIndex];
}
/**
* Returns the energy of the entry at {@code jointIndex}; zero-weight entries have energy
* {@link Double#POSITIVE_INFINITY}.
* <p>
* The switch lists every possible {@code _representation} value. It reads, in order of
* preference, dense energies, dense weights, sparse energies, or sparse weights, whichever is
* present first. A joint index with no sparse entry breaks out of the switch and yields
* positive infinity.
*/
@Override
public final double getEnergyForJointIndex(int jointIndex)
{
int sparseIndex;
// Optimized for speed. Using a single switch instead of multiple if/else's ensures
// there is only a single branching instruction. Switching over the enum may allow the
// JIT compiler to infer that an unconditional branch may be used. This code assumes
// that converting from weight to energy is cheaper than looking up the sparse index
// from the joint one.
switch (_representation)
{
// Deterministic: _sparseIndexToJointIndex holds, per input index, the one joint index with weight
// one (energy zero); every other joint index has weight zero.
case DETERMINISTIC:
case DETERMINISTIC_WITH_INDICES:
final int expectedJoint = _sparseIndexToJointIndex[getDomainIndexer().inputIndexFromJointIndex(jointIndex)];
return expectedJoint == jointIndex ? 0.0 : Double.POSITIVE_INFINITY;
// Every pattern with the DENSE_ENERGY bit set.
case ALL_VALUES:
case ALL_DENSE:
case ALL_ENERGY:
case DENSE_ENERGY:
case NOT_DENSE_WEIGHT:
case NOT_SPARSE_ENERGY:
case DENSE_ENERGY_SPARSE_WEIGHT:
case NOT_SPARSE_WEIGHT:
case ALL:
case ALL_DENSE_WITH_INDICES:
case ALL_ENERGY_WITH_INDICES:
case DENSE_ENERGY_WITH_INDICES:
case NOT_DENSE_WEIGHT_WITH_INDICES:
case NOT_SPARSE_ENERGY_WITH_INDICES:
case DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES:
case NOT_SPARSE_WEIGHT_WITH_INDICES:
return _denseEnergies[jointIndex];
// DENSE_WEIGHT set but DENSE_ENERGY clear.
case ALL_WEIGHT:
case DENSE_WEIGHT:
case NOT_DENSE_ENERGY:
case SPARSE_ENERGY_DENSE_WEIGHT:
case ALL_WEIGHT_WITH_INDICES:
case DENSE_WEIGHT_WITH_INDICES:
case NOT_DENSE_ENERGY_WITH_INDICES:
case SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES:
return Utilities.weightToEnergy(_denseWeights[jointIndex]);
// No dense values, SPARSE_ENERGY set.
case ALL_SPARSE:
case SPARSE_ENERGY:
case ALL_SPARSE_WITH_INDICES:
case SPARSE_ENERGY_WITH_INDICES:
sparseIndex = sparseIndexFromJointIndex(jointIndex);
if (sparseIndex >= 0)
{
return _sparseEnergies[sparseIndex];
}
break;
// Only SPARSE_WEIGHT.
case SPARSE_WEIGHT:
case SPARSE_WEIGHT_WITH_INDICES:
sparseIndex = sparseIndexFromJointIndex(jointIndex);
if (sparseIndex >= 0)
{
return Utilities.weightToEnergy(_sparseWeights[sparseIndex]);
}
break;
}
// Reached when the joint index has no sparse entry: treat as zero weight.
return Double.POSITIVE_INFINITY;
}
/**
* Returns the energy of the entry at {@code sparseIndex}.
* <p>
* Side effect: if the table currently has only dense values, a sparse representation is added
* via {@code setRepresentation} before reading, so this call can change the representation.
* Deterministic tables store no values; every stored entry has weight one, i.e. energy zero.
*/
@Override
public final double getEnergyForSparseIndex(int sparseIndex)
{
switch (_representation)
{
case DETERMINISTIC:
case DETERMINISTIC_WITH_INDICES:
return 0.0;
// Dense-only with energies available: materialize sparse energies, then read them.
case ALL_DENSE:
case DENSE_ENERGY:
case ALL_DENSE_WITH_INDICES:
case DENSE_ENERGY_WITH_INDICES:
setRepresentation(_representation | SPARSE_ENERGY);
// $FALL-THROUGH$
// Every pattern with the SPARSE_ENERGY bit set.
case ALL_VALUES:
case ALL_ENERGY:
case ALL_SPARSE:
case NOT_DENSE_WEIGHT:
case NOT_DENSE_ENERGY:
case SPARSE_ENERGY:
case NOT_SPARSE_WEIGHT:
case SPARSE_ENERGY_DENSE_WEIGHT:
case ALL:
case ALL_ENERGY_WITH_INDICES:
case ALL_SPARSE_WITH_INDICES:
case NOT_DENSE_WEIGHT_WITH_INDICES:
case NOT_DENSE_ENERGY_WITH_INDICES:
case SPARSE_ENERGY_WITH_INDICES:
case NOT_SPARSE_WEIGHT_WITH_INDICES:
case SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES:
return _sparseEnergies[sparseIndex];
// Dense weights only: materialize sparse weights, then read them.
case DENSE_WEIGHT:
case DENSE_WEIGHT_WITH_INDICES:
setRepresentation(_representation | SPARSE_WEIGHT);
// $FALL-THROUGH$
// SPARSE_WEIGHT set but SPARSE_ENERGY clear.
case ALL_WEIGHT:
case NOT_SPARSE_ENERGY:
case DENSE_ENERGY_SPARSE_WEIGHT:
case SPARSE_WEIGHT:
case ALL_WEIGHT_WITH_INDICES:
case NOT_SPARSE_ENERGY_WITH_INDICES:
case DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES:
case SPARSE_WEIGHT_WITH_INDICES:
return weightToEnergy(_sparseWeights[sparseIndex]);
}
return Double.POSITIVE_INFINITY;
}
/**
* Returns the weight of the entry at {@code jointIndex}; entries absent from a sparse
* representation have weight zero.
* <p>
* Mirrors {@code getEnergyForJointIndex}, but prefers dense weights, then dense energies,
* then sparse weights, then sparse energies.
*/
@Override
public final double getWeightForJointIndex(int jointIndex)
{
int sparseIndex;
// See comment in getEnergyForJointIndex
switch (_representation)
{
// Deterministic: joint index = inputIndex * outputCardinality + outputIndex, so integer division
// recovers the input index, whose single weight-one joint index is stored in the map.
case DETERMINISTIC:
case DETERMINISTIC_WITH_INDICES:
final int expectedJoint = _sparseIndexToJointIndex[jointIndex / getDomainIndexer().getOutputCardinality()];
return expectedJoint == jointIndex ? 1.0 : 0.0;
// Every pattern with the DENSE_WEIGHT bit set.
case ALL_VALUES:
case ALL_DENSE:
case ALL_WEIGHT:
case NOT_SPARSE_ENERGY:
case DENSE_WEIGHT:
case NOT_DENSE_ENERGY:
case NOT_SPARSE_WEIGHT:
case SPARSE_ENERGY_DENSE_WEIGHT:
case ALL:
case ALL_DENSE_WITH_INDICES:
case ALL_WEIGHT_WITH_INDICES:
case NOT_SPARSE_ENERGY_WITH_INDICES:
case DENSE_WEIGHT_WITH_INDICES:
case NOT_DENSE_ENERGY_WITH_INDICES:
case NOT_SPARSE_WEIGHT_WITH_INDICES:
case SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES:
return _denseWeights[jointIndex];
// DENSE_ENERGY set but DENSE_WEIGHT clear.
case ALL_ENERGY:
case DENSE_ENERGY:
case NOT_DENSE_WEIGHT:
case DENSE_ENERGY_SPARSE_WEIGHT:
case ALL_ENERGY_WITH_INDICES:
case DENSE_ENERGY_WITH_INDICES:
case NOT_DENSE_WEIGHT_WITH_INDICES:
case DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES:
return energyToWeight(_denseEnergies[jointIndex]);
// No dense values, SPARSE_WEIGHT set.
case ALL_SPARSE:
case SPARSE_WEIGHT:
case ALL_SPARSE_WITH_INDICES:
case SPARSE_WEIGHT_WITH_INDICES:
sparseIndex = sparseIndexFromJointIndex(jointIndex);
if (sparseIndex >= 0)
{
return _sparseWeights[sparseIndex];
}
break;
// Only SPARSE_ENERGY.
case SPARSE_ENERGY:
case SPARSE_ENERGY_WITH_INDICES:
sparseIndex = sparseIndexFromJointIndex(jointIndex);
if (sparseIndex >= 0)
{
return energyToWeight(_sparseEnergies[sparseIndex]);
}
break;
}
// Reached when the joint index has no sparse entry.
return 0.0;
}
/**
 * Returns the weight of the entry at {@code sparseIndex}.
 * <p>
 * Side effect: if the table currently has only dense values, a sparse representation is added
 * via {@code setRepresentation} before reading, so this call can change the representation.
 * Deterministic tables store no values; every stored entry has weight one.
 * <p>
 * The dense-only arms now genuinely fall through to the shared sparse read, matching
 * {@code getEnergyForSparseIndex}; previously they duplicated the return statement beside a
 * misleading {@code $FALL-THROUGH$} marker.
 */
@Override
public final double getWeightForSparseIndex(int sparseIndex)
{
    switch (_representation)
    {
    case DETERMINISTIC:
    case DETERMINISTIC_WITH_INDICES:
        return 1.0;

    // Dense-only with weights available: materialize sparse weights, then read them.
    case ALL_DENSE:
    case DENSE_WEIGHT:
    case ALL_DENSE_WITH_INDICES:
    case DENSE_WEIGHT_WITH_INDICES:
        setRepresentation(_representation | SPARSE_WEIGHT);
        // $FALL-THROUGH$
    // Every pattern with the SPARSE_WEIGHT bit set.
    case ALL_VALUES:
    case ALL_WEIGHT:
    case ALL_SPARSE:
    case NOT_SPARSE_ENERGY:
    case NOT_DENSE_WEIGHT:
    case DENSE_ENERGY_SPARSE_WEIGHT:
    case NOT_DENSE_ENERGY:
    case SPARSE_WEIGHT:
    case ALL:
    case ALL_WEIGHT_WITH_INDICES:
    case ALL_SPARSE_WITH_INDICES:
    case NOT_SPARSE_ENERGY_WITH_INDICES:
    case NOT_DENSE_WEIGHT_WITH_INDICES:
    case DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES:
    case NOT_DENSE_ENERGY_WITH_INDICES:
    case SPARSE_WEIGHT_WITH_INDICES:
        return _sparseWeights[sparseIndex];

    // Dense energies only: materialize sparse energies, then read them.
    case DENSE_ENERGY:
    case DENSE_ENERGY_WITH_INDICES:
        setRepresentation(_representation | SPARSE_ENERGY);
        // $FALL-THROUGH$
    // SPARSE_ENERGY set but SPARSE_WEIGHT clear.
    case ALL_ENERGY:
    case SPARSE_ENERGY:
    case NOT_SPARSE_WEIGHT:
    case SPARSE_ENERGY_DENSE_WEIGHT:
    case ALL_ENERGY_WITH_INDICES:
    case SPARSE_ENERGY_WITH_INDICES:
    case NOT_SPARSE_WEIGHT_WITH_INDICES:
    case SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES:
        return energyToWeight(_sparseEnergies[sparseIndex]);
    }

    return 0.0;
}
/** True if either dense array (energies or weights) is part of the current representation. */
@Override
public final boolean hasDenseRepresentation()
{
    final int denseBits = _representation & ALL_DENSE;
    return denseBits != 0;
}
/** True if {@link #DENSE_ENERGY} is set in the current representation. */
@Override
public final boolean hasDenseEnergies()
{
    final int denseEnergyBit = _representation & DENSE_ENERGY;
    return denseEnergyBit != 0;
}
/** True if {@link #DENSE_WEIGHT} is set in the current representation. */
@Override
public final boolean hasDenseWeights()
{
    final int denseWeightBit = _representation & DENSE_WEIGHT;
    return denseWeightBit != 0;
}
/** True when every joint index has a non-zero weight. */
@Override
public boolean hasMaximumDensity()
{
    final int nonZero = _nonZeroWeights;
    return nonZero == jointSize();
}
/**
 * False only for the purely dense value layouts (dense energies, dense weights, or both),
 * with or without {@link #SPARSE_INDICES}; every other layout, including
 * {@link #DETERMINISTIC}, counts as sparse.
 */
@Override
public final boolean hasSparseRepresentation()
{
    final int layout = _representation & ~SPARSE_INDICES;
    return layout != DENSE_ENERGY && layout != DENSE_WEIGHT && layout != ALL_DENSE;
}
/**
 * Reports whether the {@code CONDITIONAL} bit is set in {@code _computedMask}, computing it on
 * first use after a change.
 * <p>
 * For directed tables, {@code normalizeDirected(true, false)} is invoked to establish the bit
 * (presumably a check-only normalization; confirm against its definition). A table found
 * not conditional is also marked as having had its deterministic status computed, since
 * it cannot be deterministic directed.
 */
@Override
public final boolean isConditional()
{
    if ((_computedMask & CONDITIONAL_COMPUTED) != 0)
    {
        return (_computedMask & CONDITIONAL) != 0;
    }

    if (isDirected())
    {
        normalizeDirected(true, false);
    }

    if ((_computedMask & CONDITIONAL) == 0)
    {
        // Not conditional implies not deterministic directed either.
        _computedMask |= DETERMINISTIC_COMPUTED;
    }
    _computedMask |= CONDITIONAL_COMPUTED;

    return (_computedMask & CONDITIONAL) != 0;
}
/**
* Returns true if the table is directed and has exactly one non-zero entry per input index, all
* with (fuzzily) equal values.
* <p>
* When that holds, the table is converted in place to the {@link #DETERMINISTIC} representation.
* All value arrays are dropped and only the sparse-to-joint map is kept, with any
* {@link #SPARSE_INDICES} flag preserved. NOTE(review): this implicitly treats the common
* weight as one.
* <p>
* A negative answer is cached in {@code _computedMask} via {@link #DETERMINISTIC_COMPUTED}.
*/
@Override
public boolean isDeterministicDirected()
{
if ((_representation & ALL_VALUES) == DETERMINISTIC)
{
return true;
}
// Already computed and not converted to DETERMINISTIC above, so the answer was false.
if ((_computedMask & DETERMINISTIC_COMPUTED) != 0)
{
return false;
}
boolean deterministic = false;
final JointDomainIndexer domains = getDomainIndexer();
if (isDirected() && _nonZeroWeights == domains.getInputCardinality())
{
// Table can only be deterministic if there is exactly one
// valid output for each possible input and all outputs have the
// same weight.
final int[] sparseToJoint = computeSparseToJointIndexMap();
deterministic = true;
final int outputSize = domains.getOutputCardinality();
int prevInputIndex = -1;
// Joint indices are visited in order; two consecutive entries with the same input index
// mean that input has more than one non-zero output.
for (int joint : sparseToJoint)
{
int inputIndex = joint / outputSize;
if (inputIndex == prevInputIndex)
{
deterministic = false;
break;
}
prevInputIndex = inputIndex;
}
if (deterministic && (_computedMask & CONDITIONAL) == 0)
{
// Ensure that weights are the same. No need to do this if CONDITIONAL.
final double tolerance = 1e-12;
if (hasSparseEnergies())
{
deterministic = ArrayUtil.allFuzzyEqual(_sparseEnergies, tolerance);
}
else if (hasSparseWeights())
{
deterministic = ArrayUtil.allFuzzyEqual(_sparseWeights, tolerance);
}
else if (hasDenseEnergies())
{
// Compare only the entries named by the sparse-to-joint map.
deterministic = ArrayUtil.subsetFuzzyEqual(_denseEnergies, sparseToJoint, tolerance);
}
else
{
deterministic = ArrayUtil.subsetFuzzyEqual(_denseWeights, sparseToJoint, tolerance);
}
}
if (deterministic)
{
// Switch to the compact deterministic form: the map alone describes the table.
_sparseIndexToJointIndex = sparseToJoint;
_sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
_sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
_denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
_denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
_representation = DETERMINISTIC | (_representation & SPARSE_INDICES);
// deterministic directed is a special case of conditional
_computedMask |= CONDITIONAL|CONDITIONAL_COMPUTED;
}
}
_computedMask |= DETERMINISTIC_COMPUTED;
return deterministic;
}
/**
 * Sets the energy of the entry at {@code jointIndex}, invalidating all cached computed properties
 * and keeping the non-zero weight count in step.
 * <p>
 * A deterministic table stores no values, so it is first switched to an energy representation.
 * The value is then written through the same {@code setWeightEnergyForJointIndex} path used for
 * every other representation. Previously the deterministic case wrote only
 * {@code _denseEnergies}, leaving {@code _sparseEnergies} stale whenever
 * {@link #ALL_ENERGY_WITH_INDICES} was selected.
 */
@Override
public final void setEnergyForJointIndex(double energy, int jointIndex)
{
    final double prevEnergy = getEnergyForJointIndex(jointIndex);
    if (prevEnergy == energy)
    {
        return;
    }

    _computedMask = 0;
    if ((_representation & ALL_VALUES) == DETERMINISTIC)
    {
        // If we have sparse indices, then presumably a sparse representation is still wanted.
        setRepresentation(hasSparseIndices() ? ALL_ENERGY_WITH_INDICES : DENSE_ENERGY);
    }

    final double weight = (_representation & ALL_WEIGHT) == 0 ? 0.0 : Utilities.energyToWeight(energy);
    setWeightEnergyForJointIndex(weight, energy, jointIndex);

    // A weight is zero exactly when its energy is +infinity. Testing isInfinite() would also
    // count an energy of -infinity (an infinite, non-zero weight) as zero weight.
    if (prevEnergy == Double.POSITIVE_INFINITY)
    {
        ++_nonZeroWeights;
    }
    else if (energy == Double.POSITIVE_INFINITY)
    {
        --_nonZeroWeights;
    }
}
/**
 * Sets the energy (negative log weight) of the entry at position {@code sparseIndex} of the
 * sparse representation.
 * <p>
 * A deterministic table is first given an explicit sparse-energy array. All cached computed
 * properties are cleared. The non-zero-weight count is adjusted when the entry moves between
 * infinite energy (zero weight) and finite energy.
 */
@Override
public void setEnergyForSparseIndex(double energy, int sparseIndex)
{
    final double oldEnergy = getEnergyForSparseIndex(sparseIndex);
    if (oldEnergy == energy)
    {
        // Nothing changes, so cached state stays valid.
        return;
    }

    _computedMask = 0;

    final boolean deterministic = (_representation & ALL_VALUES) == DETERMINISTIC;
    if (!deterministic)
    {
        final boolean storesWeights = (_representation & ALL_WEIGHT) != 0;
        final double newWeight = storesWeights ? energyToWeight(energy) : 0.0;
        setWeightEnergyForSparseIndex(newWeight, energy, sparseIndex);
    }
    else
    {
        setRepresentation(_representation | SPARSE_ENERGY);
        _sparseEnergies[sparseIndex] = energy;
    }

    final boolean wasZeroWeight = Double.isInfinite(oldEnergy);
    final boolean isZeroWeight = Double.isInfinite(energy);
    if (wasZeroWeight)
    {
        ++_nonZeroWeights;
    }
    else if (isZeroWeight)
    {
        --_nonZeroWeights;
    }
}
/**
 * Sets the weight of the entry at {@code jointIndex}.
 * <p>
 * A table with a deterministic representation is first converted to a weight representation.
 * When sparse indices are present the sparse form is kept (ALL_WEIGHT_WITH_INDICES, which carries
 * both sparse and dense weights); otherwise the table becomes DENSE_WEIGHT. The value is then
 * written through {@code setWeightEnergyForJointIndex}, the same path used for every
 * non-deterministic representation. This keeps the sparse and dense copies consistent.
 * <p>
 * Clears all cached computed properties and adjusts the non-zero-weight count when the entry
 * moves between zero and non-zero weight.
 */
@Override
public void setWeightForJointIndex(double weight, int jointIndex)
{
    final double prevWeight = getWeightForJointIndex(jointIndex);
    if (prevWeight == weight)
    {
        return;
    }

    _computedMask = 0;

    if ((_representation & ALL_VALUES) == DETERMINISTIC)
    {
        // If we have sparse indices, then presumably a sparse representation is still wanted.
        setRepresentation(hasSparseIndices() ? ALL_WEIGHT_WITH_INDICES : DENSE_WEIGHT);
    }

    // Only compute an energy if the (possibly just converted) representation stores energies.
    final double energy = (_representation & ALL_ENERGY) == 0 ? 0.0 : weightToEnergy(weight);
    setWeightEnergyForJointIndex(weight, energy, jointIndex);

    if (prevWeight == 0.0)
    {
        ++_nonZeroWeights;
    }
    else if (weight == 0.0)
    {
        --_nonZeroWeights;
    }
}
/**
 * Sets the weight of the entry at position {@code sparseIndex} of the sparse representation.
 * <p>
 * A deterministic table is first given an explicit sparse-weight array. All cached computed
 * properties are cleared, and the non-zero-weight count is adjusted when the entry moves between
 * zero and non-zero weight.
 */
@Override
public void setWeightForSparseIndex(double weight, int sparseIndex)
{
    final double oldWeight = getWeightForSparseIndex(sparseIndex);
    if (oldWeight == weight)
    {
        // Nothing changes, so cached state stays valid.
        return;
    }

    _computedMask = 0;

    final boolean deterministic = (_representation & ALL_VALUES) == DETERMINISTIC;
    if (!deterministic)
    {
        final boolean storesEnergies = (_representation & ALL_ENERGY) != 0;
        final double newEnergy = storesEnergies ? weightToEnergy(weight) : 0.0;
        setWeightEnergyForSparseIndex(weight, newEnergy, sparseIndex);
    }
    else
    {
        setRepresentation(_representation | SPARSE_WEIGHT);
        _sparseWeights[sparseIndex] = weight;
    }

    final boolean wasZero = oldWeight == 0.0;
    final boolean isZero = weight == 0.0;
    if (wasZero)
    {
        ++_nonZeroWeights;
    }
    else if (isZero)
    {
        --_nonZeroWeights;
    }
}
/**
 * Returns the per-dimension domain indices of the entry at {@code sparseIndex}.
 * <p>
 * {@code indices} is passed to the domain indexer, which may reuse it. When sparse indices are
 * cached they are copied out; otherwise they are decoded from the entry's joint index.
 */
@Override
public int[] sparseIndexToIndices(int sparseIndex, @Nullable int[] indices)
{
    final JointDomainIndexer indexer = getDomainIndexer();

    if ((_representation & SPARSE_INDICES) == 0)
    {
        // No cached per-entry indices: decode them from the joint index.
        return indexer.jointIndexToIndices(sparseIndexToJointIndex(sparseIndex), indices);
    }

    final int[] result = indexer.allocateIndices(indices);
    System.arraycopy(_sparseIndices[sparseIndex], 0, result, 0, result.length);
    return result;
}
/**
 * Returns the joint index of the entry at position {@code sparseIndex}.
 * <p>
 * If the table has no sparse representation, one is created first. Sparse weights are used when
 * dense weights exist; otherwise sparse energies are used. Sparse indices past the end of the
 * sparse-to-joint map are returned unchanged. For example, the map may be left empty when sparse
 * and joint indices coincide.
 */
@Override
public final int sparseIndexToJointIndex(int sparseIndex)
{
    if (!hasSparseRepresentation())
    {
        final int sparseKind = hasDenseWeights() ? SPARSE_WEIGHT : SPARSE_ENERGY;
        setRepresentation(_representation | sparseKind);
    }

    final int[] sparseToJoint = _sparseIndexToJointIndex;
    return sparseIndex < sparseToJoint.length ? sparseToJoint[sparseIndex] : sparseIndex;
}
/**
 * Returns the sparse index of the entry with joint index {@code jointIndex}.
 * <p>
 * If no sparse entry exists for {@code jointIndex}, returns {@code -1 - insertionPoint}, following
 * the {@link Arrays#binarySearch(int[], int)} convention. When the sparse size equals the joint
 * size, the joint index is returned unchanged. A dense-only table is first given a sparse energy
 * representation so that it can be searched.
 */
@Override
public final int sparseIndexFromJointIndex(int jointIndex)
{
if (sparseSize() == jointSize())
{
return jointIndex;
}
int sparseIndex = jointIndex;
switch (_representation & ALL_VALUES)
{
case DETERMINISTIC:
// Optimize deterministic case. Since there is exactly one entry per distinct
// set of inputs, the candidate sparse index is the input index (jointIndex divided by
// the output cardinality), and we only need to check whether jointIndex is stored there.
sparseIndex /= getDomainIndexer().getOutputCardinality();
final int prevJointIndex = _sparseIndexToJointIndex[sparseIndex];
if (prevJointIndex != jointIndex)
{
// Not present: encode the insertion point (before or after the entry for this input).
if (jointIndex > prevJointIndex)
{
++sparseIndex;
}
sparseIndex = -1-sparseIndex;
}
break;
case ALL_DENSE:
case DENSE_ENERGY:
case DENSE_WEIGHT:
// No sparse arrays yet: build sparse energies (and the sparse-to-joint map) before searching.
setRepresentation(_representation | SPARSE_ENERGY);
// $FALL-THROUGH$
case ALL_VALUES:
case ALL_SPARSE:
case ALL_ENERGY:
case ALL_WEIGHT:
case DENSE_ENERGY_SPARSE_WEIGHT:
case NOT_DENSE_ENERGY:
case NOT_DENSE_WEIGHT:
case NOT_SPARSE_ENERGY:
case NOT_SPARSE_WEIGHT:
case SPARSE_ENERGY:
case SPARSE_ENERGY_DENSE_WEIGHT:
case SPARSE_WEIGHT:
// The sparse-to-joint map is searched with binary search, so it is expected to be sorted.
sparseIndex = Arrays.binarySearch(_sparseIndexToJointIndex, jointIndex);
break;
}
return sparseIndex;
}
/**
 * Returns the number of entries in the sparse representation of this table.
 * <ul>
 * <li>Deterministic tables report the length of the sparse-to-joint map.</li>
 * <li>Tables with sparse energies report the length of that array.</li>
 * <li>Tables with sparse weights but no sparse energies report the length of the weight array.</li>
 * <li>Dense-only tables report the number of non-zero weights, which is the size a sparse
 * representation would have.</li>
 * </ul>
 */
@Override
public final int sparseSize()
{
    switch (_representation)
    {
    case SPARSE_WEIGHT:
    case DENSE_ENERGY_SPARSE_WEIGHT:
    case ALL_WEIGHT:
    case NOT_SPARSE_ENERGY:
    case SPARSE_WEIGHT_WITH_INDICES:
    case DENSE_ENERGY_SPARSE_WEIGHT_WITH_INDICES:
    case ALL_WEIGHT_WITH_INDICES:
    case NOT_SPARSE_ENERGY_WITH_INDICES:
        return _sparseWeights.length;

    case SPARSE_ENERGY:
    case ALL_ENERGY:
    case SPARSE_ENERGY_DENSE_WEIGHT:
    case NOT_SPARSE_WEIGHT:
    case ALL_SPARSE:
    case NOT_DENSE_WEIGHT:
    case NOT_DENSE_ENERGY:
    case ALL_VALUES:
    case SPARSE_ENERGY_WITH_INDICES:
    case ALL_ENERGY_WITH_INDICES:
    case SPARSE_ENERGY_DENSE_WEIGHT_WITH_INDICES:
    case NOT_SPARSE_WEIGHT_WITH_INDICES:
    case ALL_SPARSE_WITH_INDICES:
    case NOT_DENSE_WEIGHT_WITH_INDICES:
    case NOT_DENSE_ENERGY_WITH_INDICES:
    case ALL:
        return _sparseEnergies.length;

    case DENSE_ENERGY:
    case DENSE_WEIGHT:
    case ALL_DENSE:
    case DENSE_ENERGY_WITH_INDICES:
    case DENSE_WEIGHT_WITH_INDICES:
    case ALL_DENSE_WITH_INDICES:
        return _nonZeroWeights;

    case DETERMINISTIC:
    case DETERMINISTIC_WITH_INDICES:
        return _sparseIndexToJointIndex.length;

    default:
        return 0;
    }
}
/*--------------------------
* INewFactorTable methods
*/
/**
 * Removes zero-weight entries from the sparse arrays.
 * <p>
 * An entry is dropped when its sparse weight is zero. When there are no sparse weights, it is
 * dropped when its sparse energy is infinite. Sparse energies, sparse weights, sparse indices and
 * the sparse-to-joint map are all rebuilt together. Dense arrays are left untouched. Nothing
 * happens if the table has no sparse values or the sparse size does not exceed the non-zero-weight
 * count.
 *
 * @return the number of sparse entries removed
 */
@Override
public int compact()
{
    if ((_representation & ALL_SPARSE) == 0)
    {
        return 0;
    }

    final int oldSize = sparseSize();
    final int newSize = _nonZeroWeights;
    if (oldSize <= newSize)
    {
        return 0;
    }

    // If every joint entry was present, the sparse index is also the joint index.
    final boolean wasDense = oldSize == jointSize();
    final boolean hasWeight = hasSparseWeights();
    final boolean hasEnergy = hasSparseEnergies();
    final boolean hasIndices = hasSparseIndices();

    final int[] newSparseToJoint = new int[newSize];
    final double[] newWeights = hasWeight ? new double[newSize] : ArrayUtil.EMPTY_DOUBLE_ARRAY;
    final double[] newEnergies = hasEnergy ? new double[newSize] : ArrayUtil.EMPTY_DOUBLE_ARRAY;
    final int[][] newIndices = hasIndices ? new int[newSize][] : ArrayUtil.EMPTY_INT_ARRAY_ARRAY;

    int dst = 0;
    for (int src = 0; src < oldSize; ++src)
    {
        // Weights decide when present; otherwise (sparse energies must then be present) an entry
        // survives if its energy is finite.
        final boolean keep = hasWeight
            ? _sparseWeights[src] != 0.0
            : !Double.isInfinite(_sparseEnergies[src]);
        if (!keep)
        {
            continue;
        }
        if (hasWeight)
        {
            newWeights[dst] = _sparseWeights[src];
        }
        if (hasEnergy)
        {
            newEnergies[dst] = _sparseEnergies[src];
        }
        if (hasIndices)
        {
            newIndices[dst] = _sparseIndices[src];
        }
        newSparseToJoint[dst] = wasDense ? src : _sparseIndexToJointIndex[src];
        ++dst;
    }

    _sparseEnergies = newEnergies;
    _sparseWeights = newWeights;
    _sparseIndexToJointIndex = newSparseToJoint;
    _sparseIndices = newIndices;

    return oldSize - newSize;
}
/**
 * Returns the internal sparse energy array, without copying.
 * <p>
 * For a non-deterministic table without sparse energies, the representation is extended with
 * SPARSE_ENERGY first. For a deterministic table, a zero-filled array with one entry per input is
 * cached and returned, and the representation is left unchanged. Every deterministic entry has
 * energy zero (weight one).
 */
@Override
public final double[] getEnergiesSparseUnsafe()
{
    final boolean missing = _sparseEnergies.length == 0 && !hasSparseEnergies();
    if (missing)
    {
        if (!hasDeterministicRepresentation())
        {
            setRepresentation(_representation | SPARSE_ENERGY);
        }
        else
        {
            _sparseEnergies = new double[getDomainIndexer().getInputCardinality()];
        }
    }
    return _sparseEnergies;
}
/**
 * Returns the internal dense energy array, without copying.
 * <p>
 * If the table does not currently hold dense energies, it is converted to a DENSE_ENERGY-only
 * representation first. Any other arrays are discarded.
 */
@Override
public double[] getEnergiesDenseUnsafe()
{
    if (hasDenseEnergies())
    {
        return _denseEnergies;
    }
    setRepresentation(DENSE_ENERGY);
    return _denseEnergies;
}
/**
 * Returns the energies of the entries obtained by varying dimension {@code sliceDimension} while
 * holding every other dimension at the value given in {@code indices}.
 * <p>
 * The caller's value at {@code indices[sliceDimension]} is ignored. It is temporarily set to zero
 * to locate the first entry of the slice, then restored.
 */
@Override
public final double[] getEnergySlice(@Nullable double[] slice, int sliceDimension, int ... indices)
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final int callerIndex = indices[sliceDimension];

    indices[sliceDimension] = 0;
    final int firstJointIndex = indexer.jointIndexFromIndices(indices);
    indices[sliceDimension] = callerIndex;

    return getEnergySliceImpl(slice, sliceDimension, firstJointIndex);
}
/**
 * Returns the energies of the entries obtained by varying dimension {@code sliceDimension} while
 * holding every other dimension at the value given in {@code values}.
 * <p>
 * The index of {@code values[sliceDimension]} is temporarily set to zero to locate the first entry
 * of the slice, then restored.
 */
@Override
public final double[] getEnergySlice(@Nullable double[] slice, int sliceDimension, Value ... values)
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final Value sliced = values[sliceDimension];
    final int callerIndex = sliced.getIndex();

    sliced.setIndex(0);
    final int firstJointIndex = indexer.jointIndexFromValues(values);
    sliced.setIndex(callerIndex);

    return getEnergySliceImpl(slice, sliceDimension, firstJointIndex);
}
/**
 * Fills a slice with the energies of the entries at joint indices {@code start},
 * {@code start + stride}, and so on, where the stride and length come from dimension
 * {@code sliceDimension}.
 * <p>
 * {@code slice} is reused when it is non-null and long enough; otherwise a new array is allocated.
 * Dense energies are read directly when available.
 */
private final double[] getEnergySliceImpl(@Nullable double[] slice, int sliceDimension, int start)
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final int length = indexer.getDomainSize(sliceDimension);
    final int step = indexer.getStride(sliceDimension);

    final double[] out = (slice != null && slice.length >= length) ? slice : new double[length];
    final boolean dense = hasDenseEnergies();

    int ji = start;
    for (int i = 0; i < length; ++i)
    {
        out[i] = dense ? _denseEnergies[ji] : getEnergyForJointIndex(ji);
        ji += step;
    }
    return out;
}
/**
 * Returns the weights of the entries obtained by varying dimension {@code sliceDimension} while
 * holding every other dimension at the value given in {@code indices}.
 * <p>
 * The caller's value at {@code indices[sliceDimension]} is ignored. It is temporarily set to zero
 * to locate the first entry of the slice, then restored.
 */
@Override
public final double[] getWeightSlice(@Nullable double[] slice, int sliceDimension, int ... indices)
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final int callerIndex = indices[sliceDimension];

    indices[sliceDimension] = 0;
    final int firstJointIndex = indexer.jointIndexFromIndices(indices);
    indices[sliceDimension] = callerIndex;

    return getWeightSliceImpl(slice, sliceDimension, firstJointIndex);
}
/**
 * Returns the weights of the entries obtained by varying dimension {@code sliceDimension} while
 * holding every other dimension at the value given in {@code values}.
 * <p>
 * The index of {@code values[sliceDimension]} is temporarily set to zero to locate the first entry
 * of the slice, then restored.
 */
@Override
public final double[] getWeightSlice(@Nullable double[] slice, int sliceDimension, Value ... values)
{
    JointDomainIndexer indexer = getDomainIndexer();
    final int savedIndex = values[sliceDimension].getIndex();
    values[sliceDimension].setIndex(0);
    final int start = indexer.jointIndexFromValues(values);
    values[sliceDimension].setIndex(savedIndex);
    // Weights, not energies: fill the slice with the weight implementation, matching the
    // int-indices overload of this method.
    return getWeightSliceImpl(slice, sliceDimension, start);
}
/**
 * Fills a slice with the weights of the entries at joint indices {@code start},
 * {@code start + stride}, and so on, where the stride and length come from dimension
 * {@code sliceDimension}.
 * <p>
 * {@code slice} is reused when it is non-null and long enough; otherwise a new array is allocated.
 * Dense weights are read directly when available.
 */
private final double[] getWeightSliceImpl(@Nullable double[] slice, int sliceDimension, int start)
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final int length = indexer.getDomainSize(sliceDimension);
    final int step = indexer.getStride(sliceDimension);

    final double[] out = (slice != null && slice.length >= length) ? slice : new double[length];
    final boolean dense = hasDenseWeights();

    int ji = start;
    for (int i = 0; i < length; ++i)
    {
        out[i] = dense ? _denseWeights[ji] : getWeightForJointIndex(ji);
        ji += step;
    }
    return out;
}
/**
 * Returns the internal sparse weight array, without copying.
 * <p>
 * For a non-deterministic table without sparse weights, the representation is extended with
 * SPARSE_WEIGHT first. For a deterministic table, an array of ones with one entry per input is
 * cached and returned, and the representation is left unchanged.
 */
@Override
public final double[] getWeightsSparseUnsafe()
{
    final boolean missing = _sparseWeights.length == 0 && !hasSparseWeights();
    if (missing)
    {
        if (!hasDeterministicRepresentation())
        {
            setRepresentation(_representation | SPARSE_WEIGHT);
        }
        else
        {
            final double[] ones = new double[getDomainIndexer().getInputCardinality()];
            Arrays.fill(ones, 1.0);
            _sparseWeights = ones;
        }
    }
    return _sparseWeights;
}
/**
 * Returns the internal dense weight array, without copying.
 * <p>
 * If the table does not currently hold dense weights, it is converted to a DENSE_WEIGHT-only
 * representation first. Any other arrays are discarded.
 */
@Override
public final double[] getWeightsDenseUnsafe()
{
    if (hasDenseWeights())
    {
        return _denseWeights;
    }
    setRepresentation(DENSE_WEIGHT);
    return _denseWeights;
}
/**
 * Returns the internal array of per-entry domain indices for the sparse representation, without
 * copying.
 * <p>
 * If the indices are not cached, they are added to the representation. A table that already has a
 * sparse representation just gains SPARSE_INDICES. Otherwise, sparse weights (if dense weights
 * exist) or sparse energies are created along with the indices.
 */
@Override
public final int[][] getIndicesSparseUnsafe()
{
    if (hasSparseIndices())
    {
        return _sparseIndices;
    }

    final int extraBits;
    if (hasSparseRepresentation())
    {
        extraBits = SPARSE_INDICES;
    }
    else if (hasDenseWeights())
    {
        extraBits = SPARSE_WEIGHT_WITH_INDICES;
    }
    else
    {
        extraBits = SPARSE_ENERGY_WITH_INDICES;
    }
    setRepresentation(_representation | extraBits);

    return _sparseIndices;
}
/**
 * Returns true when the value bits of the current representation equal DETERMINISTIC, ignoring
 * the sparse-indices bit.
 */
@Override
public boolean hasDeterministicRepresentation()
{
    final int valueBits = _representation & ALL_VALUES;
    return valueBits == DETERMINISTIC;
}
/**
 * Replaces the table contents with the given dense energy array by delegating to
 * {@code setDenseValues} with the DENSE_ENERGY representation.
 */
@Override
public void setEnergiesDense(double[] energies)
{
    this.setDenseValues(energies, DENSE_ENERGY);
}
/**
 * Replaces the table contents with the given dense weight array by delegating to
 * {@code setDenseValues} with the DENSE_WEIGHT representation.
 */
@Override
public void setWeightsDense(double[] weights)
{
    this.setDenseValues(weights, DENSE_WEIGHT);
}
/**
 * Makes this directed table a deterministic function from inputs to outputs.
 * <p>
 * {@code outputIndices[i]} is the output index selected for input index {@code i}. All value arrays
 * and any cached sparse indices are discarded, and the table takes the DETERMINISTIC
 * representation. The table's state is only modified after all arguments have been validated.
 *
 * @param outputIndices one output index per possible input; each must be in
 *        {@code [0, outputCardinality)}
 * @throws UnsupportedOperationException if the table is not directed
 * @throws IllegalArgumentException if {@code outputIndices} has the wrong length or contains an
 *         out-of-range output index
 */
@Override
public void setDeterministicOutputIndices(int[] outputIndices)
{
    final JointDomainIndexer domains = getDomainIndexer();
    final int size = domains.getInputCardinality();

    if (!isDirected())
    {
        throw new UnsupportedOperationException(
            "'setDeterministicOutputIndices' not supported on non-directed table");
    }

    if (size != outputIndices.length)
    {
        throw new IllegalArgumentException(
            String.format("'outputIndices' array length %d does not match size of possible inputs %d",
                outputIndices.length, size));
    }

    final int[] sparseToJoint = new int[size];
    final int outputCardinality = domains.getOutputCardinality();
    for (int inputIndex = 0; inputIndex < size; ++inputIndex)
    {
        final int outputIndex = outputIndices[inputIndex];
        if (outputIndex < 0 || outputIndex >= outputCardinality)
        {
            throw new IllegalArgumentException(String.format("Output index %d is out of range", outputIndex));
        }
        sparseToJoint[inputIndex] = domains.jointIndexFromInputOutputIndices(inputIndex, outputIndex);
    }

    // All arguments are valid; now replace the table state.
    _function = null;
    _sparseIndexToJointIndex = sparseToJoint;
    _representation = DETERMINISTIC;
    _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    // The new representation has no SPARSE_INDICES bit, so drop any indices describing old entries.
    _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
    _nonZeroWeights = size;
    _computedMask = NORMALIZED | DETERMINISTIC_COMPUTED;
}
/**
 * Replaces the table contents with the given sparse energies at the given joint indices by
 * delegating to {@code setSparseValues} with the SPARSE_ENERGY representation.
 */
@Override
public void setEnergiesSparse(int[] jointIndices, double[] energies)
{
    this.setSparseValues(jointIndices, energies, SPARSE_ENERGY);
}
/**
 * Replaces the table contents with the given sparse weights at the given joint indices by
 * delegating to {@code setSparseValues} with the SPARSE_WEIGHT representation.
 */
@Override
public void setWeightsSparse(int[] jointIndices, double[] weights)
{
    this.setSparseValues(jointIndices, weights, SPARSE_WEIGHT);
}
/**
 * Converts the table's internal storage to representation {@code newRep}.
 * <p>
 * {@code newRep} is a bit mask over DENSE_ENERGY, DENSE_WEIGHT, SPARSE_ENERGY, SPARSE_WEIGHT and
 * SPARSE_INDICES, or DETERMINISTIC, optionally with SPARSE_INDICES. Arrays required by
 * {@code newRep} that are missing are computed from the arrays currently present. Arrays not named
 * by {@code newRep} are replaced with shared empty arrays at the end.
 * <ul>
 * <li>SPARSE_INDICES is dropped from a non-deterministic representation that has no sparse
 * values.</li>
 * <li>Converting to DETERMINISTIC requires {@code isDeterministicDirected()} to return true.</li>
 * <li>Converting from DETERMINISTIC gives every stored entry weight 1 / energy 0. In the dense
 * arrays, every other joint entry gets weight 0 / energy +infinity.</li>
 * </ul>
 *
 * @throws DimpleException if DETERMINISTIC is requested but {@code isDeterministicDirected()} is
 *         false
 */
@Override
void setRepresentation(int newRep)
{
    int oldRep = _representation;

    if (oldRep == newRep)
    {
        return;
    }

    //
    // Disallow non-sparse rep combined with sparse indices
    //

    if ((newRep & ALL_DENSE_WITH_INDICES) == newRep && (newRep & ALL_VALUES) != DETERMINISTIC)
    {
        newRep &= ALL_DENSE;
    }

    //
    // Special cases for deterministic conversions
    //

    if ((newRep & ALL_VALUES) == DETERMINISTIC)
    {
        if (!isDeterministicDirected())
        {
            throw new DimpleException("Cannot set representation to DETERMINISTIC*");
        }

        if ((newRep & SPARSE_INDICES) != 0)
        {
            _sparseIndices = computeSparseIndices();
        }
        else
        {
            _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
        }
        _representation = newRep;
        return;
    }

    final int jointSize = jointSize();

    if ((oldRep & ALL_VALUES) == DETERMINISTIC)
    {
        // Deterministic -> explicit values: each stored entry has weight 1 (energy 0); every
        // other joint entry has weight 0 (energy +infinity).
        if ((newRep & SPARSE_WEIGHT) != 0)
        {
            _sparseWeights = new double[sparseSize()];
            Arrays.fill(_sparseWeights, 1.0);
        }
        if ((newRep & SPARSE_ENERGY) != 0)
        {
            _sparseEnergies = new double[sparseSize()];
        }
        if ((newRep & DENSE_WEIGHT) != 0)
        {
            _denseWeights = new double[jointSize];
            for (int ji : _sparseIndexToJointIndex)
            {
                _denseWeights[ji] = 1.0;
            }
        }
        if ((newRep & DENSE_ENERGY) != 0)
        {
            _denseEnergies = new double[jointSize];
            Arrays.fill(_denseEnergies, Double.POSITIVE_INFINITY);
            for (int ji : _sparseIndexToJointIndex)
            {
                _denseEnergies[ji] = 0.0;
            }
        }
        if ((newRep & SPARSE_INDICES) != 0 && (oldRep & SPARSE_INDICES) == 0)
        {
            // Compute the indices directly from the (still intact) sparse-to-joint map. Going
            // through getIndicesSparseUnsafe() would call setRepresentation() again while this
            // conversion is only half done.
            _sparseIndices = computeSparseIndices();
        }
        if ((newRep & ALL_SPARSE) == 0)
        {
            _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
        }
        if ((newRep & SPARSE_INDICES) == 0)
        {
            _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
        }
        _representation = newRep;
        _computedMask = 0;
        return;
    }

    //
    // Dense-to-sparse conversion
    //

    // Bits requested by newRep that are not yet available. oldRep is updated locally as each
    // kind of array is produced.
    int diff = newRep & ~oldRep;

    if ((diff & ALL_SPARSE) != 0 && (oldRep & ALL_SPARSE) == 0)
    {
        if (_nonZeroWeights == jointSize)
        {
            // sparse == dense: every joint entry is present, so reuse the dense arrays as the
            // sparse ones where the matching kind already exists.
            if ((diff & SPARSE_WEIGHT) != 0 && (oldRep & DENSE_WEIGHT) != 0)
            {
                _sparseWeights = _denseWeights;
                oldRep |= SPARSE_WEIGHT;
            }
            if ((diff & SPARSE_ENERGY) != 0 && (oldRep & DENSE_ENERGY) != 0)
            {
                _sparseEnergies = _denseEnergies;
                oldRep |= SPARSE_ENERGY;
            }
            diff = newRep & ~oldRep;
        }

        if ((diff & ALL_SPARSE) != 0)
        {
            final int[] sparseToJoint = _sparseIndexToJointIndex = computeSparseToJointIndexMap();
            final int sparseSize = sparseToJoint.length;

            if ((diff & SPARSE_WEIGHT) != 0)
            {
                final double[] sparseWeights = _sparseWeights = new double[sparseSize];
                if ((oldRep & DENSE_WEIGHT) != 0)
                {
                    final double[] denseWeights = _denseWeights;
                    for (int si = sparseSize; --si>=0;)
                    {
                        sparseWeights[si] = denseWeights[sparseToJoint[si]];
                    }
                }
                else
                {
                    final double[] denseEnergies = _denseEnergies;
                    for (int si = sparseSize; --si>=0;)
                    {
                        sparseWeights[si] = Utilities.energyToWeight(denseEnergies[sparseToJoint[si]]);
                    }
                }
                oldRep |= SPARSE_WEIGHT;
            }

            if ((diff & SPARSE_ENERGY) != 0)
            {
                final double[] sparseEnergies = _sparseEnergies = new double[sparseToJoint.length];
                if ((oldRep & DENSE_ENERGY) != 0)
                {
                    final double[] denseEnergies = _denseEnergies;
                    for (int si = sparseSize; --si>=0;)
                    {
                        sparseEnergies[si] = denseEnergies[sparseToJoint[si]];
                    }
                }
                else
                {
                    final double[] denseWeights = _denseWeights;
                    for (int si = sparseSize; --si>=0;)
                    {
                        sparseEnergies[si] = Utilities.weightToEnergy(denseWeights[sparseToJoint[si]]);
                    }
                }
                oldRep |= SPARSE_ENERGY;
            }
        }
    }

    final int[] sparseToJoint = _sparseIndexToJointIndex;

    //
    // Compute sparse indices
    //

    if ((newRep & SPARSE_INDICES) != 0 && (oldRep & SPARSE_INDICES) == 0)
    {
        _sparseIndices = computeSparseIndices();
        oldRep |= SPARSE_INDICES;
    }

    //
    // Sparse-to-sparse conversions
    //

    diff &= ~oldRep;

    if ((diff & ALL_SPARSE) != 0)
    {
        if ((diff & SPARSE_ENERGY) != 0 && (oldRep & SPARSE_WEIGHT) != 0)
        {
            final double[] sparseWeights = _sparseWeights;
            final double[] sparseEnergies = _sparseEnergies = new double[sparseWeights.length];
            for (int i = sparseWeights.length; --i >= 0;)
            {
                sparseEnergies[i] = Utilities.weightToEnergy(sparseWeights[i]);
            }
            oldRep |= SPARSE_ENERGY;
        }
        else if ((diff & SPARSE_WEIGHT) != 0 && (oldRep & SPARSE_ENERGY) != 0)
        {
            final double[] sparseEnergies = _sparseEnergies;
            final double[] sparseWeights = _sparseWeights = new double[sparseEnergies.length];
            for (int i = sparseEnergies.length; --i >= 0;)
            {
                sparseWeights[i] = Utilities.energyToWeight(sparseEnergies[i]);
            }
            oldRep |= SPARSE_WEIGHT;
        }
    }

    //
    // *-to-dense conversions
    //

    diff &= ~oldRep;

    if ((diff & ALL_DENSE) != 0)
    {
        if ((oldRep & ALL_SPARSE) != 0)
        {
            // Expand from sparse. Entries absent from the sparse form get weight 0 / energy +inf.
            if ((diff & DENSE_ENERGY) != 0)
            {
                if ((oldRep & SPARSE_ENERGY) != 0)
                {
                    if (jointSize == sparseSize())
                    {
                        _denseEnergies = _sparseEnergies;
                    }
                    else
                    {
                        final double[] sparseEnergies = _sparseEnergies;
                        final double[] denseEnergies = _denseEnergies = new double[jointSize];
                        Arrays.fill(denseEnergies, Double.POSITIVE_INFINITY);
                        for (int si = sparseToJoint.length; --si >= 0;)
                        {
                            denseEnergies[sparseToJoint[si]] = sparseEnergies[si];
                        }
                    }
                }
                else
                {
                    assert((oldRep & SPARSE_WEIGHT) != 0);
                    final double[] sparseWeights = _sparseWeights;
                    final double[] denseEnergies = _denseEnergies = new double[jointSize];
                    Arrays.fill(_denseEnergies, Double.POSITIVE_INFINITY);
                    if (denseEnergies.length == sparseWeights.length)
                    {
                        for (int di = denseEnergies.length; --di >=0;)
                        {
                            denseEnergies[di] = Utilities.weightToEnergy(sparseWeights[di]);
                        }
                    }
                    else
                    {
                        for (int si = sparseToJoint.length; --si >= 0;)
                        {
                            denseEnergies[sparseToJoint[si]] = Utilities.weightToEnergy(sparseWeights[si]);
                        }
                    }
                }
                oldRep |= DENSE_ENERGY;
            }
            if ((diff & DENSE_WEIGHT) != 0)
            {
                if ((oldRep & SPARSE_WEIGHT) != 0)
                {
                    if (jointSize == sparseSize())
                    {
                        _denseWeights = _sparseWeights;
                    }
                    else
                    {
                        final double[] sparseWeights = _sparseWeights;
                        final double[] denseWeights = _denseWeights = new double[jointSize];
                        for (int si = sparseToJoint.length; --si >= 0;)
                        {
                            denseWeights[sparseToJoint[si]] = sparseWeights[si];
                        }
                    }
                }
                else
                {
                    assert((oldRep & SPARSE_ENERGY) != 0);
                    final double[] sparseEnergies = _sparseEnergies;
                    final double[] denseWeights = _denseWeights = new double[jointSize];
                    if (denseWeights.length == sparseEnergies.length)
                    {
                        for(int di = denseWeights.length; --di>=0;)
                        {
                            denseWeights[di] = Utilities.energyToWeight(sparseEnergies[di]);
                        }
                    }
                    else
                    {
                        for (int si = sparseToJoint.length; -- si >= 0;)
                        {
                            denseWeights[sparseToJoint[si]] = Utilities.energyToWeight(sparseEnergies[si]);
                        }
                    }
                }
                oldRep |= DENSE_WEIGHT;
            }
        }
        else
        {
            // Dense-only source: convert between dense weights and dense energies.
            if ((diff & DENSE_ENERGY) != 0)
            {
                final double[] denseWeights = _denseWeights;
                final double[] denseEnergies = _denseEnergies = new double[jointSize];
                for (int i = 0; i < jointSize; ++i)
                {
                    denseEnergies[i] = Utilities.weightToEnergy(denseWeights[i]);
                }
                oldRep |= DENSE_ENERGY;
            }
            else
            {
                final double[] denseEnergies = _denseEnergies;
                final double[] denseWeights = _denseWeights = new double[jointSize];
                for (int i = 0; i < jointSize; ++i)
                {
                    denseWeights[i] = Utilities.energyToWeight(denseEnergies[i]);
                }
                oldRep |= DENSE_WEIGHT;
            }
        }
    }

    assert((newRep & ~oldRep) == 0);

    //
    // Remove old arrays
    //

    if ((newRep & SPARSE_ENERGY) == 0)
    {
        _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    }
    if ((newRep & SPARSE_WEIGHT) == 0)
    {
        _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    }
    if ((newRep & DENSE_ENERGY) == 0)
    {
        _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    }
    if ((newRep & DENSE_WEIGHT) == 0)
    {
        _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
    }
    if ((newRep & ALL_SPARSE) == 0)
    {
        _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
    }
    if ((newRep & SPARSE_INDICES) == 0)
    {
        _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
    }

    _representation = newRep;
}
/*----------------------
* Old IFactorTable methods
*/
/**
 * Returns a copy of this table, as produced by {@link #clone()}.
 */
public FactorTable copy()
{
    final FactorTable duplicate = clone();
    return duplicate;
}
/**
 * Replaces the contents of this table with a copy of {@code that}'s contents.
 * <p>
 * Both tables must have equal domains. When {@code that} is a {@code FactorTable}, its internal
 * arrays, representation and computed-property mask are cloned directly. Otherwise the
 * representation and non-zero-weight count are taken from its public API. Its sparse and/or dense
 * values are then copied, and the computed-property mask is cleared.
 *
 * @throws DimpleException if the domains differ
 */
@Override
public void copy(IFactorTable that)
{
    if (that == this)
    {
        return;
    }

    if (!getDomainIndexer().domainsEqual(that.getDomainIndexer()))
    {
        throw new DimpleException("Cannot copy from factor table with different domains");
    }

    if (that instanceof FactorTable)
    {
        FactorTable other = (FactorTable)that;
        _nonZeroWeights = other._nonZeroWeights;
        _representation = other._representation;
        _denseEnergies = Objects.requireNonNull(ArrayUtil.cloneArray(other._denseEnergies));
        _denseWeights = Objects.requireNonNull(ArrayUtil.cloneArray(other._denseWeights));
        _sparseEnergies = Objects.requireNonNull(ArrayUtil.cloneArray(other._sparseEnergies));
        _sparseWeights = Objects.requireNonNull(ArrayUtil.cloneArray(other._sparseWeights));
        _sparseIndexToJointIndex = Objects.requireNonNull(ArrayUtil.cloneArray(other._sparseIndexToJointIndex));
        _sparseIndices = Objects.requireNonNull(ArrayUtil.cloneArray(other._sparseIndices));
        _computedMask = other._computedMask;
    }
    else
    {
        _representation = that.getRepresentation().mask();
        _nonZeroWeights = that.countNonZeroWeights();
        _denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
        _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
        _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
        _denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
        _sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
        _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY;
        _computedMask = 0;

        if (that.hasSparseRepresentation())
        {
            final int size = that.sparseSize();
            _sparseIndexToJointIndex = new int[size];
            for (int si = 0; si < size; ++si)
            {
                _sparseIndexToJointIndex[si] = that.sparseIndexToJointIndex(si);
            }
            if (that.hasSparseEnergies())
            {
                _sparseEnergies = that.getEnergiesSparseUnsafe().clone();
            }
            if (that.hasSparseWeights())
            {
                _sparseWeights = that.getWeightsSparseUnsafe().clone();
            }
            if (that.hasSparseIndices())
            {
                // NOTE: shallow copy; the per-entry index rows are shared with 'that'.
                _sparseIndices = that.getIndicesSparseUnsafe().clone();
            }
        }

        if (that.hasDenseRepresentation())
        {
            final int size = that.jointSize();
            if (that.hasDenseEnergies())
            {
                _denseEnergies = new double[size];
                for (int ji = 0; ji < size; ++ji)
                {
                    _denseEnergies[ji] = that.getEnergyForJointIndex(ji);
                }
            }
            if (that.hasDenseWeights())
            {
                _denseWeights = new double[size];
                for (int ji = 0; ji < size; ++ji)
                {
                    // Weights go into the dense weight array (not the dense energies).
                    _denseWeights[ji] = that.getWeightForJointIndex(ji);
                }
            }
        }
    }
}
// FIXME: what to do if table is directed? Should we assert that the joined
// variables are all either inputs or outputs?
/**
 * Returns a new table in which the variables at {@code varIndices} are joined into a single
 * variable with domain {@code jointDomain}. The reindexing is built by
 * {@code makeConverterForJoinVariables} and applied to a copy via the
 * {@code FactorTable(FactorTable, JointDomainReindexer)} constructor.
 */
@Override
public FactorTable joinVariablesAndCreateNewTable(
    int[] varIndices,
    int[] indexToJointIndex,
    DiscreteDomain[] allDomains,
    DiscreteDomain jointDomain)
{
    return new FactorTable(this,
        makeConverterForJoinVariables(getDomainIndexer(), varIndices, indexToJointIndex, allDomains, jointDomain));
}
/**
 * Builds a reindexer that joins the variables at {@code varIndices} into one trailing variable.
 * <p>
 * First, the domains are permuted. Unjoined variables keep their relative order at the front, and
 * joined variable {@code varIndices[i]} moves to position
 * {@code unjoinedSize + indexToJointIndex[i]}. This step is skipped when the permutation would be
 * the identity. Then the trailing {@code varIndices.length} domains are joined.
 */
static JointDomainReindexer makeConverterForJoinVariables(
    JointDomainIndexer domains,
    int[] varIndices,
    int[] indexToJointIndex,
    DiscreteDomain[] allDomains,
    DiscreteDomain jointDomain)
{
    assert(Arrays.equals(allDomains, domains.toArray()));
    assert(varIndices.length == indexToJointIndex.length);

    final int nDomains = domains.size();
    final int joinedSize = varIndices.length;
    final int unjoinedSize = nDomains - joinedSize;

    // destination[i] is the new position of domain i; -1 means "not assigned yet".
    final int[] destination = new int[nDomains];
    Arrays.fill(destination, -1);
    for (int k = 0; k < joinedSize; ++k)
    {
        destination[varIndices[k]] = unjoinedSize + indexToJointIndex[k];
    }

    // Pack the unjoined variables at the front and note whether anything actually moves.
    boolean identity = true;
    int nextFront = 0;
    for (int i = 0; i < nDomains; ++i)
    {
        if (destination[i] < 0)
        {
            destination[i] = nextFront++;
        }
        identity &= destination[i] == i;
    }

    JointDomainReindexer permuter = null;
    if (!identity)
    {
        final DiscreteDomain[] permutedDomains = new DiscreteDomain[nDomains];
        for (int i = nDomains; --i >= 0;)
        {
            permutedDomains[destination[i]] = domains.get(i);
        }
        permuter = createPermuter(domains, JointDomainIndexer.create(permutedDomains), destination);
    }

    if (permuter == null)
    {
        return createJoiner(domains, unjoinedSize, joinedSize);
    }
    return permuter.combineWith(createJoiner(permuter.getToDomains(), unjoinedSize, joinedSize));
}
/*-----------------
* Private methods
*/
/**
 * Returns the sparse index for {@code jointIndex}, inserting a new sparse slot if the entry is not
 * yet present.
 * <p>
 * On insertion, every sparse array named by the current representation grows by one at the
 * insertion point, and so does the sparse-to-joint map. Cached sparse indices get the decoded
 * indices for {@code jointIndex}. The value stored in the new energy/weight slot is whatever
 * {@code ArrayUtil.copyArrayForInsert} leaves there (presumably zero), so callers are expected to
 * overwrite it.
 */
private int allocateSparseIndexForJointIndex(int jointIndex)
{
    int sparseIndex = sparseIndexFromJointIndex(jointIndex);
    if (sparseIndex >= 0)
    {
        return sparseIndex;
    }

    // Read the representation only after the lookup: sparseIndexFromJointIndex adds SPARSE_ENERGY
    // to a dense-only table. A snapshot taken earlier would skip growing that new array and leave
    // it out of step with the sparse-to-joint map.
    final int representation = _representation;

    sparseIndex = -1-sparseIndex;
    if ((representation & SPARSE_ENERGY) != 0)
    {
        _sparseEnergies = ArrayUtil.copyArrayForInsert(_sparseEnergies, sparseIndex, 1);
    }
    if ((representation & SPARSE_WEIGHT) != 0)
    {
        _sparseWeights = ArrayUtil.copyArrayForInsert(_sparseWeights, sparseIndex, 1);
    }
    if ((representation & SPARSE_INDICES) != 0)
    {
        _sparseIndices = ArrayUtil.copyArrayForInsert(_sparseIndices, sparseIndex, 1);
        _sparseIndices[sparseIndex] = getDomainIndexer().jointIndexToIndices(jointIndex);
    }
    _sparseIndexToJointIndex = ArrayUtil.copyArrayForInsert(_sparseIndexToJointIndex, sparseIndex, 1);
    _sparseIndexToJointIndex[sparseIndex] = jointIndex;

    return sparseIndex;
}
/**
 * Recomputes {@code _nonZeroWeights} from the current representation.
 * <p>
 * Deterministic tables count one entry per slot of the sparse-to-joint map. Otherwise the count
 * comes from sparse weights if present, else sparse energies, else dense weights, else dense
 * energies. An energy counts as non-zero weight when it is not infinite.
 */
private void computeNonZeroWeights()
{
    int count = 0;

    switch (_representation & ALL_VALUES)
    {
    case DETERMINISTIC:
        // Assign to the local count: the unconditional store after the switch would otherwise
        // overwrite this with zero.
        count = _sparseIndexToJointIndex.length;
        break;

    case SPARSE_WEIGHT:
    case DENSE_ENERGY_SPARSE_WEIGHT:
    case ALL_WEIGHT:
    case NOT_SPARSE_ENERGY:
    case ALL_SPARSE:
    case NOT_DENSE_ENERGY:
    case NOT_DENSE_WEIGHT:
    case ALL_VALUES:
        for (double w : _sparseWeights)
        {
            if (w != 0)
            {
                ++count;
            }
        }
        break;

    case SPARSE_ENERGY:
    case ALL_ENERGY:
    case SPARSE_ENERGY_DENSE_WEIGHT:
    case NOT_SPARSE_WEIGHT:
        for (double e : _sparseEnergies)
        {
            if (!Double.isInfinite(e))
            {
                ++count;
            }
        }
        break;

    case DENSE_WEIGHT:
    case ALL_DENSE:
        for (double w : _denseWeights)
        {
            if (w != 0)
            {
                ++count;
            }
        }
        break;

    case DENSE_ENERGY:
        for (double e : _denseEnergies)
        {
            if (!Double.isInfinite(e))
            {
                ++count;
            }
        }
        break;
    }

    _nonZeroWeights = count;
}
/**
 * Builds a fresh array holding the decoded domain indices of every sparse entry.
 * <p>
 * When the sparse size reaches the joint size, sparse and joint indices coincide. The
 * sparse-to-joint map is then not consulted; entry {@code i} decodes joint index {@code i}.
 */
private int[][] computeSparseIndices()
{
    final JointDomainIndexer indexer = getDomainIndexer();
    final int jointSize = indexer.getCardinality();
    final int sparseSize = sparseSize();
    final int[] sparseToJoint = _sparseIndexToJointIndex;
    final int[][] result = new int[sparseSize][];

    final boolean identity = sparseSize >= jointSize;
    final int count = identity ? jointSize : sparseSize;
    for (int i = 0; i < count; ++i)
    {
        result[i] = indexer.jointIndexToIndices(identity ? i : sparseToJoint[i]);
    }
    return result;
}
/**
 * Returns the sparse-to-joint index map, computing it if it is not already cached.
 * <p>
 * The computed map lists, in increasing order, the joint indices whose weight is non-zero (or
 * whose energy is finite). It uses dense values when present, otherwise sparse values that must
 * cover every joint index. Its length is {@code _nonZeroWeights}, which must therefore be
 * accurate. The result is not stored in the field.
 */
private int[] computeSparseToJointIndexMap()
{
    final int[] cached = _sparseIndexToJointIndex;
    if (cached.length != 0)
    {
        return cached;
    }

    final int jointSize = jointSize();
    final int[] map = new int[_nonZeroWeights];
    int si = 0;

    if ((_representation & ALL_WEIGHT) == 0)
    {
        final double[] energies = hasDenseEnergies() ? _denseEnergies : _sparseEnergies;
        assert(energies.length == jointSize);
        for (int ji = 0; ji < jointSize; ++ji)
        {
            if (!Double.isInfinite(energies[ji]))
            {
                map[si++] = ji;
            }
        }
    }
    else
    {
        final double[] weights = hasDenseWeights() ? _denseWeights : _sparseWeights;
        assert(weights.length == jointSize);
        // Stop as soon as every non-zero entry has been found.
        for (int ji = 0; si < map.length; ++ji)
        {
            if (weights[ji] != 0.0)
            {
                map[si++] = ji;
            }
        }
    }

    return map;
}
/**
 * Normalizes (or checks normalization of) a directed table so that, for every input index, the
 * weights over all outputs sum to one.
 * <p>
 * When {@code justCheck} is false, each input's weights are divided by that input's total. In
 * energy form, {@code log(total)} is added to each energy. Inputs whose total weight is zero are
 * left unchanged. When {@code justCheck} is true, nothing is modified. The method returns 1 as soon
 * as two inputs' totals differ by more than 1e-12; totals that are all equal but differ from one
 * are accepted.
 * <p>
 * An input whose total weight is zero is counted in the return value when
 * {@code ignoreZeroWeightInputs} is true. Otherwise the method returns
 * {@code normalizeDirectedHandleZeroForInput(justCheck)} immediately, before any value is
 * modified.
 * <p>
 * NOTE(review): with {@code justCheck && ignoreZeroWeightInputs}, a zero-total input still takes
 * part in the total comparison. It can become the reference total, or mismatch it, so the check
 * returns 1. Confirm this is intended.
 * <p>
 * On normal completion, CONDITIONAL_COMPUTED is set in the computed mask, plus CONDITIONAL when no
 * input was skipped. The setters used during normalization clear the mask, so these bits are set
 * afterwards.
 *
 * @return the number of inputs left unnormalized because their total weight was zero, or 1 from
 *         an early mismatch in check mode, or the result of
 *         {@code normalizeDirectedHandleZeroForInput}
 */
@Override
int normalizeDirected(boolean justCheck, boolean ignoreZeroWeightInputs)
{
final JointDomainIndexer domains = getDomainIndexer();
final int inputSize = domains.getInputCardinality();
final int outputSize = domains.getOutputCardinality();
// An empty map means sparse index == joint index.
final boolean hasSparseToJoint = _sparseIndexToJointIndex.length > 0;
boolean computeNormalizedTotal = justCheck;
double normalizedTotal = 1.0;
double totalForInput = 0.0;
int nNotNormalized = 0;
// Per-input normalizers: the weight total (weight cases) or its log (energy cases).
final double[] normalizers = justCheck ? null : new double[inputSize];
// We represent the joint index such that the outputs for the same
// input are stored consecutively, so we only need to walk through
// the values in order.
//
// When just checking, we allow total to equal something other than one
// as long as they are all the same.
switch (_representation & ALL_VALUES)
{
case DETERMINISTIC:
// Exactly one entry of weight one per input: already normalized.
break;
case ALL_VALUES:
case ALL_WEIGHT:
case ALL_SPARSE:
case NOT_DENSE_WEIGHT:
case NOT_SPARSE_ENERGY:
case DENSE_ENERGY_SPARSE_WEIGHT:
case NOT_DENSE_ENERGY:
case SPARSE_WEIGHT:
// Sparse weights: si advances through the entries belonging to input ii, i.e. those whose
// joint index does not exceed the joint index of (ii, last output).
for (int ii = 0, si = 0, size = _sparseWeights.length; ii < inputSize; ++ii)
{
final int maxji = domains.jointIndexFromInputOutputIndices(ii, outputSize-1);
totalForInput = 0.0;
for (; si < size && maxji >= (hasSparseToJoint ? _sparseIndexToJointIndex[si] : si); ++si)
{
totalForInput += _sparseWeights[si];
}
if (totalForInput == 0.0)
{
if (ignoreZeroWeightInputs)
++nNotNormalized;
else
return normalizeDirectedHandleZeroForInput(justCheck);
}
if (computeNormalizedTotal)
{
normalizedTotal = totalForInput;
computeNormalizedTotal = false;
}
else if (!DoubleMath.fuzzyEquals(totalForInput, normalizedTotal, 1e-12))
{
if (justCheck)
{
return 1;
}
}
if (normalizers != null)
{
normalizers[ii] = totalForInput;
}
}
if (normalizers != null)
{
// Second pass: divide each weight by its input's total (zero totals are skipped).
for (int si = 0, size = _sparseWeights.length; si < size; ++si)
{
final int ji = hasSparseToJoint ? _sparseIndexToJointIndex[si] : si;
final int ii = domains.inputIndexFromJointIndex(ji);
final double normalizer = normalizers[ii];
if (normalizer != 0.0)
{
setWeightForSparseIndex(_sparseWeights[si] / normalizers[ii], si);
}
}
}
break;
case ALL_ENERGY:
case SPARSE_ENERGY:
case NOT_SPARSE_WEIGHT:
case SPARSE_ENERGY_DENSE_WEIGHT:
// TODO: if sparse size is large enough, it would be faster to iterate over the dense weights
for (int ii = 0, si = 0, size = _sparseEnergies.length; ii < inputSize; ++ii)
{
final int maxji = domains.jointIndexFromInputOutputIndices(ii, outputSize-1);
totalForInput = 0.0;
for (; si < size && maxji >= (hasSparseToJoint ? _sparseIndexToJointIndex[si] : si); ++si)
{
totalForInput += energyToWeight(_sparseEnergies[si]);
}
if (totalForInput == 0.0)
{
if (ignoreZeroWeightInputs)
++nNotNormalized;
else
return normalizeDirectedHandleZeroForInput(justCheck);
}
if (computeNormalizedTotal)
{
normalizedTotal = totalForInput;
computeNormalizedTotal = false;
}
else if (!DoubleMath.fuzzyEquals(totalForInput, normalizedTotal, 1e-12))
{
if (justCheck)
{
return 1;
}
}
if (normalizers != null)
{
// Dividing a weight by total corresponds to adding log(total) to its energy.
normalizers[ii] = Math.log(totalForInput);
}
}
if (normalizers != null)
{
// log(0) == -infinity marks a zero-total input, which is left unchanged.
for (int si = 0, size = _sparseEnergies.length; si < size; ++si)
{
final int ji = hasSparseToJoint ? _sparseIndexToJointIndex[si] : si;
final int ii = domains.inputIndexFromJointIndex(ji);
final double normalizer = normalizers[ii];
if (normalizer != Double.NEGATIVE_INFINITY)
{
setEnergyForSparseIndex(_sparseEnergies[si] + normalizers[ii], si);
}
}
}
break;
case ALL_DENSE:
case DENSE_WEIGHT:
// Dense weights: the outputs of one input occupy outputSize consecutive joint indices.
for (int jointIndex = 0, inputIndex = 0; inputIndex < inputSize; ++inputIndex, jointIndex += outputSize)
{
totalForInput = 0.0;
for (int outputIndex = 0; outputIndex < outputSize; ++outputIndex)
{
totalForInput += _denseWeights[jointIndex + outputIndex];
}
if (totalForInput == 0.0)
{
if (ignoreZeroWeightInputs)
++nNotNormalized;
else
return normalizeDirectedHandleZeroForInput(justCheck);
}
if (computeNormalizedTotal)
{
normalizedTotal = totalForInput;
computeNormalizedTotal = false;
}
else if (!DoubleMath.fuzzyEquals(totalForInput, normalizedTotal, 1e-12))
{
if (justCheck)
{
return 1;
}
}
if (normalizers != null)
{
normalizers[inputIndex] = totalForInput;
}
}
if (normalizers != null)
{
for (int jointIndex = 0, inputIndex = 0; inputIndex < inputSize; ++inputIndex, jointIndex += outputSize)
{
totalForInput = normalizers[inputIndex];
if (totalForInput != 0.0)
{
for (int outputIndex = 0; outputIndex < outputSize; ++outputIndex)
{
int ji = jointIndex + outputIndex;
setWeightForJointIndex(_denseWeights[ji] / totalForInput, ji);
}
}
}
}
break;
case DENSE_ENERGY:
// Dense energies: same layout as the dense-weight case, normalized in log space.
for (int jointIndex = 0, inputIndex = 0; inputIndex < inputSize; ++inputIndex, jointIndex += outputSize)
{
totalForInput = 0.0;
for (int outputIndex = 0; outputIndex < outputSize; ++outputIndex)
{
totalForInput += energyToWeight(_denseEnergies[jointIndex + outputIndex]);
}
if (totalForInput == 0.0)
{
if (ignoreZeroWeightInputs)
++nNotNormalized;
else
return normalizeDirectedHandleZeroForInput(justCheck);
}
if (computeNormalizedTotal)
{
normalizedTotal = totalForInput;
computeNormalizedTotal = false;
}
else if (!DoubleMath.fuzzyEquals(totalForInput, normalizedTotal, 1e-12))
{
if (justCheck)
{
return 1;
}
}
if (normalizers != null)
{
normalizers[inputIndex] = Math.log(totalForInput);
}
}
if (normalizers != null)
{
for (int jointIndex = 0, inputIndex = 0; inputIndex < inputSize; ++inputIndex, jointIndex += outputSize)
{
final double normalizer = normalizers[inputIndex];
if (normalizer != Double.NEGATIVE_INFINITY)
{
for (int outputIndex = 0; outputIndex < outputSize; ++outputIndex)
{
int ji = jointIndex + outputIndex;
setEnergyForJointIndex(_denseEnergies[ji] + normalizer, ji);
}
}
}
}
break;
}
// The setters above reset _computedMask, so record the conditional status afterwards.
if (nNotNormalized == 0)
{
_computedMask |= CONDITIONAL|CONDITIONAL_COMPUTED;
}
else
{
_computedMask |= CONDITIONAL_COMPUTED;
}
return nNotNormalized;
}
@Override
boolean normalizeUndirected(boolean justCheck)
{
	// A previous call already established that this table is normalized.
	if ((_computedMask & NORMALIZED) != 0)
	{
		return true;
	}

	// Sum the weights from one of the representations currently held. Energies are
	// converted to weights with energyToWeight before being accumulated.
	double sum = 0.0;
	switch (_representation & ALL_VALUES)
	{
	case ALL_VALUES:
	case ALL_WEIGHT:
	case ALL_SPARSE:
	case NOT_DENSE_WEIGHT:
	case NOT_SPARSE_ENERGY:
	case DENSE_ENERGY_SPARSE_WEIGHT:
	case NOT_DENSE_ENERGY:
	case SPARSE_WEIGHT:
		for (int i = 0, n = _sparseWeights.length; i < n; ++i)
		{
			sum += _sparseWeights[i];
		}
		break;
	case ALL_ENERGY:
	case SPARSE_ENERGY:
	case NOT_SPARSE_WEIGHT:
	case SPARSE_ENERGY_DENSE_WEIGHT:
		// TODO: if sparse size is large enough, it would be faster to iterate over the dense weights
		for (int i = 0, n = _sparseEnergies.length; i < n; ++i)
		{
			sum += energyToWeight(_sparseEnergies[i]);
		}
		break;
	case ALL_DENSE:
	case DENSE_WEIGHT:
		for (int i = 0, n = _denseWeights.length; i < n; ++i)
		{
			sum += _denseWeights[i];
		}
		break;
	case DENSE_ENERGY:
		for (int i = 0, n = _denseEnergies.length; i < n; ++i)
		{
			sum += energyToWeight(_denseEnergies[i]);
		}
		break;
	}

	final boolean sumsToOne = DoubleMath.fuzzyEquals(sum, 1.0, 1e-12);
	if (!sumsToOne)
	{
		// In check-only mode, report the result without touching any values.
		if (justCheck)
		{
			return false;
		}
		if (sum == 0.0)
		{
			throw normalizeUndirectedHandleZero();
		}

		// Scale every weight by 1/sum. The sparse and dense arrays may be the same object,
		// in which case the dense pass is skipped to avoid dividing twice.
		for (int i = 0, n = _sparseWeights.length; i < n; ++i)
		{
			_sparseWeights[i] /= sum;
		}
		if (_denseWeights != _sparseWeights)
		{
			for (int i = 0, n = _denseWeights.length; i < n; ++i)
			{
				_denseWeights[i] /= sum;
			}
		}

		// The matching energy update adds log(sum), again skipping an aliased dense array.
		final double logSum = Math.log(sum);
		for (int i = 0, n = _sparseEnergies.length; i < n; ++i)
		{
			_sparseEnergies[i] += logSum;
		}
		if (_denseEnergies != _sparseEnergies)
		{
			for (int i = 0, n = _denseEnergies.length; i < n; ++i)
			{
				_denseEnergies[i] += logSum;
			}
		}
	}

	_computedMask |= NORMALIZED|NORMALIZED_COMPUTED;
	return true;
}
/**
 * Replaces the table contents with a copy of a dense array of energies or weights.
 * <p>
 * All sparse arrays and the other dense array are reset to empty and the computed-state mask is
 * cleared. {@code computeNonZeroWeights()} is then invoked.
 * <p>
 * @param values one value per joint index; its length must equal the domain cardinality.
 * @param representation either {@code DENSE_ENERGY} or {@code DENSE_WEIGHT}.
 * @throws IllegalArgumentException if {@code values.length} does not equal the cardinality of
 * the domain indexer. No state is modified in that case.
 */
private void setDenseValues(double[] values, int representation)
{
	final JointDomainIndexer domains = getDomainIndexer();
	final int cardinality = domains.getCardinality();
	if (values.length != cardinality)
	{
		throw new IllegalArgumentException(String.format("Bad dense length: was %d, expected %d",
			values.length, cardinality));
	}

	// Only start mutating state after the argument has been validated. A rejected call then
	// leaves the table, including _function, unchanged.
	_function = null;
	_computedMask = 0;
	switch(representation)
	{
	case DENSE_ENERGY:
		_denseEnergies = values.clone();
		_denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		break;
	case DENSE_WEIGHT:
		_denseWeights = values.clone();
		_denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		break;
	default:
		assert(false);
	}
	_representation = representation;
	_sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_sparseIndexToJointIndex = ArrayUtil.EMPTY_INT_ARRAY;
	computeNonZeroWeights();
}
/**
 * Changes which variables are treated as outputs, re-indexing the stored values if the joint
 * ordering changes.
 * <p>
 * If {@code outputSet} is null and the table was in DETERMINISTIC form, it is converted to a
 * sparse-energy form with all-zero energies for the existing sparse entries.
 * <p>
 * If anything fails part-way through (including {@code assertIsConditional()} when
 * {@code assertConditional} is set), all state touched here is restored. That includes the
 * domain indexer, so the values and their ordering stay consistent.
 * <p>
 * @param outputSet indices of the output variables, or null to remove directionality.
 * @param assertConditional if true, {@code assertIsConditional()} is invoked after the change.
 */
@Override
void setDirected(@Nullable BitSet outputSet, boolean assertConditional)
{
	final JointDomainIndexer oldDomains = getDomainIndexer();
	final JointDomainIndexer newDomains = JointDomainIndexer.create(outputSet, oldDomains);
	if (oldDomains.equals(newDomains))
	{
		if (assertConditional)
		{
			assertIsConditional();
		}
		return;
	}

	// Snapshot everything that may be modified below so a failure can be rolled back.
	final int oldComputedMask = _computedMask;
	final int oldRepresentation = _representation;
	final double[] oldDenseEnergies = _denseEnergies;
	final double[] oldDenseWeights = _denseWeights;
	final double[] oldSparseEnergies = _sparseEnergies;
	final double[] oldSparseWeights = _sparseWeights;
	final int[] oldSparseToJoint = _sparseIndexToJointIndex;
	final int[][] oldSparseIndices = _sparseIndices;

	boolean ok = false;

	// FIXME: I don't think this is quite right if starting from DETERMINISTIC*
	try
	{
		_computedMask = 0;
		setDomainIndexer(newDomains);

		// Values only need to be permuted when either ordering is non-canonical.
		if (!oldDomains.hasCanonicalDomainOrder() || !newDomains.hasCanonicalDomainOrder())
		{
			JointDomainReindexer converter =
				JointDomainReindexer.createPermuter(oldDomains, newDomains);
			if (hasDenseEnergies())
			{
				_denseEnergies = converter.convertDenseEnergies(_denseEnergies);
			}
			if (hasDenseWeights())
			{
				_denseWeights = converter.convertDenseWeights(_denseWeights);
			}
			if (_sparseIndexToJointIndex.length > 0)
			{
				_sparseIndexToJointIndex = converter.convertSparseToJointIndex(oldSparseToJoint);
				if (hasSparseEnergies())
				{
					_sparseEnergies =
						converter.convertSparseEnergies(_sparseEnergies, oldSparseToJoint, _sparseIndexToJointIndex);
				}
				if (hasSparseWeights())
				{
					_sparseWeights =
						converter.convertSparseWeights(_sparseWeights, oldSparseToJoint, _sparseIndexToJointIndex);
				}
				if (hasSparseIndices())
				{
					_sparseIndices =
						converter.convertSparseIndices(_sparseIndices, oldSparseToJoint, _sparseIndexToJointIndex);
				}
			}
		}

		if (outputSet == null)
		{
			if ((_representation & ALL_VALUES) == DETERMINISTIC)
			{
				// If direction removed, then convert DETERMINISTIC format to SPARSE_ENERGY.
				_sparseEnergies = new double[_sparseIndexToJointIndex.length];
				_representation = hasSparseIndices() ? SPARSE_ENERGY_WITH_INDICES : SPARSE_ENERGY;
			}
		}
		else if (assertConditional)
		{
			assertIsConditional();
		}
		ok = true;
	}
	finally
	{
		if (!ok)
		{
			// Restore the old domain indexer first; the old value arrays are ordered for it.
			setDomainIndexer(oldDomains);
			_computedMask = oldComputedMask;
			_representation = oldRepresentation;
			_denseEnergies = oldDenseEnergies;
			_denseWeights = oldDenseWeights;
			_sparseEnergies = oldSparseEnergies;
			_sparseWeights = oldSparseWeights;
			_sparseIndexToJointIndex = oldSparseToJoint;
			_sparseIndices = oldSparseIndices;
		}
	}
}
/**
 * Replaces the table contents with sparse entries addressed by per-variable element indices.
 * <p>
 * Each row of {@code indices} is validated by the domain indexer and converted to a joint index.
 * The work is then delegated to the joint-index overload.
 */
@Override
void setSparseValues(int[][] indices, double[] values, int representation)
{
	final JointDomainIndexer indexer = getDomainIndexer();
	final int count = indices.length;
	final int[] joint = new int[count];
	// Convert from the last row to the first.
	for (int row = count - 1; row >= 0; --row)
	{
		final int[] validated = indexer.validateIndices(indices[row]);
		joint[row] = indexer.jointIndexFromIndices(validated);
	}
	setSparseValues(joint, values, representation);
}
/**
 * Replaces the table contents with the given sparse entries.
 * <p>
 * The entries are copied. If they are not already in ascending joint-index order they are sorted,
 * with {@code values} permuted to match. The dense arrays and the unused sparse array are reset
 * to empty, the computed-state mask is cleared, and {@code computeNonZeroWeights()} is invoked.
 * <p>
 * @param jointIndices joint index of each entry; each must lie in [0, cardinality).
 * @param values energy or weight of each entry, parallel to {@code jointIndices}.
 * @param representation either {@code SPARSE_ENERGY} or {@code SPARSE_WEIGHT}.
 * @throws IllegalArgumentException if the arrays differ in length, if any joint index is out of
 * range, or if the same joint index appears more than once. No state is modified in that case.
 */
private void setSparseValues(int[] jointIndices, double[] values, int representation)
{
	final int size = jointIndices.length;
	if (size != values.length)
	{
		throw new IllegalArgumentException(
			String.format("Arrays have different sizes: %d and %d",
				size, values.length));
	}

	final JointDomainIndexer domains = getDomainIndexer();
	final int cardinality = domains.getCardinality();

	// Range-check every entry, including the first one and any that follow an out-of-order
	// pair, while also detecting whether a sort is required.
	boolean doSort = false;
	for (int i = size; --i >= 0;)
	{
		final int jointIndex = jointIndices[i];
		if (jointIndex < 0 || jointIndex >= cardinality)
		{
			throw new IllegalArgumentException(String.format("Joint index %d is out of range", jointIndex));
		}
		if (i > 0 && jointIndex < jointIndices[i - 1])
		{
			doSort = true;
		}
	}

	final int[] jointIndices2;
	final double[] values2;
	if (doSort)
	{
		// Pack (jointIndex, originalPosition) into one long so a single primitive sort orders
		// the entries by joint index and remembers where each value came from. Both halves are
		// non-negative here, so the packing preserves ordering.
		final long[] sortedIndices = new long[size];
		for (int i = size; --i >= 0;)
		{
			sortedIndices[i] = ((long)jointIndices[i] << 32) | i;
		}
		Arrays.sort(sortedIndices);
		jointIndices2 = new int[size];
		values2 = new double[size];
		for (int i = size; --i >= 0;)
		{
			jointIndices2[i] = (int)(sortedIndices[i] >>> 32);
			values2[i] = values[(int)sortedIndices[i]];
		}
	}
	else
	{
		jointIndices2 = jointIndices.clone();
		values2 = values.clone();
	}

	// After sorting, duplicate joint indices are adjacent.
	int prev = -1;
	for (int jointIndex : jointIndices2)
	{
		if (jointIndex == prev)
		{
			throw new IllegalArgumentException(String.format(
				"Multiple entries with same set of indices %s (joint index %d)",
				Arrays.toString(domains.jointIndexToIndices(jointIndex)), jointIndex));
		}
		prev = jointIndex;
	}

	// All validation passed: now it is safe to modify the table.
	_function = null;
	switch (representation)
	{
	case SPARSE_ENERGY:
		_sparseEnergies = values2;
		_sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		break;
	case SPARSE_WEIGHT:
		_sparseWeights = values2;
		_sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
		break;
	default:
		assert(false);
	}
	_representation = representation;
	_sparseIndexToJointIndex = jointIndices2;
	_computedMask = 0;
	_denseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	_denseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY;
	computeNonZeroWeights();
}
/**
 * Writes a weight/energy pair for one joint index into every representation the table holds.
 * <p>
 * For implementation of {@link #setWeightForJointIndex(double, int)} and
 * {@link #setEnergyForJointIndex(double, int)}.
 * <p>
 * When a sparse representation is present, a sparse slot for {@code jointIndex} is obtained via
 * {@code allocateSparseIndexForJointIndex} before any sparse write.
 */
private void setWeightEnergyForJointIndex(double weight, double energy, int jointIndex)
{
	if ((_representation & ALL_SPARSE) != 0)
	{
		final int slot = allocateSparseIndexForJointIndex(jointIndex);
		// Read the representation flags after allocation and write into the current sparse arrays.
		final boolean keepsSparseEnergy = (_representation & SPARSE_ENERGY) != 0;
		final boolean keepsSparseWeight = (_representation & SPARSE_WEIGHT) != 0;
		if (keepsSparseEnergy)
		{
			_sparseEnergies[slot] = energy;
		}
		if (keepsSparseWeight)
		{
			_sparseWeights[slot] = weight;
		}
	}

	final boolean keepsDenseEnergy = (_representation & DENSE_ENERGY) != 0;
	final boolean keepsDenseWeight = (_representation & DENSE_WEIGHT) != 0;
	if (keepsDenseEnergy)
	{
		_denseEnergies[jointIndex] = energy;
	}
	if (keepsDenseWeight)
	{
		_denseWeights[jointIndex] = weight;
	}
}
/**
 * Writes a weight/energy pair for one sparse index into every representation the table holds.
 * <p>
 * For implementation of {@link #setWeightForSparseIndex(double, int)} and
 * {@link #setEnergyForSparseIndex(double, int)}.
 */
private void setWeightEnergyForSparseIndex(double weight, double energy, int sparseIndex)
{
	final int rep = _representation;

	if ((rep & ALL_DENSE) != 0)
	{
		// An empty sparse-to-joint map means sparse and joint indices coincide.
		final int jointIndex = _sparseIndexToJointIndex.length > 0
			? _sparseIndexToJointIndex[sparseIndex]
			: sparseIndex;
		if ((rep & DENSE_ENERGY) != 0)
		{
			_denseEnergies[jointIndex] = energy;
		}
		if ((rep & DENSE_WEIGHT) != 0)
		{
			_denseWeights[jointIndex] = weight;
		}
	}

	if ((rep & SPARSE_ENERGY) != 0)
	{
		_sparseEnergies[sparseIndex] = energy;
	}
	if ((rep & SPARSE_WEIGHT) != 0)
	{
		_sparseWeights[sparseIndex] = weight;
	}
}
}