/*******************************************************************************
*   Copyright 2012 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.matlabproxy;

import java.util.ArrayList;
import java.util.Collection;

import com.analog.lyric.collect.Supers;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.matlabproxy.repeated.IPDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.IPDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSource;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.INode;
import com.analog.lyric.dimple.model.core.Node;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.factors.DiscreteFactor;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.factors.FactorBase;
import com.analog.lyric.dimple.model.repeated.DoubleArrayDataSink;
import com.analog.lyric.dimple.model.repeated.DoubleArrayDataSource;
import com.analog.lyric.dimple.model.repeated.MultivariateDataSink;
import com.analog.lyric.dimple.model.repeated.MultivariateDataSource;
import com.analog.lyric.dimple.model.repeated.VariableStreamBase;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableList;
import com.analog.lyric.util.misc.Matlab;

/**
 * Static conversion helpers between the proxy wrapper types of this package
 * ({@code P*Vector}, {@code P*Domain}) and the underlying model objects
 * (nodes, variables, factors, domains), plus helpers that translate index
 * arguments given as {@code double} values/arrays into {@code int} arrays.
 */
// TODO: how many of these functions are intended to be invoked from MATLAB? And how many don't need to be public?
public class PHelpers
{
    /**
     * Flattens every element of {@code vlVectors} into a single variable array.
     * Equivalent to {@link #convertToVariableArray(Object[], int)} with a start
     * index of zero.
     *
     * @param vlVectors array whose elements must all be {@link PVariableVector}s
     * @return the variables of all vectors, in order
     * @throws ClassCastException if an element is not a {@link PVariableVector}
     */
    @Matlab
    public static Variable[] convertToVariableArray(Object[] vlVectors)
    {
        return convertToVariableArray(vlVectors, 0);
    }

    /**
     * Flattens the elements of {@code vectors} from index {@code start} onward
     * into a single variable array.
     *
     * @param vectors array whose elements at and after {@code start} must be
     *        {@link PVariableVector}s
     * @param start index of the first element to include
     * @return the variables of the selected vectors, in order
     * @throws ClassCastException if a selected element is not a {@link PVariableVector}
     */
    public static Variable[] convertToVariableArray(Object[] vectors, int start)
    {
        ArrayList<Variable> collected = new ArrayList<Variable>();

        for (int i = start, n = vectors.length; i < n; ++i)
        {
            PVariableVector vector = (PVariableVector) vectors[i];
            for (Variable variable : vector.getVariableArray())
            {
                collected.add(variable);
            }
        }

        return collected.toArray(new Variable[collected.size()]);
    }

    /**
     * Casts {@code obj} to a {@link PNodeVector} and returns its single node.
     *
     * @param obj a {@link PNodeVector} holding exactly one node
     * @return the wrapped node
     * @throws DimpleException if the vector does not hold exactly one node
     * @throws ClassCastException if {@code obj} is not a {@link PNodeVector}
     */
    public static Node convertToNode(Object obj)
    {
        return convertToNode((PNodeVector) obj);
    }

    /**
     * Casts {@code nodeVector} to a {@link PNodeVector} and returns its nodes.
     *
     * @param nodeVector a {@link PNodeVector}
     * @return the wrapped nodes, in order
     * @throws ClassCastException if {@code nodeVector} is not a {@link PNodeVector}
     */
    public static Node[] convertToNodeArray(Object nodeVector)
    {
        return convertToNodeArray((PNodeVector) nodeVector);
    }

    /**
     * Copies the model nodes of {@code nodeVector} into a new array.
     *
     * @param nodeVector the proxy vector to unwrap
     * @return a new array with one entry per element of {@code nodeVector}
     */
    public static Node[] convertToNodeArray(PNodeVector nodeVector)
    {
        Node[] nodes = new Node[nodeVector.size()];
        for (int i = 0; i < nodes.length; i++)
        {
            nodes[i] = nodeVector.getModelerNode(i);
        }
        return nodes;
    }

    /**
     * Returns the single model node held by {@code nodeVector}.
     *
     * @param nodeVector a proxy vector that must have size one
     * @return the node at index zero
     * @throws DimpleException if the vector's size is not one
     */
    public static Node convertToNode(PNodeVector nodeVector)
    {
        if (nodeVector.size() != 1)
        {
            throw new DimpleException("only works with 1 node currently");
        }
        return nodeVector.getModelerNode(0);
    }

