/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.nestedgraphs; import java.util.Arrays; import java.util.HashSet; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; public class MultiplexerCPD extends FactorGraph { private Discrete _y; private Discrete _a; private Discrete _za; private Discrete [] _zs; public MultiplexerCPD(Object [] domain, int numZs) { this(buildDomains(domain, numZs),false,false); } /** * @since 0.05 */ public MultiplexerCPD(DiscreteDomain domain, int numZs) { this(buildDomains(domain, numZs),false,false); } public MultiplexerCPD(Object [][] zDomains) { this(zDomains,false,false); } @SuppressWarnings("null") public MultiplexerCPD(Object [][] zDomains, boolean oneBased, boolean aAsDoubles) { super("MultiplexerCPD"); create(zDomains,oneBased,aAsDoubles); } @SuppressWarnings("null") public MultiplexerCPD(Object [] domain, int numZs, boolean oneBased, boolean aAsDoubles) { super("MultiplexerCPD"); create(buildDomains(domain,numZs),oneBased,aAsDoubles); } /** * @since 0.05 */ public MultiplexerCPD(DiscreteDomain [] domains) { this(domains,false,false); } /** * @since 0.05 */ @SuppressWarnings("null") public MultiplexerCPD(DiscreteDomain [] domains, boolean oneBased, boolean aAsDoubles) { super("MultiplexerCPD"); create(domains,oneBased,aAsDoubles); } public Discrete getY() { return _y; } public Discrete getA() { return _a; } public Discrete getZA() { return _za; } public Discrete [] getZs() { return _zs; } private MultiplexerCPD create(Object [][] zDomains, boolean oneBased, boolean aAsDouble) { DiscreteDomain [] domains = new DiscreteDomain[zDomains.length]; for (int i = 0; i < domains.length; i++) domains[i] = DiscreteDomain.create(zDomains[i]); return create(domains,oneBased,aAsDouble); } private MultiplexerCPD create(DiscreteDomain [] zDomains, boolean oneBased, boolean aAsDouble) { Discrete [] Zs = new Discrete[zDomains.length]; int zasize = 0; HashSet<Object> yDomainValues = new HashSet<Object>(); for (int i = 0; i < zDomains.length; i++) { Zs[i] = new Discrete(zDomains[i]); zasize += zDomains[i].size(); for (int j = 0; j < zDomains[i].size(); j++) { yDomainValues.add(zDomains[i].getElement(j)); } } Object [] yDomain = yDomainValues.toArray(); Arrays.sort(yDomain); Discrete Y = new Discrete(yDomain); return create(Y,Zs,zasize, oneBased, aAsDouble); } private MultiplexerCPD create(Discrete Y, Discrete [] Zs, int zasize, boolean oneBased, boolean aAsDouble) { Y.setLabel("Y"); java.util.Hashtable<Object, Integer> yDomainObj2index = new java.util.Hashtable<Object, Integer>(); final DiscreteDomain yDomain = Y.getDiscreteDomain(); for (int i = 0, end = yDomain.size(); i < end; i++) yDomainObj2index.put(yDomain.getElement(i), i); //Create a variable Object [] adomain = new Object[Zs.length]; for (int i = 0; i < adomain.length; i++) { int val = oneBased ? i+1 : i; if (aAsDouble) adomain[i] = (double)val; else adomain[i] = (int)val; } Discrete A = new Discrete(adomain); A.setLabel("A"); addBoundaryVariables(Y); addBoundaryVariables(A); addBoundaryVariables(Zs); //Make all of those boundary variables Variable [] vars = new Variable[Zs.length+2]; vars[0] = Y; vars[1] = A; for (int i = 0; i < Zs.length; i++) vars[i+2] = Zs[i]; //Create ZA variable Object [] zaDomain = new Object[zasize]; for (int i = 0; i < zaDomain.length; i++) zaDomain[i] = i; Discrete ZA = new Discrete(zaDomain); ZA.setLabel("ZA"); //Create Z* variables Discrete [] Zstars = new Discrete[Zs.length]; for (int i = 0; i < Zstars.length; i++) { Object [] domain = new Object[Zs[i].getDiscreteDomain().size() + 1]; for (int j = 0; j < domain.length; j++) domain[j] = j; Zstars[i] = new Discrete(domain); } //Create ZA Y factor int [][] indices = new int[zasize][2]; double [] weights = new double [zasize]; int index = 0; for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; indices[index][1] = yDomainObj2index.get(Zs[i].getDiscreteDomain().getElement(j)); weights[index] = 1; index++; } } Factor f = this.addFactor(indices,weights,ZA,Y); f.setLabel("Y2ZA"); //Create ZA A factor indices = new int[zasize][2]; weights = new double[zasize]; index = 0; for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; indices[index][1] = i; weights[index] = 1; index++; } } f = this.addFactor(indices,weights,ZA,A); f.setLabel("ZA2A"); //Create ZA Z* factors //Create Z* Z factors for (int a = 0; a < Zs.length; a++) { Zs[a].setLabel("Z" + a); Zstars[a].setLabel("Z*" + a); indices = new int[zasize][2]; weights = new double[zasize]; index = 0; //Factor from ZA to Z* for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; if (a == i) { indices[index][1] = j; } else { int sz = Zs[a].getDiscreteDomain().size(); indices[index][1] = sz; } weights[index] = 1; index++; } } f = this.addFactor(indices,weights,ZA,Zstars[a]); f.setLabel("ZA2Z*"); //From Z* to Z indices = new int[Zs[a].getDiscreteDomain().size()*2][2]; weights = new double[indices.length]; int ds = Zs[a].getDiscreteDomain().size(); for (int i = 0; i < ds; i++) { indices[i][0] = i; indices[ds+i][0] = ds; indices[i][1] = i; indices[ds+i][1] = i; weights[i] = 1; weights[ds+i] = 1; } f = this.addFactor(indices, weights,Zstars[a],Zs[a]); f.setLabel("Z*2Z"); } this._y = Y; this._a = A; this._za = ZA; this._zs = Zs; return this; } public static DiscreteDomain [] buildDomains(DiscreteDomain domain, int numZs) { DiscreteDomain [] retval = new DiscreteDomain[numZs]; for (int i = 0; i < retval.length; i++) retval[i] = domain; return retval; } public static DiscreteDomain [] buildDomains(Object [] domain, int numZs) { return buildDomains(DiscreteDomain.create(domain), numZs); } }