/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.optimizedupdate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.eclipse.jdt.annotation.NonNullByDefault;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.util.misc.Internal;
import com.google.common.primitives.Ints;
/**
* Produces the sequence of marginalization and output steps necessary to apply the optimized
* update algorithm.
*
* @param <T> The type that represents a factor table within the optimized update tree.
* @since 0.06
* @author jking
*/
@Internal
final class TreeWalker<T>
{
private final IFactorTable _factorTable;
private final int[] _mapping;
/**
* Constructs an instance that walks a given factor table.
*
* @since 0.06
*/
public TreeWalker(IFactorTable factorTable)
{
_factorTable = factorTable;
/*
* The remainder of this method sorts the factor table dimensions by domain size,
* producing an array that contains a mapping from old dimension index to new dimension
* index.
*/
final int dimensions = factorTable.getDimensions();
final int[] domainSizes = new int[dimensions];
final JointDomainIndexer domainIndexer = factorTable.getDomainIndexer();
List<Integer> indices = new ArrayList<Integer>(dimensions);
for (int i = 0; i < dimensions; i++)
{
indices.add(i);
domainSizes[i] = domainIndexer.getDomainSize(i);
}
Comparator<Integer> comparator = new Comparator<Integer>() {
@NonNullByDefault(false)
@Override
public int compare(Integer i, Integer j)
{
// Decreasing order
return Integer.compare(domainSizes[j], domainSizes[i]);
}
};
Collections.sort(indices, comparator);
_mapping = Ints.toArray(indices);
}
/**
* Produce the optimized update sequence.
*
* @param treeBuilder Receives calls describing the sequence.
* @since 0.06
*/
public void accept(ITreeBuilder<T> treeBuilder)
{
T rootT = treeBuilder.createRootT(_factorTable);
int order = _factorTable.getDomainIndexer().size();
loop(0, 1, rootT, order, treeBuilder, _mapping);
}
private void loop(final int p,
int step,
final T f,
final int order,
final ITreeBuilder<T> treeBuilder,
final int[] entries)
{
final int left = p;
final int right = p + step;
loop2(left, right, step * 2, f, order, treeBuilder, entries);
if (right < order)
{
loop2(right, left, step * 2, f, order, treeBuilder, entries);
}
}
private void loop2(final int x,
final int y,
final int step,
T f,
final int order,
ITreeBuilder<T> treeBuilder,
int[] entries)
{
final int offset = x > y ? 1 : 0;
for (int i = 0; x + i * step < order; i++)
{
int portNum = _mapping[x + i * step];
int rawLocalDimension = i + offset;
int localDimension = entries[i + offset];
f = treeBuilder.buildMarginalize(f, portNum, localDimension);
// Remove the entry at rawLocalDimension from entries, and decrease all entries
// greater than that entry by 1:
int[] entries2 = new int[entries.length - 1];
int k = 0;
for (int j = 0; j < entries.length; j++)
{
if (j != rawLocalDimension)
{
int v = entries[j];
if (v > entries[rawLocalDimension])
{
v -= 1;
}
entries2[k] = v;
k += 1;
}
}
entries = entries2;
}
if (y + step < order)
{
loop(y, step, f, order, treeBuilder, entries);
}
else
{
int portNum = _mapping[y];
treeBuilder.buildOutput(f, portNum);
}
}
}