    /**
     * Unwraps an array of proxy discrete domains into their model objects.
     *
     * @param domains the proxy domains to unwrap
     * @return the model domains, in the same order
     * @throws DimpleException if a domain's model object reports that it is
     *         not discrete
     */
    public static DiscreteDomain[] convertDomains(PDiscreteDomain[] domains)
    {
        DiscreteDomain[] unwrapped = new DiscreteDomain[domains.length];

        for (int i = 0; i < domains.length; i++)
        {
            DiscreteDomain domain = domains[i].getModelerObject();
            if (!domain.isDiscrete())
            {
                throw new DimpleException("domain at index " + i + " is not discrete");
            }
            unwrapped[i] = domain;
        }

        return unwrapped;
    }

    /**
     * Wraps a model domain in the proxy domain class matching its type.
     * {@link RealJointDomain}, {@link DiscreteDomain} and {@link RealDomain}
     * are tested in that order; any other domain gets a plain {@link PDomain}.
     *
     * @param d the model domain to wrap
     * @return a new proxy wrapper for {@code d}
     */
    public static PDomain wrapDomain(Domain d)
    {
        if (d instanceof RealJointDomain)
        {
            return new PRealJointDomain((RealJointDomain) d);
        }
        if (d instanceof DiscreteDomain)
        {
            return new PDiscreteDomain((DiscreteDomain) d);
        }
        if (d instanceof RealDomain)
        {
            return new PRealDomain((RealDomain) d);
        }
        return new PDomain(d);
    }

    /**
     * Wraps an array of nodes in the most specific factor proxy vector that
     * fits the nearest common superclass of the nodes, as reported by
     * {@link Supers#nearestCommonSuperClass}.
     *
     * @param nodes the nodes to wrap; an empty array yields an empty
     *        {@link PFactorBaseVector}
     * @return a proxy vector over {@code nodes}
     */
    public static PFactorBaseVector convertToFactorVector(Node[] nodes)
    {
        if (nodes.length == 0)
        {
            return new PFactorBaseVector();
        }

        Class<?> commonClass = Supers.nearestCommonSuperClass(nodes);

        // Order of tests is preserved from the original implementation.
        if (DiscreteFactor.class.isAssignableFrom(commonClass))
        {
            return new PDiscreteFactorVector(nodes);
        }
        if (FactorGraph.class.isAssignableFrom(commonClass))
        {
            return new PFactorGraphVector(nodes);
        }
        if (Factor.class.isAssignableFrom(commonClass))
        {
            return new PFactorVector(nodes);
        }
        return new PFactorBaseVector(nodes);
    }

    /**
     * Wraps each factor of the collection in its own proxy vector. Delegates to
     * {@link #convertFactorListToFactors(Collection)}.
     *
     * @param factors the factors to wrap
     * @return one proxy vector per factor
     */
    public static PFactorVector[] convertToFactorVector(Collection<Factor> factors)
    {
        return convertFactorListToFactors(factors);
    }

    /**
     * Wraps the contents of a {@link VariableList} in a proxy variable vector.
     *
     * @param vars the variables to wrap
     * @return see {@link #convertToVariableVector(Variable[])}
     */
    public static PVariableVector convertToVariableVector(VariableList vars)
    {
        Variable[] array = new Variable[vars.size()];
        vars.toArray(array);
        return convertToVariableVector(array);
    }

    /**
     * Wraps an array of variables in a proxy vector chosen from the component
     * type of the array returned by {@link Supers#narrowArrayOf}: discrete,
     * real and real-joint component types get their dedicated proxy vector,
     * anything else a plain {@link PVariableVector}.
     *
     * @param variables the variables to wrap; an empty array yields an empty
     *        {@link PVariableVector}
     * @return a proxy vector over the variables
     */
    @SuppressWarnings("deprecation")
    static public PVariableVector convertToVariableVector(Variable[] variables)
    {
        if (variables.length == 0)
        {
            return new PVariableVector();
        }

        // TODO: When VariableBase is removed, change this to Variable.class
        variables = Supers.narrowArrayOf(com.analog.lyric.dimple.model.variables.VariableBase.class, 1, variables);
        Class<?> commonVarClass = variables.getClass().getComponentType();

        if (Discrete.class.isAssignableFrom(commonVarClass))
        {
            return new PDiscreteVariableVector(variables);
        }
        if (Real.class.isAssignableFrom(commonVarClass))
        {
            return new PRealVariableVector(variables);
        }
        if (RealJoint.class.isAssignableFrom(commonVarClass))
        {
            return new PRealJointVariableVector(variables);
        }
        return new PVariableVector(variables);
    }

