/*******************************************************************************
 * Copyright 2014 Felipe Takiyama
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package br.usp.poli.takiyama.common;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

import br.usp.poli.takiyama.prv.Prv;
import br.usp.poli.takiyama.prv.RangeElement;
import br.usp.poli.takiyama.prv.StdPrv;
import br.usp.poli.takiyama.prv.Substitution;
import br.usp.poli.takiyama.prv.Term;
import br.usp.poli.takiyama.utils.Lists;
import br.usp.poli.takiyama.utils.MathUtils;

/**
 * Factors are table representations for joint distributions.
 * <p>
 * In order to save space, factors do not store all tuples, only their values.
 * Values are indexed based on the order of variables (columns) and the order
 * of the range of each variable: the last variable varies fastest.
 * </p>
 * <p>
 * Instances never modify their variable or value lists after construction;
 * every "modifying" operation returns a new factor.
 * </p>
 *
 * @author ftakiyama
 */
public class StdFactor implements Factor {

    /**
     * The name of this factor.
     */
    private final String name;

    /**
     * A list with all parameterized random variables. Variables must be
     * ordered because of internal implementation.
     */
    private final List<? extends Prv> variables;

    /**
     * The list of values mapped to tuples of PRVs.
     */
    private final List<BigDecimal> values;

    /**
     * The expected number of tuples, derived from the ranges of the
     * variables (see {@link #getSize(List)}).
     */
    private final int size;

    /* ************************************************************************
     *    Constructors
     * ************************************************************************/

    /**
     * Creates a parameterized factor.
     * <p>
     * Both lists are defensively copied. The value list may be empty; if it
     * is not, it must contain exactly one value per tuple.
     * </p>
     *
     * @param name The name of this factor
     * @param variables An ordered list of {@link Prv}.
     * @param values An ordered list of values, determined by the order
     * of Prvs.
     * @throws NullPointerException If <code>name</code> is
     * <code>null</code>.
     * @throws IllegalArgumentException If the number of values specified is
     * not compatible with the PRVs specified.
     */
    private StdFactor(String name, List<? extends Prv> variables,
            List<BigDecimal> values) throws IllegalArgumentException {
        if (name == null) {
            throw new NullPointerException("name");
        }
        this.name = name;
        this.variables = new ArrayList<Prv>(variables);
        this.values = new ArrayList<BigDecimal>(values);
        this.size = getSize(this.variables);
        if (!this.values.isEmpty() && this.values.size() != size) {
            throw new IllegalArgumentException("Wrong number of values. Expected: "
                    + size + ", received: " + this.values.size());
        }
    }

    /**
     * Returns the expected size of a factor on the specified variables, that
     * is, the product of the range sizes of all variables. An empty list of
     * variables yields size 0.
     *
     * @param variables The variables of the factor
     * @return The expected size of a factor on the specified variables.
     */
    private static int getSize(List<? extends Prv> variables) {
        if (variables.isEmpty()) {
            return 0;
        }
        int size = 1;
        for (Prv prv : variables) {
            size = size * prv.range().size();
        }
        return size;
    }

    /* ************************************************************************
     *    Static factories
     * ************************************************************************/

    /**
     * Returns a parameterized factor.
     * <p>
     * The order of <code>values</code> is dependent on the order of the list
     * of PRVs and their ranges.
     * </p>
     *
     * @param name The name of this factor
     * @param variables An ordered list of {@link Prv}.
     * @param values An ordered list of values, determined by the order
     * of Prvs.
     * @return A factor with the specified name, variables and values.
     * @throws IllegalArgumentException If the number of values specified is
     * not compatible with the PRVs specified.
     */
    public static Factor getInstance(String name, List<? extends Prv> variables,
            List<BigDecimal> values) throws IllegalArgumentException {
        return new StdFactor(name, variables, values);
    }

    /**
     * Returns a parameterized factor with one PRV.
     *
     * @param name The name of this factor
     * @param variable A {@link Prv}.
     * @param values An ordered list of values
     * @return A factor on the single specified PRV.
     */
    public static Factor getInstance(String name, Prv variable,
            List<BigDecimal> values) {
        List<Prv> variables = new ArrayList<Prv>(1);
        variables.add(variable);
        return StdFactor.getInstance(name, variables, values);
    }

