/******************************************************************************* * Copyright 2013 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.factorfunctions; import java.util.Map; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities; import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction; import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction; import com.analog.lyric.dimple.model.values.Value; /** * Parameterized Bernoulli distribution, which corresponds to p(x | p), * where p is the probability parameter * * The conjugate prior for p is a Beta distribution. * Depending on the solver, it may or may not be necessary to use a * conjugate prior (for the Gibbs solver, for example, it is not). * * The variables in the argument list are ordered as follows: * * 1) p: Probability parameter * 2...) An arbitrary number of discrete output variable (MUST be zero-based integer values [e.g., Bit variable]) // TODO: remove this restriction * * The parameter may optionally be specified as a constant in the constructor. * In this case, the parameter is not included in the list of arguments. */ public class Bernoulli extends UnaryFactorFunction implements IParametricFactorFunction { private static final long serialVersionUID = 1L; private double _p; private boolean _parametersConstant; private int _firstDirectedToIndex; /*-------------- * Construction */ public Bernoulli() // Variable parameter { super((String)null); _parametersConstant = false; _firstDirectedToIndex = 1; } /** * @since 0.05 */ public Bernoulli(double p) // Constant parameter { this(); _p = p; _parametersConstant = true; _firstDirectedToIndex = 0; if (p < 0 || p > 1) throw new DimpleException("Invalid parameter value. Must be in range [0, 1]."); } /** * Constructs with fixed probability parameter. * @param parameterMap specifies the Bernoulli parameter in the entry with the key "p". * If there is no such entry, the parameter will default to .5. * @since 0.07 */ public Bernoulli(Map<String,Object> parameterMap) { this((double)getOrDefault(parameterMap, "p", .5)); } protected Bernoulli(Bernoulli other) { super(other); _p = other._p; _parametersConstant = other._parametersConstant; _firstDirectedToIndex = other._firstDirectedToIndex; } @Override public Bernoulli clone() { return new Bernoulli(this); } /*---------------- * IDatum methods */ @Override public boolean objectEquals(@Nullable Object other) { if (this == other) { return true; } if (other instanceof Bernoulli) { Bernoulli that = (Bernoulli)other; return _p == that._p && _parametersConstant == that._parametersConstant && _firstDirectedToIndex == that._firstDirectedToIndex; } return false; } /*------------------------- * FactorFunction methods */ @Override public final double evalEnergy(Value[] arguments) { double p; if (_parametersConstant) { p = _p; } else { p = arguments[0].getDouble(); // First argument, if present, is parameter, p if (p < 0 || p > 1) return Double.POSITIVE_INFINITY; } int numZeros = 0; final int length = arguments.length; for (int i = _firstDirectedToIndex; i < length; i++) { int x = arguments[i].getInt(); // Remaining arguments are Discrete or Bit variables if (x == 0) numZeros++; } final int N = length - _firstDirectedToIndex; // Number of non-parameter variables final int numOnes = N - numZeros; if (p == 0) if (numOnes > 0) return Double.POSITIVE_INFINITY; else return 0; else if (p == 1) if (numZeros > 0) return Double.POSITIVE_INFINITY; else return 0; else return -(numOnes * Math.log(p) + numZeros * Math.log(1-p)); } @Override public final boolean isDirected() {return true;} @Override public final int[] getDirectedToIndices(int numEdges) { // All edges except the parameter edges (if present) are directed-to edges return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1); } /*----------------------------------- * IParametricFactorFunction methods */ @Override public int copyParametersInto(Map<String, Object> parameters) { if (_parametersConstant) { parameters.put("p", _p); return 1; } else { return 0; } } @Override public @Nullable Object getParameter(String parameterName) { if (_parametersConstant) { switch (parameterName) { case "p": return _p; } } return null; } @Override public final boolean hasConstantParameters() { return _parametersConstant; } /*-------------------------- * Factor-specific methods */ public final double getParameter() { return _p; } }