    /**
     * Converts each object (a single-node {@link PNodeVector}) into the
     * {@link Factor} it wraps.
     *
     * @param objects single-node proxy vectors wrapping factors
     * @return the unwrapped factors, in order
     * @throws DimpleException if a vector does not hold exactly one node
     * @throws ClassCastException if an element is not a {@link PNodeVector} or
     *         its node is not a {@link Factor}
     */
    public static Factor[] convertObjectArrayToFactors(Object[] objects)
    {
        Factor[] factors = new Factor[objects.length];
        for (int i = 0; i < objects.length; i++)
        {
            factors[i] = (Factor) convertToNode(objects[i]);
        }
        return factors;
    }

    /**
     * Copies {@code objects} into a {@link PNodeVector} array, casting each element.
     *
     * @param objects elements that must all be {@link PNodeVector}s
     * @return a new, typed array with the same elements
     * @throws ClassCastException if an element is not a {@link PNodeVector}
     */
    public static PNodeVector[] convertObjectArrayToNodeVectorArray(Object[] objects)
    {
        PNodeVector[] vectors = new PNodeVector[objects.length];
        for (int i = 0; i < objects.length; i++)
        {
            vectors[i] = (PNodeVector) objects[i];
        }
        return vectors;
    }

    /**
     * Copies {@code objects} into a {@link PVariableVector} array, casting each element.
     *
     * @param objects elements that must all be {@link PVariableVector}s
     * @return a new, typed array with the same elements
     * @throws ClassCastException if an element is not a {@link PVariableVector}
     */
    public static PVariableVector[] convertObjectArrayToVariableVectorArray(Object[] objects)
    {
        PVariableVector[] vectors = new PVariableVector[objects.length];
        for (int i = 0; i < objects.length; i++)
        {
            vectors[i] = (PVariableVector) objects[i];
        }
        return vectors;
    }

    /**
     * Wraps each element with {@link #wrapObject(INode)} and casts the result
     * to {@link PFactorVector}.
     *
     * @param functions the model objects to wrap
     * @return one proxy vector per element
     * @throws DimpleException if {@link #wrapObject(INode)} does not recognize an element
     * @throws ClassCastException if a wrapper is not a {@link PFactorVector}
     */
    public static PFactorVector[] convertToFactors(FactorBase[] functions)
    {
        PFactorVector[] wrapped = new PFactorVector[functions.length];
        for (int i = 0; i < functions.length; i++)
        {
            wrapped[i] = (PFactorVector) wrapObject(functions[i]);
        }
        return wrapped;
    }

    /**
     * Wraps each factor of the collection in its own proxy vector via
     * {@link #convertToFactors(FactorBase[])}.
     *
     * @param vbs the factors to wrap
     * @return one proxy vector per factor
     */
    public static PFactorVector[] convertFactorListToFactors(Collection<Factor> vbs)
    {
        return convertToFactors(vbs.toArray(new FactorBase[0]));
    }

    /**
     * Flattens a mixed argument list: every {@link PVariableVector} element is
     * replaced by the model nodes it wraps (one output entry per node); every
     * other element is passed through unchanged.
     *
     * @param vars proxy variable vectors and/or arbitrary other objects
     * @return a new array of model nodes and pass-through objects, in order
     */
    public static Object[] convertToMVariablesAndConstants(Object[] vars)
    {
        ArrayList<Object> flattened = new ArrayList<Object>(vars.length);

        for (Object var : vars)
        {
            if (var instanceof PVariableVector)
            {
                PVariableVector varVec = (PVariableVector) var;
                for (int j = 0; j < varVec.size(); j++)
                {
                    flattened.add(varVec.getModelerNode(j));
                }
            }
            else
            {
                flattened.add(var);
            }
        }

        return flattened.toArray();
    }

