/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.model;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Random;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.BitSetUtil;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.google.common.collect.ObjectArrays;
import cern.colt.list.IntArrayList;
/**
 * Test helper class for generating test graphs over {@link Discrete} variables.
 * <p>
 * The random choices made directly by this class (graph shape, domain selection and factor
 * table weights) are all drawn from the {@link Random} supplied at construction.
 * <p>
 * Instances are mutable via the chaining attribute setters below.
 *
 * @since 0.05
 * @author Christopher Barber
 */
public class RandomGraphGenerator
{
	/*-------
	 * State
	 */

	/**
	 * Controls whether and how factors created by {@link #addClique(FactorGraph, int, Discrete...)}
	 * are directed.
	 */
	public static enum Direction
	{
		/** Factors are left undirected. */
		NONE,
		/** The variables <em>after</em> the first {@code nOutputs} variables are the directed-to variables. */
		BACKWARD,
		/** The first {@code nOutputs} variables are the directed-to variables. */
		FORWARD;
	}

	private static final DiscreteDomain[] _defaultDomains = { DiscreteDomain.bit() };

	/** Source of randomness for all choices made by this generator. */
	private final Random _rand;

	/** Candidate domains used by {@link #chooseDomain()}. */
	private DiscreteDomain[] _domains = _defaultDomains;
	private int _maxBranches = 1;
	private int _maxTreeWidth = 1;
	private Direction _direction = Direction.NONE;

	/*--------------
	 * Construction
	 */

	/**
	 * @param rand supplies every random choice made by this generator; use a seeded instance
	 * for repeatable tests.
	 */
	public RandomGraphGenerator(Random rand)
	{
		_rand = rand;
	}

	/*-------------------
	 * Attribute methods
	 */

	/**
	 * Direction applied to factors created by {@link #addClique}. Defaults to {@link Direction#NONE}.
	 */
	public Direction direction()
	{
		return _direction;
	}

	/**
	 * Sets {@link #direction()}.
	 *
	 * @return this generator, for chaining.
	 */
	public RandomGraphGenerator direction(Direction direction)
	{
		_direction = direction;
		return this;
	}

	/**
	 * Domains from which new variable domains are randomly chosen. Defaults to a single
	 * {@link DiscreteDomain#bit()} domain. The internal array is returned, not a copy.
	 */
	public DiscreteDomain[] domains()
	{
		return _domains;
	}

	/**
	 * Sets {@link #domains()}. A null or empty argument restores the default (bit domain only).
	 * The given array is stored without copying.
	 *
	 * @return this generator, for chaining.
	 */
	public RandomGraphGenerator domains(@Nullable DiscreteDomain ... domains)
	{
		_domains = (domains != null && domains.length > 0) ? domains : _defaultDomains;
		return this;
	}

	/**
	 * Upper bound on the number of branches created at each level by {@link #addRandomTree} and
	 * {@link #addRandomGraph}. Defaults to 1; must be positive when those methods are used.
	 */
	public int maxBranches()
	{
		return _maxBranches;
	}

	/**
	 * Sets {@link #maxBranches()}.
	 *
	 * @return this generator, for chaining.
	 */
	public RandomGraphGenerator maxBranches(int n)
	{
		_maxBranches = n;
		return this;
	}

	/**
	 * Upper bound on the number of roots created by {@link #buildRandomGraph} and on the number of
	 * new variables created per branch by {@link #addRandomGraph}. Defaults to 1; must be positive
	 * when those methods are used.
	 */
	public int maxTreeWidth()
	{
		return _maxTreeWidth;
	}

	/**
	 * Sets {@link #maxTreeWidth()}.
	 *
	 * @return this generator, for chaining.
	 */
	public RandomGraphGenerator maxTreeWidth(int width)
	{
		_maxTreeWidth = width;
		return this;
	}

	/*--------------------------
	 * Graph generation methods
	 */

	/**
	 * Builds an n x n grid graph with variable domains chosen randomly from {@link #domains()}.
	 *
	 * @param n number of rows and columns.
	 * @see #buildGrid(int, int)
	 */
	public FactorGraph buildGrid(int n)
	{
		return buildGrid(n, n);
	}