    /**
     * Returns a parameterized factor with the same name, variables and
     * values as the specified factor.
     *
     * @param factor The factor to "copy"
     * @return A parameterized factor with the same variables and values as
     * the specified factor.
     */
    public static Factor getInstance(Factor factor) {
        return new StdFactor(factor.name(), factor.variables(), factor.values());
    }

    /**
     * Returns a constant parameterized factor.
     * <p>
     * A constant factor returns the value 1 for all tuples in
     * the factor. Creation is delegated to {@link ConstantFactor}.
     * </p>
     *
     * @param variables An ordered list of {@link Prv}
     * @return A constant parameterized factor with the specified PRVs.
     */
    public static Factor getInstance(List<? extends Prv> variables) {
        return ConstantFactor.getInstance(variables);
    }

    /**
     * Returns a constant factor with one PRV.
     * <p>
     * A constant factor returns the value 1 for all tuples in
     * the factor.
     * </p>
     *
     * @see #getInstance(List)
     * @param variable A {@link Prv}
     * @return A constant factor with the specified PRV.
     */
    public static Factor getInstance(Prv variable) {
        List<Prv> variables = Lists.listOf(variable);
        return ConstantFactor.getInstance(variables);
    }

    /**
     * Returns an empty factor.
     * <p>
     * An empty factor has no values, nor variables.
     * </p>
     *
     * @see #getInstance(List)
     * @return An empty factor.
     */
    public static Factor getInstance() {
        List<Prv> empty = new ArrayList<Prv>(0);
        return StdFactor.getInstance(empty);
    }

    /* ************************************************************************
     *    Getters
     * ************************************************************************/

    /**
     * Returns the position, in the value list, of the value associated with
     * the specified tuple. The index is a mixed-radix number whose digits
     * are the range indexes of the tuple's elements, the last variable being
     * the least significant digit.
     *
     * @param tuple The tuple to locate
     * @return The index of the specified tuple in this factor.
     * @throws IllegalArgumentException If the tuple is empty.
     */
    @Override
    public int getIndex(Tuple<RangeElement> tuple) throws IllegalArgumentException {
        if (tuple.isEmpty()) {
            throw new IllegalArgumentException("This tuple is empty!");
        }
        int index = 0;
        int radix = 1;
        for (int i = tuple.size() - 1; i >= 0; i--) {
            index = index + radix * indexOf(i, tuple);
            radix = radix * rangeSize(i);
        }
        return index;
    }

    /**
     * Returns the index of the range element that occupies the specified
     * position in the tuple.
     *
     * @param i The position in the tuple
     * @param tuple A tuple of {@link RangeElement}
     * @return The index, within the range of the i-th variable, of the i-th
     * element of the tuple (as reported by the range's
     * <code>indexOf</code>).
     */
    private int indexOf(int i, Tuple<RangeElement> tuple) {
        return variables.get(i).range().indexOf(tuple.get(i));
    }

    /**
     * Returns PRV's range size occupying the specified position in this
     * factor.
     *
     * @param i PRV's index in this factor
     * @return PRV's range size occupying the specified position in this
     * factor.
     */
    private int rangeSize(int i) {
        return variables.get(i).range().size();
    }

    /**
     * Returns the element at position <code>rangeIndex</code> in the range
     * of the variable at position <code>prvIndex</code>.
     */
    private RangeElement rangeElementAt(int rangeIndex, int prvIndex) {
        return variables.get(prvIndex).range().get(rangeIndex);
    }

    /**
     * Returns the tuple associated with the specified index. This is the
     * inverse of {@link #getIndex(Tuple)}: the index is decomposed digit by
     * digit, from the last variable to the first.
     *
     * @param index The index of the tuple
     * @return The tuple at the specified index.
     */
    @Override
    public Tuple<RangeElement> getTuple(int index) {
        List<RangeElement> elements = new ArrayList<RangeElement>(variables.size());
        int remaining = index;
        for (int j = variables.size() - 1; j > 0; j--) {
            int domainSize = rangeSize(j);
            elements.add(rangeElementAt(remaining % domainSize, j));
            remaining = remaining / domainSize;
        }
        // Whatever is left over indexes the first (most significant) variable
        elements.add(rangeElementAt(remaining, 0));
        // Elements were collected from last variable to first
        Collections.reverse(elements);
        return Tuple.getInstance(elements);
    }

    @Override
    public BigDecimal getValue(int index) {
        return values.get(index);
    }

    @Override
    public BigDecimal getValue(Tuple<RangeElement> tuple) {
        return getValue(getIndex(tuple));
    }