    /**
     * Wraps a single model node in the proxy vector type matching its class.
     * The {@code instanceof} tests run in the order discrete factor, factor,
     * real, discrete, factor graph.
     *
     * @param node the node to wrap
     * @return a new single-element proxy vector
     * @throws DimpleException if {@code node} is none of the handled types
     */
    public static PNodeVector wrapObject(INode node)
    {
        // NOTE(review): RealJoint is not handled here although
        // convertToVariableVector has a dedicated case for it - confirm
        // whether that omission is intentional.
        if (node instanceof DiscreteFactor)
        {
            return new PDiscreteFactorVector((DiscreteFactor) node);
        }
        if (node instanceof Factor)
        {
            return new PFactorVector((Factor) node);
        }
        if (node instanceof Real)
        {
            return new PRealVariableVector((Real) node);
        }
        if (node instanceof Discrete)
        {
            return new PDiscreteVariableVector((Discrete) node);
        }
        if (node instanceof FactorGraph)
        {
            return new PFactorGraphVector((FactorGraph) node);
        }
        throw new DimpleException("unrecognized type");
    }

    /**
     * Slices each node vector once per factor to be added.
     * <p>
     * {@code indices[j][i]} holds the indices to take from {@code nodeVectors[j]}
     * for factor {@code i}; the result satisfies
     * {@code result[i][j] = nodeVectors[j].getSlice(indices[j][i])}.
     *
     * @param nodeVectors the vectors to slice; entry {@code j} is used for
     *        every {@code j < indices.length}
     * @param indices per-vector, per-factor index arrays; every
     *        {@code indices[j]} must have the same length
     * @return the slices, indexed first by factor and then by vector
     * @throws DimpleException if {@code indices} is empty or its entries
     *         differ in length
     */
    public static PNodeVector[][] extractVectorization(PNodeVector[] nodeVectors, int[][][] indices)
    {
        final int vectorsPerFactor = indices.length;
        if (vectorsPerFactor == 0)
        {
            // Previously failed with a bare ArrayIndexOutOfBoundsException on indices[0].
            throw new DimpleException("indices must not be empty");
        }

        final int numFactors = indices[0].length;
        for (int j = 0; j < vectorsPerFactor; j++)
        {
            if (indices[j].length != numFactors)
            {
                throw new DimpleException("mismatch of variables sizes");
            }
        }

        PNodeVector[][] slices = new PNodeVector[numFactors][];
        for (int i = 0; i < numFactors; i++)
        {
            slices[i] = new PNodeVector[vectorsPerFactor];
            for (int j = 0; j < vectorsPerFactor; j++)
            {
                slices[i][j] = nodeVectors[j].getSlice(indices[j][i]);
            }
        }

        return slices;
    }

    /**
     * Converts index arguments into {@code int} arrays for the vectorized case.
     * Each element of {@code indices} may be:
     * <ul>
     * <li>a {@link Double}: becomes a 1x1 array holding the truncated value;</li>
     * <li>a {@code double[][]}: becomes an {@code int} matrix with
     *     {@code tmp.length} rows and as many columns as the first row
     *     (an empty outer array yields an empty matrix);</li>
     * <li>a {@code double[]}: becomes one single-element row per entry.</li>
     * </ul>
     * Values are converted with an {@code (int)} cast.
     *
     * @param indices the index arguments
     * @return one {@code int[][]} per element of {@code indices}
     * @throws DimpleException if an element has any other type
     */
    public static int[][][] extractIndicesVectorized(Object[] indices)
    {
        int[][][] converted = new int[indices.length][][];

        for (int i = 0; i < indices.length; i++)
        {
            Object entry = indices[i];

            if (entry instanceof Double)
            {
                converted[i] = new int[][] { { (int) (double) (Double) entry } };
            }
            else if (entry instanceof double[][])
            {
                double[][] matrix = (double[][]) entry;
                // Guard the empty case: matrix[0] does not exist then.
                final int columns = matrix.length == 0 ? 0 : matrix[0].length;
                converted[i] = new int[matrix.length][columns];
                for (int j = 0; j < matrix.length; j++)
                {
                    for (int k = 0; k < columns; k++)
                    {
                        converted[i][j][k] = (int) matrix[j][k];
                    }
                }
            }
            else if (entry instanceof double[])
            {
                double[] column = (double[]) entry;
                converted[i] = new int[column.length][];
                for (int j = 0; j < column.length; j++)
                {
                    converted[i][j] = new int[] { (int) column[j] };
                }
            }
            else
            {
                throw new DimpleException("unsupported indices format: " + entry);
            }
        }

        return converted;
    }