	/**
	 * Builds an m x n grid graph with variable domains chosen randomly from {@link #domains()}.
	 * <p>
	 * Each variable is named by its row, written in base 26 using letters (a, b, ..., z, ba, ...),
	 * followed by its column number, e.g. "a0", "c12". Each variable is connected to its upper
	 * and left neighbors (when present) by a pairwise factor created by
	 * {@link #addClique(FactorGraph, Discrete...)}, with the newer variable listed first.
	 *
	 * @param m number of rows.
	 * @param n number of columns.
	 */
	public FactorGraph buildGrid(int m, int n)
	{
		final FactorGraph graph = new FactorGraph();
		final Discrete[][] vars = new Discrete[m][n];

		for (int i = 0; i < m; ++i)
		{
			for (int j = 0; j < n; ++j)
			{
				Discrete var = newDiscrete(String.format("%s%d", intToBase26(i), j));
				vars[i][j] = var;
				if (i > 0)
				{
					Discrete prev = vars[i-1][j];
					addClique(graph, var, prev);
				}
				if (j > 0)
				{
					Discrete prev = vars[i][j-1];
					addClique(graph, var, prev);
				}
			}
		}

		if (m == 1 && n == 1)
		{
			// A 1x1 grid has no factors, so its only variable would otherwise never be added.
			graph.addVariables(vars[0][0]);
		}

		return graph;
	}

	/**
	 * Builds a random graph containing {@code size} variables.
	 * <p>
	 * Between 1 and {@link #maxTreeWidth()} root variables named "root0", "root1", ... are
	 * created. The remaining variables are added beneath them by {@link #addRandomGraph}.
	 *
	 * @param size total number of variables; a non-positive value yields an empty graph.
	 */
	public FactorGraph buildRandomGraph(int size)
	{
		final FactorGraph graph = new FactorGraph();
		if (size <= 0)
		{
			// Nothing to build. Random.nextInt would otherwise reject a non-positive bound.
			return graph;
		}

		final int nRoots = _rand.nextInt(Math.min(size, maxTreeWidth())) + 1;
		final Discrete[] roots = newDiscretes(nRoots, "root", 0);
		graph.addVariables(roots);
		addRandomGraph(graph, size - nRoots, roots);
		return graph;
	}

	/**
	 * Adds {@code size} new variables to {@code graph} beneath the given {@code roots}.
	 * <p>
	 * The budget is split into between 1 and {@link #maxBranches()} branches. For each branch:
	 * <ul>
	 * <li>between 1 and {@link #maxTreeWidth()} new variables are created;</li>
	 * <li>they are connected to the roots by between 1 and min(new variables, roots) cliques,
	 * each containing at least one root and one new variable;</li>
	 * <li>the rest of the branch's budget is added recursively beneath the new variables.</li>
	 * </ul>
	 * Within each clique the root variables come first and are passed as the {@code nOutputs}
	 * of {@link #addClique(FactorGraph, int, Discrete...)}.
	 *
	 * @param graph graph to which factors (and thereby variables) are added.
	 * @param size number of new variables to add; nothing is done if non-positive.
	 * @param roots existing variables to which the new variables are attached.
	 * @throws IllegalArgumentException if {@code size} is positive and {@code roots} is empty.
	 */
	public void addRandomGraph(FactorGraph graph, int size, Discrete ... roots)
	{
		if (size <= 0)
		{
			return;
		}

		final int nRoots = roots.length;
		if (nRoots == 0)
		{
			throw new IllegalArgumentException("addRandomGraph requires at least one root variable");
		}

		final int nBranches = 1 + _rand.nextInt(Math.min(size, maxBranches()));
		final int sizePerBranch = size / nBranches;

		for (int branch = nBranches; --branch>=0;)
		{
			// The last branch processed (branch == 0) absorbs the remainder so all branches sum to size.
			final int branchSize = branch > 0 ? sizePerBranch : sizePerBranch + size % nBranches;
			final int nVars = 1 + _rand.nextInt(Math.min(branchSize, maxTreeWidth()));
			final Discrete[] vars = newDiscretes(nVars);

			final int nCliques = _rand.nextInt(Math.min(nVars, nRoots)) + 1;

			if (nCliques == 1)
			{
				// Just make one big clique
				addClique(graph, nRoots, ObjectArrays.concat(roots, vars, Discrete.class));
			}
			else
			{
				// Randomly order roots and new variables
				final ArrayList<Discrete> randomRoots = new ArrayList<Discrete>(Arrays.asList(roots));
				Collections.shuffle(randomRoots, _rand);
				final ArrayList<Discrete> randomVars = new ArrayList<Discrete>(Arrays.asList(vars));
				Collections.shuffle(randomVars, _rand);

				// n counts the cliques still to be built. The last clique takes everything left.
				// Earlier cliques take an even share of at least one each, which always leaves
				// enough roots and vars for the remaining cliques because nCliques <= min(nVars, nRoots).
				for (int n = nCliques + 1; --n>=1;)
				{
					int nRootsChosen = randomRoots.size();
					int nVarsChosen = randomVars.size();
					if (n > 1)
					{
						nRootsChosen = Math.max(1, nRootsChosen / n);
						nVarsChosen = Math.max(1, nVarsChosen / n);
					}
					final Discrete[] cliqueVars = new Discrete[nRootsChosen + nVarsChosen];
					int i = 0;
					for (int j = 0; j < nRootsChosen; ++j)
					{
						cliqueVars[i++] = randomRoots.remove(randomRoots.size() - 1);
					}
					for (int j = 0; j < nVarsChosen; ++j)
					{
						cliqueVars[i++] = randomVars.remove(randomVars.size() - 1);
					}
					addClique(graph, nRootsChosen, cliqueVars);
				}
			}

			// Spend the rest of this branch's budget beneath the variables just created.
			addRandomGraph(graph, branchSize - nVars, vars);
		}
	}

