/*
* File: SumProductPairwiseBayesNet.java
* Authors: Jeremy Wendt
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.util.Pair;
import java.util.Collection;
/**
* This class implements a Bayes Net -- but only allowing for pairwise influence
* along edges. That is p(x | a, b, c) is approximated by p(x | a) * p(x | b) *
* p(x | c). We realize that's a considerable approximation, but on real graphs
* with very high degree nodes, it quickly becomes infeasible to compute p(x |
* a, b, ...) when there are hundreds of incoming edges and some approximation
* must be chosen.
*
* @author jdwendt
* @param <LabelType>
*/
@PublicationReference(author
= "Tu-Thach Quach and Jeremy D. Wendt",
title
= "A diffusion model for maximizing influence spread in large networks",
type = PublicationType.Conference,
publication
= "Proceedings of the International Conference on Social Informatics", year
= 2016)
public class SumProductDirectedPropagation<LabelType>
extends SumProductInferencingAlgorithm<LabelType>
{
/**
* Creates a new pairwise Bayes Net solver with the specified settings.
*
* @param maxNumIterations The maximum number of iterations that will be run
* @param eps The stopping epsilon that will be used
* @param numThreads The number of threads that will be used
*/
public SumProductDirectedPropagation(int maxNumIterations,
double eps,
int numThreads)
{
super(maxNumIterations, eps, numThreads);
}
/**
* Creates a new pairwise Bayes Net solver with the default settings
* excepting max number of iterations.
*
* @param maxNumIterations The maximum number of iterations that will be run
*/
public SumProductDirectedPropagation(int maxNumIterations)
{
this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS);
}
/**
* Creates a new pairwise Bayes Net solver with the default settings.
*/
public SumProductDirectedPropagation()
{
this(DEFAULT_MAX_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS);
}
@Override
void initMessages(Pair<Integer, Integer> edgePair)
{
Node<LabelType> node = nodes.get(edgePair.getSecond());
node.link(edgePair.getFirst(), false);
}
/**
* Private helper that computes the temporary message for the specified edge
* for the current iteration going in the specified direction
*
* @param edge The edge to compute the temporary message for (you can't
* replace the current message yet as the current message may be needed for
* other edges).
*/
@Override
protected void computeTemporaryMessage(int edge)
{
Pair<Integer, Integer> edgePair = fn.getEdge(edge);
Node<LabelType> sourceNode = nodes.get(edgePair.getFirst());
Node<LabelType> targetNode = nodes.get(edgePair.getSecond());
int sourceNodeId = sourceNode.getId();
int targetNodeId = targetNode.getId();
// Compute m_{ij}(x_j).
Collection<LabelType> sourceLabels = fn.getPossibleLabels(sourceNodeId);
Collection<LabelType> targetLabels = fn.getPossibleLabels(targetNodeId);
int size = sourceLabels.size() * targetLabels.size();
double[] values = new double[size];
double max = -Double.MAX_VALUE;
int ij = 0;
for (LabelType targetLabel : targetLabels)
{
int sourceLabelIdx = 0;
for (LabelType sourceLabel : sourceLabels)
{
values[ij] = -fn.getUnaryCost(sourceNodeId, sourceLabel);
values[ij] += -fn.getPairwiseCost(edge, sourceLabel,
targetLabel);
values[ij] += sourceNode.getLogMessageSum(sourceLabelIdx,
targetNodeId);
max = Math.max(values[ij], max);
++sourceLabelIdx;
++ij;
}
}
// Now correct so that all log values are less than or equal to 0
ij = 0;
Message targetMessage = targetNode.getMessageFromSource(sourceNodeId);
for (int i = 0; i < targetLabels.size(); ++i)
{
double value = 0;
for (int j = 0; j < sourceLabels.size(); ++j)
{
value += Math.exp(values[ij] - max);
++ij;
}
targetMessage.setTempValue(i, value);
}
}
}