/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.factorfunctions.core; import static com.analog.lyric.math.Utilities.*; import java.io.Serializable; import java.util.Arrays; import java.util.BitSet; import java.util.Comparator; import java.util.HashMap; import java.util.Map; import java.util.Objects; import net.jcip.annotations.Immutable; import net.jcip.annotations.NotThreadSafe; import org.eclipse.jdt.annotation.NonNull; import org.eclipse.jdt.annotation.NonNullByDefault; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.DiscreteIndicesIterator; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer.Indices; import com.analog.lyric.dimple.model.values.Value; import com.google.common.math.DoubleMath; /** * @since 0.05 * @author Christopher Barber */ @NotThreadSafe public class SparseFactorTable extends SparseFactorTableBase implements IFactorTable { /*----------- * Constants */ private static final long serialVersionUID = 1L; /*------- * State */ private IndexEntryComparator _entryComparator; @NotThreadSafe private static class IndexEntry implements Serializable { private static final long serialVersionUID = 1L; int[] _indices; int _sparseIndex; private IndexEntry(int[] indices, int sparseIndex) { _indices = indices; _sparseIndex = sparseIndex; } @Override protected IndexEntry clone() { return new IndexEntry(this._indices, this._sparseIndex); } @Override public boolean equals(@Nullable Object that) { return that instanceof IndexEntry && Arrays.equals(this._indices, ((IndexEntry)that)._indices); } @Override public int hashCode() { return Arrays.hashCode(_indices); } /** * Just print the indices as a debugging aid. */ @Override public String toString() { return Arrays.toString(_indices); } } @NotThreadSafe private static class IndexEntryWithWeight extends IndexEntry { private static final long serialVersionUID = 1L; private double _weight; private IndexEntryWithWeight(int[] indices, int sparseIndex, double weight) { super(indices, sparseIndex); _weight = weight; } } @Immutable @NonNullByDefault(false) private final static class IndexEntryComparator implements Comparator<IndexEntry>, Serializable { private static final long serialVersionUID = 1L; private final Comparator<int[]> _indicesComparator; IndexEntryComparator(@NonNull JointDomainIndexer domains) { _indicesComparator = domains.getIndicesComparator(); } @Override public int compare(IndexEntry entry1, IndexEntry entry2) { return _indicesComparator.compare(entry1._indices, entry2._indices); } } private IndexEntry[] _indexArray = new IndexEntry[0]; private final Map<IndexEntry,IndexEntry> _indexSet; private IndexEntry _scratchEntry = new IndexEntry(ArrayUtil.EMPTY_INT_ARRAY, -1); private final int[] _scratchIndices; /*-------------- * Construction */ SparseFactorTable(final JointDomainIndexer domains) { super(domains); _representation = FactorTable.SPARSE_ENERGY; _scratchIndices = domains.allocateIndices(null); _indexSet = indexMapForDomains(domains, 8); _entryComparator = new IndexEntryComparator(domains); } SparseFactorTable(SparseFactorTable that) { super(that); _indexArray = new IndexEntry[that._indexArray.length]; for (int i = _indexArray.length; --i>=-0;) { _indexArray[i] = that._indexArray[i].clone(); } JointDomainIndexer domains = getDomainIndexer(); _indexSet = indexMapForDomains(domains, _indexArray.length); for (IndexEntry entry : _indexArray) { _indexSet.put(entry, entry); } _scratchIndices = domains.allocateIndices(null); _entryComparator = new IndexEntryComparator(domains); } SparseFactorTable(IFactorTable other, JointDomainReindexer converter) { this(other, converter, other.getRepresentation()); } SparseFactorTable(IFactorTable other, JointDomainReindexer converter, FactorTableRepresentation representation) { this(converter.getToDomains()); final FactorTableRepresentation otherRep = other.getRepresentation(); // Start out with just sparse weights, and convert below if necessary _representation = FactorTable.SPARSE_WEIGHT; int curRep = 0; if (representation.hasWeight() || representation.isDeterministic()) { curRep = FactorTable.SPARSE_WEIGHT; } if (representation.hasEnergy()) { curRep |= FactorTable.SPARSE_ENERGY; } if (representation.hasSparseIndices()) { curRep |= FactorTable.SPARSE_INDICES; } Indices scratch = converter.getScratch(); // Note: this may change the representation of 'other', so we will change it back below. double[] otherWeights = other.getWeightsSparseUnsafe(); int oldSparseSize = other.sparseSize(); final DiscreteIndicesIterator addedIterator = scratch.addedIndices.length > 0 ? new DiscreteIndicesIterator(Objects.requireNonNull(converter.getAddedDomains()), scratch.addedIndices) : new DiscreteIndicesIterator(ArrayUtil.EMPTY_INT_ARRAY, scratch.addedIndices); for (int i = 0; i < oldSparseSize; ++i) { other.sparseIndexToIndices(i, scratch.fromIndices); double weight = otherWeights[i]; while (addedIterator.hasNext()) { addedIterator.next(); converter.convertIndices(scratch); if (scratch.toIndices[0] >= 0) { IndexEntryWithWeight entry = new IndexEntryWithWeight(scratch.toIndices.clone(), i, weight); IndexEntryWithWeight prevEntry = (IndexEntryWithWeight)_indexSet.put(entry, entry); if (prevEntry != null) { prevEntry._weight = weight; } } } addedIterator.reset(); } scratch.release(); int size = _indexSet.size(); _indexArray = _indexSet.values().toArray(new IndexEntry[size]); Arrays.sort(_indexArray, _entryComparator); double[] weights = new double[size]; for (int i = 0; i < size; ++i) { IndexEntryWithWeight entry = ((IndexEntryWithWeight)_indexArray[i]); double weight = entry._weight; if (weight != 0.0) { ++_nonZeroWeights; } weights[i] = weight; IndexEntry newEntry = new IndexEntry(entry._indices, i); _indexArray[i] = newEntry; _indexSet.put(newEntry, newEntry); } _sparseWeights = weights; other.setRepresentation(otherRep); setRepresentation(curRep); } /*---------------- * Object methods */ @Override public SparseFactorTable clone() { return new SparseFactorTable(this); } /*------------------ * Iterable methods */ @Override public IFactorTableIterator iterator() { return new SparseFactorTableIterator(this); } @Override public FactorTableIterator fullIterator() { throw notDense("fullIterator"); } /*-------------------------- * IFactorTableBase methods */ @Override public double density() { double d = _nonZeroWeights; for (DiscreteDomain domain : getDomainIndexer()) { d /= domain.size(); } return d; } @Override public void evalDeterministic(Value[] arguments) { throw notDeterministic("evalDeterministic"); } @Override public double getEnergyForElements(Object ... elements) { return getEnergyForIndices(getDomainIndexer().elementsToIndices(elements, _scratchIndices)); } @Override public double getEnergyForIndices(int ... indices) { IndexEntry entry = _indexSet.get(getScratchEntry(indices)); return entry != null ? getEnergyForSparseIndex(entry._sparseIndex) : Double.POSITIVE_INFINITY; } @Override public double getEnergyForIndicesDense(int... indices) { throw notDense("getEnergyForIndicesDense"); } @Override public double getEnergyForValuesDense(Value ... values) { throw notDense("getEnergyForValuesDense"); } @Override public double getWeightForIndicesDense(int... indices) { throw notDense("getWeightForIndicesDense"); } @Override public double getWeightForValuesDense(Value ... values) { throw notDense("getWeightForValuesDense"); } @Override public double getEnergyForJointIndex(int jointIndex) { throw notDense("getEnergyForJointIndex"); } @Override public double getEnergyForSparseIndex(int sparseIndex) { return hasSparseEnergies() ? _sparseEnergies[sparseIndex] : weightToEnergy(_sparseWeights[sparseIndex]); } @Override public double getWeightForElements(Object ... elements) { return getWeightForIndices(getDomainIndexer().elementsToIndices(elements, _scratchIndices)); } @Override public double getWeightForIndices(int ... indices) { IndexEntry entry = _indexSet.get(getScratchEntry(indices)); return entry != null ? getWeightForSparseIndex(entry._sparseIndex) : 0.0; } @Override public double getWeightForJointIndex(int jointIndex) { throw notDense("getWeightForJointIndex"); } @Override public double getWeightForSparseIndex(int sparseIndex) { return hasSparseWeights() ? _sparseWeights[sparseIndex] : energyToWeight(_sparseEnergies[sparseIndex]); } @Override public boolean hasDenseRepresentation() { return false; } @Override public boolean hasDenseEnergies() { return false; } @Override public boolean hasDenseWeights() { return false; } @Override public boolean hasMaximumDensity() { // Because this class is only intended to be used when the joint cardinality is larger than // 2^31, and is not designed to hold more than that many elements, it should not be possible // for this to ever be true. return false; } @Override public boolean hasSparseRepresentation() { return true; } @Override public boolean isDeterministicDirected() { return false; } @Override public boolean isConditional() { if ((_computedMask & CONDITIONAL_COMPUTED) == 0) { if (isDirected()) { normalizeDirected(true, false); } _computedMask |= CONDITIONAL_COMPUTED; } return (_computedMask & CONDITIONAL) != 0; } @Override public void setEnergyForElements(double energy, Object ... elements) { setEnergyForIndices(energy, getDomainIndexer().elementsToIndices(elements, _scratchIndices)); } @Override public void setEnergyForIndices(double energy, int ... indices) { getDomainIndexer().validateIndices(indices); setEnergyForSparseIndex(energy, createSparseIndexForIndices(indices)); } @Override public void setWeightForElements(double weight, Object ... elements) { setWeightForIndices(weight, getDomainIndexer().elementsToIndices(elements, _scratchIndices)); } @Override public void setWeightForIndices(double weight, int ... indices) { getDomainIndexer().validateIndices(indices); setWeightForSparseIndex(weight, createSparseIndexForIndices(indices)); } @Override public int sparseIndexFromElements(Object ... elements) { return sparseIndexFromIndices(getDomainIndexer().elementsToIndices(elements, _scratchIndices)); } @Override public int sparseIndexFromIndices(int ... indices) { IndexEntry entry = _indexSet.get(getScratchEntry(indices)); return entry != null ? entry._sparseIndex : -1; } @Override public Object[] sparseIndexToElements(int sparseIndex, @Nullable Object[] elements) { return getDomainIndexer().elementsFromIndices(sparseIndexToIndices(sparseIndex, _scratchIndices), elements); } @Override public int[] sparseIndexToIndices(int sparseIndex, @Nullable int[] indices) { indices = getDomainIndexer().allocateIndices(indices); System.arraycopy(_indexArray[sparseIndex]._indices, 0, indices, 0, indices.length); return indices; } /** * Like {@link #sparseIndexToIndices(int)} but returns actual internal indices array (which must not * be modified!). */ int[] sparseIndexToIndicesUnsafe(int sparseIndex) { return _indexArray[sparseIndex]._indices; } @Override public int sparseIndexFromJointIndex(int joint) { throw notDense("sparseIndexFromJointIndex"); } @Override public int sparseIndexToJointIndex(int sparseIndex) { throw notDense("sparseIndexToJointIndex"); } @Override public int sparseSize() { return _indexArray.length; } /*---------------------- * IFactorTable methods */ @Override public int compact() { int nRemoved = 0; final int curSparseSize = sparseSize(); if (curSparseSize > _nonZeroWeights) { nRemoved = curSparseSize - _nonZeroWeights; final boolean hasEnergy = hasSparseEnergies(); final double[] sparseEnergies = hasEnergy ? new double[_nonZeroWeights] : ArrayUtil.EMPTY_DOUBLE_ARRAY; final boolean hasWeight = hasSparseWeights(); final double[] sparseWeights = hasWeight ? new double[_nonZeroWeights] : ArrayUtil.EMPTY_DOUBLE_ARRAY; final IndexEntry[] indexArray = new IndexEntry[_nonZeroWeights]; if (hasWeight) { for (int i = 0, j = 0; i < curSparseSize; ++i) { IndexEntry entry = _indexArray[i]; double w = _sparseWeights[i]; if (w == 0.0) { _indexSet.remove(entry); } else { sparseWeights[j] = w; if (hasEnergy) { sparseEnergies[j] = _sparseEnergies[i]; } entry._sparseIndex = j; indexArray[j] = entry; ++j; } } } else { for (int i = 0, j = 0; i < curSparseSize; ++i) { IndexEntry entry = _indexArray[i]; double e = _sparseEnergies[i]; if (Double.isInfinite(e)) { _indexSet.remove(entry); } else { sparseEnergies[j] = e; entry._sparseIndex = j; indexArray[j] = entry; ++j; } } } _sparseEnergies = sparseEnergies; _sparseWeights = sparseWeights; _indexArray = indexArray; recomputeSparseIndices(); } return nRemoved; } @Override public void copy(IFactorTable that) { if (that == this) { return; } // REFACTOR: share if (!getDomainIndexer().domainsEqual(that.getDomainIndexer())) { throw new DimpleException("Cannot copy from factor table with different domains"); } if (!that.hasSparseRepresentation()) { throw new DimpleException("Cannot copy to SparseFactorTable from table without sparse representation"); } int sparseSize = that.sparseSize(); _indexArray = new IndexEntry[sparseSize]; _indexSet.clear(); for (int si = 0; si < sparseSize; ++si) { int[] indices = that.sparseIndexToIndices(si); IndexEntry entry = new IndexEntry(indices, si); _indexArray[si] = entry; _indexSet.put(entry, entry); } _computedMask = 0; _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY; _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY; _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY; _representation = 0; if (that.hasSparseEnergies()) { _representation |= FactorTable.SPARSE_ENERGY; _sparseEnergies = that.getEnergiesSparseUnsafe().clone(); } if (that.hasSparseWeights()) { _representation |= FactorTable.SPARSE_WEIGHT; _sparseWeights = that.getWeightsSparseUnsafe().clone(); } if (that.hasSparseIndices()) { getIndicesSparseUnsafe(); } _nonZeroWeights = that.countNonZeroWeights(); } @Override public IFactorTable joinVariablesAndCreateNewTable(int[] varIndices, int[] indexToJointIndex, DiscreteDomain[] allDomains, DiscreteDomain jointDomain) { final JointDomainIndexer domains = getDomainIndexer(); final JointDomainReindexer converter = FactorTable.makeConverterForJoinVariables(domains, varIndices, indexToJointIndex, allDomains, jointDomain); return new SparseFactorTable(this, converter); } @Override public boolean hasDeterministicRepresentation() { return false; } @Override public double[] getEnergiesSparseUnsafe() { if (_sparseEnergies.length == 0 && !hasSparseEnergies()) { setRepresentation(_representation | FactorTable.SPARSE_ENERGY); } return _sparseEnergies; } @Override public double[] getEnergiesDenseUnsafe() { throw notDense("getEnergiesDenseUnsafe"); } @Override public double[] getEnergySlice(@Nullable double[] slice, int sliceDimension, int... indices) { final int[] scratchIndices = _scratchIndices; System.arraycopy(indices, 0, scratchIndices, 0, scratchIndices.length); return getEnergySliceImpl(slice, sliceDimension, scratchIndices); } @Override public double[] getEnergySlice(@Nullable double[] slice, int sliceDimension, Value ... values) { final int[] scratchIndices = _scratchIndices; for (int i = scratchIndices.length; --i>=0;) { scratchIndices[i] = values[i].getIndex(); } return getEnergySliceImpl(slice, sliceDimension, scratchIndices); } private double[] getEnergySliceImpl(@Nullable double[] slice, int sliceDimension, int[] scratchIndices) { JointDomainIndexer indexer = getDomainIndexer(); int size = indexer.getDomainSize(sliceDimension); if (slice == null || slice.length < size) { slice = new double[size]; } for (int i = 0; i < size; ++i) { scratchIndices[sliceDimension] = i; slice[i] = getEnergyForIndices(scratchIndices); } return slice; } @Override public int[][] getIndicesSparseUnsafe() { if (!hasSparseIndices()) { _representation |= FactorTable.SPARSE_INDICES; recomputeSparseIndices(); } return _sparseIndices; } @Override public double[] getWeightsSparseUnsafe() { if (_sparseWeights.length == 0 && !hasSparseWeights()) { setRepresentation(_representation | FactorTable.SPARSE_WEIGHT); } return _sparseWeights; } @Override public double[] getWeightSlice(@Nullable double[] slice, int sliceDimension, int... indices) { final int[] scratchIndices = _scratchIndices; System.arraycopy(indices, 0, scratchIndices, 0, scratchIndices.length); return getWeightSliceImpl(slice, sliceDimension, scratchIndices); } @Override public double[] getWeightSlice(@Nullable double[] slice, int sliceDimension, Value ... values) { final int[] scratchIndices = _scratchIndices; for (int i = scratchIndices.length; --i>=0;) { scratchIndices[i] = values[i].getIndex(); } return getWeightSliceImpl(slice, sliceDimension, scratchIndices); } private double[] getWeightSliceImpl(@Nullable double[] slice, int sliceDimension, int[] scratchIndices) { JointDomainIndexer indexer = getDomainIndexer(); int size = indexer.getDomainSize(sliceDimension); if (slice == null || slice.length < size) { slice = new double[size]; } for (int i = 0; i < size; ++i) { scratchIndices[sliceDimension] = i; slice[i] = getWeightForIndices(scratchIndices); } return slice; } @Override public void setEnergiesDense(double[] energies) { throw notDense("setEnergiesDense"); } @Override public void setWeightsDense(double[] weights) { throw notDense("setWeightsDense"); } @Override public void setDeterministicOutputIndices(int[] outputIndices) { throw notDeterministic("setDeterministicOutputIndices"); } @Override public void setEnergyForJointIndex(double energy, int jointIndex) { throw notDense("setEnergyForJointIndex"); } @Override public void setEnergyForSparseIndex(double energy, int sparseIndex) { final double prevEnergy = getEnergyForSparseIndex(sparseIndex); if (prevEnergy != energy) { _computedMask = 0; double weight = hasSparseWeights() ? energyToWeight(energy) : 0.0; setWeightEnergyForSparseIndex(weight, energy, sparseIndex); if (Double.isInfinite(prevEnergy)) { ++_nonZeroWeights; } else if (Double.isInfinite(energy)) { --_nonZeroWeights; } } } @Override public void setWeightForJointIndex(double weight, int jointIndex) { throw notDense("setWeightForJointIndex"); } @Override public void setWeightForSparseIndex(double weight, int sparseIndex) { final double prevWeight = getWeightForSparseIndex(sparseIndex); if (prevWeight != weight) { _computedMask = 0; double energy = hasSparseEnergies() ? weightToEnergy(weight) : Double.POSITIVE_INFINITY; setWeightEnergyForSparseIndex(weight, energy, sparseIndex); if (prevWeight == 0.0) { ++_nonZeroWeights; } else if (weight == 0.0) { --_nonZeroWeights; } } } @Override public void setEnergiesSparse(int[] jointIndices, double[] energies) { throw notDense("setEnergiesSparse(int[] jointIndices, double[])"); } @Override public void setWeightsSparse(int[] jointIndices, double[] weights) { throw notDense("setWeightsSparse(int[] jointIndices, double[])"); } /*----------------- * Private methods */ private static Map<IndexEntry,IndexEntry> indexMapForDomains(JointDomainIndexer domains, int capacity) { // TODO: it is possible that for a large domains list with small fixed domain sizes, // it might better to use a radix tree, especially when all of the domains are binary. return new HashMap<IndexEntry,IndexEntry>(capacity); } private int createSparseIndexForIndices(int[] indices) { IndexEntry scratchEntry = getScratchEntry(indices); IndexEntry entry = _indexSet.get(scratchEntry); if (entry != null) { return entry._sparseIndex; } // Need to insert a new sparse index. indices = Arrays.copyOf(indices, indices.length); // Find position by doing a binary search in _indexArray int sparseIndex = -Arrays.binarySearch(_indexArray, scratchEntry, _entryComparator) - 1; entry = new IndexEntry(indices, sparseIndex); int newSize = _indexArray.length + 1; IndexEntry[] indexArray = new IndexEntry[newSize]; double[] sparseEnergies = hasSparseEnergies() ? new double[newSize] : _sparseEnergies; double[] sparseWeights = hasSparseWeights() ? new double[newSize] : _sparseWeights; int[][] sparseIndices = hasSparseIndices() ? new int[newSize][] : _sparseIndices; if (sparseIndex > 0) { System.arraycopy(_indexArray, 0, indexArray, 0, sparseIndex); if (sparseEnergies.length > 0) { System.arraycopy(_sparseEnergies, 0, sparseEnergies, 0, sparseIndex); } if (sparseWeights.length > 0) { System.arraycopy(_sparseWeights, 0, sparseWeights, 0, sparseIndex); } if (sparseIndices.length > 0) { System.arraycopy(_sparseIndices, 0, sparseIndices, 0, sparseIndex); } } indexArray[sparseIndex] = entry; if (sparseEnergies.length > 0) { sparseEnergies[sparseIndex] = Double.POSITIVE_INFINITY; } // No need to initialize value for sparseWeights because default is 0.0 for new array. if (sparseIndices.length > 0) { sparseIndices[sparseIndex] = indices; } if (sparseIndex < _indexArray.length) { int endSize = _indexArray.length - sparseIndex; System.arraycopy(_indexArray, sparseIndex, indexArray, sparseIndex + 1, endSize); if (sparseEnergies.length > 0) { System.arraycopy(_sparseEnergies, sparseIndex, sparseEnergies, sparseIndex + 1, endSize); } if (sparseWeights.length > 0) { System.arraycopy(_sparseWeights, sparseIndex, sparseWeights, sparseIndex + 1, endSize); } if (sparseIndices.length > 0) { System.arraycopy(_sparseIndices, sparseIndex, sparseIndices, sparseIndex + 1, endSize); } for (int i = sparseIndex + 1; i < newSize; ++i) { indexArray[i]._sparseIndex = i; } } _sparseEnergies = sparseEnergies; _sparseWeights = sparseWeights; _sparseIndices = sparseIndices; _indexArray = indexArray; _indexSet.put(entry, entry); return sparseIndex; } private void computeNonZeroWeights() { int count = 0; if (hasSparseWeights()) { for (double w : _sparseWeights) if (w != 0) ++count; } else { for (double e : _sparseEnergies) if (!Double.isInfinite(e)) ++count; } _nonZeroWeights = count; } private IndexEntry getScratchEntry(int[] indices) { _scratchEntry._indices = indices; return _scratchEntry; } private DimpleException notDense(String method) { return DimpleException.unsupportedMethod(getClass(), method, "dense representation not supported."); } @Override int normalizeDirected(boolean justCheck, boolean ignoreZeroWeightInputs) { final JointDomainIndexer domains = getDomainIndexer(); boolean computeNormalizedTotal = justCheck; double normalizedTotal = 1.0; int nNotNormalized = 0; // We represent the joint index such that the outputs for the same // input are stored consecutively, so we only need to walk through // the values in order. // // When just checking, we allow total to equal something other than one // as long as they are all the same. final int size = _indexArray.length; for (int si = 0, nextsi = 1, start = 0; si < size; si = nextsi++) { if (nextsi == size || !domains.hasSameInputs(_indexArray[si]._indices, _indexArray[nextsi]._indices)) { double totalForInput = 0.0; if (hasSparseWeights()) { for (int i = start; i < nextsi; ++i) { totalForInput += _sparseWeights[i]; } } else { for (int i = start; i < nextsi; ++i) { totalForInput += energyToWeight(_sparseEnergies[i]); } } if (totalForInput == 0.0) { if (ignoreZeroWeightInputs) ++nNotNormalized; else return normalizeDirectedHandleZeroForInput(justCheck); } if (computeNormalizedTotal) { normalizedTotal = totalForInput; computeNormalizedTotal = false; } else if (!DoubleMath.fuzzyEquals(totalForInput, normalizedTotal, 1e-12)) { if (justCheck) { return 1; } } if (!justCheck) { if (totalForInput != 0.0) { if (hasSparseWeights()) { for (int i = start; i < nextsi; ++i) { _sparseWeights[i] /= totalForInput; } } if (hasSparseEnergies()) { double logTotalForInput = Math.log(totalForInput); for (int i = start; i < nextsi; ++i) { _sparseEnergies[i] += logTotalForInput; } } } } start = nextsi; } } if (nNotNormalized == 0) { _computedMask |= CONDITIONAL|CONDITIONAL_COMPUTED; } else { _computedMask |= CONDITIONAL_COMPUTED; } return nNotNormalized; } @Override boolean normalizeUndirected(boolean justCheck) { if ((_computedMask & NORMALIZED) != 0) { return true; } double total = 0.0; if (hasSparseWeights()) { for (double w : _sparseWeights) { total += w; } } else { for (double e: _sparseEnergies) { total += energyToWeight(e); } } if (!DoubleMath.fuzzyEquals(total, 1.0, 1e-12)) { if (justCheck) { return false; } if (total == 0.0) { throw normalizeUndirectedHandleZero(); } for (int i = _sparseWeights.length; --i>=0;) { _sparseWeights[i] /= total; } if (_sparseEnergies.length > 0) { final double logTotal = Math.log(total); for (int i = _sparseEnergies.length; --i>=0;) { _sparseEnergies[i] += logTotal; } } } _computedMask |= NORMALIZED|NORMALIZED_COMPUTED; return true; } private DimpleException notDeterministic(String method) { return DimpleException.unsupportedMethod(getClass(), method, "deterministic representation not supported."); } /** * If {@link #hasSparseIndices()} this will recompute their values based on the current * index entries. */ private void recomputeSparseIndices() { if (hasSparseIndices()) { IndexEntry[] indexArray = _indexArray; final int sparseSize = indexArray.length; if (sparseSize == 0) { _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY; } else { int[][] sparseIndices = _sparseIndices = new int[indexArray.length][]; for (int i = indexArray.length; --i >=0;) { sparseIndices[i] = indexArray[i]._indices; } } } } @Override void setDirected(@Nullable BitSet outputSet, boolean assertConditional) { // REFACTOR: share? final JointDomainIndexer oldDomains = getDomainIndexer(); final JointDomainIndexer newDomains = JointDomainIndexer.create(outputSet, oldDomains); if (oldDomains.equals(newDomains)) { if (assertConditional) { assertIsConditional(); } return; } _computedMask = 0; setDomainIndexer(newDomains); _entryComparator = new IndexEntryComparator(newDomains); if (!oldDomains.hasCanonicalDomainOrder() | !newDomains.hasCanonicalDomainOrder()) { // Need to reorder the entries and values. int sparseSize = sparseSize(); Arrays.sort(_indexArray, _entryComparator); if (_sparseWeights.length > 0) { double[] sparseWeights = new double[sparseSize]; for (int si = 0; si < sparseSize; ++ si) { sparseWeights[si] = _sparseWeights[_indexArray[si]._sparseIndex]; } _sparseWeights = sparseWeights; } if (_sparseEnergies.length > 0) { double[] sparseEnergies = new double[sparseSize]; for (int si = 0; si < sparseSize; ++ si) { sparseEnergies[si] = _sparseEnergies[_indexArray[si]._sparseIndex]; } _sparseEnergies = sparseEnergies; } recomputeSparseIndices(); for (int si = 0; si < sparseSize; ++ si) { _indexArray[si]._sparseIndex = si; } } if (assertConditional) { assertIsConditional(); } } @Override void setRepresentation(int newRep) { if (_representation == newRep) { return; } boolean convertFromWeights = false; boolean convertFromEnergies = false; switch (newRep) { case FactorTable.SPARSE_ENERGY: case FactorTable.SPARSE_ENERGY_WITH_INDICES: convertFromWeights = !hasSparseEnergies(); break; case FactorTable.SPARSE_WEIGHT: case FactorTable.SPARSE_WEIGHT_WITH_INDICES: convertFromEnergies = !hasSparseWeights(); break; case FactorTable.ALL_SPARSE: case FactorTable.ALL_SPARSE_WITH_INDICES: if (!hasSparseEnergies()) { convertFromWeights = true; } else { convertFromEnergies = true; } break; default: throw new DimpleException( "Cannot set representation to '%s' because '%s' does not support dense representations.", FactorTableRepresentation.forMask(newRep).name(), getClass().getSimpleName() ); } if (convertFromWeights) { double[] sparseWeights = _sparseWeights; double[] sparseEnergies = _sparseEnergies = new double[sparseWeights.length]; for (int i = sparseWeights.length; --i>=0;) { sparseEnergies[i] = weightToEnergy(sparseWeights[i]); } } else if (convertFromEnergies) { double[] sparseEnergies = _sparseEnergies; double[] sparseWeights = _sparseWeights = new double[sparseEnergies.length]; for (int i = sparseEnergies.length; --i>=0;) { sparseWeights[i] = energyToWeight(sparseEnergies[i]); } } if (!hasSparseIndices() && (newRep & FactorTable.SPARSE_INDICES) != 0) { getIndicesSparseUnsafe(); } _representation = newRep; if (!hasSparseEnergies()) { _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY; } if (!hasSparseWeights()) { _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY; } if (!hasSparseIndices()) { _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY; } } @Override void setSparseValues(int[][] indicesArray, double[] values, int representation) { _function = null; int size = indicesArray.length; if (size != values.length) { // REFACTOR: share throw new IllegalArgumentException( String.format("'Arrays have different sizes: %d and %d", size, values.length)); } final JointDomainIndexer domainIndexer = getDomainIndexer(); final IndexEntry[] indexArray = new IndexEntry[size]; for (int i = 0; i < size; ++i) { int[] indices = indicesArray[i]; domainIndexer.validateIndices(indices); indexArray[i] = new IndexEntry(Objects.requireNonNull(ArrayUtil.cloneArray(indices)), i); } boolean doSort = false; for (int i = 1; i < size; ++i) { if (0 < _entryComparator.compare(indexArray[i-1], indexArray[i])) { doSort = true; break; } } double[] copiedValues = new double[size]; if (doSort) { Arrays.sort(indexArray, _entryComparator); for (int i = 0; i < size; ++i) { IndexEntry entry = indexArray[i]; copiedValues[i] = values[entry._sparseIndex]; entry._sparseIndex = i; } } else { System.arraycopy(values, 0, copiedValues, 0, size); } for (int i = 1; i < size; ++i) { IndexEntry entry1 = indexArray[i - 1]; IndexEntry entry2 = indexArray[i]; if (entry1.equals(entry2)) { throw new IllegalArgumentException(String.format( "Multiple entries with same set of indices %s", entry1._indices)); } } _indexArray = indexArray; _indexSet.clear(); for (IndexEntry entry : _indexArray) { _indexSet.put(entry, entry); } switch (representation) { case FactorTable.SPARSE_ENERGY: _sparseEnergies = copiedValues; _sparseWeights = ArrayUtil.EMPTY_DOUBLE_ARRAY; break; case FactorTable.SPARSE_WEIGHT: _sparseWeights = copiedValues; _sparseEnergies = ArrayUtil.EMPTY_DOUBLE_ARRAY; break; default: assert(false); } _sparseIndices = ArrayUtil.EMPTY_INT_ARRAY_ARRAY; _representation = representation; recomputeSparseIndices(); _computedMask = 0; computeNonZeroWeights(); } /** * For implementation of {@link #setWeightForSparseIndex(double, int)} and * {@link #setEnergyForSparseIndex(double, int)} */ private void setWeightEnergyForSparseIndex(double weight, double energy, int sparseIndex) { if (hasSparseEnergies()) { _sparseEnergies[sparseIndex] = energy; } if (hasSparseWeights()) { _sparseWeights[sparseIndex] = weight; } } @Override public double[] getWeightsDenseUnsafe() { throw notDense("getWeightsDenseUnsafe"); } }