/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.model.repeated;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.Port;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
/**
 * Represents one stream of nested factor graphs.
 * <p>
 * A stream instantiates the same repeated graph once per buffered step inside a parent
 * graph. The boundary variables of step {@code n} are taken from position {@code n} of each
 * variable stream slice passed as an argument; any argument that is not a slice is treated
 * as a plain {@link Variable} (a "parameter") shared by every step.
 * <p>
 * The number of buffered steps is always the number of nested graphs currently held, i.e.
 * {@link #getBufferSize()} and {@link #setBufferSize(int)} both work from the size of the
 * nested graph list rather than from a separately maintained counter.
 */
public class FactorGraphStream
{
    // NOTE(review): these three lists are package-private and mutable; their visibility is
    // kept as-is because other classes in this package may access them directly.
    ArrayList<ArrayList<BlastFromThePastFactor>> _blastFromThePastChains = new ArrayList<ArrayList<BlastFromThePastFactor>>();
    ArrayList<FactorGraph> _nestedGraphs = new ArrayList<FactorGraph>();
    ArrayList<VariableStreamBase<?>> _variableStreams = new ArrayList<>();

    // Saved so that further steps can be added when the buffer size is increased later.
    private final Object [] _args;
    private final FactorGraph _graph;
    private final FactorGraph _repeatedGraph;
    private final FactorGraph _parameterFactorGraph;

    // One handler per parameter variable, i.e. per boundary variable of the first nested
    // graph that does not belong to any variable stream.
    private final HashMap<Variable, ParameterBlastFromThePastHandler> _parameter2blastFromThePastHandler = new HashMap<>();

    /**
     * Builds the stream: creates {@code bufferSize} nested copies of {@code repeatedGraph}
     * inside {@code fg} and then attaches blast-from-the-past factors to the boundary
     * variables of the first nested graph.
     *
     * @param fg the parent graph that receives the nested graphs and the blast-from-the-past
     *        factors
     * @param repeatedGraph the graph that is instantiated once per step
     * @param bufferSize the number of steps to create initially; must be at least 1 because
     *        the first nested graph is needed to wire up the blast-from-the-past factors
     * @param args the boundary arguments, each either an {@link IVariableStreamSlice} or a
     *        {@link Variable}
     * @throws IllegalArgumentException if {@code bufferSize} is less than 1
     */
    public FactorGraphStream(FactorGraph fg,
        FactorGraph repeatedGraph,
        int bufferSize,
        Object ... args)
    {
        // Fail before touching any graph; previously this surfaced later as an
        // IndexOutOfBoundsException when the first nested graph was looked up.
        if (bufferSize < 1)
        {
            throw new IllegalArgumentException("bufferSize must be at least 1, got " + bufferSize);
        }

        _args = args;
        _graph = fg;
        _repeatedGraph = repeatedGraph;
        _parameterFactorGraph = new FactorGraph();
        _parameterFactorGraph.setSolverFactory(_graph.getFactorGraphFactory());

        // Collect each distinct variable stream referenced by the slice arguments.
        HashSet<VariableStreamBase<?>> uniqueStreams = new HashSet<>();
        for (Object arg : args)
        {
            if (arg instanceof IVariableStreamSlice)
            {
                VariableStreamBase<?> stream = ((IVariableStreamSlice<?>)arg).getStream();
                uniqueStreams.add(stream);
            }
        }
        _variableStreams.addAll(uniqueStreams);

        // Build up the nested graphs.
        setBufferSize(bufferSize);

        // Set up the blast-from-the-past factors from the first nested graph's ports.
        FactorGraph firstGraph = _nestedGraphs.get(0);
        Collection<Port> ports = firstGraph.getPorts();
        for (Port port : ports)
        {
            Variable var = (Variable)port.getSiblingNode();
            VariableStreamBase<?> stream = getVariableStream(var);
            if (stream == null)
            {
                // Not part of any stream: this variable is a parameter.
                addParameterBlastFromThePast(var, port);
            }
            else
            {
                addBlastFromThePastChain(stream, var, port);
            }
        }
    }

