/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.variables;
import static java.util.Objects.*;
import java.util.Arrays;
import org.eclipse.jdt.annotation.NonNull;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.values.DiscreteValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteEnergyMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage;
import com.analog.lyric.dimple.solvers.interfaces.IDiscreteSolverVariable;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.util.misc.Internal;
@SuppressWarnings("deprecation")
public class Discrete extends VariableBase
{
/*--------------
* Construction
*/
public Discrete(DiscreteDomain domain)
{
this(domain, "Discrete");
}
public Discrete(Object... domain)
{
this(DiscreteDomain.create(domain),"Discrete");
if (domain.length < 1)
throw new DimpleException(String.format("ERROR Variable domain length %d must be at least 2", domain.length));
}
/**
* @deprecated as of release 0.08 use {@link #Discrete(DiscreteDomain)} instead.
*/
@Deprecated
public Discrete(DiscreteDomain domain, String modelerClassName)
{
super(domain, modelerClassName);
}
protected Discrete(Discrete that)
{
super(that);
}
@Override
@NonNull
public Discrete clone()
{
return new Discrete(this);
}
/*---------------------
* ISolverNode methods
*/
@Override
public @Nullable IDiscreteSolverVariable getSolver()
{
return (IDiscreteSolverVariable)super.getSolver();
}
/*------------------
* Variable methods
*/
@Override
public final Discrete asDiscreteVariable()
{
return this;
}
@Override
public DiscreteDomain getDomain()
{
return (DiscreteDomain)super.getDomain();
}
public DiscreteDomain getDiscreteDomain()
{
return getDomain();
}
@Override
public @Nullable Integer getFixedValueObject()
{
IDatum datum = getPrior();
if (datum instanceof Value)
{
return ((Value)datum).getIndex();
}
return null;
}
/*------------------
* Discrete methods
*/
public double [] getBelief()
{
return (double[])getBeliefObject();
}
@Override
public Object getBeliefObject()
{
final ISolverVariable svar = getSolver();
if (svar != null)
{
final Object belief = svar.getBelief();
if (belief != null)
{
return belief;
}
}
return getInputObject();
}
public int getGuessIndex()
{
return requireNonNull(getSolver()).getGuessIndex();
}
public void setGuessIndex(int guess)
{
requireNonNull(getSolver()).setGuessIndex(guess);
}
public Object getValue()
{
return requireSolver("getValue").getValue();
}
public int getValueIndex()
{
return ((IDiscreteSolverVariable)requireSolver("getValueIndex")).getValueIndex();
}
private double [] getDefaultPriors(DiscreteDomain domain)
{
final int length = domain.size();
double [] retval = new double[length];
double val = 1.0/length;
for (int i = 0; i < retval.length; i++)
retval[i] = val;
return retval;
}
/**
* {@inheritDoc}
*
* All {@code variables} must be of type {@link Discrete}. The domain of the returned
* variable will be a {@link JointDiscreteDomain} with the subdomains in the same order
* as {@code variables}.
*/
@Internal
@Override
public Variable createJointNoFactors(Variable ... variables)
{
final boolean thisIsFirst = (variables[0] == this);
final int dimensions = thisIsFirst ? variables.length: variables.length + 1;
final DiscreteDomain[] domains = new DiscreteDomain[dimensions];
final IDatum[] subdomainPriors = new IDatum[dimensions];
domains[0] = getDomain();
subdomainPriors[0] = getPrior();
for (int i = thisIsFirst ? 1 : 0; i < dimensions; ++i)
{
final Discrete var = variables[i].asDiscreteVariable();
domains[i] = var.getDomain();
subdomainPriors[i] = var.getPrior();
}
final JointDiscreteDomain<?> jointDomain = DiscreteDomain.joint(domains);
final Discrete jointVar = new Discrete(jointDomain);
jointVar.setPrior(joinPriors(jointDomain ,subdomainPriors));
return jointVar;
}
@Override
public @Nullable IDatum setPrior(@Nullable Object prior)
{
if (prior instanceof double[])
{
return setPrior((double[])prior);
}
if (prior instanceof Value)
{
Value value = (Value)prior;
final DiscreteDomain domain = getDomain();
if (!domain.equals(value.getDomain()))
{
// If domain does not match, create a new value with the correct domain. This ensures
// that indexing operations can be assumed to be correct for this variable.
prior = Value.create(domain, requireNonNull(value.getObject()));
}
}
return super.setPrior(prior);
}
/**
* Sets prior to new {@link DiscreteWeightMessage} with given weights.
* @param weights must have same length as variable's domain.
* @return previous value of prior
* @since 0.08
*/
public @Nullable IDatum setPrior(@Nullable double ... weights)
{
return setPrior(weights == null ? null : new DiscreteWeightMessage(weights));
}
/**
* Sets prior to a fixed discrete value with given index.
* @param index a valid index into the variable's {@linkplain #getDomain() domain}.
* @return previous value of prior
* @since 0.08
*/
public @Nullable IDatum setPriorIndex(int index)
{
return setPrior(Value.createWithIndex(getDomain(), index));
}
/**
* If prior is a {@link DiscreteValue}, returns its index, otherwise -1.
* @since 0.08
*/
public final int getPriorIndex()
{
Value value = getPriorValue();
return value != null ? value.getIndex() : -1;
}
/*--------------------
* Deprecated methods
*/
@Deprecated
public double [] getInput()
{
return (double[])getInputObject();
}
@Deprecated
@Override
public Object getInputObject()
{
Object input = priorToInput(getPrior());
if (input == null)
{
input = getDefaultPriors(getDiscreteDomain());
}
return input;
}
/**
* @deprecated use {@link #setPrior(double...)} instead
*/
@Deprecated
public void setInput(@Nullable double ... value)
{
setPrior(value);
}
/**
* @deprecated use {@link #getPriorIndex()} instead
*/
@Deprecated
public final int getFixedValueIndex()
{
Integer index = getFixedValueObject();
if (index == null)
throw new DimpleException("Fixed value not set");
return index;
}
/**
* @deprecated use {@link #getPriorValue()} instead
*/
@Deprecated
public final Object getFixedValue()
{
Integer index = getFixedValueObject();
if (index == null)
throw new DimpleException("Fixed value not set");
return getDomain().getElement(index);
}
/**
* @deprecated use {@link #setPriorIndex(int)} instead.
*/
@Deprecated
public void setFixedValueIndex(int fixedValueIndex)
{
setPrior(Value.createWithIndex(getDomain(), fixedValueIndex));
}
/**
* @deprecated use {@link #setPrior} instead.
*/
@Deprecated
public void setFixedValue(Object fixedValue)
{
setPrior(Value.create(getDomain(), fixedValue));
}
@Deprecated
@Override
public void setFixedValueObject(@Nullable Object value)
{
setPrior(value != null ? Value.createWithIndex(getDomain(), (Integer)value) : null);
}
/*----------------------------
* Protected/internal methods
*/
@Override
protected @Nullable Object priorToFixedValue(@Nullable IDatum prior)
{
return prior instanceof Value ? ((Value)prior).getIndex() : null;
}
@Override
protected @Nullable Object priorToInput(@Nullable IDatum prior)
{
if (prior instanceof Value)
{
final int index = ((Value)prior).getIndex();
final double[] input = new double[getDomain().size()];
input[index] = 1.0;
return input;
}
else if (prior instanceof DiscreteMessage)
{
return ((DiscreteMessage)prior).getWeights();
}
return prior;
}
/*-----------------
* Private methods
*/
private @Nullable IDatum joinPriors(JointDiscreteDomain<?> jointDomain, IDatum[] subdomainPriors)
{
final JointDomainIndexer domains = jointDomain.getDomainIndexer();
final int dimensions = jointDomain.getDimensions();
boolean hasPrior = false;
int[] fixedIndices = new int[dimensions];
Arrays.fill(fixedIndices, -1);
for (int i = 0; i < dimensions; ++i)
{
DiscreteDomain domain = domains.get(i);
IDatum prior = subdomainPriors[i];
if (prior != null)
{
hasPrior = true;
if (prior instanceof Value)
{
Value value = (Value)prior;
fixedIndices[i] =
domain.equals(value.getDomain()) ? value.getIndex() : domain.getIndex(value.getObject());
subdomainPriors[i] = new DiscreteEnergyMessage(domain, value);
}
else
{
DiscreteMessage msg = prior instanceof DiscreteMessage ? (DiscreteMessage)prior :
new DiscreteWeightMessage(domain, prior);
subdomainPriors[i] = msg;
fixedIndices[i] = msg.toDeterministicValueIndex();
}
}
}
if (!hasPrior)
{
// If none of the component variables has a prior, then neither will the joint variable.
return null;
}
boolean hasAllFixedPriors = true;
for (int i : fixedIndices)
{
if (i < 0)
{
hasAllFixedPriors = false;
break;
}
}
if (hasAllFixedPriors)
{
// Return fixed value with appropriate joint index.
return Value.createWithIndex(jointDomain, domains.jointIndexFromIndices(fixedIndices));
}
int cardinality = jointDomain.size();
double[] energies = new double[cardinality];
int inner = 1, outer = cardinality;
for (int dim = 0; dim < dimensions; ++dim)
{
final DiscreteDomain domain = domains.get(dim);
final DiscreteMessage prior = (DiscreteMessage)subdomainPriors[dim];
final int size = domain.size();
int i = 0;
outer /= size;
if (prior != null)
{
for (int o = 0; o < outer; ++o)
{
for (double energy : prior.getEnergies())
{
for (int r = 0; r < inner; ++r)
{
energies[i++] += energy;
}
}
}
}
inner *= size;
}
return new DiscreteEnergyMessage(energies);
}
}