    /* ************ Iterator **************************************************/

    /**
     * This class is an Iterator over all tuples of this factor.
     * <p>
     * The code was inspired on OpenJDK's implementation of ArrayList Iterator.
     * </p>
     */
    private class Itr implements Iterator<Tuple<RangeElement>> {

        /** Index of the next tuple to return. */
        private int nextElementToReturn;

        @Override
        public boolean hasNext() {
            return nextElementToReturn < size;
        }

        @Override
        public Tuple<RangeElement> next() {
            int i = nextElementToReturn;
            // Valid indexes are 0 .. size - 1
            if (i >= size) {
                throw new NoSuchElementException();
            }
            nextElementToReturn = i + 1;
            return getTuple(i);
        }

        /**
         * Throws {@link UnsupportedOperationException}.
         */
        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /* (non-Javadoc)
     * @see br.usp.poli.takiyama.common.Factor#iterator()
     */
    @Override
    public Iterator<Tuple<RangeElement>> iterator() {
        return new Itr();
    }

    /**
     * Returns an iterator over all tuples of a parameterized factor having
     * the specified variables.
     * <p>
     * The tuples returned <b>depend</b> on the order of the given
     * parameterized random variable list, that is, the order of parameterized
     * random variables defines the way the tuples are created.
     * </p>
     *
     * @param variables A list of parameterized random variables.
     * @return An iterator over all tuples of a parameterized factor having
     * the specified parameterized random variables.
     */
    public static Iterator<Tuple<RangeElement>> iterator(List<Prv> variables) {
        return StdFactor.getInstance(variables).iterator();
    }

    /* ************************************************************************/

    @Override
    public int size() {
        return size;
    }

    @Override
    public String name() {
        return name;
    }

    /**
     * Returns a copy of the list of variables of this factor.
     */
    @Override
    public List<Prv> variables() {
        return new ArrayList<Prv>(variables);
    }

    /**
     * Returns a copy of the list of values of this factor.
     */
    @Override
    public List<BigDecimal> values() {
        // TODO make it more flexible
        return new ArrayList<BigDecimal>(values);
    }

