/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.matlabproxy;
import java.util.ArrayList;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.matlabproxy.repeated.PDiscreteStream;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.PDoubleArrayDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PFactorFunctionDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PMultivariateDataSink;
import com.analog.lyric.dimple.matlabproxy.repeated.PMultivariateDataSource;
import com.analog.lyric.dimple.matlabproxy.repeated.PRealJointStream;
import com.analog.lyric.dimple.matlabproxy.repeated.PRealStream;
import com.analog.lyric.dimple.matlabproxy.repeated.PVariableStreamBase;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.Model;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.schedulers.SchedulerOptionKey;
import com.analog.lyric.dimple.solvers.core.CustomFactors;
import com.analog.lyric.dimple.solvers.core.CustomFactorsOptionKey;
import com.analog.lyric.dimple.solvers.core.multithreading.ThreadPool;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteEnergyMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.dimple.solvers.interfaces.IFactorGraphFactory;
import com.analog.lyric.options.IOptionKey;
import com.analog.lyric.util.misc.Matlab;
/**
* The model factory creates variable vectors and FactorGraphs for MATLAB
*/
@Matlab
public class ModelFactory
{
/**
* Create empty {@link PCustomFactors} wrapper appropriate for given option.
* <p>
* @param optionName identifies a {@link CustomFactorsOptionKey} instance that is
* looked up in the {@link DimpleEnvironment#optionRegistry() optionRegistry} of the
* {@linkplain DimpleEnvironment#active() active environment}, e.g. {@code "GibbsOptions.customFactors"}.
* @throws DimpleException if matching option cannot be found or is not the correct type.
* @since 0.08
*/
public PCustomFactors createCustomFactors(String optionName)
{
IOptionKey<?> key = DimpleEnvironment.active().optionRegistry().get(optionName);
if (key == null)
{
throw new DimpleException("Cannot find option '%s'", optionName);
}
if (!(key instanceof CustomFactorsOptionKey))
{
throw new DimpleException("Option '%s' is not a CustomFactors option", optionName);
}
CustomFactorsOptionKey<?,?,?> customFactorsKey = (CustomFactorsOptionKey<?,?,?>)key;
CustomFactors<?,?> customFactors = customFactorsKey.createValue();
return new PCustomFactors(customFactors);
}
/**
* Create custom scheduler for graph.
*
* @param graph
* @param schedulerKey either a {@link SchedulerOptionKey} or its qualified name.
* @param scheduleEntries are in format given by CustomScheduler class in MATLAB.
* @since 0.08
*/
public PScheduler createCustomScheduler(PFactorGraphVector graph, @Nullable Object schedulerKey,
Object[] scheduleEntries)
{
if (schedulerKey instanceof SchedulerOptionKey)
{
return new PScheduler(graph, (SchedulerOptionKey)schedulerKey, scheduleEntries);
}
else
{
String schedulerKeyName = (String)schedulerKey;
if (schedulerKeyName == null || schedulerKeyName.isEmpty())
{
// This is only to support MATLAB FactorGraph.Schedule setter
@SuppressWarnings("deprecation")
PScheduler scheduler = new PScheduler(graph, scheduleEntries);
return scheduler;
}
else
{
return new PScheduler(graph, schedulerKeyName, scheduleEntries);
}
}
}
/**
* Create scheduler with given class name, looking up in active Dimple environment.
* @since 0.08
*/
public PScheduler createScheduler(String name)
{
return new PScheduler(DimpleEnvironment.active().schedulers().instantiate(name));
}
public MultivariateNormalParameters createMultivariateNormalParameters(double[] mean, double[][] covariance)
{
return new MultivariateNormalParameters(mean, covariance);
}
public NormalParameters createNormalParameters(double mean, double precision)
{
return new NormalParameters(mean, precision);
}
@Deprecated
public PRealJointVariableVector createRealJointVariableVector(String className, PRealJointDomain domain, int numEls)
{
return new PRealJointVariableVector(domain, numEls);
}
public PRealJointVariableVector createRealJointVariableVector(PRealJointDomain domain, int numEls)
{
return new PRealJointVariableVector(domain, numEls);
}
public DiscreteMessage createDiscreteMessage(String type, int size, @Nullable double[] values)
{
DiscreteMessage msg = type.equals("energy") ? new DiscreteEnergyMessage(size) : new DiscreteWeightMessage(size);
if (values != null)
{
System.arraycopy(values, 0, msg.representation(), 0, size);
}
return msg;
}
@Deprecated
public PDiscreteVariableVector createDiscreteVariableVector(String className, PDiscreteDomain domain, int numEls)
{
return new PDiscreteVariableVector(domain,numEls);
}
public PDiscreteVariableVector createDiscreteVariableVector(PDiscreteDomain domain, int numEls)
{
return new PDiscreteVariableVector(domain,numEls);
}
// For backwards compatibility with MATLAB
@Deprecated
public PDiscreteVariableVector createVariableVector(String className, PDiscreteDomain domain, int numEls)
{
return new PDiscreteVariableVector(domain,numEls);
}
public PFiniteFieldVariableVector createFiniteFieldVariableVector(PFiniteFieldDomain domain, int numEls)
{
return new PFiniteFieldVariableVector(domain, numEls);
}
public PRealJointDomain createRealJointDomain(Object [] realDomains)
{
return new PRealJointDomain(realDomains);
}
public PDiscreteDomain createDiscreteDomain(Object [] elements)
{
return new PDiscreteDomain(DiscreteDomain.create(elements));
}
public PFiniteFieldDomain createFiniteFieldDomain(int primitivePolynomial)
{
return new PFiniteFieldDomain(DiscreteDomain.finiteField(primitivePolynomial));
}
public PRealDomain createRealDomain(double lowerBound, double upperBound)
{
return new PRealDomain(RealDomain.create(lowerBound,upperBound));
}
public PVariableStreamBase createDiscreteStream(PDiscreteDomain domain, double numVars)
{
return new PDiscreteStream(domain,(int)numVars);
}
public PVariableStreamBase createRealStream(PRealDomain domain, int numVars)
{
return new PRealStream(domain,numVars);
}
public PVariableStreamBase createRealJointStream(PRealJointDomain domain, int numVars)
{
return new PRealJointStream(domain, numVars);
}
public PTableFactorFunction createTableFactorFunction(String name, int [][] indices, double [] values, Object [] domains)
{
PDiscreteDomain [] dds = new PDiscreteDomain[domains.length];
for (int i = 0; i < domains.length; i++)
{
dds[i] = (PDiscreteDomain)domains[i];
}
return new PTableFactorFunction(name,indices,values,dds);
}
public PFactorTable createFactorTable(Object table, Object [] domains)
{
PDiscreteDomain [] dds = new PDiscreteDomain[domains.length];
for (int i = 0; i < domains.length; i++)
{
dds[i] = (PDiscreteDomain)domains[i];
}
return new PFactorTable(table,dds);
}
public PFactorTable createFactorTable(int [][] indices, double [] values, Object [] domains)
{
PDiscreteDomain [] dds = new PDiscreteDomain[domains.length];
for (int i = 0; i < domains.length; i++)
{
dds[i] = (PDiscreteDomain)domains[i];
}
return new PFactorTable(indices,values,dds);
}
public PFactorTable createFactorTable(Object [] domains)
{
PDiscreteDomain [] dds = new PDiscreteDomain[domains.length];
for (int i = 0; i < domains.length; i++)
{
dds[i] = (PDiscreteDomain)domains[i];
}
return new PFactorTable(dds);
}
@Deprecated
public PRealVariableVector createRealVariableVector(String className, PRealDomain domain, int numEls)
{
return new PRealVariableVector(domain, numEls);
}
public PRealVariableVector createRealVariableVector(PRealDomain domain, int numEls)
{
return new PRealVariableVector(domain, numEls);
}
// Create graph
public PFactorGraphVector createGraph(Object [] vector)
{
ArrayList<Variable> alVars = new ArrayList<Variable>();
for (int i = 0; i < vector.length; i++)
{
PVariableVector tmp = (PVariableVector)vector[i];
Variable [] vars = tmp.getVariableArray();
for (int j = 0; j <vars.length; j++)
alVars.add(vars[j]);
}
Variable [] input = new Variable[alVars.size()];
alVars.toArray(input);
FactorGraph f = new FactorGraph(input);
return new PFactorGraphVector(f);
}
public PDimpleEventLogger createEventLogger()
{
return new PDimpleEventLogger();
}
// Set solver
public void setSolver(@Nullable IFactorGraphFactory<?> solver)
{
DimpleEnvironment.active().setDefaultSolver(solver);
}
public PFactorFunctionDataSource getFactorFunctionDataSource(double numVars)
{
return new PFactorFunctionDataSource((int)numVars);
}
public PDoubleArrayDataSource getDoubleArrayDataSource(double numVars)
{
return new PDoubleArrayDataSource((int)numVars);
}
public PDoubleArrayDataSink getDoubleArrayDataSink(double numVars)
{
return new PDoubleArrayDataSink((int)numVars);
}
public PMultivariateDataSource getMultivariateDataSource(double numVars)
{
return new PMultivariateDataSource((int)numVars);
}
public PMultivariateDataSink getMultivariateDataSink(double numVars)
{
return new PMultivariateDataSink((int)numVars);
}
public PMultiplexerCPD getMultiplexerCPD(Object [] zDomains)
{
Object [][] domains = new Object[zDomains.length][];
for (int i = 0; i < zDomains.length; i++)
{
domains[i] = (Object[])zDomains[i];
}
return new PMultiplexerCPD(domains);
}
public PMultiplexerCPD getMultiplexerCPD(Object [] domain, double numZs)
{
return new PMultiplexerCPD(domain,(int)numZs);
}
public void setNumThreadsToDefault()
{
ThreadPool.setNumThreadsToDefault();
}
public void setNumThreads(int numThreads)
{
ThreadPool.setNumThreads(numThreads);
}
public int getNumThreads()
{
return ThreadPool.getNumThreads();
}
public PLogger getLogger()
{
return PLogger.INSTANCE;
}
}