/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.data;
import static java.lang.String.*;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.FactorGraphIterables;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.IVariableToValue;
import com.analog.lyric.dimple.model.variables.Variable;
import com.google.common.collect.Lists;
/**
 * An ordered stack of {@link DataLayer}s that all share the same {@linkplain #rootGraph() root graph}.
 * <p>
 * The value of a variable is taken from the first layer (lowest index) that holds a {@link Value}
 * for it; see {@link #getValue(Variable)}. Layers that precede that layer may contribute
 * {@link IUnaryFactorFunction}s when {@linkplain #computeTotalEnergy() computing energy}.
 * <p>
 * The stack is fixed at construction. As a {@link java.util.List} it is read-only: the mutating
 * methods inherited from {@link AbstractList} throw {@link UnsupportedOperationException}.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class DataStack extends AbstractList<DataLayer<?>> implements IVariableToValue
{
	/*-------
	 * State
	 */

	/** Layers in lookup order; index 0 is consulted first. Never empty. */
	private final ArrayList<DataLayer<?>> _stack;

	/*--------------
	 * Construction
	 */

	/**
	 * Creates a stack from the given layers, in iteration order of {@code layers}.
	 * <p>
	 * The layers are copied, so later changes to {@code layers} do not affect this stack.
	 *
	 * @param layers non-empty collection of layers, all with the same {@link DataLayer#rootGraph()}.
	 * @throws IllegalArgumentException if {@code layers} is empty or the layers do not all share
	 * the same root graph.
	 */
	public DataStack(Collection<? extends DataLayer<?>> layers)
	{
		if (layers.isEmpty())
		{
			throw new IllegalArgumentException(format("Cannot create %s with no layers.", getClass().getSimpleName()));
		}

		_stack = new ArrayList<>(layers);

		// Ensure that all layers are for the same graph (identity comparison is intentional).
		final FactorGraph root = _stack.get(0).rootGraph();
		for (int i = 1, n = _stack.size(); i < n; ++i)
		{
			if (_stack.get(i).rootGraph() != root)
			{
				throw new IllegalArgumentException(format("Cannot create %s with layers from different graphs",
					getClass().getSimpleName()));
			}
		}
	}

	/**
	 * Creates a stack from one or more layers, in argument order.
	 *
	 * @throws IllegalArgumentException if the layers do not all share the same root graph.
	 */
	public DataStack(DataLayer<?> firstLayer, DataLayer<?> ... additionalLayers)
	{
		this(Lists.asList(firstLayer, additionalLayers));
	}

	/*--------------
	 * List methods
	 */

	@Override
	public DataLayer<?> get(int index)
	{
		return _stack.get(index);
	}

	@Override
	public int size()
	{
		return _stack.size();
	}

	/*------------------
	 * IVariableToValue
	 */

	/**
	 * Same as {@link #getValue(Variable)}.
	 */
	@Override
	@Nullable
	public Value varToValue(Variable var)
	{
		return getValue(var);
	}

	/*-------------------
	 * DataStack methods
	 */

	/**
	 * Computes the total energy for the graph tree represented by this data stack.
	 * <p>
	 * Computes the total energy by adding the energy evaluated for all the factors
	 * and variable priors and conditioning functions given the value specified for
	 * each variable in the data stack. Specifically:
	 * <ul>
	 * <li><b>for each variable</b>: {@linkplain #getValue(Variable) looks up the value} for the variable and
	 * passes it to the {@linkplain IUnaryFactorFunction#evalEnergy(Value) evalEnergy} method of each
	 * {@link IUnaryFactorFunction} specified for that variable in layers that precede the layer containing
	 * the variable value. If the value is not in the variable's domain, the result is
	 * {@link Double#POSITIVE_INFINITY}.
	 *
	 * <li><b>for each factor</b>: {@linkplain #getValue(Variable) looks up the value} for each of the variables
	 * connected to the factor and passes them to the {@linkplain FactorFunction#evalEnergy(Value[])
	 * evalEnergy} method of the factor's {@linkplain Factor#getFactorFunction() factor function}.
	 * </ul>
	 * Evaluation stops early and returns {@link Double#POSITIVE_INFINITY} as soon as the running
	 * total reaches it.
	 *
	 * @return the total energy, possibly {@link Double#POSITIVE_INFINITY}.
	 * @throws IllegalStateException if any variable in the graph lacks a value.
	 * @since 0.08
	 */
	public double computeTotalEnergy()
	{
		final FactorGraph root = rootGraph();
		final int nLayers = _stack.size();
		double energy = 0.0;

		// Scratch space: functions[i] holds the unary function (or null) found in layer i for the
		// current variable. For each variable, every slot below the value-bearing layer is assigned
		// before it is read, so stale entries from earlier variables are never used.
		final IUnaryFactorFunction[] functions = new IUnaryFactorFunction[nLayers];

		for (Variable var : FactorGraphIterables.variables(root))
		{
			boolean hasValue = false;

			for (int i = 0; i < nLayers; ++i)
			{
				IDatum datum = _stack.get(i).get(var);
				if (datum instanceof Value)
				{
					hasValue = true;
					Value value = (Value)datum;
					if (!var.getDomain().valueInDomain(value))
					{
						return Double.POSITIVE_INFINITY;
					}
					// Apply the unary functions from all preceding layers, nearest layer first.
					while (--i >= 0)
					{
						IUnaryFactorFunction function = functions[i];
						if (function != null)
						{
							energy += function.evalEnergy(value);
							if (energy == Double.POSITIVE_INFINITY)
							{
								return energy;
							}
						}
					}
					break;
				}
				else if (datum instanceof IUnaryFactorFunction)
				{
					functions[i] = (IUnaryFactorFunction)datum;
				}
				else
				{
					functions[i] = null;
				}
			}

			if (!hasValue)
			{
				// Enforce the documented contract: every variable must have a value in some layer.
				throw new IllegalStateException(format("Cannot compute total energy: variable %s has no value", var));
			}
		}

		for (Factor factor : FactorGraphIterables.factors(root))
		{
			final Value[] values = factor.fillInArgumentValues(this, null);
			energy += factor.evalEnergy(values);
			if (energy == Double.POSITIVE_INFINITY)
			{
				return energy;
			}
		}

		return energy;
	}

	/**
	 * Returns the value of {@code var} from the first layer in this stack that holds a
	 * {@link Value} for it. Layers that hold any other kind of datum for the variable are skipped.
	 *
	 * @return the value, or null if no layer holds a {@link Value} for {@code var}.
	 */
	public @Nullable Value getValue(Variable var)
	{
		for (DataLayer<?> layer : _stack)
		{
			IDatum datum = layer.get(var);
			if (datum instanceof Value)
			{
				return (Value)datum;
			}
		}
		return null;
	}

	/**
	 * Root graph for all layers in this stack.
	 * @since 0.08
	 * @see DataLayer#rootGraph()
	 */
	public FactorGraph rootGraph()
	{
		return _stack.get(0).rootGraph();
	}
}