    /**
     * Registers a boundary port of the first nested graph whose variable is a parameter.
     * The first time a parameter is seen, a blast-from-the-past factor is added to the
     * parent graph and a handler is created for it; every occurrence (including the first)
     * then adds one more blast-from-the-past factor through that handler.
     */
    private void addParameterBlastFromThePast(Variable parameter, Port port)
    {
        ParameterBlastFromThePastHandler handler = _parameter2blastFromThePastHandler.get(parameter);
        if (handler == null)
        {
            BlastFromThePastFactor f = _graph.addBlastFromPastFactor(parameter, port.getSiblingPort());
            handler = new ParameterBlastFromThePastHandler(parameter, _parameterFactorGraph, f);
            _parameter2blastFromThePastHandler.put(parameter, handler);
        }
        handler.addBlastFromThePast(port.getSiblingPort());
    }

    /**
     * Registers a boundary port of the first nested graph whose variable belongs to a
     * stream. For every variable that precedes {@code var} in the stream, walking backwards
     * from the nearest one, a blast-from-the-past factor is added to the parent graph; each
     * factor is linked to the sibling port of the previous link (initially {@code port}).
     * Nothing is added when {@code var} is the first variable of the stream.
     */
    private void addBlastFromThePastChain(VariableStreamBase<?> stream, Variable var, Port port)
    {
        int index = stream.indexOf(var);
        if (index <= 0)
        {
            return;
        }

        ArrayList<BlastFromThePastFactor> chain = new ArrayList<BlastFromThePastFactor>();
        _blastFromThePastChains.add(chain);

        Port nextPort = port;
        for (int i = index - 1; i >= 0; i--)
        {
            Variable earlier = stream.get(i);
            BlastFromThePastFactor f = _graph.addBlastFromPastFactor(earlier, nextPort.getSiblingPort());
            chain.add(f);
            // The next link attaches to this factor's first port.
            nextPort = f.getPort(0);
        }
    }

    /**
     * @return the separate graph that holds the cloned parameter variables and their
     *         blast-from-the-past factors
     */
    public FactorGraph getParameterFactorGraph()
    {
        return _parameterFactorGraph;
    }

    /**
     * @return the number of nested graphs currently held by this stream
     */
    public int getBufferSize()
    {
        return _nestedGraphs.size();
    }

    /**
     * Grows or shrinks the stream to exactly {@code size} nested graphs.
     * <p>
     * Growing appends new steps. Shrinking removes nested graphs from the end, both from the
     * parent graph and from this stream, and then asks every variable stream to clean up its
     * unused variables. The current size is always read from the nested graph list, so
     * repeated shrink/grow calls stay consistent.
     *
     * @param size the desired number of nested graphs; must not be negative
     * @throws IllegalArgumentException if {@code size} is negative
     */
    public void setBufferSize(int size)
    {
        // Reject before mutating anything: a negative size would otherwise remove every
        // nested graph and then fail on an out-of-range index.
        if (size < 0)
        {
            throw new IllegalArgumentException("Buffer size must not be negative, got " + size);
        }

        final int current = _nestedGraphs.size();
        if (size > current)
        {
            for (int i = current; i < size; i++)
            {
                addStep();
            }
        }
        else if (size < current)
        {
            for (int i = current - 1; i >= size; i--)
            {
                _graph.remove(_nestedGraphs.get(i));
                _nestedGraphs.remove(i);
            }
            for (VariableStreamBase<?> stream : _variableStreams)
            {
                stream.cleanupUnusedVariables();
            }
        }
    }

    /**
     * Book-keeping for one parameter variable.
     * <p>
     * The handler keeps a clone of the parameter (with its prior cleared) and adds
     * blast-from-the-past factors for that clone to the graph it is given. It does not use
     * any state of the enclosing stream, hence it is a static nested class.
     */
    private static final class ParameterBlastFromThePastHandler
    {
        private final Variable _myVar;
        private final FactorGraph _fg;
        private final BlastFromThePastFactor _mainBlastFromThePast;
        private final ArrayList<BlastFromThePastFactor> _allBlastFromThePasts = new ArrayList<BlastFromThePastFactor>();

