/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.model;
import static com.analog.lyric.math.Utilities.*;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.collect.Comparators;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.DiscreteIndicesIterator;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.domains.DomainList;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.util.test.SerializationTester;
/**
 * Unit tests for {@link JointDomainIndexer} and {@link JointDomainReindexer}.
 * <p>
 * Besides the {@code @Test} methods, this class exposes a public static
 * {@link #testInvariants(JointDomainIndexer)} helper that asserts the invariants any
 * {@link JointDomainIndexer} is expected to satisfy.
 */
public class TestJointDomainIndexer extends DimpleTestBase
{
	/**
	 * Exercises {@link JointDomainIndexer} construction errors, instance reuse and equality,
	 * directed variants, concatenation, large-cardinality limits and {@link DomainList}
	 * discreteness checks. Every indexer that is built is also run through
	 * {@link #testInvariants(JointDomainIndexer)}.
	 */
	@Test
	public void test()
	{
		DiscreteDomain d1 = DiscreteDomain.create(42);
		DiscreteDomain d2 = DiscreteDomain.range(0,1);
		DiscreteDomain d3 = DiscreteDomain.range(0,2);

		// Construction errors: no domains, a null domain array, and an output set {0,1,2}
		// given for only two domains.
		expectThrow(DimpleException.class, JointDomainIndexer.class, "create");
		expectThrow(NullPointerException.class, JointDomainIndexer.class, "create", (DiscreteDomain[])null);
		BitSet bitset = new BitSet();
		bitset.set(0);
		bitset.set(1);
		bitset.set(2);
		expectThrow(DimpleException.class, "Illegal output set for domain list.*", JointDomainIndexer.class, "create", bitset, d2, d3);

		JointDomainIndexer dl2 = JointDomainIndexer.create(d2);
		testInvariants(dl2);

		JointDomainIndexer dl3 = JointDomainIndexer.create(d3);
		testInvariants(dl3);

		JointDomainIndexer dl2by3 = JointDomainIndexer.create(d2, d3);
		testInvariants(dl2by3);
		// Creating the same domain list again must return the identical instance; domain order matters.
		assertSame(dl2by3, JointDomainIndexer.create(d2, d3));
		assertNotEquals(dl2by3, JointDomainIndexer.create(d3, d2));

		JointDomainIndexer dl2by1by3 = JointDomainIndexer.create(d2, d1, d3);
		testInvariants(dl2by1by3);

		// Directed variants: output is domain 1 (d3) vs. output is domain 0 (d2).
		JointDomainIndexer dl2to3 = JointDomainIndexer.create(new int[] { 1 }, new DiscreteDomain[] {d2, d3 });
		testInvariants(dl2to3);
		assertNotEquals(dl2by3, dl2to3);
		assertNotEquals(dl2to3, dl2by3);
		assertNotEquals(dl2by3.hashCode(), dl2to3.hashCode());

		JointDomainIndexer dl2from3 = JointDomainIndexer.create(new int[] { 0 }, new DiscreteDomain[] { d2, d3 });
		testInvariants(dl2from3);
		assertNotEquals(dl2to3, dl2from3);
		assertNotEquals(dl2to3.hashCode(), dl2from3.hashCode());

		DomainList<?> dl2by2 = DomainList.create(new Domain[] { d2, d2 });
		assertTrue(dl2by2.isDiscrete());
		testInvariants(requireNonNull(dl2by2.asJointDomainIndexer()));

		//
		// Test concat
		//

		// Concatenating with null yields the other argument unchanged.
		assertSame(dl2, JointDomainIndexer.concat(dl2, null));
		assertSame(dl2, JointDomainIndexer.concat(null, dl2));
		JointDomainIndexer dl2by2a = requireNonNull(JointDomainIndexer.concat(dl2, dl2));
		testInvariants(dl2by2a);
		assertSame(dl2by2, dl2by2a);

		// Concatenating two directed lists keeps both outputs; the second list's output
		// index (0) is shifted past the first list's two domains (to 2).
		JointDomainIndexer dlfoo = requireNonNull(JointDomainIndexer.concat(dl2to3, dl2from3));
		testInvariants(dlfoo);
		assertTrue(dlfoo.isDirected());
		assertArrayEquals(new int[] { 1, 2 }, dlfoo.getOutputDomainIndices());
		assertArrayEquals(new Object[] { d2, d3, d2, d3}, dlfoo.toArray());
		assertSame(dl2to3, dlfoo.subindexer(0, 2));
		assertSame(dl2from3, dlfoo.subindexer(2, 2));

		// A null output set produces an undirected indexer over the same domains.
		JointDomainIndexer dlbar = JointDomainIndexer.create((BitSet)null, dlfoo);
		assertTrue(dlfoo.domainsEqual(dlbar));
		assertFalse(dlbar.isDirected());
		testInvariants(dlbar);

		//
		// Test large-cardinality cases
		//

		// 65536 x 65536 joint states exceed the int range: neither joint nor output indexing.
		final DiscreteDomain dshort = DiscreteDomain.range(Short.MIN_VALUE, Short.MAX_VALUE);
		JointDomainIndexer dlshort2 = JointDomainIndexer.create(dshort, dshort);
		assertFalse(dlshort2.supportsJointIndexing());
		assertFalse(dlshort2.supportsOutputIndexing());
		testInvariants(dlshort2);

		// 2^31 joint states: too many for joint indexing, but the two output domains
		// (4 output states) can still be indexed.
		DiscreteDomain[] d2x31 = new DiscreteDomain[31];
		Arrays.fill(d2x31, d2);
		JointDomainIndexer dl2x31 = JointDomainIndexer.create(BitSetUtil.bitsetFromIndices(31, 0, 1), d2x31);
		assertArrayEquals(new int[] {0,1}, dl2x31.getOutputDomainIndices());
		assertFalse(dl2x31.supportsJointIndexing());
		assertTrue(dl2x31.supportsOutputIndexing());
		assertTrue(dl2x31.hasCanonicalDomainOrder());
		testInvariants(dl2x31);

		// 2^32 joint states with non-adjacent outputs {0,2}; domain order is expected to be non-canonical.
		DiscreteDomain[] d2x32 = new DiscreteDomain[32];
		Arrays.fill(d2x32, d2);
		JointDomainIndexer dl2x32 = JointDomainIndexer.create(BitSetUtil.bitsetFromIndices(32, 0, 2), d2x32);
		assertArrayEquals(new int[] {0,2}, dl2x32.getOutputDomainIndices());
		assertFalse(dl2x32.supportsJointIndexing());
		assertTrue(dl2x32.supportsOutputIndexing());
		assertFalse(dl2x32.hasCanonicalDomainOrder());
		testInvariants(dl2x32);

		// 46340 is the largest n for which n * n still fits in an int.
		DiscreteDomain d46340 = DiscreteDomain.range(1,46340);
		DiscreteDomain[] d46340x2 = new DiscreteDomain[] { d46340, d46340 };
		JointDomainIndexer dl46340x2 = JointDomainIndexer.create(d46340x2);
		assertTrue(dl46340x2.supportsJointIndexing());
		assertTrue(dl46340x2.supportsOutputIndexing());
		testInvariants(dl46340x2);

		//
		// Test DomainList
		//

		// A list mixing discrete and real domains is not discrete and has no joint indexer.
		DomainList<?> mixed = DomainList.create(d2, RealDomain.unbounded());
		assertFalse(mixed.isDiscrete());
		assertNull(mixed.asJointDomainIndexer());
		assertFalse(DomainList.allDiscrete(d2, RealDomain.unbounded()));
		assertTrue(DomainList.allDiscrete(new DiscreteDomain[] { d2, d3 }));
	}

