/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.instructions.cp;
import java.util.HashMap;
import com.ibm.bi.dml.lops.Lop;
import com.ibm.bi.dml.parser.Statement;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
import com.ibm.bi.dml.runtime.functionobjects.ParameterizedBuiltin;
import com.ibm.bi.dml.runtime.functionobjects.ValueFunction;
import com.ibm.bi.dml.runtime.instructions.Instruction;
import com.ibm.bi.dml.runtime.instructions.InstructionUtils;
import com.ibm.bi.dml.runtime.instructions.mr.GroupedAggregateInstruction;
import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
import com.ibm.bi.dml.runtime.matrix.operators.Operator;
import com.ibm.bi.dml.runtime.matrix.operators.SimpleOperator;
/**
 * CP instruction for builtin functions whose arguments are passed as named
 * parameters ({@code name=value} pairs) instead of positional operands.
 *
 * <p>{@link #parseInstruction(String)} accepts the opcodes {@code cdf},
 * {@code invcdf}, {@code groupedagg}, {@code rmempty}, {@code replace},
 * {@code rexpand} and {@code transform}. {@link #processInstruction(ExecutionContext)}
 * executes all of them except {@code transform}, for which this class has no
 * execution branch (NOTE(review): presumably executed elsewhere -- confirm).
 */
