/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.domains;
import static java.util.Objects.*;
import java.io.Serializable;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.collect.Comparators;
import com.analog.lyric.collect.Supers;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.exceptions.DomainException;
import com.analog.lyric.dimple.model.values.DiscreteValue;
import com.analog.lyric.dimple.model.values.Value;
import net.jcip.annotations.Immutable;
/**
* Provides a representation and canonical indexing operations for an ordered list of
* {@link DiscreteDomain} for use in implementing discrete factor tables and messages.
* <p>
* Construct using one of the {@code create} methods listed below.
* <p>
* The methods of this class employ the following terminology:
* <dl>
* <dt>size</dt>
* <dd>Refers to number of things in a single dimension: {@link #size()} is the number of
* domains in the indexer, {@link #getInputSize()}/{@link #getOutputSize()} are the number of domains that are
* designated as inputs/outputs respectively, and {@link #getDomainSize(int)} is the size of the nth domain.
* </dd>
*
* <dt>cardinality</dt>
* <dd>Is a combinatoric size across multiple dimensions: {@link #getCardinality()} is the number of
* combinations of elements from all domains, and {@link #getInputCardinality()}/{@link #getOutputCardinality()}
* are the combinations involving only domains designated as inputs/outputs respectively.
* </dd>
*
* <dt>elements</dt>
* <dd>Refers to an array of elements of the domains in the indexer in their canonical order in the indexer,
* where the ith element is a member of the domain returned by {@link #get}(i).
* </dd>
*
* <dt>indices</dt>
* <dd>Refers to an array of indices of domain elements in their canonical order in the indexer. This is
* an equivalent representation to elements but uses the element indexes rather than their actual values.
* </dd>
*
* <dt>values</dt>
* <dd>Refers to a {@link Value} array presumed to contain non-null {@link DiscreteValue}s that contain
* both the element and index.
* </dd>
*
* <dt>joint index</dt>
* <dd>Is a single integral index representing a unique combination of indices/elements that maps each
* combination of values to the range [0, {@link #getCardinality()}-1].
* </dd>
*
* <dt>domain index</dt>
* <dd>Refers to the index of a domain in the indexer considered as a {@link DomainList}, e.g.
* {@link #getInputDomainIndices} returns an array of indexes of domains designated as input domains.
* </dd>
*
* <dt>input index</dt>
* <dd>Is a single integral index representing a unique combination of just input indices/elements that
* maps each such combination to the range [0, {@link #getInputCardinality()}-1]. Only meaningful for
* directed indexers.
* </dd>
*
* <dt>output index</dt>
* <dd>Is the same as input index but for outputs.</dd>
*
* </dl>
* <p>
* @see #create(DiscreteDomain...)
* @see #create(Domain...)
* @see #create(BitSet, DiscreteDomain...)
* @see #create(BitSet, JointDomainIndexer)
* @see #create(int[], DiscreteDomain[])
* @see #create(int[], Domain...)
*/
@Immutable
public abstract class JointDomainIndexer extends DomainList<DiscreteDomain>
{
/*-------
* State
*/
/** Serialization version identifier. */
private static final long serialVersionUID = 1L;
/**
* Precomputed {@link #hashCode()}.
* <p>
* Assigned once in the constructor; {@link #equals(Object)} compares it before it
* compares the domain arrays themselves.
*/
final int _hashCode;
/**
* Nearest common base class across all domains.
* <p>
* Computed in the constructor by combining the element classes of all of the domains
* using {@code Supers.nearestCommonSuperClass}. Never null.
*/
private final Class<?> _elementClass;
/*--------------
* Construction
*/
JointDomainIndexer(int hashCode, DiscreteDomain[] domains)
{
super(domains);
_hashCode = hashCode;
if (domains.length == 0)
{
throw new DimpleException("Empty domain list");
}
Class<?> elementClass = domains[0].getElementClass();
for (int i = domains.length; --i >= 1; )
{
assert(elementClass != null);
elementClass = Supers.nearestCommonSuperClass(elementClass, domains[i].getElementClass());
}
// This could only be null if the getElementClass() could return null or a primitive type.
_elementClass = Objects.requireNonNull(elementClass);
}
/**
* Constructs an indexer over {@code domains}, obtaining the hash code from
* {@code computeHashCode(domains)}.
*
* @param domains the domains in indexer order; must not be empty.
*/
JointDomainIndexer(DiscreteDomain[] domains)
{
this(computeHashCode(domains), domains);
}
/**
* Converts {@code domains} to a new array and forwards to the array based overload
* without requesting an additional copy.
*/
private static JointDomainIndexer lookupOrCreate(@Nullable BitSet outputs, List<DiscreteDomain> domains)
{
    final DiscreteDomain[] domainArray = domains.toArray(new DiscreteDomain[domains.size()]);
    return lookupOrCreate(outputs, domainArray, false);
}
/**
* Builds the indexer implementation appropriate for {@code domains} and passes it through
* {@code intern}.
* <p>
* A "large" implementation is chosen when {@code domainsTooLargeForIntegerIndex(domains)} is
* true, otherwise a "standard" one; a directed implementation is chosen when {@code outputs}
* is non-null, otherwise an undirected one.
*
* @param outputs the indexes of the output domains, or null for an undirected indexer.
* @param domains the domains in indexer order.
* @param cloneDomains if true, the indexer is built from a copy of {@code domains} rather than
* from the array itself.
*/
private static JointDomainIndexer lookupOrCreate(@Nullable BitSet outputs, DiscreteDomain[] domains, boolean cloneDomains)
{
    if (cloneDomains)
    {
        domains = Arrays.copyOf(domains, domains.length, DiscreteDomain[].class);
    }

    if (domainsTooLargeForIntegerIndex(domains))
    {
        return outputs != null
            ? intern(new LargeDirectedJointDomainIndexer(outputs, domains))
            : intern(new LargeJointDomainIndexer(domains));
    }

    return outputs != null
        ? intern(new StandardDirectedJointDomainIndexer(outputs, domains))
        : intern(new StandardJointDomainIndexer(domains));
}
/**
* Converts {@code domains} to a new array and forwards to the array based overload
* without requesting an additional copy.
*/
private static JointDomainIndexer lookupOrCreate(@Nullable int[] outputIndices, List<DiscreteDomain> domains)
{
    final DiscreteDomain[] domainArray = domains.toArray(new DiscreteDomain[domains.size()]);
    return lookupOrCreate(outputIndices, domainArray, false);
}
/**
* Converts {@code outputIndices}, if not null, into a {@link BitSet} sized for {@code domains}
* and forwards to the {@link BitSet} based overload.
*
* @param outputIndices the indexes of the output domains, or null for an undirected indexer.
* @param domains the domains in indexer order.
* @param cloneDomains passed through unchanged.
*/
static JointDomainIndexer lookupOrCreate(@Nullable int[] outputIndices, DiscreteDomain[] domains, boolean cloneDomains)
{
    final BitSet outputs =
        outputIndices != null ? BitSetUtil.bitsetFromIndices(domains.length, outputIndices) : null;
    return lookupOrCreate(outputs, domains, cloneDomains);
}
/**
* Creates a directed indexer consisting of the specified {@code domains} in the given
* order and with the specified domains designated as outputs.
* <p>
* If {@code outputs} is null, this will instead return an undirected indexer.
* <p>
* The indexer is built from a copy of the {@code domains} array.
*
* @param outputs the indexes in {@code domains} of the output domains, or null.
* @param domains the domains in indexer order.
*/
public static JointDomainIndexer create(@Nullable BitSet outputs, DiscreteDomain ... domains)
{
return lookupOrCreate(outputs, domains, true);
}
/**
* Creates a directed indexer consisting of the specified {@code domains} in the given
* order and with the specified domains designated as outputs.
* <p>
* If {@code outputs} is null, this will instead return an undirected indexer.
*
* @param outputs the indexes in {@code domains} of the output domains, or null.
* @param domains the domains in indexer order.
* @since 0.08
*/
public static JointDomainIndexer create(@Nullable BitSet outputs, List<DiscreteDomain> domains)
{
return lookupOrCreate(outputs, domains);
}
/**
* Creates an undirected indexer consisting of the specified {@code domains} in the given order.
* <p>
* May return a previously cached value.
* <p>
* The indexer is built from a copy of the {@code domains} array.
*
* @param domains the domains in indexer order.
*/
public static JointDomainIndexer create(DiscreteDomain ... domains)
{
return lookupOrCreate((BitSet)null, domains, true);
}
/**
* Creates an undirected indexer consisting of the specified {@code domains} in the given order.
* <p>
* May return a previously cached value.
*
* @param domains the domains in indexer order.
* @since 0.08
*/
public static JointDomainIndexer create(List<DiscreteDomain> domains)
{
return lookupOrCreate((BitSet)null, domains);
}
/**
* Creates a directed indexer consisting of the specified {@code domains} in the given
* order and with the specified domains designated as outputs.
* <p>
* If {@code outputDomainIndices} is null, this will instead return an undirected indexer.
* <p>
* The indexer is built from a copy of the {@code domains} array.
*
* @param outputDomainIndices the indexes in {@code domains} of the output domains, or null.
* @param domains the domains in indexer order.
*/
public static JointDomainIndexer create(@Nullable int[] outputDomainIndices, DiscreteDomain[] domains)
{
return lookupOrCreate(outputDomainIndices, domains, true);
}
/**
* Creates a directed indexer consisting of the specified {@code domains} in the given
* order and with the specified domains designated as outputs.
* <p>
* If {@code outputDomainIndices} is null, this will instead return an undirected indexer.
*
* @param outputDomainIndices the indexes in {@code domains} of the output domains, or null.
* @param domains the domains in indexer order.
* @since 0.08
*/
public static JointDomainIndexer create(@Nullable int[] outputDomainIndices, List<DiscreteDomain> domains)
{
return lookupOrCreate(outputDomainIndices, domains);
}
/**
* Creates a directed indexer consisting of the domains of the indexer {@code domains} in the
* same order and with the specified domains designated as outputs.
* <p>
* If {@code outputs} is null, this will instead return an undirected indexer.
* <p>
* The domain array of {@code domains} is passed on without being copied.
*
* @param outputs the indexes of the output domains, or null.
* @param domains the indexer whose domains are to be used.
*/
public static JointDomainIndexer create(@Nullable BitSet outputs, JointDomainIndexer domains)
{
return lookupOrCreate(outputs, domains._domains, false);
}
/**
* Returns an indexer that concatenates the domains of {@code domains1} with those of
* {@code domains2}. Only produces a directed indexer if both arguments are directed.
* <p>
* If either argument is null the other argument is returned as is, so the result is null
* only if both are null.
*
* @see #concatNonNull(JointDomainIndexer, JointDomainIndexer)
*/
public static @Nullable JointDomainIndexer concat(
    @Nullable JointDomainIndexer domains1,
    @Nullable JointDomainIndexer domains2)
{
    if (domains1 == null)
    {
        return domains2;
    }
    if (domains2 == null)
    {
        return domains1;
    }
    return concatNonNull(domains1, domains2);
}
/**
* Returns an indexer that concatenates the domains of {@code domains1} followed by those of
* {@code domains2}.
* <p>
* The result is directed only if both arguments are directed, in which case its outputs are
* the outputs of {@code domains1} together with the outputs of {@code domains2} shifted by
* the size of {@code domains1}. Otherwise the result is undirected.
*
* @param domains1 supplies the leading domains.
* @param domains2 supplies the trailing domains.
* @see #concat(JointDomainIndexer, JointDomainIndexer)
*/
public static JointDomainIndexer concatNonNull(JointDomainIndexer domains1, JointDomainIndexer domains2)
{
    final int size1 = domains1.size();
    final int size2 = domains2.size();

    final DiscreteDomain[] domains = Arrays.copyOf(domains1._domains, size1 + size2);
    System.arraycopy(domains2._domains, 0, domains, size1, size2);

    BitSet outputs = null;
    if (domains1.isDirected() && domains2.isDirected())
    {
        // getOutputSet() is documented to return a copy, so it is safe to extend it in place.
        outputs = requireNonNull(domains1.getOutputSet());
        for (int i : requireNonNull(domains2.getOutputDomainIndices()))
        {
            outputs.set(size1 + i);
        }
    }

    // The array was freshly allocated above, so no further defensive copy is needed.
    return lookupOrCreate(outputs, domains, false);
}
/**
* Creates a new indexer constructed from the {@code length} domains in this list starting
* at {@code offset}. Returns a directed indexer if this indexer is directed.
*
* @param offset index of the first domain to include; must be in range [0, {@link #size()}-1].
* @param length number of domains to include; must be non-negative and no greater than
* {@link #size()} - {@code offset}.
* @throws IllegalArgumentException if {@code offset} or {@code length} is out of range.
* @since 0.05
*/
public JointDomainIndexer subindexer(int offset, int length)
{
    final int size = size();

    // Compare length against (size - offset) instead of computing (length + offset),
    // which could overflow for very large lengths and slip past the check.
    if (offset < 0 || offset >= size || length < 0 || length > size - offset)
    {
        throw new IllegalArgumentException(String.format("Bad offset/length %d/%d for subindexer", offset, length));
    }

    final DiscreteDomain[] domains = new DiscreteDomain[length];
    System.arraycopy(_domains, offset, domains, 0, length);

    BitSet outputs = null;
    if (isDirected())
    {
        // BitSet.get(from, to) returns a new set holding bits [from, to) shifted down to start at zero.
        outputs = requireNonNull(getOutputSet()).get(offset, offset + length);
    }

    // The array was freshly allocated above, so no further defensive copy is needed.
    return lookupOrCreate(outputs, domains, false);
}
/*----------------
* Object methods
*/
/**
* Compares this indexer with {@code that}.
* <p>
* The implementation in this class returns true if {@code that} is this same object, or if
* it is a {@link JointDomainIndexer} with the same precomputed hash code that is not
* directed and whose domains are equal to this indexer's domains in the same order.
*/
@Override
public boolean equals(@Nullable Object that)
{
    if (this == that)
    {
        return true;
    }

    if (!(that instanceof JointDomainIndexer))
    {
        return false;
    }

    final JointDomainIndexer other = (JointDomainIndexer)that;

    // Cheap checks first: the hash codes must match and the other indexer must be undirected.
    if (other._hashCode != _hashCode || other.isDirected())
    {
        return false;
    }

    return Arrays.equals(_domains, other._domains);
}
/**
* Returns the hash code that was supplied or computed when this indexer was constructed.
*/
@Override
public final int hashCode()
{
return _hashCode;
}
/*--------------------
* DomainList methods
*/
/**
* Returns this object, since it already is a {@link JointDomainIndexer}.
*/
@Override
public JointDomainIndexer asJointDomainIndexer()
{
return this;
}
/**
* Always returns true.
*/
@Override
public boolean isDiscrete()
{
return true;
}
/*----------------------------
* JointDomainIndexer methods
*/
/**
* Returns an array with length at least {@link #size()} and with component type
* compatible with elements of all domains in list (i.e. a superclass of the type specified
* by {@link #getElementClass()}).
* <p>
* If {@code elements} fits the above description, it will simply be returned.
* If {@code elements} is too short but has a compatible component type, this will
* return a new array with length equal to {@link #size()} and component type
* the same as {@code elements}. Otherwise returns a new array with component
* type same as {@link #getElementClass()}.
*
* @param elements a candidate array to reuse; may be null.
* @see #allocateIndices(int[])
*/
public final <T> T[] allocateElements(@Nullable T [] elements)
{
return ArrayUtil.allocateArrayOfType(_elementClass, elements, _domains.length);
}
/**
* Returns an int array that can hold one index for each domain, i.e. one whose length
* is at least {@link #size()}.
* <p>
* The {@code indices} argument is returned unchanged if it is non-null and long enough.
* Otherwise a new array with exactly one entry per domain is returned.
*
* @param indices a candidate array to reuse; may be null.
* @see #allocateElements(Object[])
*/
public final int[] allocateIndices(@Nullable int [] indices)
{
    final int required = _domains.length;
    if (indices != null && indices.length >= required)
    {
        return indices;
    }
    return new int[required];
}
/**
* Convert elements to corresponding domain indices.
* <p>
* Same as {@link #elementsToIndices(Object[], int[])} with null second argument.
* <p>
* @param elements array of domain elements. Each element must be a member of
* the correspondingly numbered domain.
* @return individual domain indices
* @throws DomainException if any element is not a member of the corresponding domain.
*/
public final int[] elementsToIndices(Object[] elements)
{
return elementsToIndices(elements, null);
}
/**
* Converts each element to its index within the correspondingly numbered domain.
* <p>
* @param elements array of domain elements, one per domain. Each element must be a member of
* the correspondingly numbered domain.
* @param indices is the array into which indices will be written. Will only be used
* if adequately large (see {@link #allocateIndices(int[])}).
* @return the array holding the individual domain indices, which is {@code indices} if it
* was used.
* @throws DomainException if any element is not a member of the corresponding domain.
*/
public final int[] elementsToIndices(Object[] elements, @Nullable int indices[])
{
    final int[] result = allocateIndices(indices);

    int i = 0;
    for (DiscreteDomain domain : _domains)
    {
        result[i] = domain.getIndexOrThrow(elements[i]);
        ++i;
    }

    return result;
}
/**
* Convert domain indices to corresponding elements.
* <p>
* Same as {@link #elementsFromIndices(int[], Object[])} with null second argument.
*
* @param indices array of element indices, one for each correspondingly numbered domain.
* @return the elements of the domains at the given indices.
*/
public final Object[] elementsFromIndices(int indices[])
{
return elementsFromIndices(indices, null);
}
/**
* Converts each index to the element at that position in the correspondingly numbered domain.
*
* @param indices array of element indices, one for each correspondingly numbered domain.
* @param elements is the array into which elements will be written. Will only be used
* if suitable (see {@link #allocateElements(Object[])}).
* @return the array holding the elements, which is {@code elements} if it was used.
*/
public final Object[] elementsFromIndices(int indices[], @Nullable Object[] elements)
{
    final Object[] result = allocateElements(elements);

    int i = 0;
    for (DiscreteDomain domain : _domains)
    {
        result[i] = domain.getElement(indices[i]);
        ++i;
    }

    return result;
}
/**
* The number of possible combinations of all domain elements. Equal to the product of
* all of the domain sizes.
* <p>
* {@link #randomJointIndex(Random)} uses this value as the exclusive upper bound of the
* joint index it generates.
* <p>
* @see #getInputCardinality()
* @see #getOutputCardinality()
*/
public abstract int getCardinality();
/**
* Returns the size of the ith domain in the list.
*
* @param i index of the domain; must be in range [0, {@link #size()}-1].
*/
public final int getDomainSize(int i)
{
return _domains[i].size();
}
/**
* Returns nearest common superclass for elements in all domains.
* <p>
* The value is computed once at construction and is never null.
*
* @see #allocateElements(Object[])
*/
public final Class<?> getElementClass()
{
return _elementClass;
}
/**
* Returns a comparator that orders indices arrays.
* <p>
* If {@link #supportsJointIndexing()}, this is guaranteed to produce the same order as
* the natural order of the corresponding joint indexes. If not {@link #isDirected()} or
* {@link #hasCanonicalDomainOrder()}, then the comparator implements a reverse lexicographical
* ordering (see {@link Comparators#reverseLexicalIntArray()}).
* <p>
* The comparator is intended to be used with arrays of length {@link #size()}.
* <p>
* The implementation in this class always returns {@link Comparators#reverseLexicalIntArray()}.
*/
public Comparator<int[]> getIndicesComparator()
{
return Comparators.reverseLexicalIntArray();
}
/**
* The number of possible combinations of input domain elements. Equal to the product of
* all of the input domain sizes. Will be one if not {@link #isDirected()}.
* <p>
* The implementation in this class always returns one.
* @see #getCardinality()
* @see #getOutputCardinality()
*/
public int getInputCardinality()
{
return 1;
}
/**
* Returns the index of the ith domain designated as an input domain.
* <p>
* This is equivalent to returning the ith element of {@link #getInputDomainIndices()} but
* without having to allocate and copy an array.
* <p>
* The implementation in this class always throws.
* <p>
* @throws ArrayIndexOutOfBoundsException if i is not in range [0, {@link #getInputSize()}-1]
* (will always throw if not {@link #isDirected()}).
*/
public int getInputDomainIndex(int i)
{
throw new ArrayIndexOutOfBoundsException();
}
/**
* Returns a copy of the indexes of the input domains listed in increasing order or
* else null if not {@link #isDirected()}.
* <p>
* Use {@link #getInputDomainIndex(int)} to lookup an index without allocating a new array object.
* <p>
* The implementation in this class always returns null.
*
* @see #getInputSet()
*/
public @Nullable int[] getInputDomainIndices()
{
return null;
}
/**
* Returns a copy of the {@link BitSet} representing the indexes of the input domains or
* else null if not {@link #isDirected()}.
* <p>
* The implementation in this class always returns null.
*
* @see #getInputDomainIndices()
* @see #getInputDomainIndex(int)
* @see #getOutputSet()
*/
public @Nullable BitSet getInputSet()
{
return null;
}
/**
* The number of input domains if {@link #isDirected()}, otherwise zero.
* <p>
* If directed, this must be greater than zero and when combined with {@link #getOutputSize()} must add up to
* {@link #size()}.
* <p>
* The implementation in this class always returns zero.
*/
public int getInputSize()
{
return 0;
}
/**
* The number of possible combinations of output domain elements. Equal to the product of
* all of the output domain sizes. Will be the same as {@link #getCardinality()} if not {@link #isDirected()}.
* @see #getCardinality()
* @see #getInputCardinality()
*/
public abstract int getOutputCardinality();
/**
* Returns the index of the ith output domain.
* <p>
* The implementation in this class treats every domain as an output domain, so it simply
* returns {@code i} after checking that it is a valid domain index.
*
* @throws ArrayIndexOutOfBoundsException if {@code i} is negative or not less than
* {@link #getOutputSize()}.
*/
public int getOutputDomainIndex(int i)
{
    if (0 <= i && i < size())
    {
        return i;
    }
    throw new ArrayIndexOutOfBoundsException();
}
/**
* Returns a copy of the indexes of the output domains listed in increasing order or
* else null if not {@link #isDirected()}.
* <p>
* Use {@link #getOutputDomainIndex(int)} to lookup an index without allocating a new array object.
* <p>
* The implementation in this class always returns null.
*
* @see #getOutputSet()
*/
public @Nullable int[] getOutputDomainIndices()
{
return null;
}
/**
* Returns a copy of the {@link BitSet} representing the indexes of the output domains or
* else null if not {@link #isDirected()}.
* <p>
* The implementation in this class always returns null.
*
* @see #getOutputDomainIndices()
* @see #getOutputDomainIndex(int)
* @see #getInputSet()
*/
public @Nullable BitSet getOutputSet()
{
return null;
}
/**
* The number of output domains if {@link #isDirected()}, otherwise the same as {@link #size()}.
* <p>
* If directed, this must be greater than zero and when combined with {@link #getInputSize()} must add up to
* {@link #size()}.
* <p>
* The implementation in this class always returns {@link #size()}.
*/
public int getOutputSize()
{
return size();
}
/**
* Returns sum of sizes of all of the domains.
*
* @see #getDomainSize(int)
* @since 0.08
*/
public abstract int getSumOfDomainSizes();
/**
* Returns amount by which joint index returned by {@link #jointIndexFromIndices(int...)} changes
* when ith element index changes by 1.
* <p>
* This can be used to iterate over the joint indexes for one dimension for fixed values of all of
* the other dimensions.
* <p>
* @param i index of the domain; must be in range [0, {@link #size()}-1].
* @see #getUndirectedStride(int)
*/
public abstract int getStride(int i);
/**
* Returns amount by which joint index returned by {@link #undirectedJointIndexFromIndices(int...)} changes
* when ith element index changes by 1.
* <p>
* This can be used to iterate over the joint indexes for one dimension for fixed values of all of
* the other dimensions.
* <p>
* @param i index of the domain; must be in range [0, {@link #size()}-1].
* @see #getStride(int)
*/
public abstract int getUndirectedStride(int i);
/**
* True if domain list is partitioned into inputs and outputs.
* <p>
* The implementation in this class always returns false.
*/
public boolean isDirected()
{
return false;
}
/**
* True if not {@link #isDirected()}. Otherwise,
* if {@link #isDirected()}, then this is true if all the
* output domains in {@link #getOutputSet()} are at the front
* of the list.
* <p>
* The implementation in this class always returns true.
*/
public boolean hasCanonicalDomainOrder()
{
return true;
}
/**
* Returns true if two indices arrays have the same values at the
* indexes specified by {@link #getInputDomainIndices()}. If
* not {@link #isDirected()}, this will always return true.
* <p>
* The implementation in this class returns true without examining its arguments.
*/
public boolean hasSameInputs(int[] indices1, int[] indices2)
{
return true;
}
/**
* Computes a unique index for the subset of {@code elements} designated as inputs.
* <p>
* Similar to {@link #jointIndexFromElements(Object...)} but computes value only
* from those elements indexed by {@link #getInputSet()}.
* <p>
* Returns 0 if not {@link #isDirected()}. The implementation in this class always returns 0
* without examining {@code elements}.
* <p>
* @throws DomainException if an element is not a member of the corresponding domain
* @see #inputIndexFromIndices(int...)
* @see #outputIndexFromElements(Object...)
*/
public int inputIndexFromElements(Object ... elements)
{
return 0;
}
/**
* Computes a unique index for the subset of {@code indices} designated as inputs.
* <p>
* Similar to {@link #jointIndexFromIndices(int...)} but computes value only
* from those element indices designated by {@link #getInputSet()}.
* <p>
* Returns 0 if not {@link #isDirected()}. The implementation in this class always returns 0
* without examining {@code indices}.
* <p>
* @see #inputIndexFromElements(Object...)
* @see #outputIndexFromIndices(int...)
*/
public int inputIndexFromIndices(int ... indices)
{
return 0;
}
/**
* Computes a unique index for the subset of {@code values} designated as inputs.
* <p>
* Similar to {@link #jointIndexFromValues(Value...)} but computes value only
* from those element indices designated by {@link #getInputSet()}.
* <p>
* Returns 0 if not {@link #isDirected()}. The implementation in this class always returns 0
* without examining {@code values}.
* <p>
* @see #inputIndexFromElements(Object...)
* @see #inputIndexFromIndices(int...)
* @see #outputIndexFromValues(Value...)
*/
public int inputIndexFromValues(Value ... values)
{
return 0;
}
/**
* Converts a joint index to an input index.
* <p>
* Returns 0 if not {@link #isDirected()}. The implementation in this class always returns 0.
* <p>
* @param jointIndex a joint index in the range [0, {@link #getCardinality()}-1].
* @see #outputIndexFromJointIndex(int)
* @see #jointIndexFromInputOutputIndices(int, int)
*/
public int inputIndexFromJointIndex(int jointIndex)
{
return 0;
}
/**
* Writes elements corresponding to {@code inputIndex} into corresponding {@code elements} array.
* <p>
* Only updates members of {@code elements} designated as inputs, and therefore will do nothing
* if not {@link #isDirected()}.
* <p>
* @param inputIndex must be in range [0, {@link #getInputCardinality()}-1].
* @param elements must have length equal to {@link #size()}.
* <p>
* @see #inputIndexFromElements(Object...)
* @see #outputIndexToElements(int, Object[])
*/
public void inputIndexToElements(int inputIndex, Object[] elements)
{
// Intentionally empty: the implementation in this class has no inputs to write.
}
/**
* Writes values corresponding to {@code inputIndex} into corresponding {@code values} array.
* <p>
* Only updates members of {@code values} designated as inputs, and therefore will do nothing
* if not {@link #isDirected()}.
* <p>
* @param inputIndex must be in range [0, {@link #getInputCardinality()}-1].
* @param values must have length equal to {@link #size()} and must be fully populated (i.e. no null entries)
* with {@link Value} objects with domain compatible with corresponding indexer domains.
*
* @see #inputIndexFromValues(Value...)
* @see #outputIndexToValues(int, Value[])
* @since 0.07
*/
public void inputIndexToValues(int inputIndex, Value[] values)
{
// Intentionally empty: the implementation in this class has no inputs to write.
}
/**
* Writes element indices corresponding to {@code inputIndex} into corresponding {@code indices} array.
* <p>
* Only updates members of {@code indices} for domains designated as inputs, and therefore will do nothing
* if not {@link #isDirected()}.
* <p>
* @param inputIndex must be in range [0, {@link #getInputCardinality()}-1].
* @param indices must have length equal to {@link #size()}.
* <p>
* @see #inputIndexFromIndices(int...)
* @see #outputIndexToIndices(int, int[])
*/
public void inputIndexToIndices(int inputIndex, int[] indices)
{
// Intentionally empty: the implementation in this class has no inputs to write.
}
/**
* Computes a unique joint index associated with the specified domain elements.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexFromElements(Object...)}.
* <p>
* @param elements must have length equal to {@link #size()} and each element must
* be an element of the corresponding domain.
* @see #jointIndexFromIndices(int[])
* @see #jointIndexToElements(int, Object[])
*/
public int jointIndexFromElements(Object ... elements)
{
return undirectedJointIndexFromElements(elements);
}
/**
* Computes a unique joint index associated with the specified {@code indices}.
* <p>
* The joint index is equivalent to the inner product of the vector of stride values and {@code indices}.
* That is:
* <pre>
* int jointIndex = 0;
* for (int i = 0; i &lt; size(); ++i)
* jointIndex += getStride(i) * indices[i];
* </pre>
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexFromIndices(int...)}.
* <p>
* @param indices must have length equal to {@link #size()} and each index must be a non-negative
* value less than the size of the corresponding domain otherwise the function could return an
* incorrect result.
* @see #jointIndexFromElements
* @see #jointIndexToIndices
* @see #validateIndices(int...)
*/
public int jointIndexFromIndices(int ... indices)
{
return undirectedJointIndexFromIndices(indices);
}
/**
* Computes a unique joint index associated with the specified {@code values}.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexFromValues(Value...)}.
* <p>
* @see #jointIndexFromElements(Object...)
* @see #jointIndexFromIndices(int...)
* @see #jointIndexToValues(int, Value[])
*/
public int jointIndexFromValues(Value ... values)
{
return undirectedJointIndexFromValues(values);
}
/**
* Converts input and output indexes to a joint index.
* <p>
* Returns {@code outputIndex} if not {@link #isDirected()}. The implementation in this class
* always returns {@code outputIndex} and ignores {@code inputIndex}.
* <p>
* @param inputIndex must be in range [0, {@link #getInputCardinality()}-1]
* @param outputIndex must be in range [0, {@link #getOutputCardinality()}-1]
* <p>
* @see #inputIndexFromJointIndex(int)
* @see #outputIndexFromJointIndex(int)
*/
public int jointIndexFromInputOutputIndices(int inputIndex, int outputIndex)
{
return outputIndex;
}
/**
* Computes domain values corresponding to given joint index.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToElements(int, Object[])}.
* <p>
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
* @param elements if this is an array of length {@link #size()}, the computed values will
* be placed in this array, otherwise a new array will be allocated.
* @return the array containing the computed elements.
* @see #jointIndexToIndices(int, int[])
* @see #jointIndexFromElements(Object...)
*/
public <T> T[] jointIndexToElements(int jointIndex, @Nullable T[] elements)
{
return undirectedJointIndexToElements(jointIndex, elements);
}
/**
* Computes domain values corresponding to given joint index.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToValues(int, Value[])}.
* <p>
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
* @param values must have length equal to {@link #size()} and must be fully populated (i.e. no null entries)
* with {@link Value} objects with domain compatible with corresponding indexer domains.
* @see #jointIndexFromValues(Value...)
* @since 0.07
*/
public Value[] jointIndexToValues(int jointIndex, Value[] values)
{
return undirectedJointIndexToValues(jointIndex, values);
}
/**
* Computes domain values corresponding to given joint index.
* <p>
* Same as {@link #jointIndexToElements(int, Object[])} with null second argument.
*
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
*/
public final Object[] jointIndexToElements(int jointIndex)
{
return jointIndexToElements(jointIndex, null);
}
/**
* Computes domain values corresponding to given joint index.
* <p>
* Same as {@link #jointIndexToValues(int, Value[])} with a second argument newly created
* from this indexer's domains by {@code Value.createFromDomains}.
* <p>
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
* @see #jointIndexToValues(int, Value[])
* @since 0.07
*/
public final Value[] jointIndexToValues(int jointIndex)
{
return jointIndexToValues(jointIndex, Value.createFromDomains(_domains));
}
/**
* Computes element index for a single domain from a joint index.
* <p>
* This is like {@link #jointIndexToIndices} but only computes one element index.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToElementIndex(int, int)}.
* <p>
* @param jointIndex must be in range [0, {@link #getCardinality()}-1].
* @param domainIndex must be in range [0, {@link #size()}-1].
* <p>
* @see #undirectedJointIndexToElementIndex(int, int)
*/
public int jointIndexToElementIndex(int jointIndex, int domainIndex)
{
return undirectedJointIndexToElementIndex(jointIndex, domainIndex);
}
/**
* Computes domain indices corresponding to given joint index.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToIndices(int, int[])}.
* <p>
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
* @param indices if this is an array of length {@link #size()}, the computed values will
* be placed in this array, otherwise a new array will be allocated.
* @return the array containing the computed indices.
* @see #jointIndexToElements(int, Object[])
* @see #jointIndexFromIndices(int...)
* @see #jointIndexToElementIndex(int, int)
*/
public int[] jointIndexToIndices(int jointIndex, @Nullable int[] indices)
{
return undirectedJointIndexToIndices(jointIndex, indices);
}
/**
* Computes domain indices corresponding to given joint index.
* <p>
* Same as {@link #jointIndexToIndices(int, int[])} with null second argument.
*
* @param jointIndex a unique joint table index in the range [0,{@link #getCardinality()}).
*/
public final int[] jointIndexToIndices(int jointIndex)
{
return jointIndexToIndices(jointIndex, null);
}
/**
* Computes a unique index for the subset of {@code elements} designated as outputs.
* <p>
* Similar to {@link #jointIndexFromElements(Object...)} but computes value only
* from those elements indexed by {@link #getOutputSet()}.
* <p>
* Same as {@link #jointIndexFromElements(Object...)} if not {@link #isDirected()}.
* The implementation in this class delegates to {@link #undirectedJointIndexFromElements(Object...)}.
* <p>
* @throws DomainException if an element is not a member of the corresponding domain
* @see #outputIndexFromIndices(int...)
* @see #inputIndexFromElements(Object...)
*/
public int outputIndexFromElements(Object ... elements)
{
return undirectedJointIndexFromElements(elements);
}
/**
* Computes a unique index for the subset of {@code indices} designated as outputs.
* <p>
* Similar to {@link #jointIndexFromIndices(int...)} but computes value only
* from those element indices designated by {@link #getOutputSet()}.
* <p>
* Same as {@link #jointIndexFromIndices(int...)} if not {@link #isDirected()}.
* The implementation in this class delegates to {@link #undirectedJointIndexFromIndices(int...)}.
* <p>
* @see #outputIndexFromElements(Object...)
* @see #inputIndexFromIndices(int...)
*/
public int outputIndexFromIndices(int ... indices)
{
return undirectedJointIndexFromIndices(indices);
}
/**
* Computes a unique index for the subset of {@code values} designated as outputs.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexFromValues(Value...)}.
* <p>
* @see #outputIndexFromElements(Object...)
* @see #outputIndexFromIndices(int...)
* @see #inputIndexFromValues(Value...)
*/
public int outputIndexFromValues(Value ... values)
{
return undirectedJointIndexFromValues(values);
}
/**
* Converts a joint index to an output index.
* <p>
* Returns {@code jointIndex} if not {@link #isDirected()}. The implementation in this class
* always returns {@code jointIndex} unchanged.
* <p>
* @see #inputIndexFromJointIndex(int)
* @see #jointIndexFromInputOutputIndices(int, int)
*/
public int outputIndexFromJointIndex(int jointIndex)
{
return jointIndex;
}
/**
* Writes elements corresponding to {@code outputIndex} into corresponding {@code elements} array.
* <p>
* Only updates members of {@code elements} designated as outputs, so this will only update all
* of the elements if not {@link #isDirected()}.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToElements(int, Object[])}
* and discards its return value.
* <p>
* @param outputIndex must be in range [0, {@link #getOutputCardinality()}-1].
* @param elements must have length equal to {@link #size()}.
* <p>
* @see #outputIndexFromElements(Object...)
* @see #inputIndexToElements(int, Object[])
*/
public void outputIndexToElements(int outputIndex, Object[] elements)
{
undirectedJointIndexToElements(outputIndex, elements);
}
/**
* Writes values corresponding to {@code outputIndex} into corresponding {@code values} array.
* <p>
* Only updates members of {@code values} designated as outputs, so this will only update all
* of the values if not {@link #isDirected()}.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToValues(int, Value[])}
* and discards its return value.
* <p>
* @param outputIndex must be in range [0, {@link #getOutputCardinality()}-1].
* @param values must have length equal to {@link #size()} and must be fully populated (i.e. no null entries)
* with {@link Value} objects with domain compatible with corresponding indexer domains.
* @since 0.07
*/
public void outputIndexToValues(int outputIndex, Value[] values)
{
undirectedJointIndexToValues(outputIndex, values);
}
/**
* Writes element indices corresponding to {@code outputIndex} into corresponding {@code indices} array.
* <p>
* Only updates members of {@code indices} for domains designated as outputs, so this will only
* update all of the indices if not {@link #isDirected()}.
* <p>
* The implementation in this class delegates to {@link #undirectedJointIndexToIndices(int, int[])}
* and discards its return value.
* <p>
* @param outputIndex must be in range [0, {@link #getOutputCardinality()}-1].
* @param indices must have length equal to {@link #size()}.
* <p>
* @see #outputIndexFromIndices(int...)
* @see #inputIndexToIndices(int, int[])
*/
public void outputIndexToIndices(int outputIndex, int[] indices)
{
undirectedJointIndexToIndices(outputIndex, indices);
}
/**
* Computes a joint index from {@code elements} using the undirected ordering of domains.
* <p>
* In this class, {@link #jointIndexFromElements(Object...)} and
* {@link #outputIndexFromElements(Object...)} delegate to this method.
*
* @param elements one element for each correspondingly numbered domain.
*/
public abstract int undirectedJointIndexFromElements(Object ... elements);
/**
* Computes a joint index from {@code indices} using the undirected ordering of domains.
* <p>
* In this class, {@link #jointIndexFromIndices(int...)} and
* {@link #outputIndexFromIndices(int...)} delegate to this method.
*
* @param indices one element index for each correspondingly numbered domain.
* @see #getUndirectedStride(int)
*/
public abstract int undirectedJointIndexFromIndices(int ... indices);
/**
* Computes a joint index from {@code values} using the undirected ordering of domains.
* <p>
* In this class, {@link #jointIndexFromValues(Value...)} and
* {@link #outputIndexFromValues(Value...)} delegate to this method.
*
* @param values one value for each correspondingly numbered domain.
*/
public abstract int undirectedJointIndexFromValues(Value ... values);
/**
* Computes domain elements corresponding to given joint index using the undirected ordering
* of domains.
* <p>
* In this class, {@link #jointIndexToElements(int, Object[])} and
* {@link #outputIndexToElements(int, Object[])} delegate to this method.
*
* @param jointIndex a joint index in the range [0, {@link #getCardinality()}-1].
* @param elements an array to receive the elements; may be null.
* NOTE(review): presumably reused only when suitable, as with {@link #allocateElements} — confirm
* against the implementations.
* @return the array containing the computed elements.
*/
public abstract <T> T[] undirectedJointIndexToElements(int jointIndex, @Nullable T[] elements);
/**
* Computes domain values corresponding to given joint index using the undirected ordering
* of domains.
* <p>
* In this class, {@link #jointIndexToValues(int, Value[])} and
* {@link #outputIndexToValues(int, Value[])} delegate to this method.
*
* @param jointIndex a joint index in the range [0, {@link #getCardinality()}-1].
* @param elements the {@link Value} array to be updated.
* @since 0.07
*/
public abstract Value[] undirectedJointIndexToValues(int jointIndex, Value[] elements);
/**
* Computes domain elements corresponding to given joint index using the undirected ordering
* of domains.
* <p>
* Same as {@link #undirectedJointIndexToElements(int, Object[])} with null second argument.
*/
public final Object[] undirectedJointIndexToElements(int jointIndex)
{
return undirectedJointIndexToElements(jointIndex, null);
}
/**
* Computes domain values corresponding to given joint index using the undirected ordering
* of domains.
* <p>
* Same as {@link #undirectedJointIndexToValues(int, Value[])} with a second argument newly
* created from this indexer's domains by {@code Value.createFromDomains}.
*
* @since 0.07
*/
public final Value[] undirectedJointIndexToValues(int jointIndex)
{
return undirectedJointIndexToValues(jointIndex, Value.createFromDomains(_domains));
}
/**
* Computes element index for a single domain from a joint index using undirected ordering
* of domains.
* <p>
* This is like {@link #undirectedJointIndexToIndices} but only computes one element index.
* <p>
* @param jointIndex must be in range [0, {@link #getCardinality()}-1].
* @param domainIndex must be in range [0, {@link #size()}-1].
* @return the element index for the domain identified by {@code domainIndex}.
* <p>
* @see #jointIndexToElementIndex(int, int)
*/
public abstract int undirectedJointIndexToElementIndex(int jointIndex, int domainIndex);
/**
* Converts {@code jointIndex} into per-domain element indices using undirected ordering of domains.
* <p>
* @param jointIndex the joint index to convert.
* @param indices may be null, as is done by {@link #undirectedJointIndexToIndices(int)}.
* NOTE(review): presumably a new array is allocated when this is null or unsuitable -- confirm against
* the implementations.
* @return the array holding the element indices.
*/
public abstract int[] undirectedJointIndexToIndices(int jointIndex, @Nullable int[] indices);
/**
* Convenience overload of {@link #undirectedJointIndexToIndices(int, int[])} that passes
* {@code null} for the indices array and returns whatever array that call produces.
*/
public final int[] undirectedJointIndexToIndices(int jointIndex)
{
return undirectedJointIndexToIndices(jointIndex, null);
}
/**
 * Draws a joint index uniformly from the range [0, {@link #getCardinality()} - 1] using the
 * supplied random number generator.
 * <p>
 * Will throw an exception if not {@link #supportsJointIndexing()}.
 * <p>
 * @param rand the source of randomness.
 * @return the randomly chosen joint index.
 * @see #randomIndices(Random, int[])
 */
public int randomJointIndex(Random rand)
{
    final int cardinality = getCardinality();
    return rand.nextInt(cardinality);
}
/**
 * Generates a random element index for every domain using the supplied random number generator.
 *
 * @param rand the source of randomness; one draw is made per domain, in domain order.
 * @param indices if non-null and of length at least {@link #size()}, the indices will be written
 * into this array. Otherwise a new one will be allocated.
 * @return the array containing the generated indices.
 *
 * @see #randomJointIndex(Random)
 */
public int[] randomIndices(Random rand, @Nullable int[] indices)
{
    final int[] result = allocateIndices(indices);
    for (int dimension = 0; dimension < size(); ++dimension)
    {
        final int domainSize = getDomainSize(dimension);
        result[dimension] = rand.nextInt(domainSize);
    }
    return result;
}
/**
* Indicates whether class supports operations involving single integer jointIndex representation.
* This will be false when the joint cardinality of the component domains is larger than 2<sup>31</sup>.
* <p>
* NOTE(review): the helper {@code domainsTooLargeForIntegerIndex} in this class treats any cardinality
* greater than {@link Integer#MAX_VALUE} (2<sup>31</sup>-1) as too large; presumably that is the
* exact threshold used by implementations of this method -- confirm.
* <p>
* @return true if joint index operations are supported.
*/
public abstract boolean supportsJointIndexing();
/**
* Indicates whether class supports operations involving single integer output index representation
* of the subset of domains identified by {@link #getOutputSet()}. This will be false when the
* joint cardinality of the component output domains is larger than 2<sup>31</sup>.
* Note that this can be true when {@link #supportsJointIndexing} is false.
* <p>
* NOTE(review): the helper {@code domainSubsetTooLargeForIntegerIndex} in this class treats any
* cardinality greater than {@link Integer#MAX_VALUE} (2<sup>31</sup>-1) as too large; presumably that
* is the exact threshold used by implementations of this method -- confirm.
* <p>
* @return true if output index operations are supported.
*/
public abstract boolean supportsOutputIndexing();
/**
 * Checks that the provided {@code indices} form a valid set of element indices for
 * this domain list. Specifically:
 * <ul>
 * <li>{@code indices} has exactly as many entries as there are domains
 * <li>no entry is negative
 * <li>{@code indices[i]} &lt; size of the i-th domain
 * </ul>
 * @param indices the element indices to check, one per domain.
 * @return the same {@code indices} array that was passed in, if valid.
 * @throws IllegalArgumentException if wrong number of indices
 * @throws IndexOutOfBoundsException if any index is out of range for its domain.
 */
public int[] validateIndices(int ... indices)
{
    final DiscreteDomain[] domainArray = _domains;
    final int expectedLength = domainArray.length;

    if (indices.length != expectedLength)
    {
        final String message =
            String.format("Wrong number of indices: %d instead of %d", indices.length, expectedLength);
        throw new IllegalArgumentException(message);
    }

    for (int position = 0; position < expectedLength; ++position)
    {
        final int index = indices[position];
        final boolean inRange = index >= 0 && index < domainArray[position].size();
        if (!inRange)
        {
            final String message = String.format("Index %d out of bounds for domain %d with size %d",
                index, position, domainArray[position].size());
            throw new IndexOutOfBoundsException(message);
        }
    }

    return indices;
}
/**
 * Checks that the provided {@code values} are consistent with this domain list, namely that:
 * <ul>
 * <li>{@code values} has exactly as many entries as there are domains
 * <li>the index reported by each value's {@code getIndex()} is non-negative and less than
 * the size of the domain at the same position
 * </ul>
 * This does not check that each value's own domain matches the corresponding domain.
 *
 * @param values the values to check, one per domain.
 * @return the same {@code values} array that was passed in, if valid.
 * @throws IllegalArgumentException if wrong number of values
 * @throws IndexOutOfBoundsException if any value's index is out of range for its domain.
 */
public Value[] validateValues(Value ... values)
{
    final DiscreteDomain[] domainArray = _domains;
    final int expectedLength = domainArray.length;

    if (values.length != expectedLength)
    {
        final String message =
            String.format("Wrong number of values: %d instead of %d", values.length, expectedLength);
        throw new IllegalArgumentException(message);
    }

    // TODO: should this check that value is in the appropriate domain?
    for (int position = 0; position < expectedLength; ++position)
    {
        final int index = values[position].getIndex();
        final boolean inRange = index >= 0 && index < domainArray[position].size();
        if (!inRange)
        {
            final String message = String.format("Index %d out of bounds for domain %d with size %d",
                index, position, domainArray[position].size());
            throw new IndexOutOfBoundsException(message);
        }
    }

    return values;
}
/*-------------------
* Protected methods
*/
/**
* Computes a hash code from the given array of domains.
* <p>
* @param domains the domains to hash.
* @return the value of {@link Arrays#hashCode(Object[])} applied to {@code domains}.
*/
protected static int computeHashCode(DiscreteDomain[] domains)
{
return Arrays.hashCode(domains);
}
/**
 * Computes a hash code that combines the given array of domains with the given input set.
 * <p>
 * @param inputs bit set that is mixed into the hash.
 * @param domains the domains to hash.
 * @return thirteen times the array hash of {@code domains} plus the hash of {@code inputs}
 * (using wrapping {@code int} arithmetic).
 */
static int computeHashCode(BitSet inputs, DiscreteDomain[] domains)
{
    final int domainsHash = Arrays.hashCode(domains);
    final int inputsHash = inputs.hashCode();
    return 13 * domainsHash + inputsHash;
}
/**
 * Indicate whether the joint cardinality of the specified discrete domains
 * (i.e. the product of their sizes) is too large to be represented in an {@code int}.
 * <p>
 * The answer depends only on the product of the domain sizes, not on how many domains there
 * are, so any number of single-element domains may be present without affecting the result.
 *
 * @param domains the domains whose sizes are to be multiplied.
 * @return true if the product of the sizes of {@code domains} is greater than
 * {@link Integer#MAX_VALUE}.
 * @see #domainSubsetTooLargeForIntegerIndex(DiscreteDomain[], int[])
 */
static boolean domainsTooLargeForIntegerIndex(DiscreteDomain[] domains)
{
    // Estimate the number of bits needed using logarithms. This cannot overflow no matter
    // how many domains there are, and domains of size one contribute nothing to the sum.
    // (No shortcut is taken based on the number of domains alone, since that would
    // incorrectly reject long lists that consist mostly of single-element domains.)
    double logProduct = 0.0;
    for (DiscreteDomain domain : domains)
    {
        logProduct += Math.log(domain.size());
    }

    final double approxBits = logProduct / Math.log(2);

    if (approxBits >= 32)
    {
        return true;
    }
    else if (approxBits <= 30)
    {
        return false;
    }

    // Compute product with longs to get exact answer when close to threshold.
    // The product is known to be below roughly 2^32 here, so it cannot overflow a long.
    long product = 1;
    for (DiscreteDomain domain : domains)
    {
        product *= domain.size();
    }

    return product > Integer.MAX_VALUE;
}
/**
 * Indicate whether the joint cardinality of a subset of the specified discrete domains
 * (i.e. the product of the sizes of the selected domains) is too large to be represented
 * in an {@code int}.
 * <p>
 * The answer depends only on the product of the selected domain sizes, not on how many domains
 * are selected, so any number of single-element domains may be included without affecting the result.
 *
 * @param domains the full array of domains.
 * @param subindexes specifies a unique set of indices into {@code domains}.
 * @return true if the product of the sizes of the selected domains is greater than
 * {@link Integer#MAX_VALUE}.
 * @see #domainsTooLargeForIntegerIndex(DiscreteDomain[])
 */
static boolean domainSubsetTooLargeForIntegerIndex(DiscreteDomain[] domains, int[] subindexes)
{
    // Estimate the number of bits needed using logarithms. This cannot overflow no matter
    // how many domains are selected, and domains of size one contribute nothing to the sum.
    // (No shortcut is taken based on the number of subindexes alone, since that would
    // incorrectly reject long subsets that consist mostly of single-element domains.)
    double logProduct = 0.0;
    for (int i : subindexes)
    {
        logProduct += Math.log(domains[i].size());
    }

    final double approxBits = logProduct / Math.log(2);

    if (approxBits >= 32)
    {
        return true;
    }
    else if (approxBits <= 30)
    {
        return false;
    }

    // Compute product with longs to get exact answer when close to threshold.
    // The product is known to be below roughly 2^32 here, so it cannot overflow a long.
    long product = 1;
    for (int i : subindexes)
    {
        product *= domains[i].size();
    }

    return product > Integer.MAX_VALUE;
}
/**
 * Indicates whether two index arrays agree at every position listed in {@code inputIndices}.
 * <p>
 * @param array1 first array to compare.
 * @param array2 second array to compare.
 * @param inputIndices the positions at which the two arrays are compared; positions are
 * examined in order and the scan stops at the first mismatch.
 * @return true if {@code array1[i] == array2[i]} for every {@code i} in {@code inputIndices}
 * (trivially true when {@code inputIndices} is empty).
 */
protected static boolean hasSameInputsImpl(int[] array1, int[] array2, int[] inputIndices)
{
    final int count = inputIndices.length;
    int position = 0;
    while (position < count && array1[inputIndices[position]] == array2[inputIndices[position]])
    {
        ++position;
    }
    // Reaching the end means no mismatch was found.
    return position == count;
}
/**
 * Decomposes {@code location} into one domain element for each entry of {@code subindices}
 * and stores those elements into {@code elements}.
 * <p>
 * The entries of {@code subindices} are processed from last to first. For each domain index
 * {@code j} taken from it, the element index is the remaining location divided by
 * {@code products[j]}, and the remainder is carried over to the next entry.
 * <p>
 * @param location the value to decompose.
 * @param elements receives the element for each processed domain at position {@code j};
 * other positions are left untouched.
 * @param subindices the domain indices to compute.
 * @param products per-domain divisors, indexed by domain index.
 * NOTE(review): presumably the cumulative products of the preceding domain sizes -- confirm with callers.
 */
void locationToElements(int location, Object[] elements, int[] subindices, int[] products)
{
    final DiscreteDomain[] domainArray = _domains;
    int remaining = location;
    for (int position = subindices.length - 1; position >= 0; --position)
    {
        final int dimension = subindices[position];
        final int stride = products[dimension];
        final int elementIndex = remaining / stride;
        elements[dimension] = domainArray[dimension].getElement(elementIndex);
        remaining -= elementIndex * stride;
    }
}
/**
 * Decomposes {@code location} into one element index for each entry of {@code subindices}
 * and updates the corresponding {@link Value} objects in {@code elements} accordingly.
 * <p>
 * The entries of {@code subindices} are processed from last to first. For each domain index
 * {@code j} taken from it, the element index is the remaining location divided by
 * {@code products[j]}, and the remainder is carried over to the next entry.
 * <p>
 * @param location the value to decompose.
 * @param elements the values to update; the entry at each processed domain index must be non-null.
 * @param subindices the domain indices to compute.
 * @param products per-domain divisors, indexed by domain index.
 * NOTE(review): presumably the cumulative products of the preceding domain sizes -- confirm with callers.
 */
void locationToValues(int location, Value[] elements, int[] subindices, int[] products)
{
    final DiscreteDomain[] domainArray = _domains;
    int remaining = location;
    for (int position = subindices.length - 1; position >= 0; --position)
    {
        final int dimension = subindices[position];
        final int stride = products[dimension];
        final int elementIndex = remaining / stride;

        final DiscreteDomain domain = domainArray[dimension];
        final Value value = elements[dimension];

        if (value.getDomain() != domain)
        {
            // Domain differs (or is merely a distinct instance): set through the element object.
            value.setObject(domain.getElement(elementIndex));
        }
        else
        {
            // If domain matches, then use the faster setIndex method.
            //
            // Because domains are interned, the identity check should be sufficient the vast
            // majority of the time, and in the unlikely event it is not, setObject will
            // still do the right thing.
            value.setIndex(elementIndex);
        }

        remaining -= elementIndex * stride;
    }
}
/**
 * Decomposes {@code location} into one element index for each entry of {@code subindices}
 * and stores those indices into {@code indices}.
 * <p>
 * The entries of {@code subindices} are processed from last to first. For each domain index
 * {@code j} taken from it, {@code indices[j]} is set to the remaining location divided by
 * {@code products[j]}, and the remainder is carried over to the next entry.
 * <p>
 * @param location the value to decompose.
 * @param indices receives the element index for each processed domain at position {@code j};
 * other positions are left untouched.
 * @param subindices the domain indices to compute.
 * @param products per-domain divisors, indexed by domain index.
 */
static void locationToIndices(int location, int[] indices, int[] subindices, int[] products)
{
    int remaining = location;
    for (int position = subindices.length - 1; position >= 0; --position)
    {
        final int dimension = subindices[position];
        final int stride = products[dimension];
        final int elementIndex = remaining / stride;
        indices[dimension] = elementIndex;
        remaining -= elementIndex * stride;
    }
}
/*---------------
* Inner classes
*/
/**
 * Orders integer index arrays for a directed domain list.
 * <p>
 * Arrays of different length are ordered by length. Arrays of the same length are compared first
 * at their input subindexes, scanning from the last input subindex to the first, and then in the
 * same back-to-front manner at their output subindexes. The first differing position decides
 * the result. This will order all index arrays such that the ones representing the same input
 * will be adjacent.
 */
@Immutable
static class DirectedArrayComparator implements Comparator<int[]>, Serializable
{
    private static final long serialVersionUID = 1L;