	/**
	 * Asserts the invariants that any {@link JointDomainIndexer} is expected to satisfy.
	 * <p>
	 * Checked here:
	 * <ul>
	 * <li>sizes and cardinalities, and how they relate to each other;</li>
	 * <li>per-domain accessors and undirected/directed strides;</li>
	 * <li>round trips between indices, elements, {@link Value}s and joint/input/output indexes,
	 * limited to the first 10000 joint indexes to keep the test fast;</li>
	 * <li>the exceptions thrown when joint indexing is not supported;</li>
	 * <li>input/output set consistency for directed indexers;</li>
	 * <li>index validation errors;</li>
	 * <li>that a serialization round trip yields the same instance.</li>
	 * </ul>
	 *
	 * @param indexer the indexer to check; it must contain at least one domain.
	 */
	@SuppressWarnings("null")
	public static void testInvariants(JointDomainIndexer indexer)
	{
		// Fixed seed so that the random index selection below is reproducible.
		Random rand = new Random(42);

		assertTrue(indexer.equals(indexer));
		assertFalse(indexer.equals("foo"));
		assertTrue(indexer.isDiscrete());
		assertTrue(DomainList.allDiscrete(indexer.toArray(new Domain[indexer.size()])));
		assertSame(indexer, indexer.asJointDomainIndexer());

		final int size = indexer.size();
		assertTrue(size > 0);

		final int inSize = indexer.getInputSize();
		assertTrue(inSize >= 0);
		assertTrue(inSize < size);

		final int outSize = indexer.getOutputSize();
		assertTrue(outSize >= 0);
		assertTrue(outSize <= size);

		assertEquals(size, inSize + outSize);

		// Joint indexing implies output indexing (equivalently: no output indexing, no joint indexing).
		final boolean supportsJoint = indexer.supportsJointIndexing();
		final boolean supportsOutputIndex = indexer.supportsOutputIndexing();
		if (supportsJoint)
		{
			assertTrue(supportsOutputIndex);
		}
		if (!supportsOutputIndex)
		{
			assertFalse(supportsJoint);
		}

		// Cardinalities are only queried when supported; -1 marks "not available".
		final int cardinality = supportsJoint ? indexer.getCardinality() : -1;
		if (supportsJoint)
		{
			assertTrue(cardinality > 1);
		}

		final int inCardinality = supportsJoint ? indexer.getInputCardinality() : -1;
		if (supportsJoint)
		{
			assertTrue(inCardinality >= 1);
			assertTrue(inCardinality <= cardinality);
		}

		final int outCardinality = supportsOutputIndex ? indexer.getOutputCardinality() : -1;
		if (supportsOutputIndex)
		{
			assertTrue(outCardinality >= 1);
			if (supportsJoint)
			{
				assertTrue(outCardinality <= cardinality);
				assertEquals(cardinality, inCardinality * outCardinality);
			}
		}

		// Per-domain accessors. The undirected stride of domain i is the product of the sizes
		// of domains 0..i-1; with canonical domain order it also equals the directed stride.
		DiscreteDomain[] domains = indexer.toArray(new DiscreteDomain[indexer.size()]);
		int i = 0, expectedStride = 1;
		for (DiscreteDomain domain : indexer)
		{
			if (supportsJoint)
			{
				assertEquals(expectedStride, indexer.getUndirectedStride(i));
				if (indexer.hasCanonicalDomainOrder())
				{
					assertEquals(indexer.getStride(i), indexer.getUndirectedStride(i));
				}
				expectedStride *= domain.size();
			}
			assertSame(domain, indexer.get(i));
			assertSame(domain, domains[i]);
			assertEquals(domain.size(), indexer.getDomainSize(i));
			assertTrue(indexer.getElementClass().isAssignableFrom(domain.getElementClass()));
			++i;
		}
		assertEquals(size, i);

		final BitSet inSet = indexer.getInputSet();
		final BitSet outSet = indexer.getOutputSet();
		int[] inIndices = indexer.getInputDomainIndices();
		int[] outIndices = indexer.getOutputDomainIndices();

		// Round trip a random set of indices through elements and back.
		final int[] indices = new int[size], indices2 = new int[size];
		final Object[] elements = new Object[size], elements2 = new Object[size];
		indexer.randomIndices(rand, indices);
		assertSame(elements, indexer.elementsFromIndices(indices, elements));
		assertSame(indices2, indexer.elementsToIndices(elements, indices2));
		assertArrayEquals(indices, indices2);

		// Try using arrays that are too big - the extra entry should be ignored.
		assertArrayEquals(indices, indexer.elementsToIndices(Arrays.copyOf(elements,size+1)));
		assertArrayEquals(elements, indexer.elementsFromIndices(Arrays.copyOf(indices, size+1)));

		// Count the number of times that the undirected/directed indexes match.
		int canonicalCount = 0;

		if (supportsJoint)
		{
			DiscreteIndicesIterator indicesIterator = new DiscreteIndicesIterator(indexer);

			// Limit iteration to prevent the test from taking too long for large cardinalities.
			int max = Math.min(cardinality, 10000);
			for (i = 0; i < max; ++i)
			{
				// 'i' is the undirected joint index; check every conversion from it.
				Value[] values = Value.createFromObjects(elements, domains);
				assertSame(indices, indexer.undirectedJointIndexToIndices(i, indices));
				assertArrayEquals(indices, indexer.undirectedJointIndexToIndices(i, null));
				assertArrayEquals(indices, indexer.undirectedJointIndexToIndices(i, new int[0]));
				assertArrayEquals(indices, indexer.undirectedJointIndexToIndices(i));
				assertSame(elements, indexer.undirectedJointIndexToElements(i, elements));
				assertArrayEquals(elements, indexer.undirectedJointIndexToElements(i));
				assertArrayEquals(elements, indexer.undirectedJointIndexToElements(i, null));
				assertArrayEquals(elements, indexer.undirectedJointIndexToElements(i, new Object[0]));
				assertSame(values, indexer.undirectedJointIndexToValues(i, values));
				assertArrayEquals(Value.toObjects(values), Value.toObjects(indexer.undirectedJointIndexToValues(i)));

				for (int j = 0; j < size; ++ j)
				{
					assertTrue(indices[j] >= 0);
					assertTrue(indices[j] < indexer.getDomainSize(j));
					assertEquals(indices[j], indexer.undirectedJointIndexToElementIndex(i, j));
					assertEquals(elements[j], indexer.get(j).getElement(indices[j]));
				}

				// The indices iterator visits undirected joint indexes in order.
				assertTrue(indicesIterator.hasNext());
				assertArrayEquals(indices, indicesIterator.next());

				indexer.validateIndices(indices);
				indexer.validateValues(values);

				assertEquals(i, indexer.undirectedJointIndexFromElements(elements));
				assertEquals(i, indexer.undirectedJointIndexFromIndices(indices));
				assertEquals(i, indexer.undirectedJointIndexFromValues(values));

				// Directed joint index 'ji' and its input/output components.
				int ji = indexer.jointIndexFromIndices(indices);
				assertEquals(ji, indexer.jointIndexFromElements(elements));
				assertEquals(ji, indexer.jointIndexFromValues(values));

				if (i == ji)
				{
					++canonicalCount;
				}

				int in = indexer.inputIndexFromIndices(indices);
				assertEquals(in, indexer.inputIndexFromElements(elements));
				assertEquals(in, indexer.inputIndexFromValues(values));
				assertEquals(in, indexer.inputIndexFromJointIndex(ji));

				int out = indexer.outputIndexFromIndices(indices);
				assertEquals(out, indexer.outputIndexFromElements(elements));
				assertEquals(out, indexer.outputIndexFromValues(values));
				assertEquals(out, indexer.outputIndexFromJointIndex(ji));

				assertEquals(ji, indexer.jointIndexFromInputOutputIndices(in, out));

				Arrays.fill(indices2, -1);
				Arrays.fill(elements2, null);
				assertSame(indices2, indexer.jointIndexToIndices(ji, indices2));
				assertArrayEquals(indices, indices2);
				assertArrayEquals(indices, indexer.jointIndexToIndices(ji));
				assertSame(elements2, indexer.jointIndexToElements(ji, elements2));
				assertArrayEquals(elements, elements2);
				Value[] values2 = Value.createFromObjects(elements2, domains);
				assertSame(values2, indexer.jointIndexToValues(ji, values2));
				for (int j = 0; j < size; ++j)
				{
					assertEquals(indices[j], indexer.jointIndexToElementIndex(ji, j));
				}

				Object[] elements3 = indexer.jointIndexToElements(ji);
				assertArrayEquals(elements, elements3);
				Value[] values3 = indexer.jointIndexToValues(ji);
				assertArrayEquals(Value.toObjects(values), Value.toObjects(values3));

				if (!indexer.isDirected())
				{
					assertEquals(ji, i);
				}

				// Output index varies fastest within the directed joint index.
				assertEquals(out + in * indexer.getOutputCardinality(), ji);

				Arrays.fill(indices2, -1);
				Arrays.fill(elements2, null);
				indexer.inputIndexToIndices(in, indices2);
				assertEquals(in, indexer.inputIndexFromIndices(indices2));
				indexer.inputIndexToElements(in, elements2);
				assertEquals(in, indexer.inputIndexFromElements(elements2));
				indexer.inputIndexToValues(in, values2);
				assertEquals(in, indexer.inputIndexFromValues(values2));
				if (!indexer.isDirected())
				{
					// No input domains, so converting the input index must not write any index.
					for (int j : indices2)
					{
						assertEquals(-1, j);
					}
				}

				Arrays.fill(indices2, -1);
				Arrays.fill(elements2, null);
				indexer.outputIndexToIndices(out, indices2);
				assertEquals(out, indexer.outputIndexFromIndices(indices2));
				indexer.outputIndexToElements(out, elements2);
				assertEquals(out, indexer.outputIndexFromElements(elements2));
				indexer.outputIndexToValues(out, values2);
				assertEquals(out, indexer.outputIndexFromValues(values2));
				if (!indexer.isDirected())
				{
					// Every domain is an output, so the output index reproduces all indices.
					for (int j = 0; j < size; ++j)
					{
						assertEquals(indices[j], indices2[j]);
					}
				}
			}

			if (max == cardinality)
			{
				// Full coverage: the iterator is exhausted, and directed/undirected joint indexes
				// agree everywhere exactly when the domain order is canonical.
				assertFalse(indicesIterator.hasNext());
				assertEquals(indexer.hasCanonicalDomainOrder(), canonicalCount == cardinality);
			}
			else
			{
				assertTrue(indicesIterator.hasNext());
			}
		}
		else // !supportsJoint
		{
			expectNoJoint(indexer, "getCardinality");
			expectNoJoint(indexer, "getStride", 0);
			expectNoJoint(indexer, "getUndirectedStride", 0);
			expectNoJoint(indexer, "undirectedJointIndexFromElements");
			expectNoJoint(indexer, "undirectedJointIndexFromIndices");
			expectNoJoint(indexer, "undirectedJointIndexFromValues");
			expectNoJoint(indexer, "undirectedJointIndexToElements", 42);
			expectNoJoint(indexer, "undirectedJointIndexToIndices", 42);
			expectNoJoint(indexer, "undirectedJointIndexToElementIndex", 1, 2);
			if (!supportsOutputIndex)
			{
				expectNoJoint(indexer, "getOutputCardinality");
			}
		}

		if (indexer.isDirected())
		{
			// Input and output sets partition the domains.
			assertFalse(inSet.intersects(outSet));
			assertEquals(size, inSet.cardinality() + outSet.cardinality());

			// The getters hand out equal but distinct copies.
			assertNotSame(inSet, indexer.getInputSet());
			assertEquals(inSet, indexer.getInputSet());
			assertNotSame(outSet, indexer.getOutputSet());
			assertEquals(outSet, indexer.getOutputSet());

			assertEquals(indexer.hasCanonicalDomainOrder(),
				Comparators.reverseLexicalIntArray() == indexer.getIndicesComparator());

			assertEquals(inIndices.length, inSet.cardinality());
			assertEquals(outIndices.length, outSet.cardinality());
			assertEquals(inIndices.length, indexer.getInputSize());
			assertEquals(outIndices.length, indexer.getOutputSize());

			// Directed strides: output domains vary fastest, followed by the input domains.
			expectedStride = 1;
			for (int j = 0; j < outIndices.length; ++j)
			{
				assertEquals(outIndices[j], indexer.getOutputDomainIndex(j));
				assertTrue(outSet.get(outIndices[j]));
				if (supportsJoint)
				{
					assertEquals(expectedStride, indexer.getStride(outIndices[j]));
					expectedStride *= indexer.getDomainSize(outIndices[j]);
				}
			}
			for (int j = 0; j < inIndices.length; ++j)
			{
				assertEquals(inIndices[j], indexer.getInputDomainIndex(j));
				assertTrue(inSet.get(inIndices[j]));
				if (supportsJoint)
				{
					assertEquals(expectedStride, indexer.getStride(inIndices[j]));
					expectedStride *= indexer.getDomainSize(inIndices[j]);
				}
			}

			// All-0 vs all-1 indices share inputs only when there are no inputs; once the
			// input positions are set to 1 in both, the inputs match.
			Arrays.fill(indices, 0);
			Arrays.fill(indices2, 1);
			assertEquals(inIndices.length == 0, indexer.hasSameInputs(indices, indices2));
			for (int x : inIndices)
			{
				indices[x] = 1;
			}
			assertTrue(indexer.hasSameInputs(indices, indices2));
		}
		else
		{
			// Undirected: no inputs, no input/output sets, and every domain is an output.
			assertEquals(1, indexer.getInputCardinality());
			assertEquals(0, inSize);
			assertNull(inSet);
			assertNull(outSet);
			assertNull(inIndices);
			assertNull(outIndices);
			assertSame(Comparators.reverseLexicalIntArray(), indexer.getIndicesComparator());
			assertTrue(indexer.hasSameInputs(new int[] {0}, new int[] {1}));
			for (i = 0; i < size; ++i)
			{
				assertEquals(i, indexer.getOutputDomainIndex(i));
			}
			expectThrow(ArrayIndexOutOfBoundsException.class, indexer, "getInputDomainIndex", 0);
			expectThrow(ArrayIndexOutOfBoundsException.class, indexer, "getOutputDomainIndex", -1);
			expectThrow(ArrayIndexOutOfBoundsException.class, indexer, "getOutputDomainIndex", size);
		}

		// Index validation: wrong length, negative index, index equal to the domain size.
		expectThrow(IllegalArgumentException.class, "Wrong number of indices.*", indexer, "validateIndices");
		Arrays.fill(indices, 0);
		indices[0] = -1;
		expectThrow(IndexOutOfBoundsException.class, indexer, "validateIndices", indices);
		Arrays.fill(indices, 0);
		indices[0] = indexer.getDomainSize(0);
		expectThrow(IndexOutOfBoundsException.class, indexer, "validateIndices", indices);

		// Serialization must resolve back to the very same instance.
		JointDomainIndexer domainList2 = SerializationTester.clone(indexer);
		assertSame(indexer, domainList2);
		assertEquals(indexer.hashCode(), domainList2.hashCode());
	}