	/**
	 * Builds a random tree with {@code size} variables, rooted at a variable named "root".
	 * <p>
	 * Every variable has at most {@link #maxBranches()} children. Including its parent, a
	 * variable therefore has at most {@code maxBranches()+1} neighbors. Variable domains are
	 * chosen randomly from {@link #domains()}. If {@code size} is less than 1, the graph
	 * contains only the root.
	 */
	public FactorGraph buildRandomTree(int size)
	{
		final FactorGraph graph = new FactorGraph();
		Discrete root = newDiscrete("root");
		graph.addVariables(root);
		addRandomTree(graph, size - 1, root);
		return graph;
	}

	/**
	 * Adds a random tree of {@code size} new variables beneath the given {@code root}.
	 * <p>
	 * Every variable has at most {@link #maxBranches()} children, and variable domains are chosen
	 * randomly from {@link #domains()}. Each parent/child edge is a pairwise factor created by
	 * {@link #addClique(FactorGraph, Discrete...)} with the child listed first.
	 *
	 * @param size number of new variables to add below {@code root}; nothing is done if non-positive.
	 */
	public void addRandomTree(FactorGraph graph, int size, Discrete root)
	{
		if (size <= 0)
		{
			return;
		}

		final int nChildren = 1 + _rand.nextInt(Math.min(size, maxBranches()));
		final int childSize = size / nChildren;
		final int remainder = size % nChildren;

		for (int i = 0; i < nChildren; ++i)
		{
			// Spread the remainder over the first children so the subtrees add up to exactly size.
			final int subtreeSize = i < remainder ? childSize + 1 : childSize;
			Discrete child = newDiscrete();
			addClique(graph, child, root);
			addRandomTree(graph, subtreeSize - 1, child);
		}
	}

	/**
	 * Extended student Bayesian network from Koller's Probabilistic Graphical Models (Figure 9.8)
	 * <pre>
	 *    c[3]
	 *     |
	 *     v
	 *    d[3]    i[3]
	 *       \   /    \
	 *        v v      v
	 *        g[5]    s[10]
	 *       /  |       |
	 *      /   v       |
	 *     |   l[2]     |
	 *     |      \     |
	 *     |       v    v
	 *     |        j[2]
	 *     |       /
	 *      \     /
	 *       v   v
	 *        h[2]
	 * </pre>
	 * Numbers in brackets indicate the variable cardinality. Each arrow target is the output of a
	 * directed factor whose inputs are its parents; this is independent of {@link #direction()}.
	 */
	public FactorGraph buildStudentNetwork()
	{
		FactorGraph model = new FactorGraph();

		Discrete c = newDiscrete(3, "c");
		Discrete d = newDiscrete(3, "d");
		addDirectedClique(model, d, c);

		Discrete i = newDiscrete(3, "i");
		Discrete g = newDiscrete(5, "g");
		Discrete s = newDiscrete(10, "s");
		addDirectedClique(model, g, d, i);
		addDirectedClique(model, s, i);

		Discrete l = newDiscrete(2, "l");
		addDirectedClique(model, l, g);

		Discrete j = newDiscrete(2, "j");
		addDirectedClique(model, j, l, s);

		Discrete h = newDiscrete(2, "h");
		addDirectedClique(model, h, g, j);

		return model;
	}

