/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.domains; import static com.analog.lyric.math.Utilities.*; import java.util.Arrays; import java.util.BitSet; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.collect.BitSetUtil; import cern.colt.list.IntArrayList; import cern.colt.map.OpenIntIntHashMap; import net.jcip.annotations.NotThreadSafe; import net.jcip.annotations.ThreadSafe; /** * Supports conversion of indexes between two {@link JointDomainIndexer}s. */ @ThreadSafe public abstract class JointDomainReindexer { /*---------------- * Nested classes */ /** * A set of domain indexes sized appropriately for a single {@link #converter} for use * in conversion computation. * * @see JointDomainReindexer#getScratch() * @see JointDomainReindexer#convertIndices(Indices) */ @NotThreadSafe public static class Indices { public final JointDomainReindexer converter; public final int[] fromIndices; public final int[] toIndices; public final int[] addedIndices; public final int[] removedIndices; public final int[] joinedIndices; private Indices(JointDomainReindexer converter) { this.converter = converter; fromIndices = new int[converter._fromDomains.size()]; toIndices = new int[converter._toDomains.size()]; final JointDomainIndexer addedDomains = converter._addedDomains; addedIndices = addedDomains == null ? ArrayUtil.EMPTY_INT_ARRAY : new int[addedDomains.size()]; final JointDomainIndexer removedDomains = converter._removedDomains; removedIndices = removedDomains == null ? ArrayUtil.EMPTY_INT_ARRAY : new int[removedDomains.size()]; int joinedSize = Math.abs(Math.abs(fromIndices.length - toIndices.length) - Math.abs(addedIndices.length - removedIndices.length)); joinedIndices = joinedSize == 0 ? ArrayUtil.EMPTY_INT_ARRAY : new int[joinedSize]; } /** * Fills in {@link #fromIndices} from {@code fromJointIndex} and * {@link #addedIndices} from {@code addedJointIndex}. If {@code #converter} * has null for {@link JointDomainReindexer#getAddedDomains()}, then * {@code addedJointIndex} will be ignored. */ @SuppressWarnings("null") public void writeIndices(int fromJointIndex, int addedJointIndex) { converter._fromDomains.jointIndexToIndices(fromJointIndex, fromIndices); if (addedIndices.length > 0) { converter._addedDomains.jointIndexToIndices(addedJointIndex, addedIndices); } } /** * Returns joint index computed from {@code #toIndices} and if * {@code removedJointIndexRef} is non-null and {@code #converter} * has non-null for {@link JointDomainReindexer#getRemovedDomains()}, * this will compute a joint index from {@code #removedIndices} and set * it in that {@code removedJointIndexRef}. Returns -1 if {@code #toIndices} * contains negative value. */ @SuppressWarnings("null") public int readIndices(@Nullable AtomicInteger removedJointIndexRef) { if (toIndices[0] < 0) { return -1; } int toJointIndex = converter._toDomains.jointIndexFromIndices(toIndices); if (removedJointIndexRef != null && removedIndices.length > 0) { removedJointIndexRef.set(converter._removedDomains.jointIndexFromIndices(removedIndices)); } return toJointIndex; } /** * Release object back to its {@link #converter} so that it may be reused. * The caller must not use the object after calling this method. */ public final void release() { converter.releaseScratch(this); } } /*------- * State */ protected final @Nullable JointDomainIndexer _addedDomains; protected final JointDomainIndexer _fromDomains; protected final @Nullable JointDomainIndexer _removedDomains; protected final JointDomainIndexer _toDomains; private final AtomicReference<Indices> _scratch = new AtomicReference<Indices>(); /*-------------- * Construction */ protected JointDomainReindexer( JointDomainIndexer fromDomains, @Nullable JointDomainIndexer addedDomains, JointDomainIndexer toDomains, @Nullable JointDomainIndexer removedDomains) { _fromDomains = fromDomains; _addedDomains = addedDomains; _toDomains = toDomains; _removedDomains = removedDomains; } /** * Creates a converter that removes domains conditioned on the specified value indices. * <p> * @param fromDomains * @param valueIndices is an array of length equal to size of {@code fromDomains}. Each entry * is either a negative value to indicate that that dimension will be retained in the new graph * or an index value that is valid for that dimension that indicates the value to be conditioned on. * * @since 0.05 * @see #createConditioner(JointDomainIndexer, int[], boolean) */ public static JointDomainReindexer createConditioner(JointDomainIndexer fromDomains, int[] valueIndices) { return createConditioner(fromDomains, valueIndices, false); } /** * Creates a converter that removes or replaces domains conditioned on the specified value indices. * <p> * @param fromDomains * @param valueIndices is an array of length equal to size of {@code fromDomains}. Each entry * is either a negative value to indicate that that dimension will be retained in the new graph * or an index value that is valid for that dimension that indicates the value to be conditioned on. * @param retainSize indicates whether the new indexer should retain the same size as the original one. If true * then instead of removing conditioned dimensions, they will be replaced with a single-element domain. * <p> * @since 0.08 * @see #createConditioner(JointDomainIndexer, int[], boolean) */ public static JointDomainReindexer createConditioner(final JointDomainIndexer fromDomains, int[] valueIndices, boolean retainSize) { final int fromSize = fromDomains.size(); if (valueIndices.length != fromSize) { throw new ArrayIndexOutOfBoundsException(); // FIXME add message } final int[] oldToNew = new int[fromSize]; final IntArrayList conditionedValues = new IntArrayList(); int firstConditionedIndex = -1; for (int i = 0, j = 0; i < fromSize; ++i) { final int valueIndex = valueIndices[i]; if (valueIndex >= 0) { if (valueIndex >= fromDomains.getDomainSize(i)) { throw new IndexOutOfBoundsException(); // FIXME add message } conditionedValues.add(valueIndex); oldToNew[i] = -1; if (firstConditionedIndex < 0) { firstConditionedIndex = i; } } else { oldToNew[i] = j++; } } if (conditionedValues.isEmpty()) { // No conditioning - return identity reindexer. return createPermuter(fromDomains,fromDomains); } conditionedValues.trimToSize(); JointDomainReindexer permuter = null; JointDomainIndexer prevDomains = fromDomains; if (firstConditionedIndex + conditionedValues.size() != fromSize) { // Not all conditioned dimensions are already at the end, so a permutation is needed. final int nNotConditioned = fromSize - conditionedValues.size(); for (int i = firstConditionedIndex, j = 0; i < fromSize; ++i) { if (oldToNew[i] < 0) { oldToNew[i] = nNotConditioned + j++; } } permuter = createPermuter(fromDomains, oldToNew); prevDomains = permuter.getToDomains(); } JointDomainReindexer conditioner = JointDomainIndexConditioner._createConditioner(prevDomains, conditionedValues.elements()); if (permuter != null) { conditioner = permuter.combineWith(conditioner); } if (retainSize) { // Add a permuter to replace removed domains with single-element domains. final JointDomainIndexer conditioned = conditioner.getToDomains(); final int[] newToOld = new int[conditioned.size()]; final DiscreteDomain[] domains = fromDomains.toArray(new DiscreteDomain[fromSize]); for (int i = 0, j = 0; i < fromSize; ++i) { int vi = valueIndices[i]; if (vi >= 0) { domains[i] = DiscreteDomain.create(fromDomains.get(i).getElement(vi)); } else { newToOld[j++] = i; } } JointDomainIndexer toDomains = JointDomainIndexer.create(fromDomains.getOutputSet(), domains); JointDomainReindexer domainReplacer = createPermuter(conditioned, toDomains, newToOld); conditioner = conditioner.combineWith(domainReplacer); } return conditioner; } /** * Creates a converter that implements a permutation of the domains in * {@code fromDomains} and {@code addedDomains} into {@code ToDomains} and * {@code removedDomains}. * <p> * The total number of domains in the from+added domains must equal those in * to+removed domains. * <p> * @param fromDomains must be non-null * @param addedDomains may be null * @param toDomains must be non-null * @param removedDomains may be null * @param oldToNewIndex maps the index of the old domain to the index of the new domain where the index of domains * in {@code fromDomains} are in the range [0,fromSize-1] and the index of domains in * {@code addedDomains} is in the range [fromSize, totalSize-1]. Likewise the domains in {@code toDomains} and * {@code removeDomains} are in the range [0, toSize-1] and [toSize, totalSize - 1]. The mapping must not * omit or repeat elements and may not map domains of different sizes. */ public static JointDomainReindexer createPermuter( JointDomainIndexer fromDomains, @Nullable JointDomainIndexer addedDomains, JointDomainIndexer toDomains, @Nullable JointDomainIndexer removedDomains, int[] oldToNewIndex) { return new JointDomainIndexPermuter(fromDomains, addedDomains, toDomains, removedDomains, oldToNewIndex); } /** * Creates a converter that implements a permutation of the domains * in {@code fromDomains} to those in {@code toDomains}. * <p> * This implementation will invoke * {@link #createPermuter(JointDomainIndexer, JointDomainIndexer, JointDomainIndexer, JointDomainIndexer, int[])} * after calculating the {@code addedDomains} and {@code removedDomains} arguments. * <p> * There are three cases: * <dl> * <dt>{@code fromDomains} and {@code toDomains} have the same size.</dt> * <dd>The {@code oldToNewIndex} must also be the same size. The {@code addedDomains} and {@code removedDomains} * arguments will be null. * </dd> * * <dt>{@code fromDomains} is smaller than {@code toDomains}</dt> * <dd>It is necessary to deduce the {@code addedDomains}. If {@code oldToNewIndex} is the same * length as {@code toDomains} it specifies how the added domains should be chosen from {@code toDomains}. * If it is the same length as {@code fromDomains} then the {@code addedDomains} will be chosen from * the remaining {@code toDomains} that are not in the index in order. * </dd> * * <dt>{@code fromDomains} is larger than {@code toDomains}</dt> * <dd>The {@code oldToNewIndex} must be the same size as {@code fromDomains}. * It is necessary to deduce the {@code removedDomains}. * </dd> * </ol> */ public static JointDomainReindexer createPermuter( JointDomainIndexer fromDomains, JointDomainIndexer toDomains, int[] oldToNewIndex) { final int fromSize = fromDomains.size(); final int toSize = toDomains.size(); final int diff = fromSize - toSize; if (diff == 0) { return createPermuter(fromDomains, null, toDomains, null, oldToNewIndex); } else if (diff < 0) { // Need to compute added domains final int addedSize = -diff; if (oldToNewIndex.length < toSize) { // index map is too short. Deduce rest of it from missing entries. final BitSet toSet = BitSetUtil.bitsetFromIndices(toSize, oldToNewIndex); oldToNewIndex = Arrays.copyOf(oldToNewIndex, toSize); for (int from = fromSize, to = -1; from < toSize; ++from) { to = toSet.nextClearBit(to + 1); oldToNewIndex[from] = to; } } final DiscreteDomain[] addedDomains = new DiscreteDomain[addedSize]; for (int i = 0; i < addedSize; ++i) { addedDomains[i] = toDomains.get(oldToNewIndex[i + fromSize]); } return createPermuter(fromDomains, JointDomainIndexer.create(addedDomains), toDomains, null, oldToNewIndex); } else { // From is longer than to - need to compute removed domains final int removedSize = diff; final DiscreteDomain[] removedDomains = new DiscreteDomain[removedSize]; for (int from = 0; from < fromSize; ++from) { final int to = oldToNewIndex[from]; if (to >= toSize) { removedDomains[to - toSize] = fromDomains.get(from); } } return createPermuter(fromDomains, null, toDomains, JointDomainIndexer.create(removedDomains), oldToNewIndex); } } /** * Creates a converter to convert {@code fromDomains} to {@code toDomains} * while maintaining the same domain order. This can be used to convert between * domain lists with different input/output domain sets. */ public static JointDomainReindexer createPermuter( JointDomainIndexer fromDomains, JointDomainIndexer toDomains) { int[] oldToNew = new int[fromDomains.size()]; for (int i = oldToNew.length; --i>=0;) { oldToNew[i] = i; } return createPermuter(fromDomains, toDomains, oldToNew); } /** * Creates a converter to convert {@code fromDomains} by permuting the order of its * domains according to the {@code oldToNew} mapping. * * @param fromDomains * @param oldToNew must have the same length as the size of {@code fromDomains} and must contain * integers in the range [0,size-1] with no duplicates. * * @since 0.05 */ public static JointDomainReindexer createPermuter(JointDomainIndexer fromDomains, int[] oldToNew) { final int size = fromDomains.size(); if (oldToNew.length != size) { throw new IllegalArgumentException(); // FIXME- provide a message } DiscreteDomain[] domains = new DiscreteDomain[size]; for (int i = 0; i < size; ++i) { domains[oldToNew[i]] = fromDomains.get(i); } final BitSet fromOutputSet = fromDomains.getOutputSet(); BitSet outputSet = null; if (fromOutputSet != null) { outputSet = new BitSet(size); for (int i = 0; i < size; ++i) { outputSet.set(oldToNew[i], fromOutputSet.get(i)); } } return createPermuter(fromDomains, JointDomainIndexer.create(outputSet, domains), oldToNew); } /** * Creates a converter that inserts a number of {@code addedDomains} at the given {@code offset} * within {@code fromDomains}. */ public static JointDomainReindexer createAdder( JointDomainIndexer fromDomains, int offset, DiscreteDomain ... addedDomains) { return createAdder(fromDomains, offset, JointDomainIndexer.create(addedDomains)); } /** * Creates a converter that inserts the {@code addedDomains} at the given {@code offset} * within {@code fromDomains}. */ public static JointDomainReindexer createAdder( JointDomainIndexer fromDomains, int offset, JointDomainIndexer addedDomains) { final int fromSize = fromDomains.size(); final int addedSize = addedDomains.size(); final int toSize = fromSize + addedSize; final int[] oldToNewIndex = new int[toSize]; final DiscreteDomain[] toDomains = new DiscreteDomain[toSize]; for (int i = 0; i < offset; ++i) { toDomains[i] = fromDomains.get(i); oldToNewIndex[i] = i; } for (int i = 0; i < addedSize; ++i) { int to = offset + i; toDomains[to] = addedDomains.get(i); oldToNewIndex[fromSize + i] = to; } for (int i = offset; i < fromSize; ++i) { int to = offset + addedSize; toDomains[to] = fromDomains.get(i); oldToNewIndex[i] = to; } return createPermuter(fromDomains, addedDomains, JointDomainIndexer.create(toDomains), null, oldToNewIndex); } /** * Creates a converter that joins {@code length} adjacent domains at given {@code offset} in * {@code fromDomains} into a single {@link JointDiscreteDomain}. * <p> * @see #createSplitter(JointDomainIndexer, int) */ public static JointDomainReindexer createJoiner(JointDomainIndexer fromDomains, int offset, int length) { return JointDomainIndexJoiner.createJoiner(fromDomains, offset, length); } /** * Creates a converter that removes the domains specified by {@code removedIndices} * from {@code fromDomains}. */ public static JointDomainReindexer createRemover(JointDomainIndexer fromDomains, BitSet removedIndices) { final int fromSize = fromDomains.size(); final int removedSize = removedIndices.cardinality(); final int toSize = fromSize - removedSize; final DiscreteDomain[] removedDomains = new DiscreteDomain[removedSize]; final DiscreteDomain[] toDomains = new DiscreteDomain[fromSize - removedSize]; final int[] oldToNewIndex = new int[fromSize]; for (int i = 0, to = 0, removed = 0; i < fromSize; ++i) { DiscreteDomain domain = fromDomains.get(i); if (removedIndices.get(i)) { removedDomains[removed] = domain; oldToNewIndex[i] = toSize + removed; ++removed; } else { toDomains[to] = domain; oldToNewIndex[i] = to; ++to; } } return createPermuter(fromDomains, null, JointDomainIndexer.create(toDomains), JointDomainIndexer.create(removedDomains), oldToNewIndex); } /** * Creates a converter that removes the domains specified by {@code removedIndices} * from {@code fromDomains}. */ public static JointDomainReindexer createRemover(JointDomainIndexer fromDomains, int ... removedIndices) { return createRemover(fromDomains, BitSetUtil.bitsetFromIndices(fromDomains.size(), removedIndices)); } /** * Creates a converter that splits a {@link JointDiscreteDomain} at given {@code offset} in * {@code fromDomains} into its constituent subdomains. * <p> * @see #createSplitter(JointDomainIndexer, int...) * @see #createJoiner(JointDomainIndexer, int, int) */ public static JointDomainIndexJoiner createSplitter(JointDomainIndexer fromDomains, int offset) { return JointDomainIndexJoiner.createSplitter(fromDomains, offset); } /** * Creates a converter that splits a {@link JointDiscreteDomain} at given {@code offsets} in * {@code fromDomains} into its constituent subdomains. * <p> * @see #createSplitter(JointDomainIndexer, int) * @see #createJoiner(JointDomainIndexer, int, int) */ public static JointDomainReindexer createSplitter(JointDomainIndexer fromDomains, int ... offsets) { offsets = offsets.clone(); Arrays.sort(offsets); JointDomainReindexer indexer = null; for (int i = offsets.length; --i>=0;) { final int offset = offsets[i]; final JointDomainIndexJoiner splitter = createSplitter(fromDomains, offset); indexer = indexer != null ? indexer.combineWith(splitter) : splitter; fromDomains = indexer.getToDomains(); } return Objects.requireNonNull(indexer); } /** * Returns a converter that combines {@code prev} converter with this one * by first applying the conversion in {@code prev} and passing the result * to this converter. * * If {@code prev} is null, simply returns this. * * @since 0.05 */ public JointDomainReindexer appendTo(@Nullable JointDomainReindexer prev) { return prev != null ? prev.combineWith(this) : this; } /** * Creates a new converter that combines this one with {@code that} by first * applying this conversion and passing the result to {@code that}. */ public JointDomainReindexer combineWith(JointDomainReindexer that) { return ChainedJointDomainReindexer.create(this, that); } /*---------------- * Object methods */ @Override public boolean equals(@Nullable Object other) { if (this == other) { return true; } if (other instanceof JointDomainReindexer) { JointDomainReindexer that = (JointDomainReindexer)other; return _fromDomains.equals(that._fromDomains) && _toDomains.equals(that._toDomains) && Objects.equals(_addedDomains, that._addedDomains) && Objects.equals(_removedDomains, that._removedDomains); } return false; } @Override public abstract int hashCode(); /*------------------------------------- * JointDomainReindexer methods */ /** * Returns {@link JointDomainIndexer} representing the subdomains to be added to * the factor table. This may be null if no dimensions are to be added. */ public final @Nullable JointDomainIndexer getAddedDomains() { return _addedDomains; } public abstract JointDomainReindexer getInverse(); /** * Returns {@link JointDomainIndexer} representing the subdomains to be removed from * the factor table. This may be null if no dimensions are to be removed. */ public final @Nullable JointDomainIndexer getRemovedDomains() { return _removedDomains; } /** * Returns {@link JointDomainIndexer} for factor table to be converted from. Will never be null. */ public final JointDomainIndexer getFromDomains() { return _fromDomains; } /** * Returns an "scratch" instance of {@link Indices} for use in conversion computations. * <p> * If the instance is later released via {@link Indices#release()} then that instance * may be returned by a future invocation of this method. */ public final Indices getScratch() { Indices scratch = _scratch.getAndSet(null); return scratch != null ? scratch : new Indices(this); } /** * Returns {@link JointDomainIndexer} for factor table to be converted to. Will never be null. */ public final JointDomainIndexer getToDomains() { return _toDomains; } /** * Computes {@link Indices#toIndices} and {@link Indices#removedIndices} fields of {@code indices} * assuming that {@link Indices#fromIndices}, {@link Indices#addedIndices} and {@link Indices#joinedIndices} * have already been set. */ public abstract void convertIndices(Indices indices); public double[] convertDenseEnergies(double[] oldEnergies) { if (_removedDomains == null) { // No domains removed, so we don't need to add together any weights and can // do a simple copy. return denseCopy(oldEnergies); } final double[] values = new double[_toDomains.getCardinality()]; // values start out in weight domain and are converted to log domain at bottom. if (hasFastJointIndexConversion()) { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { for (int added = getAddedCardinality(); --added >= 0;) { final int newJoint = convertJointIndex(oldJoint, added); if (newJoint >= 0) { values[newJoint] += energyToWeight(oldEnergies[oldJoint]); } } } } else { Indices scratch = getScratch(); final JointDomainIndexer addedDomains = _addedDomains; for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { _fromDomains.jointIndexToIndices(oldJoint, scratch.fromIndices); for (int added = getAddedCardinality(); --added >= 0;) { if (addedDomains != null) { addedDomains.jointIndexToIndices(added, scratch.addedIndices); } convertIndices(scratch); if (scratch.toIndices[0] >= 0) { values[_toDomains.jointIndexFromIndices(scratch.toIndices)] += energyToWeight(oldEnergies[oldJoint]); } } } scratch.release(); } for (int i = values.length; --i>=0;) { values[i] = weightToEnergy(values[i]); } return values; } public double[] convertDenseWeights(double[] oldWeights) { if (_removedDomains == null) { // No domains removed, so we don't need to add together any weights and can // do a simple copy. return denseCopy(oldWeights); } final double[] weights = new double[_toDomains.getCardinality()]; final JointDomainIndexer addedDomains = _addedDomains; if (hasFastJointIndexConversion()) { if (addedDomains == null) { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { final int newJoint = convertJointIndex(oldJoint, 0); if (newJoint >= 0) { weights[newJoint] += oldWeights[oldJoint]; } } } else { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { for (int added = addedDomains.getCardinality(); --added >= 0;) { final int newJoint = convertJointIndex(oldJoint, added); if (newJoint >= 0) { weights[newJoint] += oldWeights[oldJoint]; } } } } } else { Indices scratch = getScratch(); if (addedDomains == null) { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { _fromDomains.jointIndexToIndices(oldJoint, scratch.fromIndices); convertIndices(scratch); if (scratch.toIndices[0] >= 0) { weights[_toDomains.jointIndexFromIndices(scratch.toIndices)] += oldWeights[oldJoint]; } } } else { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { _fromDomains.jointIndexToIndices(oldJoint, scratch.fromIndices); for (int added = addedDomains.getCardinality(); --added >= 0;) { addedDomains.jointIndexToIndices(added, scratch.addedIndices); convertIndices(scratch); if (scratch.toIndices[0] >= 0) { weights[_toDomains.jointIndexFromIndices(scratch.toIndices)] += oldWeights[oldJoint]; } } } } scratch.release(); } return weights; } public int convertJointIndex(int oldJointIndex, int addedJointIndex, @Nullable AtomicInteger removedJointIndex) { Indices scratch = getScratch(); scratch.writeIndices(oldJointIndex, addedJointIndex); convertIndices(scratch); int newJointIndex = scratch.readIndices(removedJointIndex); scratch.release(); return newJointIndex; } public int convertJointIndex(int oldJointIndex, int addedJointIndex) { return convertJointIndex(oldJointIndex, addedJointIndex, null); } public double[] convertSparseEnergies(double[] oldEnergies, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { if (sparseIndexToJointIndex.length == oldSparseIndexToJointIndex.length * getAddedCardinality()) { // No entries need to be merged, so we can use a simple copy and avoid energy/weight conversions. return sparseCopy(oldEnergies, oldSparseIndexToJointIndex, sparseIndexToJointIndex); } final int oldSize = oldSparseIndexToJointIndex.length; final int size = sparseIndexToJointIndex.length; final double[] values = new double[size]; final OpenIntIntHashMap jointIndexToSparseIndex = new OpenIntIntHashMap(sparseIndexToJointIndex.length); for (int si = sparseIndexToJointIndex.length; --si>=0;) { jointIndexToSparseIndex.put(sparseIndexToJointIndex[si], si); } final int maxAdded = getAddedCardinality(); for (int oldSparse = 0; oldSparse < oldSize; ++oldSparse) { final double oldWeight = energyToWeight(oldEnergies[oldSparse]); final int oldJoint = oldSparseIndexToJointIndex[oldSparse]; for (int added = 0; added < maxAdded; ++added) { final int newJoint = convertJointIndex(oldJoint, added); if (newJoint >= 0) { final int newSparse = jointIndexToSparseIndex.get(newJoint); values[newSparse] += oldWeight; } } } for (int i = 0; i < size; ++i) { values[i] = weightToEnergy(values[i]); } return values; } public int[][] convertSparseIndices( int[][] oldSparseIndices, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { final int sparseSize = sparseIndexToJointIndex.length; final int[][] indices = new int[sparseSize][]; for (int si = 0; si < sparseSize; ++si) { indices[si] = _toDomains.jointIndexToIndices(sparseIndexToJointIndex[si]); } return indices; } public int[] convertSparseToJointIndex(int[] oldSparseToJointIndex) { final int[] sparseToJoint = new int[oldSparseToJointIndex.length * getAddedCardinality()]; int i = 0; for (int added = 0, maxAdded = getAddedCardinality(); added < maxAdded; ++added) { for (int oldJoint : oldSparseToJointIndex) { sparseToJoint[i++] = convertJointIndex(oldJoint, added); } } if (!maintainsJointIndexOrder()) { Arrays.sort(sparseToJoint); } if (_removedDomains != null) { // If domains are removed then we may end up with duplicates. // First determine new size. int count = 0, prev = -1; for (int joint : sparseToJoint) { if (joint != prev) { ++count; } prev = joint; } if (count != sparseToJoint.length) { final int[] sparseToJoint2 = new int[count]; i = 0; prev = -1; for (int joint : sparseToJoint) { if (joint != prev) { sparseToJoint2[i++] = joint; } prev = joint; } return sparseToJoint2; } } return sparseToJoint; } public double[] convertSparseWeights( double[] oldWeights, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { if (_removedDomains == null) { return sparseCopy(oldWeights, oldSparseIndexToJointIndex, sparseIndexToJointIndex); } final int oldSize = oldSparseIndexToJointIndex.length; final int size = sparseIndexToJointIndex.length; final double[] weights = new double[size]; final OpenIntIntHashMap jointIndexToSparseIndex = new OpenIntIntHashMap(sparseIndexToJointIndex.length); for (int si = sparseIndexToJointIndex.length; --si>=0;) { jointIndexToSparseIndex.put(sparseIndexToJointIndex[si], si); } final int maxAdded = getAddedCardinality(); for (int oldSparse = 0; oldSparse < oldSize; ++oldSparse) { final double oldWeight = oldWeights[oldSparse]; final int oldJoint = oldSparseIndexToJointIndex[oldSparse]; for (int added = 0; added < maxAdded; ++added) { final int newJoint = convertJointIndex(oldJoint, added); if (newJoint >= 0) { final int newSparse = jointIndexToSparseIndex.get(newJoint); weights[newSparse] += oldWeight; } } } return weights; } /** * The number of different possible combinations of values in {@link #getAddedDomains()} * or else returns 1 if no added domains. */ @SuppressWarnings("null") public final int getAddedCardinality() { return _addedDomains == null ? 1 : _addedDomains.getCardinality(); } /** * The number of different possible combinations of values in {@link #getRemovedDomains()} * or else returns 1 if no removed domains. */ @SuppressWarnings("null") public final int getRemovedCardinality() { return _removedDomains == null ? 1 : _removedDomains.getCardinality(); } public abstract boolean hasFastJointIndexConversion(); /*------------------- * Protected methods */ /** * Default computation of {@link #hashCode()} */ @SuppressWarnings("null") protected int computeHashCode() { int hash = _fromDomains.hashCode() * 7 + _toDomains.hashCode(); if (_addedDomains != null) { hash *= 11; hash += _addedDomains.hashCode(); } if (_removedDomains != null) { hash *= 13; hash += _removedDomains.hashCode(); } return hash; } /** * True if converter maintains the same order of joint indexes such that if * oldA <= oldB, then newA <= newB. Used to avoid sorting when converting * sparse indices. * <p> * True only if all removals are at front of list, additions are at end of * list and relative order of domains is otherwise maintained. */ protected abstract boolean maintainsJointIndexOrder(); /*----------------- * Private methods */ private double[] denseCopy(double[] oldValues) { final double[] values = new double[_toDomains.getCardinality()]; if (hasFastJointIndexConversion()) { for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { for (int added = getAddedCardinality(); --added >= 0;) { values[convertJointIndex(oldJoint, added)] = oldValues[oldJoint]; } } } else { Indices scratch = getScratch(); final JointDomainIndexer addedDomains = _addedDomains; for (int oldJoint = _fromDomains.getCardinality(); --oldJoint >= 0;) { _fromDomains.jointIndexToIndices(oldJoint, scratch.fromIndices); for (int added = getAddedCardinality(); --added >= 0;) { if (addedDomains != null) { addedDomains.jointIndexToIndices(added, scratch.addedIndices); } convertIndices(scratch); values[_toDomains.jointIndexFromIndices(scratch.toIndices)] = oldValues[oldJoint]; } } scratch.release(); } return values; } private final void releaseScratch(Indices scratch) { _scratch.lazySet(scratch); } public double[] sparseCopy(double[] oldValues, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { final int oldSize = oldSparseIndexToJointIndex.length; final int size = sparseIndexToJointIndex.length; final double[] values = new double[size]; final OpenIntIntHashMap jointIndexToSparseIndex = new OpenIntIntHashMap(sparseIndexToJointIndex.length); for (int si = sparseIndexToJointIndex.length; --si>=0;) { jointIndexToSparseIndex.put(sparseIndexToJointIndex[si], si); } for (int oldSparse = 0; oldSparse < oldSize; ++oldSparse) { final int oldJoint = oldSparseIndexToJointIndex[oldSparse]; for (int added = getAddedCardinality(); --added >=0; ) { final int newJoint = convertJointIndex(oldJoint, added); final int newSparse = jointIndexToSparseIndex.get(newJoint); values[newSparse] = oldValues[oldSparse]; } } return values; } }