	/**
	 * Exercises the {@link JointDomainReindexer} factories (permuter, remover, joiner, splitter,
	 * conditioner), equality against inverses, chaining via {@code combineWith}, and the
	 * construction error messages. Each converter is checked with
	 * {@link #testInvariants(JointDomainReindexer)}.
	 */
	@Test
	public void testReindexer()
	{
		DiscreteDomain d2 = DiscreteDomain.range(0,1);
		DiscreteDomain d3 = DiscreteDomain.range(0,2);
		DiscreteDomain d4 = DiscreteDomain.range(0,3);
		DiscreteDomain d5 = DiscreteDomain.range(0,4);

		JointDomainIndexer dl2 = JointDomainIndexer.create(d2);
		JointDomainIndexer dl3 = JointDomainIndexer.create(d3);
		JointDomainIndexer dl4 = JointDomainIndexer.create(d4);
		JointDomainIndexer dl2by3 = JointDomainIndexer.create(d2, d3);
		JointDomainIndexer dl3by2 = JointDomainIndexer.create(d3, d2);
		JointDomainIndexer dl3by4 = JointDomainIndexer.create(d3, d4);
		JointDomainIndexer dl4by2 = JointDomainIndexer.create(d4, d2);
		JointDomainIndexer dl2to3 = JointDomainIndexer.create(new int[] {1}, new DiscreteDomain[] {d2, d3});
		JointDomainIndexer dl2from3 = JointDomainIndexer.create(new int[] {0}, new DiscreteDomain[] {d2, d3});

		// A simple permutation
		JointDomainReindexer dl2by3_to_dl3by2 =
			JointDomainReindexer.createPermuter(dl2by3, null, dl3by2, null, new int[] { 1, 0});
		assertSame(dl2by3, dl2by3_to_dl3by2.getFromDomains());
		assertSame(dl3by2, dl2by3_to_dl3by2.getToDomains());
		testInvariants(dl2by3_to_dl3by2);
		assertNotEquals(dl2by3_to_dl3by2, dl2by3_to_dl3by2.getInverse());
		assertNotEquals(dl2by3_to_dl3by2.hashCode(), dl2by3_to_dl3by2.getInverse().hashCode());

		// The reverse permutation equals the inverse of the forward one, and vice versa.
		JointDomainReindexer dl3by2_to_dl2by3 =
			JointDomainReindexer.createPermuter(dl3by2, null, dl2by3, null, new int[] { 1, 0});
		testInvariants(dl3by2_to_dl2by3);
		assertEquals(dl2by3_to_dl3by2, dl3by2_to_dl2by3.getInverse());
		assertEquals(dl3by2_to_dl2by3, dl2by3_to_dl3by2.getInverse());
		assertEquals(dl2by3_to_dl3by2.hashCode(), dl3by2_to_dl2by3.getInverse().hashCode());

		// Remove a domain
		JointDomainReindexer dl2by3_to_dl3 = JointDomainReindexer.createRemover(dl2by3, 0);
		assertSame(dl2by3, dl2by3_to_dl3.getFromDomains());
		assertSame(dl2, dl2by3_to_dl3.getRemovedDomains());
		testInvariants(dl2by3_to_dl3);
		assertNotEquals(dl2by3_to_dl3by2, dl2by3_to_dl3);
		assertNotEquals(dl2by3_to_dl3by2.hashCode(), dl2by3_to_dl3.hashCode());

		JointDomainReindexer dl2by3_to_dl2 = JointDomainReindexer.createRemover(dl2by3, 1);
		assertSame(dl2by3, dl2by3_to_dl2.getFromDomains());
		assertEquals(dl2, dl2by3_to_dl2.getToDomains());
		testInvariants(dl2by3_to_dl2);

		// Join domains 1 and 2 (sizes 3 and 4) into a single domain of size 12, then split it back.
		JointDomainIndexer dl2by3by4by5 = JointDomainIndexer.create(d2, d3, d4, d5);
		JointDomainReindexer dl2by3by4by5_to_dl2by = JointDomainReindexer.createJoiner(dl2by3by4by5, 1, 2);
		JointDomainIndexer dl2by12by5 = dl2by3by4by5_to_dl2by.getToDomains();
		testInvariants(dl2by3by4by5_to_dl2by);
		assertEquals(3, dl2by12by5.size());
		assertEquals(12, dl2by12by5.getDomainSize(1));
		assertNotEquals(dl2by3_to_dl3by2, dl2by3by4by5_to_dl2by);
		assertNotEquals(dl2by3by4by5_to_dl2by, dl2by3_to_dl3);
		assertNotEquals(dl2by3by4by5_to_dl2by, dl2by3by4by5_to_dl2by.getInverse());
		assertNotEquals(dl2by3by4by5_to_dl2by.hashCode(), dl2by3by4by5_to_dl2by.getInverse().hashCode());

		JointDomainReindexer dl2by12by5_to_dl2by3by4by5 = JointDomainReindexer.createSplitter(dl2by12by5, 1);
		testInvariants(dl2by12by5_to_dl2by3by4by5);
		assertEquals(dl2by3by4by5_to_dl2by, dl2by12by5_to_dl2by3by4by5.getInverse());

		// Permutations with an extra single-domain list on each side
		// (presumably the added and removed domains; see the five-argument createPermuter).
		JointDomainReindexer dl2by3_to_dl3by4 =
			JointDomainReindexer.createPermuter(dl2by3, dl4, dl3by4, dl2, new int [] { 2, 0, 1 });
		testInvariants(dl2by3_to_dl3by4);

		JointDomainReindexer dl2by3_to_dl4by2 =
			JointDomainReindexer.createPermuter(dl2by3, dl4, dl4by2, dl3, new int [] { 1, 2, 0 });
		testInvariants(dl2by3_to_dl4by2);

		// Chain
		JointDomainReindexer dl3by2_to_dl3 = dl3by2_to_dl2by3.combineWith(dl2by3_to_dl3);
		testInvariants(dl3by2_to_dl3);
		expectThrow(DimpleException.class, dl3by2_to_dl3, "combineWith", dl2by12by5_to_dl2by3by4by5);

		// Directed conversion
		JointDomainReindexer dl2by3_to_dl2to3 =
			JointDomainReindexer.createPermuter(dl2by3, dl2to3);
		testInvariants(dl2by3_to_dl2to3);

		JointDomainReindexer dl2by3_to_dl2from3 =
			JointDomainReindexer.createPermuter(dl2by3, dl2from3);
		testInvariants(dl2by3_to_dl2from3);

		// Deduce added/removed domains
		JointDomainReindexer dl4_to_dl3by4 = JointDomainReindexer.createPermuter(dl4, dl3by4, new int[] {1});
		testInvariants(dl4_to_dl3by4);
		dl4_to_dl3by4 = JointDomainReindexer.createPermuter(dl4, dl3by4, new int[] {1, 0});
		testInvariants(dl4_to_dl3by4);

		JointDomainReindexer dl3by4_to_dl4 = JointDomainReindexer.createPermuter(dl3by4, dl4, new int[] {1,0});
		testInvariants(dl3by4_to_dl4);

		//
		// Conditioning
		//

		// Entries >= 0 fix that domain to the given element index; -1 leaves it free.
		JointDomainReindexer dl2by3by4by5_conditionedOn_2_3 =
			JointDomainReindexer.createConditioner(dl2by3by4by5, new int[] {-1, -1, 2, 3});
		testInvariants(dl2by3by4by5_conditionedOn_2_3);
		assertEquals(2, dl2by3by4by5_conditionedOn_2_3.getToDomains().size());

		// instead of removing dimension, replace with single-element domains
		dl2by3by4by5_conditionedOn_2_3 =
			JointDomainReindexer.createConditioner(dl2by3by4by5, new int[] {-1, -1, 2, 3}, true);
		testInvariants(dl2by3by4by5_conditionedOn_2_3);
		assertEquals(4, dl2by3by4by5_conditionedOn_2_3.getToDomains().size());
		assertEquals(1, dl2by3by4by5_conditionedOn_2_3.getToDomains().get(2).size());
		assertEquals(dl2by3by4by5_conditionedOn_2_3.getFromDomains().get(2).getElement(2),
			dl2by3by4by5_conditionedOn_2_3.getToDomains().get(2).getElement(0));
		assertEquals(1, dl2by3by4by5_conditionedOn_2_3.getToDomains().get(3).size());
		assertEquals(dl2by3by4by5_conditionedOn_2_3.getFromDomains().get(3).getElement(3),
			dl2by3by4by5_conditionedOn_2_3.getToDomains().get(3).getElement(0));

		// Conditioning on non-trailing domains; the inverse is deliberately not re-checked here.
		JointDomainReindexer dl2by3by4by5_conditionedOn_1_2 =
			JointDomainReindexer.createConditioner(dl2by3by4by5, new int[] {1, -1, 2, -1});
		testInvariants(dl2by3by4by5_conditionedOn_1_2, false);
		assertEquals(2, dl2by3by4by5_conditionedOn_1_2.getToDomains().size());

		dl2by3by4by5_conditionedOn_1_2 =
			JointDomainReindexer.createConditioner(dl2by3by4by5, new int[] {1, -1, 2, -1}, true);
		testInvariants(dl2by3by4by5_conditionedOn_1_2, false);
		assertEquals(4, dl2by3by4by5_conditionedOn_1_2.getToDomains().size());
		assertEquals(1, dl2by3by4by5_conditionedOn_1_2.getToDomains().get(0).size());
		assertEquals(dl2by3by4by5_conditionedOn_1_2.getFromDomains().get(0).getElement(1),
			dl2by3by4by5_conditionedOn_1_2.getToDomains().get(0).getElement(0));
		assertEquals(1, dl2by3by4by5_conditionedOn_1_2.getToDomains().get(2).size());
		assertEquals(dl2by3by4by5_conditionedOn_1_2.getFromDomains().get(2).getElement(2),
			dl2by3by4by5_conditionedOn_1_2.getToDomains().get(2).getElement(0));

		//
		// Construction errors
		//

		expectThrow(IllegalArgumentException.class, "Combined size.*",
			JointDomainReindexer.class, "createPermuter", dl2, dl3, dl3, null, new int[] { 0, 1 });
		expectThrow(IllegalArgumentException.class, ".*does not match domain sizes.*",
			JointDomainReindexer.class, "createPermuter", dl2, null, dl2, null, new int[] { 0, 1, 2 });
		expectThrow(IllegalArgumentException.class, ".*out-of-range value -1.*",
			JointDomainReindexer.class, "createPermuter", dl2, null, dl2, null, new int[] { -1 });
		expectThrow(IllegalArgumentException.class, ".*out-of-range value 2.*",
			JointDomainReindexer.class, "createPermuter", dl2, null, dl2, null, new int[] { 2 });
		expectThrow(IllegalArgumentException.class, ".*two entries mapping to 0.*",
			JointDomainReindexer.class, "createPermuter", dl2, dl3, dl2by3, null, new int[] { 0, 0 });
		expectThrow(IllegalArgumentException.class, ".*domain size mismatch at index 0.*",
			JointDomainReindexer.class, "createPermuter", dl2, null, dl3, null, new int[] { 0 });
	}