	/**
	 * Builds a three-variable loop: variables "a", "b" and "c", with domains chosen randomly from
	 * {@link #domains()}, connected pairwise by three factors (a-b, b-c, a-c).
	 */
	public FactorGraph buildTriangle()
	{
		FactorGraph model = new FactorGraph();
		Discrete a = newDiscrete("a");
		Discrete b = newDiscrete("b");
		Discrete c = newDiscrete("c");
		addClique(model, a, b);
		addClique(model, b, c);
		addClique(model, a, c);
		return model;
	}

	/**
	 * Builds the smallest possible loop: two variables with domains randomly chosen
	 * from {@link #domains()}, connected by two separate factors.
	 */
	public FactorGraph buildTrivialLoop()
	{
		FactorGraph model = new FactorGraph();
		Discrete a = newDiscrete("a");
		Discrete b = newDiscrete("b");
		addClique(model, a, b);
		addClique(model, a, b);
		return model;
	}

	/*------------------------------
	 * RandomGraphGenerator methods
	 */

	/**
	 * Adds a random-table factor over {@code variables} in which only the first variable is
	 * treated as an output when {@link #direction()} is not {@link Direction#NONE}.
	 *
	 * @see #addClique(FactorGraph, int, Discrete...)
	 */
	public Factor addClique(FactorGraph model, Discrete ... variables)
	{
		return addClique(model, 1, variables);
	}

	/**
	 * Adds a factor over {@code variables} whose table is created by {@link #randomTable(Discrete...)},
	 * and labels it using {@link #labelFactor(Factor)}.
	 * <p>
	 * Direction handling depends on {@link #direction()}:
	 * <ul>
	 * <li>{@link Direction#NONE}: the factor is left undirected;</li>
	 * <li>{@link Direction#FORWARD}: the first {@code nOutputs} variables are directed-to;</li>
	 * <li>{@link Direction#BACKWARD}: the remaining variables are directed-to.</li>
	 * </ul>
	 * For a directed factor that has a factor table, the table is also marked directed and
	 * conditionally normalized.
	 *
	 * @param nOutputs number of leading variables treated as outputs in the FORWARD case.
	 */
	public Factor addClique(FactorGraph model, int nOutputs, Discrete ... variables)
	{
		BitSet toSet = new BitSet(variables.length);
		for (int i = 0; i < nOutputs; ++i)
		{
			toSet.set(i);
		}

		final Factor factor = model.addFactor(randomTable(variables), variables);
		labelFactor(factor);

		switch (_direction)
		{
		case NONE:
			break;

		case BACKWARD:
			toSet.flip(0, variables.length);
			//$FALL-THROUGH$
		case FORWARD:
			if (factor.hasFactorTable())
			{
				IFactorTable table = factor.getFactorTable();
				table.setDirected(toSet);
				table.normalizeConditional();
			}
			factor.setDirectedTo(BitSetUtil.bitsetToIndices(toSet));
			break;
		}

		return factor;
	}

	/**
	 * Adds a directed factor with the first variable as its output (directedTo), regardless of
	 * the current {@link #direction()}. The previous direction is restored afterwards.
	 */
	public Factor addDirectedClique(FactorGraph model, Discrete ... variables)
	{
		Direction prevDirection = _direction;
		try
		{
			_direction = Direction.FORWARD;
			return addClique(model, variables);
		}
		finally
		{
			_direction = prevDirection;
		}
	}

	/**
	 * Chooses a domain randomly from {@link #domains()}. Returns {@link DiscreteDomain#bit()} if
	 * empty.
	 */
	public DiscreteDomain chooseDomain()
	{
		return chooseDomain(domains());
	}

	/**
	 * Chooses a domain randomly from {@code domains}. Returns {@link DiscreteDomain#bit()} if
	 * {@code domains} is empty. No random number is consumed when there are fewer than two domains.
	 */
	public DiscreteDomain chooseDomain(DiscreteDomain ... domains)
	{
		final int nDomains = domains.length;
		switch (nDomains)
		{
		case 0:
			return DiscreteDomain.bit();
		case 1:
			return domains[0];
		default:
			return domains[_rand.nextInt(nDomains)];
		}
	}

	/** Creates an unnamed variable with a domain chosen by {@link #chooseDomain()}. */
	public Discrete newDiscrete()
	{
		return newDiscrete(null);
	}