    // For non-vectorized node, second index dimension are indices themselves, rather than an array for each vector element
    /**
     * Converts index arguments into {@code int} arrays for the non-vectorized
     * case. Each element of {@code indices} may be a {@link Double} (becomes a
     * 1x1 array) or a {@code double[]} (becomes a single row holding all
     * entries). Values are converted with an {@code (int)} cast.
     *
     * @param indices the index arguments
     * @return one single-row {@code int[][]} per element of {@code indices}
     * @throws DimpleException if an element has any other type
     */
    public static int[][][] extractIndicesNonVectorized(Object[] indices)
    {
        int[][][] converted = new int[indices.length][][];

        for (int i = 0; i < indices.length; i++)
        {
            Object entry = indices[i];

            if (entry instanceof Double)
            {
                converted[i] = new int[][] { { (int) (double) (Double) entry } };
            }
            else if (entry instanceof double[])
            {
                double[] row = (double[]) entry;
                converted[i] = new int[1][row.length];
                for (int k = 0; k < row.length; k++)
                {
                    converted[i][0][k] = (int) row[k];
                }
            }
            else
            {
                throw new DimpleException("unsupported indices format: " + entry);
            }
        }

        return converted;
    }

    /**
     * Collects the data sources of the given streams into a single proxy.
     * The type of the first stream's data source decides the outcome; only
     * {@link DoubleArrayDataSource} is supported.
     *
     * @param streams the variable streams to query; must not be empty
     * @return a {@link PDoubleArrayDataSource} over the streams' data sources
     * @throws DimpleException if {@code streams} is empty or the first
     *         stream's data source is not a {@link DoubleArrayDataSource}
     * @throws ClassCastException if a later stream's data source is not a
     *         {@link DoubleArrayDataSource}
     */
    public static IPDataSource getDataSources(VariableStreamBase<?>[] streams)
    {
        if (streams.length == 0)
        {
            // Previously failed with a bare ArrayIndexOutOfBoundsException on streams[0].
            throw new DimpleException("no variable streams were provided");
        }

        Object first = streams[0].getDataSource();

        if (first instanceof DoubleArrayDataSource)
        {
            DoubleArrayDataSource[] sources = new DoubleArrayDataSource[streams.length];
            for (int i = 0; i < sources.length; i++)
            {
                sources[i] = (DoubleArrayDataSource) streams[i].getDataSource();
            }
            return new PDoubleArrayDataSource(sources);
        }
        else if (first instanceof MultivariateDataSource)
        {
            throw new DimpleException("not currently supported");
        }
        else
        {
            throw new DimpleException("not currently supported");
        }
    }

    /**
     * Collects the data sinks of the given streams into a single proxy.
     * The type of the first stream's data sink decides the outcome; only
     * {@link DoubleArrayDataSink} is supported.
     *
     * @param streams the variable streams to query; must not be empty
     * @return a {@link PDoubleArrayDataSink} over the streams' data sinks
     * @throws DimpleException if {@code streams} is empty or the first
     *         stream's data sink is not a {@link DoubleArrayDataSink}
     * @throws ClassCastException if a later stream's data sink is not a
     *         {@link DoubleArrayDataSink}
     */
    public static IPDataSink getDataSinks(VariableStreamBase<?>[] streams)
    {
        if (streams.length == 0)
        {
            // Previously failed with a bare ArrayIndexOutOfBoundsException on streams[0].
            throw new DimpleException("no variable streams were provided");
        }

        Object first = streams[0].getDataSink();

        if (first instanceof DoubleArrayDataSink)
        {
            DoubleArrayDataSink[] sinks = new DoubleArrayDataSink[streams.length];
            for (int i = 0; i < sinks.length; i++)
            {
                sinks[i] = (DoubleArrayDataSink) streams[i].getDataSink();
            }
            return new PDoubleArrayDataSink(sinks);
        }
        else if (first instanceof MultivariateDataSink)
        {
            throw new DimpleException("Multivariate not currently supported");
        }
        else
        {
            // Report the sink (the original mistakenly printed the data source here).
            throw new DimpleException("other not currently supported " + first + " end");
        }
    }
}