/*******************************************************************************
* Copyright (c) 2012 Michael Kutschke. All rights reserved. This program and the accompanying materials are made
* available under the terms of the Eclipse Public License v1.0 which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors: Michael Kutschke - initial API and implementation
******************************************************************************/
package org.eclipse.recommenders.jayes.transformation;
import static org.eclipse.recommenders.jayes.transformation.util.ArrayFlatten.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.factor.AbstractFactor;
import org.eclipse.recommenders.jayes.factor.DenseFactor;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.transformation.util.CanonicalDoubleArrayManager;
import org.eclipse.recommenders.jayes.transformation.util.DecompositionFailedException;
import org.eclipse.recommenders.jayes.util.MathUtils;
import com.google.common.collect.Lists;
import com.google.common.primitives.Ints;
/**
* Abstract base class for Matrix Decomposition classes used for probability distributions.
*/
public abstract class AbstractDecomposition implements IDecompositionStrategy {
@Override
public final void decompose(BayesNet net, BayesNode node) throws DecompositionFailedException {
if (!net.getNodes().contains(node)) {
throw new IllegalArgumentException("Node " + node + " is not part of the bayesnet " + net.getName());
}
AbstractFactor f = node.getFactor();
if (f.getDimensions().length == 1) {
// in a bayesian network, there are no 0-dimensional factors
throw new DecompositionFailedException("Node " + node + " has no parents, impossible to decompose");
}
f = reorderFactor(f);
int[] dimensions = f.getDimensions();
List<double[]> basis;
double[] latentProb;
// TODO this line is one of those keeping this method from working with SparseFactor
// (and consequently BayesNode from using SparseFactor as well)
List<double[]> vectors = unflatten(f.getValues().toDoubleArray(), dimensions[dimensions.length - 1]);
basis = getBasis(f, vectors);
latentProb = getLatentProbabilities(vectors, basis);
if (f == node.getFactor()) {
// there was no reordering
createLatentNodeInOriginalOrder(net, node, basis, latentProb);
} else {
createLatentNodeReordered(net, node, f, basis, latentProb);
}
}
private AbstractFactor reorderFactor(AbstractFactor f) {
int[] dimensions = f.getDimensions();
int min = Ints.min(dimensions);
int minIndex = Ints.lastIndexOf(dimensions, min);
if (minIndex == dimensions.length - 1) {
return f;
}
int[] nDim = rotateRight(dimensions, dimensions.length - 1 - minIndex);
int[] nIDs = rotateRight(f.getDimensionIDs(), dimensions.length - 1 - minIndex);
AbstractFactor f2 = new DenseFactor();
f2.setDimensionIDs(nIDs);
f2.setDimensions(nDim);
f2.fill(1);
f2.multiplyCompatible(f);
return f2;
}
protected abstract List<double[]> getBasis(AbstractFactor f, List<double[]> vectors)
throws DecompositionFailedException;
private double[] getLatentProbabilities(List<double[]> vectors, List<double[]> best)
throws DecompositionFailedException {
CanonicalDoubleArrayManager canon = new CanonicalDoubleArrayManager(); // to make sure equals will work
best = Lists.transform(best, canon);
vectors = Lists.transform(vectors, canon);
List<double[]> newVectors = toLatentSpace(vectors, best);
return flatten(newVectors.toArray(new double[0][]));
}
private List<double[]> toLatentSpace(List<double[]> vectors, List<double[]> best)
throws DecompositionFailedException {
List<double[]> latent = new ArrayList<double[]>();
for (double[] v : vectors) {
latent.add(toLatentSpace(v, best));
}
return latent;
}
protected abstract double[] toLatentSpace(double[] v, List<double[]> best) throws DecompositionFailedException;
/**
* Example: assume C as the least outcomes, A and B are parents of C, C is decomposed
*
* <pre>
* A -> C => A -> latent-C -> C
* B / B /
* </pre>
**/
private void createLatentNodeInOriginalOrder(BayesNet net, BayesNode node, List<double[]> basis, double[] latentProb) {
BayesNode newNode = net.createNode("latent-" + node.getName());
addOutcomes(newNode, basis.size());
newNode.setParents(node.getParents());
newNode.setProbabilities(latentProb);
node.setParents(Arrays.asList(newNode));
node.setProbabilities(flatten(basis.toArray(new double[0][])));
}
/**
* Example: assume B is the least outcomes, A and B are parents of C, C is decomposed
*
* <pre>
* A -> C => A -------------> C
* B / B -> latent-C /
* </pre>
**/
private void createLatentNodeReordered(BayesNet net, BayesNode node, AbstractFactor f, List<double[]> basis,
double[] latentProb) {
BayesNode newNode = net.createNode("latent-" + node.getName());
addOutcomes(newNode, basis.size());
int[] dimensions = f.getDimensions();
BayesNode parentNode = net.getNode(f.getDimensionIDs()[dimensions.length - 1]);
newNode.setParents(Arrays.asList(parentNode));
newNode.setProbabilities(flatten(transpose(basis).toArray(new double[0][])));
List<BayesNode> parents = new ArrayList<BayesNode>(node.getParents());
int index = parents.indexOf(parentNode);
parents.remove(parentNode);
parents.add(index, newNode);
node.setParents(parents);
double[] nodeProbs = undoReordering(latentProb, node.getFactor(), f, newNode.getId());
node.setProbabilities(nodeProbs);
}
private void addOutcomes(BayesNode newNode, int d) {
for (int i = 0; i < d; i++) {
newNode.addOutcome("outcome" + i);
}
}
private double[] undoReordering(double[] latentProb, AbstractFactor originalFactor, AbstractFactor newFactor,
int originalId) {
AbstractFactor o2 = originalFactor.clone();
AbstractFactor n2 = newFactor.clone();
n2.getDimensionIDs()[n2.getDimensionIDs().length - 1] = originalId;
int oInd = Ints.indexOf(o2.getDimensionIDs(), originalId);
n2.getDimensions()[n2.getDimensions().length - 1] = o2.getDimensions()[oInd];
n2.setValues(new DoubleArrayWrapper(latentProb));
o2.setValues(new DoubleArrayWrapper(new double[MathUtils.product(o2.getDimensions())]));
o2.fill(1);
o2.multiplyCompatible(n2);
return o2.getValues().toDoubleArray();
}
protected final List<double[]> transpose(List<double[]> best) {
List<double[]> result = new ArrayList<double[]>();
for (int i = 0; i < best.get(0).length; i++) {
result.add(new double[best.size()]);
}
for (int i = 0; i < best.size(); i++) {
double[] arr = best.get(i);
for (int j = 0; j < arr.length; j++) {
result.get(j)[i] = arr[j];
}
}
return result;
}
private int[] rotateRight(int[] array, int amount) {
int[] result = new int[array.length];
System.arraycopy(array, 0, result, amount, array.length - amount);
System.arraycopy(array, array.length - amount, result, 0, amount);
return result;
}
}