    /**
     * Returns <code>true</code> if at least one variable of this factor
     * contains the specified term.
     */
    @Override
    public boolean contains(Term t) {
        for (Prv prv : variables) {
            if (prv.contains(t)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the number of variables of this factor that contain the
     * specified term.
     */
    @Override
    public int occurrences(Term t) {
        int count = 0;
        for (Prv prv : variables) {
            if (prv.contains(t)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns the last variable of this factor that contains the specified
     * term, or the result of <code>StdPrv.getInstance()</code> if no variable
     * contains it.
     */
    @Override
    public Prv getVariableHaving(Term t) {
        Prv result = StdPrv.getInstance();
        for (Prv prv : variables) {
            if (prv.contains(t)) {
                result = prv;
            }
        }
        return result;
    }

    /**
     * Returns <code>true</code> if this factor is a sub-factor of the
     * specified factor.
     * <p>
     * Factor F1 is a sub-factor of factor F2 if the set of {@link Prv}s from
     * F1 is a subset of the set of Prvs from F2.
     * </p>
     *
     * @return <code>true</code> if this factor is a sub-factor of the
     * specified factor, <code>false</code> otherwise.
     */
    @Override
    public boolean isSubFactorOf(Factor factor) {
        // Fetch once: variables() may return a fresh copy on every call
        List<? extends Prv> others = factor.variables();

        // Quick check: return false if this factor is bigger
        if (others.size() < variables.size()) {
            return false;
        }

        // Checks if all variables in this factor exist in the specified factor
        return others.containsAll(variables);
    }

    /**
     * Returns <code>true</code> if every value of this factor equals
     * {@link BigDecimal#ONE}. A factor without values is also reported as
     * constant.
     * <p>
     * The comparison uses {@link BigDecimal#equals(Object)}, which is
     * sensitive to scale: a value such as <code>1.0</code> is not considered
     * equal to one.
     * </p>
     */
    @Override
    public boolean isConstant() {
        for (BigDecimal val : values) {
            if (!val.equals(BigDecimal.ONE)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns <code>true</code> if this factor has neither variables nor
     * values.
     */
    @Override
    public boolean isEmpty() {
        return variables.isEmpty() && values.isEmpty();
    }

    /* ************************************************************************
     *    hashCode, equals and toString
     * ************************************************************************/

    /**
     * Hash code based on variables and values; the name is not taken into
     * account, consistently with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int result = 17;
        result = 31 * result + variables.hashCode();
        result = 31 * result + Lists.hashCode(values);
        return result;
    }

    /**
     * Two <code>StdFactor</code>s are equal when they have equal variable
     * lists and their value lists are equal according to
     * <code>Lists.areEqual</code>. The name is not compared.
     */
    @Override
    public boolean equals(Object other) {
        // Tests if both refer to the same object
        if (this == other) {
            return true;
        }
        // Tests if the Object is an instance of this class
        if (!(other instanceof StdFactor)) {
            return false;
        }
        // Tests if both have the same attributes
        StdFactor targetObject = (StdFactor) other;
        return variables.equals(targetObject.variables)
                && Lists.areEqual(values, targetObject.values);
    }

    /**
     * Returns an empty string. No textual rendering of the table is
     * currently produced.
     */
    @Override
    public String toString() {
        return "";
    }

    /* ************************************************************************
     *    Setters
     * ************************************************************************/

    /**
     * Returns the result of applying the specified substitution to this
     * Factor. The substitution is applied to PRVs in this Factor, but its
     * values are not modified.
     *
     * @param s The substitution to apply
     * @return The result of applying the specified substitution to this
     * Factor
     */
    @Override
    public Factor apply(Substitution s) {
        List<Prv> substitutedVars = new ArrayList<Prv>(variables.size());
        for (Prv prv : variables) {
            substitutedVars.add(prv.apply(s));
        }
        return StdFactor.getInstance(name, substitutedVars, values);
    }

    /**
     * Returns a copy of this factor with the value of the specified tuple
     * replaced by the specified value.
     *
     * @param tuple The tuple whose value must be modified
     * @param value The new value of the tuple
     * @return a copy of this factor with the value of the specified tuple
     * replaced by the specified value.
     */
    @Override
    public Factor set(Tuple<RangeElement> tuple, BigDecimal value) {
        List<BigDecimal> vals = new ArrayList<BigDecimal>(values);
        vals.set(getIndex(tuple), value);
        return StdFactor.getInstance(name, variables, vals);
    }

    /* ************************************************************************
     *    Multiplication, Power and Sum Out
     * ************************************************************************/

    /**
     * Sums out a random variable from a factor.
     * <p>
     * Suppose F is a factor on random variables x<sub>1</sub>,...,
     * x<sub>i</sub>,...,x<sub>j</sub>. The summing out of random variable
     * x<sub>i</sub> from F, denoted as Σ<sub>x<sub>i</sub></sub>F
     * is the factor on
     * random variables x<sub>1</sub>,...,x<sub>i-1</sub>, x<sub>i+1</sub>,
     * ..., x<sub>j</sub> such that
     * </p>
     * (Σ<sub>x<sub>i</sub></sub> F)
     * (x<sub>1</sub>,...,x<sub>i-1</sub>,x<sub>i+1</sub>,...,x<sub>j</sub>) =
     * Σ <sub>y ∈ dom(x<sub>i</sub>)</sub>
     * F(x<sub>1</sub>,...,x<sub>i-1</sub>,x<sub>i</sub> = y,
     * x<sub>i+1</sub>,...,x<sub>j</sub>).
     * <p>
     * Each term of the sum is multiplied by the correction returned by the
     * PRV's <code>getSumOutCorrection</code> for the corresponding range
     * element.
     * </p>
     * <p>
     * If the variable to be summed out does not exist in the factor, this
     * method returns this factor unmodified.
     * </p>
     * <p>
     * NOTE(review): when <code>prv</code> is the only variable of this
     * factor, the result has one value and no variables, a combination the
     * constructor rejects with an {@link IllegalArgumentException}. Confirm
     * whether callers rely on this.
     * </p>
     *
     * @param prv The {@link Prv} to be summed out.
     * @return A factor with the specified Prv summed out, or this factor
     * if <code>prv</code> does not exist in this factor.
     */
    @Override
    public Factor sumOut(Prv prv) {
        // Checks if the random variable exists
        int prvIndex = variables.indexOf(prv);
        if (prvIndex < 0) {
            return this;
        }

        // Flags values in this factor that were already processed
        // (a new boolean[] is initialized to false)
        boolean[] wasVisited = new boolean[size];

        // Removes the PRV to be summed out
        List<Prv> vars = new ArrayList<Prv>(variables);
        vars.remove(prv);

        // Creates the new mapping, summing out the PRV
        int rangeSize = prv.range().size();
        int expectedSize = (rangeSize == 0) ? 0 : size / rangeSize;
        List<BigDecimal> vals = new ArrayList<BigDecimal>(expectedSize);
        for (int i = 0; i < size; i++) {
            if (wasVisited[i]) {
                continue;
            }
            Tuple<RangeElement> current = getTuple(i);
            BigDecimal sum = BigDecimal.ZERO;

            // Builds all tuples varying only the PRV being summed out
            // and sums their values
            for (RangeElement e : prv.range()) {
                Tuple<RangeElement> next = current.set(prvIndex, e);
                int nextIndex = getIndex(next);
                BigDecimal correction = prv.getSumOutCorrection(e);
                BigDecimal term = getValue(nextIndex)
                        .multiply(correction, MathUtils.CONTEXT);
                sum = sum.add(term, MathUtils.CONTEXT);
                wasVisited[nextIndex] = true;
            }
            vals.add(sum);
        }

        // Creates the new factor
        return getInstance(name, vars, vals);
    }

    /**
     * Returns this factor raised by <code>p/q</code>.
     * <p>
     * Raising a factor to some exponent is the same as raising its values to
     * that exponent.
     * </p>
     *
     * @param p Exponent's numerator
     * @param q Exponent's denominator
     * @return The value of this factor raised to <code>p/q</code>
     */
    @Override
    public Factor pow(int p, int q) {
        List<BigDecimal> newValues = new ArrayList<BigDecimal>(values.size());
        for (BigDecimal base : values) {
            newValues.add(MathUtils.pow(base, p, q));
        }
        return getInstance(name, variables, newValues);
    }

    /**
     * Multiplies this factor with the specified factor.
     * <p>
     * Given two factors
     * F<sub>1</sub>(x<sub>1</sub>,...,x<sub>n</sub>,
     * y<sub>1</sub>,...,y<sub>j</sub>) and
     * F<sub>2</sub>(y<sub>1</sub>,...,y<sub>j</sub>,
     * z<sub>1</sub>,...,z<sub>k</sub>)
     * the resulting factor will be
     * F(x<sub>1</sub>,...,x<sub>n</sub>,y<sub>1</sub>,...,
     * y<sub>j</sub>,z<sub>1</sub>,...,z<sub>k</sub>) =
     * F<sub>1</sub>(x<sub>1</sub>,...,x<sub>n</sub>,y<sub>1</sub>,...,
     * y<sub>j</sub>) x
     * F<sub>2</sub>(y<sub>1</sub>,...,y<sub>j</sub>,z<sub>1</sub>,...,
     * z<sub>k</sub>)
     * </p>
     * That is, for each assignment of values to the variables in the factors,
     * the method multiplies the values that have the same assignment for
     * common variables.
     * <p>
     * If either factor is empty or constant, the other factor is returned
     * as is.
     * </p>
     *
     * @param factor The second factor to be multiplied
     * @return The multiplication of this factor by the specified factor.
     */
    @Override
    public Factor multiply(Factor factor) {
        // Special case: multiplication by 1. Needed because of the way constant
        // factors are modeled
        if (this.isEmpty()) {
            return factor;
        }
        if (factor.isEmpty()) {
            return this;
        }
        if (this.isConstant()) {
            return factor;
        }
        if (factor.isConstant()) {
            return this;
        }

        String newName = name() + "*" + factor.name();
        List<Prv> union = Lists.union(variables(), factor.variables());
        List<BigDecimal> mult = new ArrayList<BigDecimal>(getSize(union));

        int[][] mapOfCommonVariables = getMapOfCommonVariables(this, factor);

        // Pairs every tuple of this factor with every compatible tuple of the
        // other factor; the resulting order is the one expected for 'union'
        for (Tuple<RangeElement> t1 : this) {
            for (Tuple<RangeElement> t2 : factor) {
                if (haveSameSubtuple(t1, t2, mapOfCommonVariables)) {
                    mult.add(getValue(t1).multiply(factor.getValue(t2),
                            MathUtils.CONTEXT));
                }
            }
        }
        return getInstance(newName, union, mult);
    }

    /**
     * Returns a mapping from indexes of variables in the first factor to the
     * indexes of the variables that also appear in the second factor.
     * <p>
     * The mapping is a 2 x n matrix, where n is the number of common
     * random variables between the first factor and the second factor.
     * </p>
     * <p>
     * The set of random variables from the first factor is analyzed
     * sequentially, and for each random variable in the set, the set from the
     * second factor is searched for a match. Thus, the first line of the
     * result will be in ascending order, while the second line may have
     * an arbitrary ordering.
     * </p>
     * <p>
     * For example, suppose that f1(x1,x2,x3,x4,x5) and f2(x5,x4,x1) are factors
     * passed as parameter for this method. Then the mapping will be the
     * following matrix:
     * </p>
     * <table>
     * <tr><td>0</td><td>3</td><td>4</td></tr>
     * <tr><td>2</td><td>1</td><td>0</td></tr>
     * </table>
     * <p>
     * which means that x1 (index 0) in f1 has a match in f2 at index 2, and
     * so on.
     * </p>
     *
     * @param f1 The first factor.
     * @param f2 The second factor.
     * @return A mapping of indexes from common variables between f1 and f2.
     */
    private int[][] getMapOfCommonVariables(Factor f1, Factor f2) {
        // Fetch once: variables() may return a fresh copy on every call
        List<? extends Prv> first = f1.variables();
        List<? extends Prv> second = f2.variables();

        int[][] mapping = new int[2][first.size()];
        int common = 0;
        for (Prv prv : first) {
            int indexInSecond = second.indexOf(prv);
            if (indexInSecond >= 0) {
                mapping[0][common] = first.indexOf(prv);
                mapping[1][common] = indexInSecond;
                common++;
            }
        }
        return trim(mapping, common);
    }

    /**
     * Trims the specified matrix to the specified size. The length of the
     * matrix is preserved.
     *
     * @param matrix The matrix to trim
     * @param size The limit of the size of each line of the matrix.
     * @return The specified matrix trimmed to the specified size
     */
    private int[][] trim(int[][] matrix, int size) {
        int[][] trimmed = new int[matrix.length][];
        for (int row = 0; row < matrix.length; row++) {
            trimmed[row] = Arrays.copyOf(matrix[row], size);
        }
        return trimmed;
    }

    /**
     * Returns <code>true</code> if both tuples have the same sub-tuple.
     * The sub-tuple is defined according to a map.
     *
     * @see #getMapOfCommonVariables(Factor, Factor)
     * @param t1 The first tuple to check
     * @param t2 The second tuple to check
     * @param map A mapping that connects indexes representing the same PRV
     * @return <code>true</code> if tuples have the same value for the
     * sub-tuple, <code>false</code> otherwise.
     */
    private boolean haveSameSubtuple(Tuple<?> t1, Tuple<?> t2, int[][] map) {
        Tuple<?> st1 = t1.subTuple(map[0]);
        Tuple<?> st2 = t2.subTuple(map[1]);
        return st1.equals(st2);
    }

    /**
     * Returns a factor with the same name and the same tuple-to-value
     * association as this factor, but with variables (and therefore values)
     * arranged in the order used by the specified reference factor.
     *
     * @param reference The factor that defines the order of variables
     * @return This factor reordered according to the reference factor.
     * @throws IllegalArgumentException If this factor and the reference do
     * not have the same variables.
     */
    @Override
    public Factor reorder(Factor reference) throws IllegalArgumentException {
        if (!Lists.sameElements(variables(), reference.variables())) {
            throw new IllegalArgumentException();
        }

        List<? extends Prv> referenceVars = reference.variables();

        // For each variable of this factor, its position in the reference
        int[] positionInReference = new int[variables.size()];
        for (int i = 0; i < positionInReference.length; i++) {
            positionInReference[i] = referenceVars.indexOf(variables.get(i));
        }

        List<BigDecimal> reordered = new ArrayList<BigDecimal>(size);
        for (Tuple<RangeElement> tuple : reference) {
            // Rewrites the reference tuple in the variable order of this
            // factor, so that its value can be looked up here
            List<RangeElement> r = new ArrayList<RangeElement>(tuple.size());
            for (int i = 0; i < positionInReference.length; i++) {
                r.add(tuple.get(positionInReference[i]));
            }
            Tuple<RangeElement> reorderedTuple = Tuple.getInstance(r);
            reordered.add(getValue(reorderedTuple));
        }

        return StdFactor.getInstance(name, referenceVars, reordered);
    }
}