	/** Creates a variable named {@code name + counter} with a domain chosen by {@link #chooseDomain()}. */
	public Discrete newDiscrete(String name, int counter)
	{
		return newDiscrete(name + counter);
	}

	/**
	 * Creates a variable with a domain chosen by {@link #chooseDomain()}.
	 *
	 * @param name variable name, or null to leave the variable unnamed.
	 */
	public Discrete newDiscrete(@Nullable String name)
	{
		return newDiscrete(chooseDomain(), name);
	}

	/** Creates a variable with domain {@code DiscreteDomain.range(1, cardinality)}. */
	public Discrete newDiscrete(int cardinality, String name)
	{
		return newDiscrete(DiscreteDomain.range(1, cardinality), name);
	}

	/**
	 * Creates a variable with the given domain.
	 *
	 * @param name variable name, or null to leave the variable unnamed.
	 */
	public Discrete newDiscrete(DiscreteDomain domain, @Nullable String name)
	{
		Discrete var = new Discrete(domain);
		if (name != null)
		{
			var.setName(name);
		}
		return var;
	}

	/** Creates {@code n} unnamed variables, each with a domain chosen by {@link #chooseDomain()}. */
	public Discrete[] newDiscretes(int n)
	{
		final Discrete[] discretes = new Discrete[n];
		for (int i = 0; i < n; ++i)
		{
			discretes[i] = newDiscrete();
		}
		return discretes;
	}

	/**
	 * Creates {@code n} variables named {@code namePrefix + (counter + i)} for i in [0, n), each
	 * with a domain chosen by {@link #chooseDomain()}.
	 */
	public Discrete[] newDiscretes(int n, String namePrefix, int counter)
	{
		final Discrete[] discretes = new Discrete[n];
		for (int i = 0; i < n; ++i)
		{
			discretes[i] = newDiscrete(namePrefix, counter+i);
		}
		return discretes;
	}

	/** Creates a random table over the domains of the given {@code variables}, in order. */
	public IFactorTable randomTable(Discrete ... variables)
	{
		DiscreteDomain[] domains = new DiscreteDomain[variables.length];
		for (int i = variables.length; --i>=0;)
		{
			domains[i] = variables[i].getDomain();
		}
		return randomTable(domains);
	}

	/**
	 * Creates a factor table over {@code domains} in dense-energy representation, with weights
	 * randomized using this generator's {@link Random}.
	 */
	public IFactorTable randomTable(DiscreteDomain ... domains)
	{
		IFactorTable table = FactorTable.create(domains);
		table.setRepresentation(FactorTableRepresentation.DENSE_ENERGY);
		table.randomizeWeights(_rand);
		return table;
	}

	/*-----------------------
	 * Static helper methods
	 */

	/**
	 * Gives the factor a label of the form f(<i>variables</i>), built from its siblings' labels,
	 * if it doesn't already have an explicit name.
	 */
	public static void labelFactor(Factor factor)
	{
		if (factor.getExplicitName() == null)
		{
			StringBuilder name = new StringBuilder("f(");
			for (int i = 0, end = factor.getSiblingCount(); i < end; ++i)
			{
				if (i > 0)
				{
					name.append(",");
				}
				name.append(factor.getSibling(i).getLabel());
			}
			name.append(")");
			factor.setLabel(name.toString());
		}
	}

	/** Applies {@link #labelFactor(Factor)} to every factor returned by {@code graph.getFactors()}. */
	public static void labelFactors(FactorGraph graph)
	{
		for (Factor factor : graph.getFactors())
		{
			labelFactor(factor);
		}
	}

	/*-----------------
	 * Private methods
	 */

	/**
	 * Renders {@code i} in base 26 using the letters 'a' (0) through 'z' (25), most significant
	 * digit first, e.g. 0 -> "a", 25 -> "z", 26 -> "ba". Negative values are treated as unsigned
	 * 32-bit integers.
	 */
	private static String intToBase26(int i)
	{
		IntArrayList digits = new IntArrayList();
		// Digits are collected least-significant first; the loop always emits at least one digit.
		for (long l = i & 0xFFFFFFFFL; true; l /= 26)
		{
			digits.add((int)(l % 26));
			if (l < 26)
			{
				break;
			}
		}

		StringBuilder sb = new StringBuilder();
		for (int j = digits.size(); --j>=0;)
		{
			sb.append((char)('a' + digits.get(j)));
		}
		return sb.toString();
	}
}