/******************************************************************************* * Copyright 2014-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.core.parameterizedMessages; import static com.analog.lyric.math.Utilities.*; import static java.lang.String.*; import java.io.PrintStream; import java.util.Arrays; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.data.IDatum; import com.analog.lyric.dimple.exceptions.NormalizationException; import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.values.DiscreteValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.util.misc.Matlab; /** * * @since 0.06 * @author Christopher Barber */ @Matlab(wrapper="DiscreteMessage") public abstract class DiscreteMessage extends ParameterizedMessageBase { private static final long serialVersionUID = 1L; /*------- * State */ protected final double[] _message; /*-------------- * Construction */ DiscreteMessage(double[] message) { _message = message.clone(); } DiscreteMessage(DiscreteMessage other) { super(other); _message = other._message.clone(); } /*---------------- * IDatum methods */ @Override public double evalEnergy(Value value) { return getEnergy(value.getIndexOrInt()); } @Override public boolean objectEquals(@Nullable Object other) { if (this == other) { return true; } if (other != null && other.getClass() == getClass()) { DiscreteMessage that = (DiscreteMessage)other; return super.objectEquals(other) && Arrays.equals(_message, that._message); } return false; } /*-------------------- * IPrintable methods */ @Override public void print(PrintStream out, int verbosity) { if (verbosity >= 0) { out.print(storesWeights() ? "weights" : "energies"); out.print('('); for (int i = 0, end = _message.length; i < end; ++i) { if (i > 0) { out.print(','); if (verbosity > 1) { out.print(' '); } } if (verbosity > 1) { out.format("%d=", i); } out.format("%g", _message[i]); } out.print(')'); final double normalizationEnergy = getNormalizationEnergy(); if (normalizationEnergy != 0.0) { if (storesWeights()) { out.format(" [/ %g]", energyToWeight(normalizationEnergy)); } else { out.format(" [- %g]", normalizationEnergy); } } } } /*------------------------------- * IParameterizedMessage methods */ @Override public abstract DiscreteMessage clone(); /** * {@inheritDoc} * <p> * @param other is {@link DiscreteMessage} with matching {@link #size()}. * @throws ClassCastException if {@code other} is not a {@link DiscreteMessage}. * @throws IllegalArgumentException if {@code other} does not have matching size. * @since 0.08 */ @Override public void addFrom(IParameterizedMessage other) { addEnergiesFrom((DiscreteMessage)other); } /** * {@inheritDoc} * <p> * Discrete messages compute KL using: * <blockquote> * <big>Σ</big> ln(P<sub>i</sub> / Q<sub>i</sub>) P<sub>i</sub> * </blockquote> */ @Override public double computeKLDivergence(IParameterizedMessage that) { if (that instanceof DiscreteMessage) { // KL(P|Q) == sum(log(Pi/Qi) * Pi) // // To normalize you need to divide Pi by sum(Pi) and Qi by sum(Qi), denote these // by Ps and Qs: // // ==> sum(log((Pi/Ps)/(Qi/Qs)) * Pi/Ps) // // ==> 1/Ps * sum(log(Pi/Qi) * Pi + log(Qs/Ps) * Pi) // // ==> sum(Pi*(log(Pi) - log(Qi)))/Ps + log(Qs/Ps) // // This formulation allows you to perform the computation using a single loop. final DiscreteMessage P = this; final DiscreteMessage Q = (DiscreteMessage)that; final int size = P.size(); if (size != Q.size()) { throw new IllegalArgumentException( String.format("Mismatched domain sizes '%d' and '%d'", P.size(), Q.size())); } double Ps = 0.0, Qs = 0.0, unnormalizedKL = 0.0; for (int i = 0; i < size; ++i) { final double pw = P.getWeight(i); if (pw == 0.0) continue; final double qw = Q.getWeight(i); Ps += pw; Qs += qw; final double pe = P.getEnergy(i); final double qe = Q.getEnergy(i); unnormalizedKL += pw * (qe - pe); } return unnormalizedKL / Ps + Math.log(Qs/Ps); } throw new IllegalArgumentException(String.format("Expected '%s' but got '%s'", getClass(), that.getClass())); } @Override public final boolean hasDeterministicValue() { return toDeterministicValueIndex() >= 0; } @Override public void setDeterministic(Value value) { int index = value.getIndex(); if (index < 0) { throw new IllegalArgumentException(format("%s is not discrete", value)); } setDeterministicIndex(index); } @Override public void setFrom(IParameterizedMessage other) { setFrom((DiscreteMessage)other); } @Override public final @Nullable Value toDeterministicValue(Domain domain) { int index = toDeterministicValueIndex(); return index >= 0 ? Value.createWithIndex((DiscreteDomain) domain, index) : null; } /*------------------------- * DiscreteMessage methods */ /** * The size of the message, i.e. the number of discrete elements of the domain. * * @since 0.06 */ public final int size() { return _message.length; } /** * Add energies from other message. * @since 0.08 */ public void addEnergiesFrom(DiscreteMessage other) { assertSameSize(other.size()); for (int i = _message.length; --i>=0;) { setEnergy(i, getEnergy(i) + other.getEnergy(i)); } } public void addWeightsFrom(DiscreteMessage other) { assertSameSize(other.size()); for (int i = _message.length; --i>=0;) { setWeight(i, getWeight(i) + other.getWeight(i)); } } /** * Returns copy of all of the energy values in the message. * @since 0.08 */ public final double[] getEnergies() { return getEnergies(new double[_message.length]); } /** * Copies energies into provided array and returns it. * @param array an array with length >= {@link #size}. * @since 0.08 */ public abstract double[] getEnergies(double[] array); /** * Returns copy of all of the weight values in the message. * @since 0.08 */ public final double[] getWeights() { return getWeights(new double[_message.length]); } /** * Copies weights into provided array and returns it. * @param array an array with length >= {@link #size}. * @since 0.08 */ public abstract double[] getWeights(double[] array); public abstract double getWeight(int i); public abstract void setWeight(int i, double weight); public abstract double getEnergy(int i); public abstract void setEnergy(int i, double energy); public abstract void setWeights(double ... weights); public abstract void setEnergies(double ... energies); public abstract void setWeightsToZero(); /** * True if weight at given index is zero. * <p> * This is the same as: * <blockquote><tt> * {@link #getWeight}(i) == 0.0 * </tt></blockquote> * but may be faster if underlying representation uses energies. * @since 0.08 */ public abstract boolean hasZeroWeight(int i); /** * Compute sum of all weights in message. * @since 0.08 */ public abstract double sumOfWeights(); /** * Normalize so that weights sum to one. * <p> * This will compute the {@linkplain #sumOfWeights() sum of the weights} and use that to * normalize the message. * <p> * If the {@link #getNormalizationEnergy() normalization energy} has not already been computed, it * will be set to zero. If it had been computed, then the energy of the sum of weights will be added * to it. * <p> * @throws NormalizationException if {@link #sumOfWeights()} is zero. */ public abstract void normalize(); /** * Sets parameters to produce weight of 1.0 for given index and zero elsewhere. * * @param index must be non-negative and less than {@link #size()}. * @since 0.08 */ public void setDeterministicIndex(int index) { setWeightsToZero(); setWeight(index, 1.0); } /** * Sets values from another message of the same size. * * @param other is another message with the same {@link #size()} as this one but not necessarily * the same representation. * @since 0.08 * @throws IllegalArgumentException if {@code other} does not have the same size. */ public void setFrom(DiscreteMessage other) { final double[] otherRep = other.representation(); if (other.storesWeights()) { setWeights(otherRep); } else { setEnergies(otherRep); } _normalizationEnergy = other._normalizationEnergy; } /** * Sets values for domain from datum. * <p> * @param domain discrete domain with size matching {@link #size()}. * @param datum either an exact {@link Value}, another {@link DiscreteMessage} or other * {@link IUnaryFactorFunction} used to evaluate energies for all possible discrete values. * @since 0.08 */ public void setFrom(DiscreteDomain domain, IDatum datum) { if (datum instanceof DiscreteMessage) { setFrom((DiscreteMessage)datum); } else if (datum instanceof Value) { Value value = (Value)datum; if (domain.equals(value.getDomain())) { setDeterministicIndex(value.getIndex()); } else { setDeterministicIndex(domain.getIndex(value.getObject())); } } else { IUnaryFactorFunction function = (IUnaryFactorFunction)datum; assertSameSize(domain.size()); DiscreteValue value = Value.create(domain); for (int i = domain.size(); --i>=0;) { value.setIndex(i); setEnergy(i, function.evalEnergy(value)); } } } /** * Returns underlying message representation. * @since 0.08 */ public final double[] representation() { return _message; } /** * True if underlying representation uses weights, false if it uses energies. * @since 0.06 */ public abstract boolean storesWeights(); /** * Returns the only index with non-zero probability (non-infinite energy) if there is one. * <p> * @return index or -1 if there is not a unique index. * @since 0.08 */ public abstract int toDeterministicValueIndex(); /*------------------- * Protected methods */ protected void assertSameSize(int otherSize) { if (size() != otherSize) { throw new IllegalArgumentException(String.format("Cannot set from message with different size (%d vs %d)", size(), otherSize)); } } protected NormalizationException weightsAddUpToZero() { return new NormalizationException("Cannot normalize message because weights add up to zero"); } @Override protected final double computeNormalizationEnergy() { return weightToEnergy(sumOfWeights()); } }