	/**
	 * Checks {@code converter} and, once, its inverse.
	 * Same as {@code testInvariants(converter, true)}.
	 *
	 * @param converter the reindexer to check.
	 */
	public void testInvariants(JointDomainReindexer converter)
	{
		testInvariants(converter, true);
	}

	/**
	 * Asserts the invariants of a {@link JointDomainReindexer}:
	 * <ul>
	 * <li>inverse-of-inverse equality;</li>
	 * <li>scratch {@link JointDomainReindexer.Indices} sizing and reuse;</li>
	 * <li>joint index conversion round trips through the inverse;</li>
	 * <li>dense and sparse weight/energy conversion. A target entry equals its single source entry,
	 * or the sum of all source entries that map to it when domains are removed.</li>
	 * </ul>
	 *
	 * @param converter the reindexer to check.
	 * @param testInverse if true, also runs these checks on {@code converter.getInverse()}
	 * (without recursing further).
	 */
	@SuppressWarnings("null")
	private void testInvariants(JointDomainReindexer converter, boolean testInverse)
	{
		assertEquals(converter, converter);

		JointDomainReindexer inverse = converter.getInverse();
		assertEquals(converter, inverse.getInverse());

		// Scratch indices are sized to match the domain lists of the converter.
		JointDomainReindexer.Indices indices = converter.getScratch();
		assertSame(converter, indices.converter);
		assertEquals(converter.getFromDomains().size(), indices.fromIndices.length);
		assertEquals(converter.getToDomains().size(), indices.toIndices.length);
		if (converter.getAddedDomains() == null)
		{
			assertSame(ArrayUtil.EMPTY_INT_ARRAY, indices.addedIndices);
		}
		else
		{
			assertEquals(converter.getAddedDomains().size(), indices.addedIndices.length);
		}
		if (converter.getRemovedDomains() == null)
		{
			assertSame(ArrayUtil.EMPTY_INT_ARRAY, indices.removedIndices);
		}
		else
		{
			assertEquals(converter.getRemovedDomains().size(), indices.removedIndices.length);
		}

		// A released scratch object is handed out again by the next getScratch(); a further
		// request while that one is still held must produce a different instance.
		indices.release();
		assertSame(indices, converter.getScratch());
		assertNotSame(indices, converter.getScratch());
		indices = converter.getScratch();

		final int maxFrom = converter.getFromDomains().getCardinality();
		final int maxAdded = converter.getAddedCardinality();

		final AtomicInteger removedRef = new AtomicInteger();
		final AtomicInteger removedRef2 = new AtomicInteger();
		final AtomicInteger addedRef = new AtomicInteger();

		double[] fromDenseWeights = new double[maxFrom];
		double[] fromDenseEnergies = new double[maxFrom];
		for (int i = 0; i < maxFrom; ++i)
		{
			double w = testRand.nextDouble();
			fromDenseWeights[i] = w;
			fromDenseEnergies[i] = weightToEnergy(w);
		}

		final int maxTo = converter.getToDomains().getCardinality();
		final int maxRemoved = converter.getRemovedCardinality();
		final double[] toDenseWeights = converter.convertDenseWeights(fromDenseWeights);
		assertEquals(maxTo, toDenseWeights.length);
		double[] toDenseEnergies = converter.convertDenseEnergies(fromDenseEnergies);
		assertEquals(maxTo, toDenseEnergies.length);

		for (int from = 0; from < maxFrom; ++from)
		{
			double fromWeight = fromDenseWeights[from];
			double fromEnergy = fromDenseEnergies[from];
			for (int added = 0; added < maxAdded; ++added)
			{
				int to = converter.convertJointIndex(from, added, null);
				assertEquals(to, converter.convertJointIndex(from, added));
				assertEquals(to, converter.convertJointIndex(from, added, removedRef));

				indices.writeIndices(from, added);
				converter.convertIndices(indices);
				if (indices.toIndices[0] < 0)
				{
					assertTrue(to < 0);
				}

				// A negative 'to' means this source entry has no target (e.g. filtered by conditioning).
				if (to >= 0)
				{
					// The inverse maps (to, removed) back to (from, added).
					assertEquals(from, inverse.convertJointIndex(to, removedRef.get(), null));
					assertEquals(from, inverse.convertJointIndex(to, removedRef.get(), addedRef));
					assertEquals(added, addedRef.get());
					int to2 = indices.readIndices(null);
					assertEquals(to, to2);
					assertEquals(to, indices.readIndices(removedRef2));
					assertEquals(removedRef.get(), removedRef2.get());

					if (maxRemoved == 1)
					{
						assertEquals(fromWeight, toDenseWeights[to], 0.0);
						assertEquals(fromEnergy, toDenseEnergies[to], 0.0);
					}
					else
					{
						// Weight must equal the sum of the entries mapping to this one
						double weightSum = 0.0;
						for (int removed = 0; removed < maxRemoved; ++removed)
						{
							int fromInverse = inverse.convertJointIndex(to, removed);
							if (fromInverse >= 0)
							{
								weightSum += fromDenseWeights[fromInverse];
							}
						}
						assertEquals(weightSum, toDenseWeights[to], 1e-12);
						assertEquals(weightToEnergy(weightSum), toDenseEnergies[to], 1e-12);
					}
				}
			}
		}

		//
		// Test sparse conversions
		//

		// A "dense" sparse to joint index.
		final int[] fromDenseSparseToJoint = new int[maxFrom];
		for (int i = 0; i < maxFrom; ++i)
		{
			fromDenseSparseToJoint[i] = i;
		}
		final int[] toDenseSparseToJoint = converter.convertSparseToJointIndex(fromDenseSparseToJoint);
		assertTrue(maxTo >= toDenseSparseToJoint.length);
		if (maxTo == toDenseSparseToJoint.length)
		{
			// Fully dense result: sparse conversion must agree with the dense conversion.
			for (int i = toDenseSparseToJoint.length; --i>=0;)
			{
				assertEquals(i, toDenseSparseToJoint[i]);
			}
			assertArrayEquals(
				toDenseWeights,
				converter.convertSparseWeights(fromDenseWeights, fromDenseSparseToJoint, toDenseSparseToJoint),
				1e-12);
			assertArrayEquals(
				toDenseEnergies,
				converter.convertSparseEnergies(fromDenseEnergies, fromDenseSparseToJoint, toDenseSparseToJoint),
				1e-12);
		}

		// Test a random sparse selection
		BitSet sparseSet = new BitSet(maxFrom);
		for (int i = maxFrom/2; --i>=0;)
		{
			sparseSet.set(testRand.nextInt(maxFrom));
		}
		// Iterating the set bits yields joint indexes in ascending order, as binarySearch below requires.
		final int[] fromSparseToJoint = new int[sparseSet.cardinality()];
		for (int i = 0, sparseIndex = -1; i < fromSparseToJoint.length; ++i)
		{
			sparseIndex = sparseSet.nextSetBit(sparseIndex+1);
			fromSparseToJoint[i] = sparseIndex;
		}

		// Every converted source index must appear in the converted sparse-to-joint table.
		final int[] toSparseToJoint = converter.convertSparseToJointIndex(fromSparseToJoint);
		for (int oldSparse : fromSparseToJoint)
		{
			for (int added = 0; added < maxAdded; ++added)
			{
				int newSparse = converter.convertJointIndex(oldSparse, added);
				if (newSparse >= 0)
				{
					assertTrue(Arrays.binarySearch(toSparseToJoint, newSparse) >= 0);
				}
			}
		}

		final double[] fromSparseWeights = new double[fromSparseToJoint.length];
		final double[] fromSparseEnergies = new double[fromSparseToJoint.length];
		for (int si = fromSparseToJoint.length; --si>=0;)
		{
			int ji = fromSparseToJoint[si];
			fromSparseWeights[si] = fromDenseWeights[ji];
			fromSparseEnergies[si] = fromDenseEnergies[ji];
		}

		final double[] toSparseWeights =
			converter.convertSparseWeights(fromSparseWeights, fromSparseToJoint, toSparseToJoint);
		final double[] toSparseEnergies =
			converter.convertSparseEnergies(fromSparseEnergies, fromSparseToJoint, toSparseToJoint);

		for (int si = toSparseToJoint.length; --si>=0;)
		{
			int ji = toSparseToJoint[si];
			if (maxRemoved == 1)
			{
				assertEquals(toDenseWeights[ji], toSparseWeights[si], 1e-12);
				assertEquals(toDenseEnergies[ji], toSparseEnergies[si], 1e-12);
			}
			else
			{
				// Weight must equal the sum of the selected sparse entries mapping to this one
				double weightSum = 0.0;
				for (int removed = 0; removed < maxRemoved; ++removed)
				{
					int fromInverse = inverse.convertJointIndex(ji, removed);
					if (Arrays.binarySearch(fromSparseToJoint, fromInverse) >= 0)
					{
						weightSum += fromDenseWeights[fromInverse];
					}
				}
				assertEquals(weightSum, toSparseWeights[si], 1e-12);
				assertEquals(weightToEnergy(weightSum), toSparseEnergies[si], 1e-12);
			}
		}

		if (testInverse)
		{
			testInvariants(inverse, false);
		}
	}

	/**
	 * Asserts that invoking {@code methodName} with {@code args} on {@code obj} throws a
	 * {@link DimpleException}. Its message must say the method is not supported for very
	 * large joint (output) domain cardinality.
	 */
	private static void expectNoJoint(Object obj, String methodName, Object ... args)
	{
		expectThrow(DimpleException.class,
			".*" + methodName + "' not supported for very large joint( output)? domain cardinality.*",
			obj, methodName, args);
	}
}