/* * File: SumProductInferencingAlgorithm.java * Authors: Tu-Thach Quach, Jeremy D. Wendt * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright 2016, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.graph.inference; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.util.Pair; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadFactory; /** * Base class for Sum-Product inferencing algorithms on graphs/energy functions * * @author jdwendt, tong * @param <LabelType> The type for labels */ @PublicationReference(author = "Jonahtan S. Yedidia, William T. Freeman, and Yair Weiss", title = "Understanding Belief Propagation and its Generalizations", type = PublicationType.TechnicalReport, year = 2001, notes = { "Institution: Mitsubishi Electric Research Laboratories" }) abstract public class SumProductInferencingAlgorithm<LabelType> implements EnergyFunctionSolver<LabelType> { /** * The default stopping epsilon that will be used */ public static final double DEFAULT_EPS = 0.001; /** * The default maximum number of iterations that will be run */ public static final int DEFAULT_MAX_ITERATIONS = 20; /** * The default number of threads that will be used */ public static final int DEFAULT_NUM_THREADS = 4; /** * The actual stopping epsilon that will be used */ private double eps; /** * The actual maximum number of iterations that will be run */ private int maxNumIterations; /** * The actual number of threads that will be used */ private int numThreads; /** * This internally stores the nodes with their values for the learning */ protected List<Node<LabelType>> nodes; /** * The input energy function to learn against */ protected EnergyFunction<LabelType> fn; /** * The edges as split for multi-threading */ private ConcurrentLinkedQueue<List<Integer>> edgeGroups; /** * The nodes as split for multi-threading */ private ConcurrentLinkedQueue<List<Node<LabelType>>> nodeGroups; /** * The splitting of the edges in the graph for the multithreading */ private List<List<Integer>> edgeGroupsMaster; /** * The splitting of the nodes in the graph for the multithreading */ private List<List<Node<LabelType>>> nodeGroupsMaster; /** * Creates a new solver with the specified settings. * * @param maxNumIterations The maximum number of iterations that will be run * @param eps The stopping epsilon that will be used * @param numThreads The number of threads that will be used */ public SumProductInferencingAlgorithm(int maxNumIterations, double eps, int numThreads) { assert (maxNumIterations > 0); this.maxNumIterations = maxNumIterations; this.eps = eps; this.numThreads = numThreads; fn = null; } /** * Creates a new solver with the default settings excepting max number of * iterations. * * @param maxNumIterations The maximum number of iterations that will be run */ public SumProductInferencingAlgorithm(int maxNumIterations) { this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS); } /** * Creates a new solver with the default settings. */ public SumProductInferencingAlgorithm() { this(DEFAULT_MAX_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS); } /** * Internal enum that maintains what portion of the solve computation is * being run */ private enum SolverSetting { COMPUTE_MESSAGES, NORMALIZE_NODES, COMPUTE_BELIEFS; }; /** * This class actually implements the solver steps. It's a distinct class as * that's required for threading in Java. */ private class SolveThread implements Runnable { /** * The maximally changed value from this latest iteration that was * computed by this thread */ private double delta; /** * This setting allows for changing what computation will occur with the * next data that is seen */ private SolverSetting setting; @Override public void run() { delta = 0; switch (setting) { case COMPUTE_MESSAGES: computeMesssages(); break; case NORMALIZE_NODES: normalizeNodes(); break; case COMPUTE_BELIEFS: computeBeliefs(); break; default: throw new RuntimeException("Unhandled case, setting = " + setting); } } /** * Computes the messages that are passed between neighboring nodes for * the assigned edges in the graph */ private void computeMesssages() { while (true) { List<Integer> edges = edgeGroups.poll(); // Null is only returned on edgeGroups.isEmpty() if (edges == null) { return; } // Compute messages. for (int edge : edges) { computeTemporaryMessage(edge); } } } /** * Normalizes the messages for each node in the assigned portion of the * graph */ private void normalizeNodes() { while (true) { List<Node<LabelType>> nodes = nodeGroups.poll(); // Null is only returned on nodeGroups.isEmpty() if (nodes == null) { return; } // Normalize and update the messages. for (Node<LabelType> node : nodes) { node.normalizeMessagesForSumProductAlgorithm(); delta = Math.max(delta, node.update()); } } } /** * Computes the final belief after all messages have converged or max * iterations are reached */ private void computeBeliefs() { while (true) { List<Node<LabelType>> nodes = nodeGroups.poll(); // Null is only returned on nodeGroups.isEmpty() if (nodes == null) { return; } // Compute beliefs for these nodes for (Node<LabelType> node : nodes) { node.computeBeliefsForSumProductAlgorithm(fn); } } } /** * Returns the amount of change seen by this thread during this * iteration * * @return the amount of change seen by this thread during this * iteration */ public double getDelta() { return delta; } }; @Override public boolean solve() { boolean converged = false; edgeGroups.clear(); nodeGroups.clear(); int iterCount = 0; // Create the factory that creates the threads final ThreadFactory threadFactory = new ThreadFactory() { private final String baseName = "BpSolver-"; private int counter = 0; @Override public Thread newThread(Runnable r) { return new Thread(r, baseName + counter++); } }; final ExecutorService executorService = Executors.newFixedThreadPool( numThreads, threadFactory); List<SolveThread> threads = new ArrayList<>(numThreads); for (int i = 0; i < numThreads; ++i) { SolveThread thread = new SolveThread(); threads.add(thread); } List<Future<?>> futures = new ArrayList<>(numThreads); while (!converged && iterCount < maxNumIterations) { // Initialize the per-thread queues ... they're emptied each // iteration copyFromMasters(); // First update and run solvers for all of the messages loadAndStartFutures(futures, executorService, threads, SolverSetting.COMPUTE_MESSAGES); waitForThreadsToComplete(futures); // Now update and run solvers for the normalizing nodes step loadAndStartFutures(futures, executorService, threads, SolverSetting.NORMALIZE_NODES); waitForThreadsToComplete(futures); // Check to see if we're converged double delta = 0; for (int i = 0; i < numThreads; ++i) { delta = Math.max(delta, threads.get(i).getDelta()); } if (delta < eps) { converged = true; } iterCount++; } // Compute final beliefs. copyFromMasters(); // Now update and run solvers for the normalizing nodes step loadAndStartFutures(futures, executorService, threads, SolverSetting.COMPUTE_BELIEFS); waitForThreadsToComplete(futures); executorService.shutdown(); return converged; } /** * Private helper that initializes and starts the multi-threading for a new * portion of the sum-product algorithm * * @param futures The Java object for threads * @param executorService The Java object for maintaining the threads * @param threads The local class that contains per-thread logic * @param setting The setting for the threads */ private void loadAndStartFutures(List<Future<?>> futures, ExecutorService executorService, List<SolveThread> threads, SolverSetting setting) { for (int i = 0; i < numThreads; ++i) { threads.get(i).setting = setting; futures.add(executorService.submit(threads.get(i))); } } /** * Private helper that waits for the threads to all complete. * * @param futures The Java object for the threads */ private void waitForThreadsToComplete(List<Future<?>> futures) { for (int i = 0; i < numThreads; ++i) { try { // This is a blocking call futures.get(i).get(); } catch (ExecutionException | InterruptedException e) { throw new RuntimeException(e); } } futures.clear(); } /** * Private helper that computes the temporary message for the specified edge * for the current iteration going in the specified direction (a different * message for each direction) * * @param edge The edge to compute the temporary message for (you can't * replace the current message yet as the current message may be needed for * other edges). */ abstract protected void computeTemporaryMessage(int edge); /** * Private helper that copies the contents of the per-thread start/stop * values into the per-thread-run queues (that will be emptied) */ private void copyFromMasters() { if (!(edgeGroups.isEmpty() && nodeGroups.isEmpty())) { throw new RuntimeException( "Can't copy if the destinations aren't empty"); } for (List<Integer> edgeIds : edgeGroupsMaster) { edgeGroups.add(edgeIds); } for (List<Node<LabelType>> masterNodes : nodeGroupsMaster) { nodeGroups.add(masterNodes); } } /** * Children classes must implement this such that it adds the connections in * the correct direction for the input edge. The nodes object should be * updated by the child in this method. * * @param edgePair The two endpoints of the specified edge */ abstract void initMessages(Pair<Integer, Integer> edgePair); @Override public void init(EnergyFunction<LabelType> f) { // Add nodes. nodes = new ArrayList<>(f.numNodes()); for (int i = 0; i < f.numNodes(); i++) { Node<LabelType> node = new Node<>(i, f.getPossibleLabels(i)); nodes.add(node); } // Add messages to node. for (int edge = 0; edge < f.numEdges(); edge++) { initMessages(f.getEdge(edge)); } // Set all message values to 1. for (Node<LabelType> node : nodes) { node.resetToOne(); } this.fn = f; // Initialize the multi-threading queues // Create the data queues edgeGroupsMaster = new ArrayList<>(); int numPieces = (numThreads * 10); int numPerPiece = f.numEdges() / numPieces; int startAt = 0; for (int i = 0; i < numPieces - 1; ++i) { List<Integer> l = new ArrayList<>(numPerPiece); for (int j = 0; j < numPerPiece; ++j) { l.add(j + startAt); } edgeGroupsMaster.add(l); startAt += numPerPiece; } // Last piece may be different sized than all others List<Integer> l = new ArrayList<>(f.numEdges() - startAt); for (int i = startAt; i < f.numEdges(); ++i) { l.add(i); } edgeGroupsMaster.add(l); nodeGroupsMaster = new ArrayList<>(); numPerPiece = f.numNodes() / numPieces; List<Node<LabelType>> labels = new ArrayList<>(numPerPiece); for (Node<LabelType> node : nodes) { labels.add(node); if (labels.size() == numPerPiece) { nodeGroupsMaster.add(labels); labels = new ArrayList<>(numPerPiece); } } if (!labels.isEmpty()) { nodeGroupsMaster.add(labels); } nodeGroups = new ConcurrentLinkedQueue<>(); edgeGroups = new ConcurrentLinkedQueue<>(); } @Override public double getBelief(int i, int label) { return nodes.get(i).getBelief(label); } }