    private final int[] _inputIndices;
    private final int[] _outputIndices;

    /**
     * @param inputIndices positions that are compared first.
     * @param outputIndices positions that are compared when all input positions are equal.
     */
    DirectedArrayComparator(int[] inputIndices, int[] outputIndices)
    {
        _inputIndices = inputIndices;
        _outputIndices = outputIndices;
    }

    /**
     * @return the difference in lengths if the arrays differ in length; otherwise -1 or 1 according
     * to the first differing position (inputs first, then outputs, each scanned back to front),
     * or 0 if the arrays agree at all input and output positions.
     */
    @Override
    @NonNullByDefault(false)
    public int compare(int[] array1, int[] array2)
    {
        final int lengthDifference = array1.length - array2.length;
        if (lengthDifference != 0)
        {
            return lengthDifference;
        }

        final int inputOrder = compareBackToFront(array1, array2, _inputIndices);
        if (inputOrder != 0)
        {
            return inputOrder;
        }

        return compareBackToFront(array1, array2, _outputIndices);
    }

    /**
     * Compares two arrays at the given positions, scanning {@code subindices} from last to first.
     *
     * @return -1 or 1 according to the first differing position, or 0 if there is none.
     */
    @NonNullByDefault(false)
    private static int compareBackToFront(int[] array1, int[] array2, int[] subindices)
    {
        for (int position = subindices.length - 1; position >= 0; --position)
        {
            final int dimension = subindices[position];
            final int first = array1[dimension];
            final int second = array2[dimension];
            if (first < second)
            {
                return -1;
            }
            if (first > second)
            {
                return 1;
            }
        }
        return 0;
    }
}
}