/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.domains; import static java.util.Objects.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; import org.eclipse.jdt.annotation.Nullable; import cern.colt.list.IntArrayList; import net.jcip.annotations.Immutable; /** * Converts {@link JointDomainIndexer} by conditioning one or more dimensions on fixed values. * <p> * Construct using {@link JointDomainReindexer#createConditioner}. There should be no reason to * refer to this class directly. * <p> * @since 0.05 * @author Christopher Barber */ @Immutable public final class JointDomainIndexConditioner extends JointDomainReindexer { private final int _hashCode; private final int[] _conditionedValues; private final int _addedJointIndex; private final JointDomainIndexConditioner _inverse; private final boolean _supportsJointIndexing; /*-------------- * Construction */ private JointDomainIndexConditioner( JointDomainIndexer fromDomains, @Nullable JointDomainIndexer addedDomains, JointDomainIndexer toDomains, @Nullable JointDomainIndexer removedDomains, int[] conditionedValues, @Nullable JointDomainIndexConditioner inverse) { super(fromDomains, addedDomains, toDomains, removedDomains); _conditionedValues = conditionedValues.clone(); _hashCode = computeHashCode(); int addedJointIndex = 0; if (addedDomains != null) { _supportsJointIndexing = toDomains.supportsJointIndexing(); if (_supportsJointIndexing) { final int addedIndex = addedDomains.jointIndexFromIndices(conditionedValues); final int fromCardinality = fromDomains.getCardinality(); addedJointIndex = addedIndex * fromCardinality; } } else { _supportsJointIndexing = fromDomains.supportsJointIndexing(); if (_supportsJointIndexing) { final int toCardinality = toDomains.getCardinality(); addedJointIndex = -requireNonNull(removedDomains).jointIndexFromIndices(conditionedValues) * toCardinality; } } _addedJointIndex = addedJointIndex; if (inverse == null) { inverse = new JointDomainIndexConditioner(toDomains, removedDomains, fromDomains, addedDomains, conditionedValues, this); } _inverse = inverse; } /** * Create new conditioning reindexer. * <p> * @param fromDomains the indexer to be conditioned * @param conditionedValues an array the same length as {@code fromDomains} size. Non-negative entries indicate * a discrete index to condition on. */ static JointDomainIndexConditioner _createConditioner(JointDomainIndexer fromDomains, int[] conditionedValues) { final int nConditioned = conditionedValues.length; final int toSize = fromDomains.size() - nConditioned; final JointDomainIndexer toDomains = fromDomains.subindexer(0, toSize); final JointDomainIndexer removedDomains = fromDomains.subindexer(toSize, nConditioned); return new JointDomainIndexConditioner(fromDomains, null, toDomains, removedDomains, conditionedValues, null); } /*---------------- * Object methods */ @Override public boolean equals(@Nullable Object other) { if (this == other) { return true; } if (other instanceof JointDomainIndexConditioner) { JointDomainIndexConditioner that = (JointDomainIndexConditioner)other; return _addedJointIndex == that._addedJointIndex && Arrays.equals(_conditionedValues, that._conditionedValues) && super.equals(that); } return false; } @Override public int hashCode() { return _hashCode; } /*------------------------------ * JointDomainReindexer methods */ @Override public JointDomainIndexConditioner getInverse() { return _inverse; } @Override public void convertIndices(Indices indices) { final int[] fromIndices = indices.fromIndices; final int[] toIndices = indices.toIndices; final int fromLength = indices.fromIndices.length; final int toLength = toIndices.length; if (fromLength < toLength) { System.arraycopy(fromIndices, 0, toIndices, 0, fromLength); System.arraycopy(_conditionedValues, 0, toIndices, fromLength, _conditionedValues.length); } else { final int[] removedIndices = indices.removedIndices; System.arraycopy(indices.fromIndices, indices.toIndices.length, removedIndices, 0, removedIndices.length); if (Arrays.equals(_conditionedValues, removedIndices)) { System.arraycopy(fromIndices, 0, toIndices, 0, toLength); } else { Arrays.fill(toIndices, -1); } } } @Override public double[] convertDenseEnergies(double[] oldEnergies) { final double[] values = new double[_toDomains.getCardinality()]; Arrays.fill(values, Double.POSITIVE_INFINITY); return convertDenseValues(oldEnergies, values); } @Override public double[] convertDenseWeights(double[] oldWeights) { return convertDenseValues(oldWeights, new double[_toDomains.getCardinality()]); } private double[] convertDenseValues(double[] oldValues, double[] newValues) { assert(_supportsJointIndexing); if (_addedDomains != null) { for (int fromji = oldValues.length; --fromji>=0;) { final int toji = fromji + _addedJointIndex; newValues[toji] = oldValues[fromji]; } } else { assert(_addedJointIndex <= 0); for (int toji = newValues.length; --toji>=0;) { final int fromji = toji - _addedJointIndex; newValues[toji] = oldValues[fromji]; } } return newValues; } @Override public int convertJointIndex(int oldJointIndex, int addedJointIndex, @Nullable AtomicInteger removedJointIndex) { assert(_supportsJointIndexing); if (_addedDomains != null) { final int fromCardinality = _fromDomains.getCardinality(); if (_addedJointIndex == addedJointIndex * fromCardinality) { return oldJointIndex + _addedJointIndex; } } else { final int toCardinality = _toDomains.getCardinality(); final int newJointIndex = oldJointIndex + _addedJointIndex; if (0 <= newJointIndex && newJointIndex < toCardinality) { if (removedJointIndex != null) { removedJointIndex.set(oldJointIndex / toCardinality); } return newJointIndex; } } return -1; } @Override public double[] convertSparseEnergies( double[] oldEnergies, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { final double[] newEnergies = new double[sparseIndexToJointIndex.length]; Arrays.fill(newEnergies, Double.POSITIVE_INFINITY); return convertSparseValues(oldEnergies, oldSparseIndexToJointIndex, sparseIndexToJointIndex, newEnergies); } private double[] convertSparseValues( double[] oldValues, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex, double[] newValues) { final int size = newValues.length; if (_addedDomains != null) { // Because we are adding new fixed values to the end of the domains, it does not // change the number or order of the existing values, so we can simply copy them. System.arraycopy(oldValues, 0, newValues, 0, size); } else { assert(_addedJointIndex < 0); final int oldsize = oldSparseIndexToJointIndex.length; int oldsi = -1; for (int newsi = 0; newsi < size; ++newsi) { final int newji = sparseIndexToJointIndex[newsi]; final int oldji = newji - _addedJointIndex; oldsi = Arrays.binarySearch(oldSparseIndexToJointIndex, oldsi + 1, oldsize, oldji); newValues[newsi] = oldValues[oldsi]; } } return newValues; } @Override public int[] convertSparseToJointIndex(int[] oldSparseToJointIndex) { assert(hasFastJointIndexConversion()); final int size = oldSparseToJointIndex.length; if (_addedDomains != null) { // Will have same number of entries as the old but with new index values for the added // fixed dimensions. int[] sparseToJointIndex = new int[size]; for (int i = size; --i>=0;) { sparseToJointIndex[i] = oldSparseToJointIndex[i] + _addedJointIndex; } return sparseToJointIndex; } else { final int toCardinality = _toDomains.getCardinality(); final IntArrayList sparseToJointIndex = new IntArrayList(size); for (int oldji : oldSparseToJointIndex) { final int newji = oldji + _addedJointIndex; if (0 <= newji && newji < toCardinality) { sparseToJointIndex.add(newji); } } sparseToJointIndex.trimToSize(); return sparseToJointIndex.elements(); } } @Override public double[] convertSparseWeights( double[] oldWeights, int[] oldSparseIndexToJointIndex, int[] sparseIndexToJointIndex) { final double[] newWeights = new double[sparseIndexToJointIndex.length]; return convertSparseValues(oldWeights, oldSparseIndexToJointIndex, sparseIndexToJointIndex, newWeights); } @Override public boolean hasFastJointIndexConversion() { return _supportsJointIndexing; } /** * {@inheritDoc} * * @return true */ @Override protected boolean maintainsJointIndexOrder() { return true; } /*------------------- * Protected methods */ @Override protected int computeHashCode() { return super.computeHashCode() * 17 + Arrays.hashCode(_conditionedValues); } /*----------------- * Private methods */ }