public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
{
    // Never assigned anywhere in this class, so it keeps its default value 0.
    private int arity;

    // Named parameters of the builtin call. Depending on the parameter, the
    // value is either a literal (e.g. "margin", "pattern") or the name of a
    // matrix variable that is looked up in the execution context (e.g. "target").
    protected HashMap<String,String> params;

    /**
     * Creates the instruction; the two input operands of the super class are
     * left {@code null} because all inputs are carried in the parameter map.
     *
     * @param op        operator to apply (may be {@code null}, as for {@code transform})
     * @param paramsMap named parameters of the builtin call
     * @param out       output operand
     * @param opcode    instruction opcode
     * @param istr      full instruction string
     */
    public ParameterizedBuiltinCPInstruction(Operator op, HashMap<String,String> paramsMap, CPOperand out, String opcode, String istr )
    {
        super(op, null, null, out, opcode, istr);
        _cptype = CPINSTRUCTION_TYPE.ParameterizedBuiltin;
        params = paramsMap;
    }

    /**
     * @return the arity field; since nothing in this class assigns it, this is always 0
     */
    public int getArity() {
        return arity;
    }

    /**
     * @return the parameter map handed to the constructor (not a copy)
     */
    public HashMap<String,String> getParameterMap() {
        return params;
    }

    /**
     * Builds the name/value map from the parts of an instruction string.
     * The first element (opcode) and the last element (output) are skipped;
     * every element in between is split at {@link Lop#NAME_VALUE_SEPARATOR}.
     *
     * <p>An element without a separator (or with an empty value) causes an
     * {@link ArrayIndexOutOfBoundsException}; if an element contains the
     * separator more than once, only the text between the first and the
     * second occurrence is kept as the value.
     *
     * @param params instruction parts: opcode, {@code name=value} entries, output
     * @return map from parameter name to parameter value
     */
    public static HashMap<String, String> constructParameterMap(String[] params) {
        HashMap<String,String> paramMap = new HashMap<String,String>();

        // all parameters are of form <name=value>; index 0 and the last index are not parameters
        String[] parts;
        for ( int i=1; i <= params.length-2; i++ ) {
            parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
            paramMap.put(parts[0], parts[1]);
        }

        return paramMap;
    }

    /**
     * Parses an instruction string into a {@link ParameterizedBuiltinCPInstruction}.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown, if {@code dist} is
     *         missing for {@code cdf}/{@code invcdf}, or if {@code fn} (or
     *         {@code order} for {@code fn="centralmoment"}) is missing for
     *         {@code groupedagg}
     * @throws DMLUnsupportedOperationException declared for callers; propagated from the operator parsing
     */
    public static Instruction parseInstruction ( String str )
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // first part is always the opcode, last part is always the output
        String opcode = parts[0];
        CPOperand out = new CPOperand( parts[parts.length-1] );
        // everything in between are the named parameters
        HashMap<String,String> paramsMap = constructParameterMap(parts);

        if ( opcode.equalsIgnoreCase("cdf") || opcode.equalsIgnoreCase("invcdf") ) {
            // the function object is selected by opcode and distribution name
            String dist = paramsMap.get("dist");
            if ( dist == null )
                throw new DMLRuntimeException("Invalid distribution: " + str);
            ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode, dist);
            return new ParameterizedBuiltinCPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
        }
        else if ( opcode.equalsIgnoreCase("groupedagg") ) {
            // check for mandatory arguments
            String fnStr = paramsMap.get("fn");
            if ( fnStr == null )
                throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
            if ( fnStr.equalsIgnoreCase("centralmoment") && paramsMap.get("order") == null )
                throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");

            Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
            return new ParameterizedBuiltinCPInstruction(op, paramsMap, out, opcode, str);
        }
        else if ( opcode.equalsIgnoreCase("rmempty")
               || opcode.equalsIgnoreCase("replace")
               || opcode.equalsIgnoreCase("rexpand") )
        {
            ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
            return new ParameterizedBuiltinCPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
        }
        else if ( opcode.equals("transform") ) {
            // matched case-sensitively (unlike the other opcodes) and created without an operator
            return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
        }
        else {
            throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
        }
    }

    /**
     * Executes the instruction by dispatching on its opcode. Scalar results
     * ({@code cdf}, {@code invcdf}) and matrix results (all other handled
     * opcodes) are stored in the execution context under the output name.
     *
     * @param ec execution context used to read inputs and store the output
     * @throws DMLRuntimeException if the opcode has no execution branch here
     *         (this includes {@code transform}), if {@code rmempty} gets an
     *         unsupported margin, or if an underlying operation fails
     * @throws DMLUnsupportedOperationException propagated from the matrix operations
     */
    @Override
    public void processInstruction(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        String opcode = getOpcode();

        if ( opcode.equalsIgnoreCase("cdf") || opcode.equalsIgnoreCase("invcdf") ) {
            executeScalarFunction(ec);
        }
        else if ( opcode.equalsIgnoreCase("groupedagg") ) {
            executeGroupedAggregate(ec);
        }
        else if ( opcode.equalsIgnoreCase("rmempty") ) {
            executeRemoveEmpty(ec);
        }
        else if ( opcode.equalsIgnoreCase("replace") ) {
            executeReplace(ec);
        }
        else if ( opcode.equalsIgnoreCase("rexpand") ) {
            executeRexpand(ec);
        }
        else {
            throw new DMLRuntimeException("Unknown opcode : " + opcode);
        }
    }

    // cdf / invcdf: applies the operator's function to the parameter map and stores the double result as scalar output.
    private void executeScalarFunction(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        SimpleOperator op = (SimpleOperator) _optr;
        double result = op.fn.execute(params);
        ScalarObject sores = new DoubleObject(result);
        ec.setScalarOutput(output.getName(), sores);
    }

    // groupedagg: aggregates the target matrix per group (optionally weighted) and stores the matrix result.
    private void executeGroupedAggregate(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        String targetName = params.get(Statement.GAGG_TARGET);
        String groupsName = params.get(Statement.GAGG_GROUPS);
        String weightsName = params.get(Statement.GAGG_WEIGHTS);

        // parse the optional group count before acquiring inputs so that a
        // malformed value fails without any input being held; -1 = not given
        String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
        int ngroups = (ngroupsStr != null) ? (int) Double.parseDouble(ngroupsStr) : -1;

        // acquire locks
        MatrixBlock target = ec.getMatrixInput(targetName);
        boolean groupsAcquired = false;
        boolean weightsAcquired = false;
        try {
            MatrixBlock groups = ec.getMatrixInput(groupsName);
            groupsAcquired = true;
            MatrixBlock weights = null;
            if ( weightsName != null ) {
                weights = ec.getMatrixInput(weightsName);
                weightsAcquired = true;
            }

            // compute the result
            MatrixBlock soresBlock = (MatrixBlock) (groups.groupedAggOperations(target, weights, new MatrixBlock(), ngroups, _optr));
            ec.setMatrixOutput(output.getName(), soresBlock);
        }
        finally {
            // release locks, also when the computation failed
            ec.releaseMatrixInput(targetName);
            if ( groupsAcquired )
                ec.releaseMatrixInput(groupsName);
            if ( weightsAcquired )
                ec.releaseMatrixInput(weightsName);
        }
    }

    // rmempty: removes empty rows or columns of the target (optionally guided by a select matrix) and stores the matrix result.
    private void executeRemoveEmpty(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        // validate the margin first: an invalid or missing margin must not
        // leave inputs acquired, and must not surface as a NullPointerException
        String margin = params.get("margin");
        boolean rows;
        if ( "rows".equals(margin) )
            rows = true;
        else if ( "cols".equals(margin) )
            rows = false;
        else
            throw new DMLRuntimeException("Unsupported margin identifier '"+margin+"'.");

        String targetName = params.get("target");
        boolean hasSelect = params.containsKey("select");
        String selectName = params.get("select");

        // acquire locks
        MatrixBlock target = ec.getMatrixInput(targetName);
        boolean selectAcquired = false;
        try {
            MatrixBlock select = null;
            if ( hasSelect ) {
                select = ec.getMatrixInput(selectName);
                selectAcquired = true;
            }

            // compute the result
            MatrixBlock soresBlock = target.removeEmptyOperations(new MatrixBlock(), rows, select);
            ec.setMatrixOutput(output.getName(), soresBlock);
        }
        finally {
            // release locks, also when the computation failed
            ec.releaseMatrixInput(targetName);
            if ( selectAcquired )
                ec.releaseMatrixInput(selectName);
        }
    }

    // replace: replaces all occurrences of "pattern" in the target by "replacement" and stores the matrix result.
    private void executeReplace(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        // parse scalar arguments before acquiring the input so that malformed values fail early
        double pattern = Double.parseDouble( params.get("pattern") );
        double replacement = Double.parseDouble( params.get("replacement") );
        String targetName = params.get("target");

        // acquire locks
        MatrixBlock target = ec.getMatrixInput(targetName);
        try {
            // compute the result
            MatrixBlock ret = (MatrixBlock) target.replaceOperations(new MatrixBlock(), pattern, replacement);
            ec.setMatrixOutput(output.getName(), ret);
        }
        finally {
            // release locks, also when the computation failed
            ec.releaseMatrixInput(targetName);
        }
    }

    // rexpand: calls rexpandOperations on the target with the parsed "max", "dir", "cast" and "ignore" arguments and stores the matrix result.
    private void executeRexpand(ExecutionContext ec)
        throws DMLRuntimeException, DMLUnsupportedOperationException
    {
        // parse scalar arguments before acquiring the input so that malformed values fail early
        double maxVal = Double.parseDouble( params.get("max") );
        boolean dirVal = params.get("dir").equals("rows");
        boolean cast = Boolean.parseBoolean(params.get("cast"));
        boolean ignore = Boolean.parseBoolean(params.get("ignore"));
        String targetName = params.get("target");

        // acquire locks
        MatrixBlock target = ec.getMatrixInput(targetName);
        try {
            // compute the result
            MatrixBlock ret = (MatrixBlock) target.rexpandOperations(new MatrixBlock(), maxVal, dirVal, cast, ignore);
            ec.setMatrixOutput(output.getName(), ret);
        }
        finally {
            // release locks, also when the computation failed
            ec.releaseMatrixInput(targetName);
        }
    }
}