/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.factorfunctions.core; import java.util.Arrays; import java.util.BitSet; import java.util.Random; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.domains.JointDomainReindexer; import com.analog.lyric.dimple.model.values.Value; import cern.colt.list.DoubleArrayList; import cern.colt.list.IntArrayList; import net.jcip.annotations.NotThreadSafe; @NotThreadSafe public abstract class FactorTableBase implements IFactorTableBase, IFactorTable { /*-------- * State */ private static final long serialVersionUID = 1L; private JointDomainIndexer _domains; protected @Nullable FactorFunction _function; /*-------------- * Construction */ protected FactorTableBase(JointDomainIndexer domains) { _domains = domains; } protected FactorTableBase(@Nullable BitSet directedTo, DiscreteDomain ... domains) { _domains = JointDomainIndexer.create(directedTo, domains); } protected FactorTableBase(FactorTableBase that) { _domains = that._domains; _function = that._function; } /*---------------- * Object methods */ @Override public abstract FactorTableBase clone(); /*------------------ * Iterable methods */ @Override public IFactorTableIterator iterator() { return new FactorTableIterator(this, false); } @Override public FactorTableIterator fullIterator() { return new FactorTableIterator(this, true); } /*-------------------------- * IFactorTableBase methods */ @Override public double density() { return (double)countNonZeroWeights() / (double)jointSize(); } @Override public final int getDimensions() { return _domains.size(); } @Override public final JointDomainIndexer getDomainIndexer() { return _domains; } protected final void setDomainIndexer(JointDomainIndexer newDomains) { assert(_domains.domainsEqual(newDomains)); _domains = newDomains; } @Override public double getEnergyForElements(Object ... elements) { return getEnergyForJointIndex(_domains.jointIndexFromElements(elements)); } @Override public double getWeightForElements(Object ... elements) { return getWeightForJointIndex(_domains.jointIndexFromElements(elements)); } @Override public final @Nullable BitSet getInputSet() { return _domains.getInputSet(); } @Override public final @Nullable BitSet getOutputSet() { return _domains.getOutputSet(); } @Override public double getEnergyForIndices(int ... indices) { return getEnergyForJointIndex(_domains.jointIndexFromIndices(indices)); } @Override public double getWeightForIndices(int ... indices) { return getWeightForJointIndex(_domains.jointIndexFromIndices(indices)); } @Override public double getEnergyForValues(Value ... values) { return getEnergyForJointIndex(_domains.jointIndexFromValues(values)); } @Override public double getWeightForValues(Value ... values) { return getWeightForJointIndex(_domains.jointIndexFromValues(values)); } @Override public IFactorTable createTableConditionedOn(int[] valueIndices) { return createTableConditionedOn(valueIndices, false); } @Override public IFactorTable createTableConditionedOn(int[] valueIndices, boolean retainDimensions) { JointDomainReindexer conditioner = JointDomainReindexer.createConditioner(getDomainIndexer(), valueIndices, retainDimensions); return convert(conditioner); } @Override public int sparseIndexFromElements(Object ... elements) { return sparseIndexFromJointIndex(_domains.jointIndexFromElements(elements)); } @Override public int sparseIndexFromIndices(int ... indices) { return sparseIndexFromJointIndex(_domains.jointIndexFromIndices(indices)); } @Override public int sparseIndexFromValues(Value ... values) { return sparseIndexFromJointIndex(_domains.jointIndexFromValues(values)); } @Override public Object[] sparseIndexToElements(int sparseIndex, @Nullable Object[] elements) { return _domains.jointIndexToElements(sparseIndexToJointIndex(sparseIndex), elements); } @Override public int[] sparseIndexToIndices(int sparseIndex, @Nullable int[] indices) { return _domains.jointIndexToIndices(sparseIndexToJointIndex(sparseIndex), indices); } @Override public final int[] sparseIndexToIndices(int sparseIndex) { return sparseIndexToIndices(sparseIndex, null); } @Override public boolean isDirected() { return _domains.isDirected(); } @Override public int jointSize() { return _domains.getCardinality(); } @Override public void setEnergyForElements(double energy, Object ... elements) { setEnergyForJointIndex(energy, _domains.jointIndexFromElements(elements)); } @Override public void setEnergyForIndices(double energy, int ... indices) { _domains.validateIndices(indices); setEnergyForJointIndex(energy, _domains.jointIndexFromIndices(indices)); } @Override public void setEnergyForValues(double energy, Value ... values) { _domains.validateValues(values); setEnergyForJointIndex(energy, _domains.jointIndexFromValues(values)); } @Override public void setWeightForElements(double weight, Object ... elements) { setWeightForJointIndex(weight, _domains.jointIndexFromElements(elements)); } @Override public void setWeightForIndices(double weight, int ... indices) { _domains.validateIndices(indices); setWeightForJointIndex(weight, _domains.jointIndexFromIndices(indices)); } @Override public void setWeightForValues(double energy, Value ... values) { _domains.validateValues(values); setWeightForJointIndex(energy, _domains.jointIndexFromValues(values)); } @Override public final boolean supportsJointIndexing() { return _domains.supportsJointIndexing(); } /*----------------------- * IFactorTable methods */ @Override public @Nullable FactorFunction getFactorFunction() { return _function; } @Override public void populateFromFunction(FactorFunction function) { final JointDomainIndexer domains = getDomainIndexer(); final IFactorTable table = this; final Value[] values = Value.createFromDomains(domains); if (function.isDeterministicDirected() && domains.isDirected()) { final int maxInput = domains.getInputCardinality(); final int[] outputs = new int[maxInput]; for (int inputIndex = 0; inputIndex < maxInput; ++inputIndex) { domains.inputIndexToValues(inputIndex, values); function.evalDeterministic(values); outputs[inputIndex] = domains.outputIndexFromValues(values); } table.setDeterministicOutputIndices(outputs); } else { IntArrayList indexes = new IntArrayList(); DoubleArrayList energies = new DoubleArrayList(); final int maxJoint = domains.getCardinality(); for (int jointIndex = 0; jointIndex < maxJoint; ++ jointIndex) { domains.jointIndexToValues(jointIndex, values); double energy = function.evalEnergy(values); if (!Double.isInfinite(energy)) { indexes.add(jointIndex); energies.add(energy); } } if (indexes.size() == maxJoint) { table.setEnergiesDense(Arrays.copyOf(energies.elements(), maxJoint)); } else { table.setEnergiesSparse(Arrays.copyOf(indexes.elements(), indexes.size()), Arrays.copyOf(energies.elements(), indexes.size())); } } _function = function; } @Override public void randomizeWeights(Random rand) { _function = null; if (hasDenseRepresentation()) { for (int i = jointSize(); --i >= 0;) { // nextDouble() produces range [0,1). Subtract that from 1.0 to get (0,1]. setWeightForJointIndex(1.0 - rand.nextDouble(), i); } } else { for (int i = sparseSize(); --i >= 0;) { setWeightForSparseIndex(1.0 - rand.nextDouble(), i); } } } }