/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import static java.util.Objects.*;
import java.util.Map;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
/**
*
* @since 0.08
* @author Christopher Barber
*/
public abstract class CategoricalBase extends UnaryFactorFunction implements IParametricFactorFunction
{
/*-------
* State
*/
private static final long serialVersionUID = 1L;
protected DiscreteMessage _parameters;
protected boolean _parametersConstant;
protected int _firstDirectedToIndex;
/*--------------
* Construction
*/
protected CategoricalBase(DiscreteMessage parameters, int firstDirectedToIndex)
{
super((String)null);
_parameters = parameters;
_parametersConstant = firstDirectedToIndex == 0;
_firstDirectedToIndex = firstDirectedToIndex;
}
protected CategoricalBase(DiscreteMessage parameters)
{
this(parameters, 0);
}
protected CategoricalBase(CategoricalBase other)
{
super(other);
_parameters = other._parameters.clone();
_parametersConstant = other._parametersConstant;
_firstDirectedToIndex = other._firstDirectedToIndex;
}
/*----------------
* IDatum methods
*/
@Override
public boolean objectEquals(@Nullable Object other)
{
if (this == other)
{
return true;
}
if (getClass().isInstance(other))
{
CategoricalBase that = (CategoricalBase)requireNonNull(other);
return _parametersConstant == that._parametersConstant &&
_firstDirectedToIndex == that._firstDirectedToIndex &&
_parameters.objectEquals(that._parameters);
}
return false;
}
/*------------------------
* FactorFunction methods
*/
@Override
public final boolean isDirected()
{
return true;
}
@Override
public final int[] getDirectedToIndices(int numEdges)
{
// All edges except the parameter edges (if present) are directed-to edges
return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1);
}
/*-----------------------------------
* IParametricFactorFunction methods
*/
@Override
public int copyParametersInto(Map<String, Object> parameters)
{
if (_parametersConstant)
{
parameters.put("alpha", getParameters().clone());
return 1;
}
return 0;
}
@Override
@Nullable
public double[] getParameter(String parameterName)
{
if (_parametersConstant)
{
switch (parameterName)
{
case "alpha":
case "alphas":
return getParameters().clone();
}
}
return null;
}
@Override
public final boolean hasConstantParameters()
{
return _parametersConstant;
}
@Override
public DiscreteMessage getParameterizedMessage()
{
return _parameters;
}
public final double[] getParameters()
{
return _parameters.representation();
}
public final int getDimension()
{
return _parameters.size();
}
}