/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.model.domains; import java.util.Arrays; import java.util.BitSet; import java.util.Comparator; import net.jcip.annotations.Immutable; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.Comparators; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.values.Value; /** * Directed implementation of {@link JointDomainIndexer}. */ @Immutable final class StandardDirectedJointDomainIndexer extends StandardJointDomainIndexer { /*------- * State */ private static final long serialVersionUID = 1L; private final BitSet _outputSet; private final int _inputCardinality; private final int[] _inputIndices; private final int[] _inputProducts; private final int _outputCardinality; private final int[] _outputIndices; private final int[] _outputProducts; private final int[] _directedProducts; private final boolean _canonicalOrder; private final Comparator<int[]> _indicesComparator; /*-------------- * Construction */ StandardDirectedJointDomainIndexer(BitSet outputs, DiscreteDomain ... domains) { super(computeHashCode(outputs, domains), domains); _outputSet = outputs; final int nDomains = domains.length; final int nOutputs = outputs.cardinality(); final int nInputs = nDomains - nOutputs; if (outputs.length() > nDomains) { throw new DimpleException("Illegal output set for domain list"); } final int[] inputIndices = new int[nInputs]; final int[] inputProducts = new int[nDomains]; final int[] outputIndices = new int[nOutputs]; final int[] outputProducts = new int[nDomains]; final int[] directedProducts = new int[nDomains]; int curInput = 0, curOutput = 0; int inputProduct = 1, outputProduct = 1; for (int i = 0; i < nDomains; ++i) { final int size = domains[i].size(); if (outputs.get(i)) { outputIndices[curOutput] = i; outputProducts[i] = outputProduct; outputProduct *= size; ++curOutput; } else { inputIndices[curInput] = i; inputProducts[i] = inputProduct; inputProduct *= size; ++curInput; } } boolean canonicalOrder = true; for (int i = 0; i < nOutputs; ++i) { int j = outputIndices[i]; if (i != j) { canonicalOrder = false; } directedProducts[j] = outputProducts[j]; } for (int i = 0; i < nInputs; ++i) { int j = inputIndices[i]; directedProducts[j] = inputProducts[j] * outputProduct; } _inputCardinality = inputProduct; _outputCardinality = outputProduct; _inputIndices = inputIndices; _inputProducts = inputProducts; _outputIndices = outputIndices; _outputProducts = outputProducts; _directedProducts = directedProducts; _canonicalOrder = canonicalOrder; _indicesComparator = canonicalOrder ? Comparators.reverseLexicalIntArray() : new DirectedArrayComparator(inputIndices, outputIndices); } /*---------------- * Object methods */ @Override public boolean equals(@Nullable Object that) { if (this == that) { return true; } if (that instanceof StandardDirectedJointDomainIndexer) { StandardDirectedJointDomainIndexer thatDiscrete = (StandardDirectedJointDomainIndexer)that; return _hashCode == thatDiscrete._hashCode && Arrays.equals(_domains, thatDiscrete._domains) && _outputSet.equals(thatDiscrete._outputSet); } return false; } /*---------------------------- * JointDomainIndexer methods */ @Override public final Comparator<int[]> getIndicesComparator() { return _indicesComparator; } @Override public int getInputCardinality() { return _inputCardinality; } @Override public int[] getInputDomainIndices() { return _inputIndices.clone(); } @Override public int getInputDomainIndex(int i) { return _inputIndices[i]; } @Override public BitSet getInputSet() { BitSet set = getOutputSet(); set.flip(0, size()); return set; } @Override public int getInputSize() { return _inputIndices.length; } @Override public int getOutputCardinality() { return _outputCardinality; } @Override public int getOutputDomainIndex(int i) { return _outputIndices[i]; } @Override public int[] getOutputDomainIndices() { return _outputIndices.clone(); } @Override public BitSet getOutputSet() { return (BitSet) _outputSet.clone(); } @Override public int getOutputSize() { return _outputIndices.length; } @Override public int getStride(int i) { return _directedProducts[i]; } @Override public boolean isDirected() { return true; } @Override public boolean hasCanonicalDomainOrder() { return _canonicalOrder; } @Override public boolean hasSameInputs(int[] indices1, int[] indices2) { return hasSameInputsImpl(indices1, indices2, _inputIndices); } @Override public int inputIndexFromElements(Object ... elements) { final DiscreteDomain[] domains = _domains; final int[] products = _inputProducts; int joint = 0; for (int i = 0, end = products.length; i < end; ++i) { int product = products[i]; if (product != 0) { joint += product * domains[i].getIndexOrThrow(elements[i]); } } return joint; } @Override public int inputIndexFromIndices(int ... indices) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += indices[i] * _inputProducts[i]; } return joint; } @Override public int inputIndexFromValues(Value ... values) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += values[i].getIndex() * _inputProducts[i]; } return joint; } @Override public int inputIndexFromJointIndex(int jointIndex) { return jointIndex / _outputCardinality; } @Override public void inputIndexToElements(int inputIndex, Object[] elements) { locationToElements(inputIndex, elements, _inputIndices, _inputProducts); } @Override public void inputIndexToValues(int inputIndex, Value[] elements) { locationToValues(inputIndex, elements, _inputIndices, _inputProducts); } @Override public void inputIndexToIndices(int inputIndex, int[] indices) { locationToIndices(inputIndex, indices, _inputIndices, _inputProducts); } @Override public int jointIndexFromElements(Object ... elements) { final DiscreteDomain[] domains = _domains; final int[] products = _directedProducts; int joint = 0; for (int i = 0, end = products.length; i < end; ++i) { joint += products[i] * domains[i].getIndexOrThrow(elements[i]); } return joint; } @Override public int jointIndexFromIndices(int ... indices) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += indices[i] * _directedProducts[i]; } return joint; } @Override public int jointIndexFromValues(Value ... values) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += values[i].getIndex() * _directedProducts[i]; } return joint; } @Override public int jointIndexFromInputOutputIndices(int inputIndex, int outputIndex) { return outputIndex + inputIndex * _outputCardinality; } @Override public <T> T[] jointIndexToElements(int jointIndex, @Nullable T[] elements) { elements = allocateElements(elements); final int inputIndex = jointIndex / _outputCardinality; final int outputIndex = jointIndex - inputIndex * _outputCardinality; inputIndexToElements(inputIndex, elements); outputIndexToElements(outputIndex, elements); return elements; } @Override public Value[] jointIndexToValues(int jointIndex, Value[] elements) { final int inputIndex = jointIndex / _outputCardinality; final int outputIndex = jointIndex - inputIndex * _outputCardinality; inputIndexToValues(inputIndex, elements); outputIndexToValues(outputIndex, elements); return elements; } @Override public int jointIndexToElementIndex(int jointIndex, int domainIndex) { return (jointIndex / _directedProducts[domainIndex]) % _domains[domainIndex].size(); } @Override public int[] jointIndexToIndices(int jointIndex, @Nullable int[] indices) { indices = allocateIndices(indices); final int inputIndex = jointIndex / _outputCardinality; final int outputIndex = jointIndex - inputIndex * _outputCardinality; inputIndexToIndices(inputIndex, indices); outputIndexToIndices(outputIndex, indices); return indices; } @Override public int outputIndexFromElements(Object ... elements) { final DiscreteDomain[] domains = _domains; final int[] products = _outputProducts; int joint = 0; for (int i = 0, end = products.length; i < end; ++i) { int product = products[i]; if (product != 0) { joint += product * domains[i].getIndexOrThrow(elements[i]); } } return joint; } @Override public int outputIndexFromIndices(int ... indices) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += indices[i] * _outputProducts[i]; } return joint; } @Override public int outputIndexFromValues(Value ... values) { final int length = size(); int joint = 0; for (int i = 0, end = length; i != end; ++i) // != is slightly faster than < comparison { joint += values[i].getIndex() * _outputProducts[i]; } return joint; } @Override public int outputIndexFromJointIndex(int jointIndex) { return jointIndex % _outputCardinality; } @Override public void outputIndexToElements(int outputIndex, Object[] elements) { locationToElements(outputIndex, elements, _outputIndices, _outputProducts); } @Override public void outputIndexToIndices(int outputIndex, int[] indices) { locationToIndices(outputIndex, indices, _outputIndices, _outputProducts); } @Override public void outputIndexToValues(int outputIndex, Value[] elements) { locationToValues(outputIndex, elements, _outputIndices, _outputProducts); } }