        /**
         * @param var the parameter variable in the parent graph
         * @param fg the graph that receives the factors created for the cloned variable
         * @param originalPlastFromPast the blast-from-the-past factor that was added to the
         *        parent graph for {@code var}
         */
        public ParameterBlastFromThePastHandler(Variable var, FactorGraph fg,
            BlastFromThePastFactor originalPlastFromPast)
        {
            _myVar = var.clone();
            _myVar.setPrior(null);
            _fg = fg;
            _mainBlastFromThePast = originalPlastFromPast;

            // Pair the factor in the parent graph with one on the cloned variable.
            Port factorPort = originalPlastFromPast.getPort(0);
            addBlastFromThePast(factorPort.getSiblingPort());
        }

        /**
         * Adds a blast-from-the-past factor connecting the cloned variable to {@code p} and
         * remembers it so that it is advanced together with the others.
         */
        public void addBlastFromThePast(Port p)
        {
            _allBlastFromThePasts.add(_fg.addBlastFromPastFactor(_myVar, p));
        }

        /**
         * Advances every factor owned by this handler and then stores the cloned variable's
         * belief object as the factor-to-variable message on edge 0 of the main factor's
         * solver.
         */
        public void advance()
        {
            // Resolved before advancing, as in the original statement order.
            final ISolverFactor sfactor = requireNonNull(_mainBlastFromThePast.getSolver());

            for (BlastFromThePastFactor f : _allBlastFromThePasts)
            {
                f.advance();
            }

            requireNonNull(sfactor.getSiblingEdgeState(0)).setFactorToVarMsg(_myVar.getBeliefObject());
        }
    }

    /**
     * Advances the stream by one step: first the parameter handlers, then every
     * blast-from-the-past chain, then each nested graph's solver is told to move messages
     * from the solver of the graph that follows it, and finally the last nested graph
     * recreates its messages.
     */
    public void advance()
    {
        for (ParameterBlastFromThePastHandler handler : _parameter2blastFromThePastHandler.values())
        {
            handler.advance();
        }

        for (ArrayList<BlastFromThePastFactor> chain : _blastFromThePastChains)
        {
            for (BlastFromThePastFactor factor : chain)
            {
                factor.advance();
            }
        }

        // Shift messages one step towards the front of the buffer.
        final int last = _nestedGraphs.size() - 1;
        for (int j = 0; j < last; j++)
        {
            final ISolverFactorGraph otherGraph = requireNonNull(_nestedGraphs.get(j + 1).getSolver());
            requireNonNull(_nestedGraphs.get(j).getSolver()).moveMessages(otherGraph);
        }

        // The newest nested graph starts over with fresh messages.
        _nestedGraphs.get(last).recreateMessages();
    }

    /**
     * @return {@code true} only if every variable stream used by this stream reports that it
     *         has a next element; {@code true} as well when there are no variable streams
     */
    public boolean hasNext()
    {
        for (VariableStreamBase<?> stream : _variableStreams)
        {
            if (!stream.hasNext())
            {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends one nested graph. Its boundary variables are taken from the saved constructor
     * arguments: slices are queried at the index of the new step, everything else is used
     * as a {@link Variable} directly.
     */
    private void addStep()
    {
        final int step = _nestedGraphs.size();
        Variable [] boundaryVariables = new Variable[_args.length];
        for (int j = 0; j < _args.length; j++)
        {
            Object arg = _args[j];
            if (arg instanceof IVariableStreamSlice)
            {
                // NOTE(review): the 'true' flag presumably asks the slice to create the
                // variable if it does not exist yet - confirm against IVariableStreamSlice.
                boundaryVariables[j] = ((IVariableStreamSlice<?>)arg).get(step, true);
            }
            else
            {
                boundaryVariables[j] = (Variable)arg;
            }
        }

        FactorGraph ng = _graph.addFactor(_repeatedGraph, boundaryVariables);
        _nestedGraphs.add(ng);
    }

    /**
     * @return the first known variable stream that contains {@code var}, or {@code null} if
     *         no stream contains it
     */
    private @Nullable VariableStreamBase<?> getVariableStream(Variable var)
    {
        for (VariableStreamBase<?> stream : _variableStreams)
        {
            if (stream.contains(var))
            {
                return stream;
            }
        }
        return null;
    }

    /**
     * @return the live internal list of nested graphs (not a copy); callers should treat it
     *         as read-only, since the buffer size is derived from it
     */
    public ArrayList<FactorGraph> getNestedGraphs()
    {
        return _nestedGraphs;
    }
}