/*******************************************************************************
*   Copyright 2014-2015 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.core.parameterizedMessages;

import static com.analog.lyric.math.Utilities.*;

import java.util.Arrays;

import org.eclipse.jdt.annotation.Nullable;

import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.exceptions.NormalizationException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.math.Utilities;

/**
 * Discrete message whose values are stored directly as weights.
 * <p>
 * Each element of the inherited {@code _message} array holds the weight of the corresponding
 * domain element. Energies are not stored; they are computed on demand with
 * {@code weightToEnergy} and converted back with {@code energyToWeight}.
 * <p>
 * Every mutator that changes individual weights calls {@code forgetNormalizationEnergy()}.
 * {@link #setNull()}, {@link #setUniform()} and {@link #normalize()} instead assign the
 * normalization energy explicitly.
 *
 * @since 0.06
 * @author Christopher Barber
 */
public class DiscreteWeightMessage extends DiscreteMessage
{
	private static final long serialVersionUID = 1L;

	/*--------------
	 * Construction
	 */

	/**
	 * Creates a message with the specified initial weights.
	 * <p>
	 * NOTE(review): the array is handed straight to the superclass constructor; whether it is
	 * copied or used as backing storage is decided there — confirm before reusing {@code weights}.
	 */
	public DiscreteWeightMessage(double[] weights)
	{
		super(weights);
	}

	/**
	 * Creates a message of the specified size with every weight set to one (zero energy).
	 *
	 * @param size number of weights in the message
	 */
	public DiscreteWeightMessage(int size)
	{
		this(new double[size]);
		setNull();
	}

	/**
	 * Creates a copy of another weight message.
	 * @since 0.08
	 */
	public DiscreteWeightMessage(DiscreteWeightMessage other)
	{
		super(other);
	}

	/**
	 * Creates a weight message with the same size as {@code other} and copies its values
	 * via {@code setFrom(other)}.
	 * @since 0.08
	 */
	public DiscreteWeightMessage(DiscreteMessage other)
	{
		this(other.size());
		setFrom(other);
	}

	/**
	 * Creates a message sized for {@code domain} and initializes it from {@code datum}.
	 *
	 * @param domain is used to set the {@link #size} and used to evaluate {@code datum}
	 * @param datum if not null is used to initialize the message via {@link #setFrom(DiscreteDomain, IDatum)},
	 * otherwise the message will be {@link #setUniform() set to uniform} distribution.
	 * @since 0.08
	 */
	public DiscreteWeightMessage(DiscreteDomain domain, @Nullable IDatum datum)
	{
		this(new double[domain.size()]);
		if (datum != null)
		{
			setFrom(domain, datum);
		}
		else
		{
			setUniform();
		}
	}

	/**
	 * Returns a new message built with the copy constructor
	 * {@link #DiscreteWeightMessage(DiscreteWeightMessage)}.
	 */
	@Override
	public DiscreteWeightMessage clone()
	{
		return new DiscreteWeightMessage(this);
	}

	/*-------------------------------
	 * IParameterizedMessage methods
	 */

	/**
	 * Returns true if every weight is exactly {@code 1.0} (vacuously true for an empty message).
	 * <p>
	 * Note that a normalized uniform message, as produced by {@link #setUniform()}, has weights
	 * {@code 1/N} and therefore is not reported as null unless {@code N == 1}.
	 */
	@Override
	public boolean isNull()
	{
		for (double w : _message)
			if (w != 1.0)
				return false;
		return true;
	}

	/**
	 * Sets every weight to {@code 1.0} and sets the normalization energy to
	 * {@code weightToEnergy(N)}, where {@code N} is the number of weights (and hence their sum).
	 */
	@Override
	public void setNull()
	{
		Arrays.fill(_message, 1.0);
		_normalizationEnergy = weightToEnergy(_message.length);
	}

	/**
	 * {@inheritDoc}
	 * <p>
	 * Sets all weight values to 1 / N (normalized) and the normalization energy to zero.
	 */
	@Override
	public void setUniform()
	{
		Arrays.fill(_message, 1.0 / _message.length);
		_normalizationEnergy = 0.0;
	}

	/*-------------------------
	 * DiscreteMessage methods
	 */

	/**
	 * Adds the weights of {@code other} element-wise into this message.
	 * <p>
	 * When {@code other} also stores weights, its backing array is read directly to avoid
	 * per-element virtual calls. Otherwise {@code other.getWeight(i)} is used.
	 */
	@Override
	public void addWeightsFrom(DiscreteMessage other)
	{
		assertSameSize(other.size());

		final double[] message = _message;

		if (other.storesWeights())
		{
			final double[] otherMessage = other._message;
			for (int i = message.length; --i >= 0; )
			{
				message[i] += otherMessage[i];
			}
		}
		else
		{
			for (int i = message.length; --i >= 0; )
			{
				message[i] += other.getWeight(i);
			}
		}

		forgetNormalizationEnergy();
	}

