/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.matlabproxy; import static java.util.Objects.*; import java.util.ArrayList; import java.util.Collection; import java.util.UUID; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.CustomFactorFunctionWrapper; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.factorfunctions.core.TableFactorFunction; import com.analog.lyric.dimple.matlabproxy.repeated.IPVariableStreamSlice; import com.analog.lyric.dimple.matlabproxy.repeated.PFactorGraphStream; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.INode; import com.analog.lyric.dimple.model.core.Node; import com.analog.lyric.dimple.model.factors.DiscreteFactor; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.factors.FactorBase; import com.analog.lyric.dimple.model.repeated.FactorGraphStream; import com.analog.lyric.dimple.model.repeated.IVariableStreamSlice; import com.analog.lyric.dimple.model.variables.Constant; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.IConstantOrVariable; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.schedulers.IScheduler; import com.analog.lyric.dimple.schedulers.SchedulerBase; import com.analog.lyric.dimple.solvers.interfaces.IFactorGraphFactory; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.util.misc.FactorGraphDiffs; import com.analog.lyric.util.misc.IMapList; import com.analog.lyric.util.misc.Matlab; @Matlab(wrapper="FactorGraph") public class PFactorGraphVector extends PFactorBaseVector { /*-------------- * Construction */ public PFactorGraphVector(FactorGraph f) { super(new Node [] {f}); } public PFactorGraphVector(Node [] nodes) { super(nodes); } /*----------------- * PObject methods */ @Override public boolean isDiscrete() { for (Factor f : getGraph().getFactors()) if (!f.isDiscrete()) return false; return true; } @Override public boolean isGraph() { return true; } /*--------------------- * PNodeVector methods */ @Override public PFactorGraphVector createNodeVector(Node[] nodes) { return new PFactorGraphVector(nodes); } /*----------------------------- * PFactorGraphVector methods */ public FactorGraph getGraph() { if (size() != 1) throw new DimpleException("operation not supported"); return (FactorGraph)getModelerNode(0); } public @Nullable String getMatlabSolveWrapper() { ISolverFactorGraph solverGraph = getGraph().getSolver(); return solverGraph != null ? solverGraph.getMatlabSolveWrapper() : null; } public int getNumSteps() { return getGraph().getNumSteps(); } public void setNumSteps(int numSteps) { getGraph().setNumSteps(numSteps); } public void setNumStepsInfinite(boolean val) { getGraph().setNumStepsInfinite(val); } public boolean getNumStepsInfinite() { return getGraph().getNumStepsInfinite(); } public PFactorGraphVector addGraph(PFactorGraphVector childGraph, PVariableVector varVector) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); return new PFactorGraphVector(getGraph().addGraph(childGraph.getGraph(), varVector.getVariableArray())); } @Deprecated public boolean customFactorExists(String funcName) { return getGraph().customFactorExists(funcName); } public PFactorVector createFactor(PFactorTable factorTable, Object [] vars) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); Factor f = getGraph().addFactor(new TableFactorFunction("table", factorTable.getModelerObject()), PHelpers.convertToMVariablesAndConstants(vars)); if (f.isDiscrete()) return new PDiscreteFactorVector((DiscreteFactor) f); else return new PFactorVector(f); } public PFactorVector createFactor(FactorFunction factorFunction, Object [] vars) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); Factor f = getGraph().addFactor(factorFunction,PHelpers.convertToMVariablesAndConstants(vars)); if (f.isDiscrete()) return new PDiscreteFactorVector((DiscreteFactor) f); else return new PFactorVector(f); } public PFactorVector createFactor(PFactorFunction factorFunction, Object [] vars) { FactorFunction ff = factorFunction.getModelerObject(); return createFactor(ff,vars); } public void solve() { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); getGraph().solve(); } public void startContinueSolve() { getSolverGraph().continueSolve(); } public void continueSolve() { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); getGraph().continueSolve(); } public void solveOneStep() { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); getGraph().solveOneStep(); } public void startSolveOneStep() { getSolverGraph().startSolveOneStep(); } public boolean isSolverRunning() { return getGraph().isSolverRunning(); } public void startSolver() { getSolverGraph().startSolver(); } public PVariableVector getVariableVector(int relativeNestingDepth,int forceIncludeBoundaryVariables) { if (isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); PVariableVector tmp = PHelpers.convertToVariableVector(getGraph().getVariables(relativeNestingDepth,forceIncludeBoundaryVariables!=0)); return tmp; } public PFactorBaseVector getFactors(int relativeNestingDepth) { return getFactors(getGraph().getFactors(relativeNestingDepth)); } public PFactorBaseVector getFactors(IMapList<FactorBase> factors) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); return PHelpers.convertToFactorVector(factors.toArray(new Node[factors.size()])); } public int [][] getAdjacencyMatrix() { return getGraph().getAdjacencyMatrix(); } //Returns an adjacency matrix with the given nesting depth. public int [][] getAdjacencyMatrix(int relativeNestingDepth) { return getGraph().getAdjacencyMatrix(relativeNestingDepth); } //Returns an adjacency matrix of the given objects. public int [][] getAdjacencyMatrix(Object [] objects) { ArrayList<Node> alNodes = new ArrayList<Node>(); for (int i = 0; i < objects.length; i++) { PNodeVector tmp = (PNodeVector)objects[i]; for (int j= 0; j < tmp.size(); j++) { alNodes.add(tmp.getModelerNode(j)); } } Node [] array = new Node[alNodes.size()]; for (int i =0 ; i < array.length; i++) array[i] = alNodes.get(i); return getGraph().getAdjacencyMatrix(array); } public void interruptSolver() { getSolverGraph().interruptSolver(); } /** * Add multiple factors. * <p> * @param factor is used as a prototype. It's factor function and constants should be copied. * @param vars should actually contain PNodeVectors. Each one contains the full set of variables * from which the factor arguments will be chosen. * @param indices */ public PNodeVector addFactorVectorized(PFactorVector factor, Object [] vars, Object [] indices) { PNodeVector [] nodes = PHelpers.convertObjectArrayToNodeVectorArray(vars); int [][][] intIndices = PHelpers.extractIndicesVectorized(indices); PNodeVector [][] args = PHelpers.extractVectorization(nodes, intIndices); final Factor modelFactor = factor.getFactor(0); final int nArgs = modelFactor.getArgumentCount(); Object[] argsWithConstants = null; if (modelFactor.hasConstants()) { argsWithConstants = new Object[nArgs]; for (int j = 0; j < nArgs; ++j) { IConstantOrVariable arg = modelFactor.getArgument(j); if (arg instanceof Constant) { argsWithConstants[j] = arg; } } } Node [] retval = new Node[args.length]; for (int i = 0; i < args.length; i++) { Object[] argsi = args[i]; if (argsWithConstants != null) { for (int j = 0; j < argsi.length; ++j) { argsWithConstants[modelFactor.siblingNumberToArgIndex(j)] = argsi[j]; } argsi = argsWithConstants; } retval[i] = createFactor(factor.getFactorFunction(),argsi).getFactor(0); } return PHelpers.convertToFactorVector(retval); } public void addBoundaryVariables(Object [] vars) { for (Object var : vars) { PVariableVector varvec = (PVariableVector)var; getGraph().addBoundaryVariables(varvec.getVariableArray()); } } public PFactorGraphVector addGraphVectorized(PFactorGraphVector graph, Object [] vars, Object [] indices) { PNodeVector [] nodes = PHelpers.convertObjectArrayToNodeVectorArray(vars); int [][][] intIndices = PHelpers.extractIndicesVectorized(indices); PNodeVector [][] args = PHelpers.extractVectorization(nodes, intIndices); PVariableVector varVector = new PVariableVector(); Node [] retval = new Node[args.length]; for (int i = 0; i < args.length; i++) { varVector = (PVariableVector)varVector.concat(args[i]); retval[i] = addGraph(graph,varVector).getModelerNode(0); } return new PFactorGraphVector(retval); } public void setSolver(@Nullable IFactorGraphFactory<?> solver) { getGraph().setSolverFactory(solver); } public PFactorVector createCustomFactor(String funcName,PVariableVector varVector) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); Variable [] vars = varVector.getVariableArray(); Factor f = getGraph().addFactor(new CustomFactorFunctionWrapper(funcName), vars); return new PFactorVector(f); } public PFactorVector createCustomFactor(String funcName, Object [] variables) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); Factor f = getGraph().addFactor(new CustomFactorFunctionWrapper(funcName),PHelpers.convertToMVariablesAndConstants(variables)); return new PFactorVector(f); } public PFactorGraphVector [] getNestedGraphs() { Collection<FactorGraph> graphs = getGraph().getOwnedGraphs(); PFactorGraphVector [] retval = new PFactorGraphVector[graphs.size()]; int i = 0; for (FactorGraph g : graphs) { retval[i] = new PFactorGraphVector(g); i++; } return retval; } public boolean isForest(int relativeNestingDepth) { return getGraph().isForest(relativeNestingDepth); } public boolean isTree(int relativeNestingDepth) { return getGraph().isTree(relativeNestingDepth); } public PNodeVector [] depthFirstSearch(PNodeVector root, int searchDepth, int relativeNestingDepth) { if (root.size() != 1) throw new DimpleException("choose one root"); IMapList<INode> nodes = getGraph().depthFirstSearch(root.getModelerNode(0), searchDepth,relativeNestingDepth); PNodeVector [] retval = new PNodeVector[nodes.size()]; for (int i = 0; i < retval.length; i++) { retval[i] = PHelpers.wrapObject(nodes.getByIndex(i)); } return retval; } public PFactorGraphStream addRepeatedFactor(PFactorGraphVector nestedGraph, int bufferSize,Object ... vars) { //Object [] arr = new Object[vars.length]; ArrayList<Object> al = new ArrayList<Object>(); for (int i = 0; i < vars.length; i++) { if (vars[i] instanceof PVariableVector) { PVariableVector pvv = (PVariableVector)vars[i]; if (pvv.size() != 1) throw new DimpleException("only support one var for now"); al.add(pvv.getModelerNode(0)); } else if (vars[i] instanceof IPVariableStreamSlice) { IVariableStreamSlice<?> [] slices = ((IPVariableStreamSlice)vars[i]).getModelerObjects(); for (int j = 0; j < slices.length; j++) al.add(slices[j]); } else { throw new DimpleException("when this happen?"); //arr[i] = vars[i]; } } Object [] newarray = al.toArray(); FactorGraphStream rfg = getGraph().addRepeatedFactorWithBufferSize(nestedGraph.getGraph(), bufferSize, newarray); return new PFactorGraphStream(rfg); } public void baumWelch(Object [] factorsAndTables,int numRestarts,int numSteps) { Object [] mfandt = new Object[factorsAndTables.length]; for (int i = 0; i < factorsAndTables.length; i++) { if (factorsAndTables[i] instanceof PFactorTable) mfandt[i] = ((PFactorTable)factorsAndTables[i]).getModelerObject(); else if (factorsAndTables[i] instanceof PFactorVector) { PFactorVector pfv = (PFactorVector)factorsAndTables[i]; if (pfv.size() != 1) throw new DimpleException("for now we only support factor vectors with a single factor"); mfandt[i] = pfv.getModelerNode(0); } else throw new DimpleException("Unsupported argument to estimateParameters"); } this.getGraph().baumWelch(mfandt,numRestarts,numSteps); } public void estimateParameters(Object [] factorsAndTables,int numRestarts,int numSteps, double stepScaleFactor) { Object [] mfandt = new Object[factorsAndTables.length]; for (int i = 0; i < factorsAndTables.length; i++) { if (factorsAndTables[i] instanceof PFactorTable) mfandt[i] = ((PFactorTable)factorsAndTables[i]).getModelerObject(); else if (factorsAndTables[i] instanceof PFactorVector) { PFactorVector pfv = (PFactorVector)factorsAndTables[i]; if (pfv.size() != 1) throw new DimpleException("for now we only support factor vectors with a single factor"); mfandt[i] = pfv.getModelerNode(0); } else throw new DimpleException("Unsupported argument to estimateParameters"); } this.getGraph().estimateParameters(mfandt,numRestarts,numSteps,stepScaleFactor); } public void advance() { getGraph().advance(); } public boolean hasNext() { return getGraph().hasNext(); } public boolean isAncestorOf(Object o) { if (! (o instanceof PNodeVector)) return false; PNodeVector pn = (PNodeVector)o; if (pn.size() != 1) throw new DimpleException("only support variable of size 1"); return getGraph().isAncestorOf(pn.getModelerNode(0)); } public void removeFactor(PFactorVector factor) { Node [] factors = factor.getModelerNodes(); for (Node n : factors) getGraph().remove((Factor)n); } public void initialize() { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); getGraph().initialize(); } public PFactorVector [] getNonGraphFactors(int relativeNestingDepth) { return getNonGraphFactors(getGraph().getNonGraphFactors(relativeNestingDepth)); } public PFactorVector [] getNonGraphFactors(IMapList<Factor> factors) { if (getGraph().isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); return PHelpers.convertFactorListToFactors(factors.values()); } public boolean hasParentGraph() { return getGraph().hasParentGraph(); } public @Nullable PFactorGraphVector getParentGraph() { FactorGraph mgraph = getGraph().getParentGraph(); if(mgraph != null) return new PFactorGraphVector(mgraph); else return null; } public PFactorGraphVector getRootGraph() { return new PFactorGraphVector(getGraph().getRootGraph()); } private ISolverFactorGraph getSolverGraph() { ISolverFactorGraph solverGraph = getGraph().getSolver(); if (solverGraph == null) { throw new DimpleException("Solver not set."); } return solverGraph; } public @Nullable PVariableVector getVariableByName(String name) { Variable mo = getGraph().getVariableByName(name); if (mo != null) return (PVariableVector)PHelpers.wrapObject(mo); else return null; } public @Nullable PFactorVector getFactorByName(String name) { Factor mo = getGraph().getFactorByName(name); if(mo != null) return (PFactorVector)PHelpers.wrapObject(mo); else return null; } public @Nullable PFactorGraphVector getGraphByName(String name) { FactorGraph mo = getGraph().getGraphByName(name); if(mo != null) return new PFactorGraphVector(mo); else return null; } public @Nullable PVariableVector getVariableByUUID(UUID uuid) { Variable mo = getGraph().getVariableByUUID(uuid); if(mo != null) return (PVariableVector)PHelpers.wrapObject(mo); else return null; } public @Nullable PFactorVector getFactorByUUID(UUID uuid) { Factor mo = getGraph().getFactorByUUID(uuid); if(mo != null) return (PFactorVector) PHelpers.wrapObject(mo); else return null; } public @Nullable PFactorGraphVector getGraphByUUID(UUID uuid) { FactorGraph mo = getGraph().getGraphByUUID(uuid); if(mo != null) return (PFactorGraphVector)PHelpers.wrapObject(mo); else return null; } @SuppressWarnings("deprecation") public void setScheduler(@Nullable Object obj) { final FactorGraph graph = getGraph(); if (graph.isSolverRunning()) throw new DimpleException("No changes allowed while the solver is running."); graph.setScheduler(obj != null ? SchedulerBase.instantiate(graph.getEnvironment(), obj) : null); } @SuppressWarnings("deprecation") public @Nullable PScheduler getScheduler() { IScheduler scheduler = getGraph().getScheduler(); return scheduler != null ? new PScheduler(scheduler) : null; } public PFactorGraphStream [] getFactorGraphStreams() { PFactorGraphStream [] retval = new PFactorGraphStream[getGraph().getFactorGraphStreams().size()]; for (int i = 0; i < retval.length; i++) { FactorGraphStream fgs = getGraph().getFactorGraphStreams().get(i); retval[i] = new PFactorGraphStream(fgs); } return retval; } public FactorGraphDiffs getFactorGraphDiffsByName(PFactorGraphVector b) { return FactorGraphDiffs.getFactorGraphDiffs( getGraph(), b.getGraph(), false, true); } public PFactorVector joinFactors(Object [] factors) { //convert Object [] to Factor array Factor [] facs = new Factor[factors.length]; for (int i = 0; i < factors.length; i++) { //TODO: error check? facs[i] = (Factor)PHelpers.convertToNode(factors[i]); } Factor f = getGraph().join(facs); return (PFactorVector)PHelpers.wrapObject(f); } public PVariableVector joinVariables(Object [] variables) { Variable [] vars = new Variable[variables.length]; for (int i = 0; i < variables.length; i++) { if (! (variables[i] instanceof PVariableVector)) throw new DimpleException("only variable bases supported"); vars[i] = (Variable)PHelpers.convertToNode(variables[i]); } return (PVariableVector)PHelpers.wrapObject(getGraph().join(vars)); } public PVariableVector split(PVariableVector variable, @Nullable Object [] factors) { Factor [] pfactors = {}; if (factors != null) pfactors = PHelpers.convertObjectArrayToFactors(factors); Node n = PHelpers.convertToNode(variable); if (n instanceof Discrete) return new PDiscreteVariableVector(getGraph().split((Variable)n,pfactors)); else return new PRealVariableVector((Real)(getGraph().split((Variable)n,pfactors))); } public double getBetheFreeEnergy() { return getGraph().getBetheFreeEnergy(); } // For operating collectively on groups of variables that are not already part of a variable vector @Deprecated public int defineVariableGroup(Object[] variables) { return addVariableBlock(variables).getLocalId(); } public PVariableBlock addVariableBlock(Object[] variables) { return addVariableBlock(getGraph(), variables); } static PVariableBlock addVariableBlock(@Nullable FactorGraph graph, Object[] variables) { ArrayList<Variable> variableList = new ArrayList<Variable>(); for (int i = 0; i < variables.length; i++) { Variable[] modelerVariables = ((PVariableVector)variables[i]).getModelerVariables(); for (int j = 0; j < modelerVariables.length; j++) variableList.add(modelerVariables[j]); } if (graph == null) { // Choose graph from variables. If all come from the same graph, use that graph, otherwise // use the root graph. graph = variableList.get(0).getParentGraph(); for (int i = 1, n = variableList.size(); i < n; ++i) { Variable var = variableList.get(i); if (graph != var.getParentGraph()) { graph = var.getRootGraph(); break; } } } return new PVariableBlock(requireNonNull(graph).addVariableBlock(variableList)); } }