/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import static java.util.Objects.*;
import java.util.List;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState;
/**
 * Conjugate Gibbs sampler for a real-valued variable whose inputs and incoming
 * factor-to-variable messages are all Normal.
 * <p>
 * Inputs ({@link NormalParameters} or {@link Normal}) and edge messages are accumulated into one
 * {@link NormalParameters}. A sample is then drawn from the resulting Normal distribution.
 * <p>
 * Instances reuse an internal scratch {@link NormalParameters} in
 * {@link #nextSample(ISolverEdgeState[], List)} without synchronization. A single instance is
 * therefore not safe for concurrent use.
 */
public class NormalSampler implements IRealConjugateSampler
{
    /**
     * Standard deviation used when the aggregated precision is not positive (including NaN).
     * A very wide Normal stands in for a zero-precision (improper) distribution.
     */
    private static final double MAX_SIGMA = 1e12;

    /** Scratch parameters reused across calls to {@link #nextSample(ISolverEdgeState[], List)}. */
    private final NormalParameters _parameters = new NormalParameters();

    /**
     * Aggregates the inputs and edge messages, then draws a sample from the combined Normal.
     *
     * @param edges solver edge states whose factor-to-variable messages are {@link NormalParameters}
     * @param inputs variable inputs; {@link NormalParameters} and {@link Normal} entries are used
     * @return a sample from the aggregated Normal distribution
     */
    @Override
    public final double nextSample(ISolverEdgeState[] edges, List<? extends IDatum> inputs)
    {
        aggregateParameters(_parameters, edges, inputs);
        return nextSample(_parameters);
    }

    /**
     * Resets {@code aggregateParameters} and accumulates into it every Normal input and every
     * factor-to-variable message on {@code edges}.
     * <p>
     * Accumulation uses {@link NormalParameters#addFrom}. Presumably this combines the
     * distributions as a product of Normal densities (precision-weighted); confirm against
     * NormalParameters.
     *
     * @param aggregateParameters destination; must be a {@link NormalParameters}
     * @param edges solver edge states whose factor-to-variable messages are {@link NormalParameters}
     * @param inputs variable inputs; entries that are neither {@link NormalParameters} nor
     *        {@link Normal} are silently skipped
     * @throws ClassCastException if {@code aggregateParameters} or an edge message is not a
     *         {@link NormalParameters}
     * @throws NullPointerException if an edge has no factor-to-variable message
     */
    @Override
    public final void aggregateParameters(IParameterizedMessage aggregateParameters, ISolverEdgeState[] edges,
        List<? extends IDatum> inputs)
    {
        final NormalParameters parameters = (NormalParameters)aggregateParameters;
        parameters.setNull();

        for (IDatum input : inputs)
        {
            final NormalParameters normalInput;
            if (input instanceof NormalParameters)
            {
                normalInput = (NormalParameters)input;
            }
            else if (input instanceof Normal)
            {
                normalInput = ((Normal)input).getParameterizedMessage();
            }
            else
            {
                // Should be impossible: the factory's isCompatible only accepts Normal inputs.
                // Skipped rather than rejected to preserve existing behavior.
                continue;
            }
            parameters.addFrom(normalInput);
        }

        for (ISolverEdgeState edge : edges)
        {
            // Each neighboring factor's message is a NormalParameters (mean, precision).
            final NormalParameters message = requireNonNull((NormalParameters)edge.getFactorToVarMsg());
            parameters.addFrom(message);
        }
    }

    /**
     * Draws one sample from the Normal distribution described by {@code parameters}.
     * <p>
     * If the precision is not strictly positive (zero, negative or NaN), the sample uses
     * standard deviation {@link #MAX_SIGMA} instead of {@code 1/sqrt(precision)}.
     *
     * @param parameters mean and precision of the distribution to sample
     * @return {@code mean + z * sigma}, where {@code z} is a standard Gaussian draw from the
     *         active random generator
     */
    public final double nextSample(NormalParameters parameters)
    {
        final double mean = parameters.getMean();
        final double precision = parameters.getPrecision();
        final double normal = activeRandom().nextGaussian();
        if (precision > 0)
            return mean + normal / Math.sqrt(precision);
        else
            return mean + normal * MAX_SIGMA;
    }

    /**
     * Creates an empty parameter container of the type this sampler aggregates into.
     *
     * @return a new {@link NormalParameters}
     */
    @Override
    public IParameterizedMessage createParameterMessage()
    {
        return new NormalParameters();
    }

    /** A static factory that creates samplers of this type and reports their applicability. */
    public static final IRealConjugateSamplerFactory factory = new IRealConjugateSamplerFactory()
    {
        @Override
        public IRealConjugateSampler create()
        {
            return new NormalSampler();
        }

        /**
         * Compatible with no factor function ({@code null}), or with a {@link Normal} or
         * {@link NormalParameters} unary factor function.
         */
        @Override
        public boolean isCompatible(@Nullable IUnaryFactorFunction factorFunction)
        {
            return factorFunction == null
                || factorFunction instanceof Normal
                || factorFunction instanceof NormalParameters;
        }

        /** A Normal sampler can only be used on a domain unbounded in both directions. */
        @Override
        public boolean isCompatible(RealDomain domain)
        {
            return (domain.getLowerBound() == Double.NEGATIVE_INFINITY)
                && (domain.getUpperBound() == Double.POSITIVE_INFINITY);
        }
    };
}