	/**
	 * Writes the energy of each weight into {@code energies} and returns it.
	 * {@code energies} must have at least {@link #size} elements.
	 */
	@Override
	public double[] getEnergies(double[] energies)
	{
		for (int i = _message.length; --i >= 0; )
			energies[i] = weightToEnergy(_message[i]);
		return energies;
	}

	/**
	 * Copies the weights into {@code weights} and returns it.
	 * {@code weights} must have at least {@link #size} elements.
	 */
	@Override
	public double[] getWeights(double[] weights)
	{
		System.arraycopy(_message, 0, weights, 0, _message.length);
		return weights;
	}

	@Override
	public double getWeight(int i)
	{
		return _message[i];
	}

	@Override
	public void setWeight(int i, double weight)
	{
		_message[i] = weight;
		forgetNormalizationEnergy();
	}

	/**
	 * Replaces all weights with copies of {@code weights}, whose length must match the message size.
	 */
	@Override
	public void setWeights(double... weights)
	{
		final int length = weights.length;
		assertSameSize(length);

		System.arraycopy(weights, 0, _message, 0, length);
		forgetNormalizationEnergy();
	}

	@Override
	public double getEnergy(int i)
	{
		return Utilities.weightToEnergy(_message[i]);
	}

	@Override
	public void setEnergy(int i, double energy)
	{
		_message[i] = Utilities.energyToWeight(energy);
		forgetNormalizationEnergy();
	}

	/**
	 * Replaces all weights with the weights corresponding to {@code energies}, whose length must
	 * match the message size.
	 */
	@Override
	public void setEnergies(double... energies)
	{
		final int length = energies.length;
		assertSameSize(length);

		for (int i = 0; i < length; ++i)
		{
			_message[i] = energyToWeight(energies[i]);
		}

		forgetNormalizationEnergy();
	}

	@Override
	public final boolean hasZeroWeight(int i)
	{
		return _message[i] == 0.0;
	}

	@Override
	public double sumOfWeights()
	{
		double sum = 0.0;
		for (double w : _message)
			sum += w;
		return sum;
	}

	/**
	 * Divides every weight by the sum of all weights so that they add up to one.
	 * <p>
	 * The normalization energy is then updated. If it was NaN, it is reset to {@code 0.0}.
	 * Otherwise {@code weightToEnergy(sum)} is added to it.
	 *
	 * @throws NormalizationException if any weight is negative or NaN. A NaN weight would
	 * otherwise make the sum NaN and silently turn every weight, and the normalization energy,
	 * into NaN.
	 * <p>
	 * If the weights sum to zero, the exception returned by {@code weightsAddUpToZero()} is thrown.
	 */
	@Override
	public void normalize()
	{
		double sum = 0.0;
		for (double w : _message)
		{
			if (w < 0)
			{
				throw new NormalizationException("Cannot normalize message because it contains a negative weight");
			}
			if (Double.isNaN(w))
			{
				throw new NormalizationException("Cannot normalize message because it contains a NaN weight");
			}
			sum += w;
		}

		if (sum == 0.0)
		{
			throw weightsAddUpToZero();
		}

		for (int i = _message.length; --i >= 0; )
			_message[i] /= sum;

		final double normalizer = weightToEnergy(sum);
		if (_normalizationEnergy != _normalizationEnergy) // NaN
		{
			_normalizationEnergy = 0.0;
		}
		else
		{
			_normalizationEnergy += normalizer;
		}
	}

	@Override
	public void setWeightsToZero()
	{
		Arrays.fill(_message, 0.0);
		forgetNormalizationEnergy();
	}

	/**
	 * Always true: this message type stores weights, not energies.
	 */
	@Override
	public final boolean storesWeights()
	{
		return true;
	}

	/**
	 * Returns the index of the only non-zero weight.
	 * <p>
	 * Returns -1 if every weight is zero or if more than one weight is non-zero.
	 */
	@Override
	public int toDeterministicValueIndex()
	{
		int index = -1;

		for (int i = _message.length; --i >= 0; )
		{
			if (_message[i] != 0)
			{
				if (index >= 0)
				{
					// A second non-zero weight: the message is not deterministic.
					index = -1;
					break;
				}
				index = i;
			}
		}

		return index;
	}
}