/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions;
import java.util.StringTokenizer;
import org.apache.sysml.lops.AppendM;
import org.apache.sysml.lops.BinaryM;
import org.apache.sysml.lops.GroupedAggregateM;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.lops.MapMultChain;
import org.apache.sysml.lops.PMMJ;
import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.lops.UAggOuterChain;
import org.apache.sysml.lops.WeightedCrossEntropy;
import org.apache.sysml.lops.WeightedCrossEntropyR;
import org.apache.sysml.lops.WeightedDivMM;
import org.apache.sysml.lops.WeightedDivMMR;
import org.apache.sysml.lops.WeightedSigmoid;
import org.apache.sysml.lops.WeightedSigmoidR;
import org.apache.sysml.lops.WeightedSquaredLoss;
import org.apache.sysml.lops.WeightedSquaredLossR;
import org.apache.sysml.lops.WeightedUnaryMM;
import org.apache.sysml.lops.WeightedUnaryMMR;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.And;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.Divide;
import org.apache.sysml.runtime.functionobjects.Equals;
import org.apache.sysml.runtime.functionobjects.GreaterThan;
import org.apache.sysml.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysml.runtime.functionobjects.IndexFunction;
import org.apache.sysml.runtime.functionobjects.IntegerDivide;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.LessThan;
import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.functionobjects.Minus;
import org.apache.sysml.runtime.functionobjects.Minus1Multiply;
import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.MinusNz;
import org.apache.sysml.runtime.functionobjects.Modulus;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Multiply2;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.Or;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.functionobjects.Power;
import org.apache.sysml.runtime.functionobjects.Power2;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceDiag;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.mr.MRInstruction.MRINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.spark.SPInstruction.SPINSTRUCTION_TYPE;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysml.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
/**
 * Static helpers for splitting instruction strings into their parts and for
 * mapping instruction opcodes to runtime operator objects (aggregate, binary,
 * scalar, ternary-aggregate operators).
 *
 * An instruction string consists of fields separated by
 * {@link Instruction#OPERAND_DELIM}: the execution type, the opcode, and the
 * operands; an operand may carry a data/value type suffix that starts with
 * {@link Instruction#DATATYPE_PREFIX}.
 */
public class InstructionUtils
{
	/**
	 * Verifies that an instruction string has exactly the expected number of
	 * operand fields (all fields except execution type and opcode).
	 *
	 * @param str instruction string
	 * @param expected expected number of operand fields
	 * @return the actual number of operand fields
	 * @throws DMLRuntimeException if the actual number differs from {@code expected}
	 */
	public static int checkNumFields( String str, int expected )
		throws DMLRuntimeException
	{
		int numFields = countOperandFields(str);
		if ( numFields != expected )
			throw new DMLRuntimeException("checkNumFields() for (" + str + ") -- expected number (" + expected + ") != is not equal to actual number (" + numFields + ").");
		return numFields;
	}

	/**
	 * Verifies that already-split instruction parts (opcode followed by
	 * operands, execution type stripped) contain exactly the expected number
	 * of operand fields.
	 *
	 * @param parts instruction parts, starting with the opcode
	 * @param expected expected number of operand fields
	 * @return the actual number of operand fields
	 * @throws DMLRuntimeException if the actual number differs from {@code expected}
	 */
	public static int checkNumFields( String[] parts, int expected )
		throws DMLRuntimeException
	{
		int numFields = parts.length - 1; //account for opcode
		if ( numFields != expected )
			throw new DMLRuntimeException("checkNumFields() -- expected number (" + expected + ") != is not equal to actual number (" + numFields + ").");
		return numFields;
	}

	/**
	 * Verifies that already-split instruction parts (opcode followed by
	 * operands) contain one of two admissible numbers of operand fields.
	 *
	 * @param parts instruction parts, starting with the opcode
	 * @param expected1 first admissible number of operand fields
	 * @param expected2 second admissible number of operand fields
	 * @return the actual number of operand fields
	 * @throws DMLRuntimeException if the actual number matches neither expectation
	 */
	public static int checkNumFields( String[] parts, int expected1, int expected2 )
		throws DMLRuntimeException
	{
		int numFields = parts.length - 1; //account for opcode
		if ( numFields != expected1 && numFields != expected2 )
			throw new DMLRuntimeException("checkNumFields() -- expected number (" + expected1 + " or "+ expected2 +") != is not equal to actual number (" + numFields + ").");
		return numFields;
	}

	/**
	 * Verifies that an instruction string has one of two admissible numbers of
	 * operand fields (all fields except execution type and opcode).
	 *
	 * @param str instruction string
	 * @param expected1 first admissible number of operand fields
	 * @param expected2 second admissible number of operand fields
	 * @return the actual number of operand fields
	 * @throws DMLRuntimeException if the actual number matches neither expectation
	 */
	public static int checkNumFields( String str, int expected1, int expected2 )
		throws DMLRuntimeException
	{
		int numFields = countOperandFields(str);
		if ( numFields != expected1 && numFields != expected2 )
			throw new DMLRuntimeException("checkNumFields() for (" + str + ") -- expected number (" + expected1 + " or "+ expected2 +") != is not equal to actual number (" + numFields + ").");
		return numFields;
	}

	/**
	 * Counts the operand fields of an instruction string: the number of
	 * delimiter-separated tokens minus execution type and opcode.
	 * String.split is used (rather than StringTokenizer) so that empty
	 * interior tokens are counted; trailing empty tokens are dropped because
	 * no split limit is given.
	 */
	private static int countOperandFields( String str ) {
		return str.split(Instruction.OPERAND_DELIM).length - 2;
	}

	/**
	 * Given an instruction string, strip-off the execution type and return
	 * opcode and all input/output operands WITHOUT their data/value type.
	 * i.e., ret.length = parts.length-1 (-1 for execution type)
	 *
	 * Note: StringTokenizer skips empty tokens, so empty operands are not
	 * represented in the result.
	 *
	 * @param str instruction string
	 * @return instruction parts as string array
	 */
	public static String[] getInstructionParts( String str )
	{
		StringTokenizer st = new StringTokenizer( str, Instruction.OPERAND_DELIM );
		String[] ret = new String[st.countTokens()-1];
		st.nextToken(); // stripping-off the exectype
		ret[0] = st.nextToken(); // opcode
		int index = 1;
		while( st.hasMoreTokens() ){
			String tmp = st.nextToken();
			// cut off everything from the data type prefix onwards (if present)
			int ix = tmp.indexOf(Instruction.DATATYPE_PREFIX);
			ret[index++] = tmp.substring(0,((ix>=0)?ix:tmp.length()));
		}
		return ret;
	}

	/**
	 * Given an instruction string, this function strips-off the
	 * execution type (CP or MR) and returns the remaining parts,
	 * which include the opcode as well as the input and output operands.
	 * Each returned part will have the datatype and valuetype associated
	 * with the operand.
	 *
	 * Empty tokens (including trailing ones) are preserved.
	 * This function is invoked mainly for parsing CPInstructions.
	 *
	 * @param str instruction string
	 * @return instruction parts as string array
	 */
	public static String[] getInstructionPartsWithValueType( String str )
	{
		//note: split (with negative limit) required to keep all empty tokens
		String[] parts = str.split(Instruction.OPERAND_DELIM, -1);
		//drop the exec type (parts[0]); parts[1] is the opcode
		String[] ret = new String[parts.length-1];
		System.arraycopy(parts, 1, ret, 0, ret.length);
		return ret;
	}

	/**
	 * Extracts the opcode, i.e., the text between the first and the second
	 * operand delimiter. If there is no second delimiter (instruction without
	 * operands), the opcode extends to the end of the string.
	 *
	 * @param str instruction string
	 * @return the opcode
	 */
	public static String getOpCode( String str )
	{
		int ix1 = str.indexOf(Instruction.OPERAND_DELIM);
		int ix2 = str.indexOf(Instruction.OPERAND_DELIM, ix1+1);
		if( ix2 < 0 ) //no operands: opcode runs to end of string
			ix2 = str.length();
		return str.substring(ix1+1, ix2);
	}

	/**
	 * @param str instruction string
	 * @return the MR instruction type registered for the opcode, or null if none
	 */
	public static MRINSTRUCTION_TYPE getMRType( String str ) {
		return MRInstructionParser.String2MRInstructionType.get( getOpCode(str) );
	}

	/**
	 * @param str instruction string
	 * @return the Spark instruction type registered for the opcode, or null if none
	 */
	public static SPINSTRUCTION_TYPE getSPType( String str ) {
		return SPInstructionParser.String2SPInstructionType.get( getOpCode(str) );
	}

	/**
	 * @param str instruction string
	 * @return the CP instruction type registered for the opcode, or null if none
	 */
	public static CPINSTRUCTION_TYPE getCPType( String str ) {
		return CPInstructionParser.String2CPInstructionType.get( getOpCode(str) );
	}

	/**
	 * @param str instruction string
	 * @return the GPU instruction type registered for the opcode, or null if none
	 */
	public static GPUINSTRUCTION_TYPE getGPUType( String str ) {
		return GPUInstructionParser.String2GPUInstructionType.get( getOpCode(str) );
	}

	/**
	 * @param opcode opcode to look up
	 * @return true if the opcode is registered as a builtin function code
	 */
	public static boolean isBuiltinFunction( String opcode ) {
		Builtin.BuiltinCode bfc = Builtin.String2BuiltinCode.get(opcode);
		return (bfc != null);
	}

	/**
	 * Evaluates if at least one instruction of the given instruction set
	 * used the distributed cache; this call can also be used for individual
	 * instructions.
	 *
	 * @param str instruction set
	 * @return true if at least one instruction uses distributed cache
	 */
	public static boolean isDistributedCacheUsed(String str)
	{
		String[] parts = str.split(Instruction.INSTRUCTION_DELIM);
		for(String inst : parts)
		{
			String opcode = getOpCode(inst);
			if(  opcode.equalsIgnoreCase(AppendM.OPCODE)
				|| opcode.equalsIgnoreCase(MapMult.OPCODE)
				|| opcode.equalsIgnoreCase(MapMultChain.OPCODE)
				|| opcode.equalsIgnoreCase(PMMJ.OPCODE)
				|| opcode.equalsIgnoreCase(UAggOuterChain.OPCODE)
				|| opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE)
				|| isDistQuaternaryOpcode( opcode ) //multiple quaternary opcodes
				|| BinaryM.isOpcode( opcode ) ) //multiple binary opcodes
			{
				return true;
			}
		}
		return false;
	}

	/**
	 * Maps a unary aggregate opcode (e.g. {@code uak+}, {@code uarmax},
	 * {@code uacvar}) to its aggregate-unary operator. The {@code uar*} /
	 * {@code uac*} prefixes denote row-wise (ReduceCol) and column-wise
	 * (ReduceRow) aggregation respectively.
	 *
	 * @param opcode unary aggregate opcode
	 * @return the operator, or null if the opcode is not supported
	 */
	public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode)
	{
		AggregateUnaryOperator aggun = null;
		if ( opcode.equalsIgnoreCase("uak+") ) {
			AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uark+") ) {
			// RowSums
			AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uack+") ) {
			// ColSums
			AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uasqk+") ) {
			AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uarsqk+") ) {
			// RowSums
			AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uacsqk+") ) {
			// ColSums
			AggregateOperator agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTROW);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uamean") ) {
			// Mean
			AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uarmean") ) {
			// RowMeans
			AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uacmean") ) {
			// ColMeans
			AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOROWS);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uavar") ) {
			// Variance
			CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
			CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
			AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uarvar") ) {
			// RowVariances
			CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
			CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
			AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uacvar") ) {
			// ColVariances
			CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
			CorrectionLocationType cloc = CorrectionLocationType.LASTFOURROWS;
			AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("ua+") ) {
			AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uar+") ) {
			// RowSums
			AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uac+") ) {
			// ColSums
			AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("ua*") ) {
			AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uar*") ) {
			// RowProducts
			AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uac*") ) {
			// ColProducts (required by cumprod, see parseBasicCumulativeAggregateUnaryOperator)
			AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uamax") ) {
			AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uamin") ) {
			AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
			aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uatrace") ) {
			AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
			aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uaktrace") ) {
			AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uarmax") ) {
			AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if (opcode.equalsIgnoreCase("uarimax") ) {
			AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uarmin") ) {
			AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if (opcode.equalsIgnoreCase("uarimin") ) {
			AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
			aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uacmax") ) {
			AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( opcode.equalsIgnoreCase("uacmin") ) {
			AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		return aggun;
	}

	/**
	 * Single-threaded variant of {@link #parseAggregateTernaryOperator(String, int)}.
	 *
	 * @param opcode ternary aggregate opcode
	 * @return the aggregate ternary operator
	 */
	public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode) {
		return parseAggregateTernaryOperator(opcode, 1);
	}

	/**
	 * Creates a Kahan-sum-of-products aggregate ternary operator. Opcode
	 * {@code tak+*} aggregates to a scalar (ReduceAll, correction in last
	 * column); any other opcode aggregates column-wise (ReduceRow, correction
	 * in last row).
	 *
	 * @param opcode ternary aggregate opcode
	 * @param numThreads degree of parallelism passed to the operator
	 * @return the aggregate ternary operator
	 */
	public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode, int numThreads) {
		boolean fullAgg = opcode.equalsIgnoreCase("tak+*");
		CorrectionLocationType corr = fullAgg ?
			CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.LASTROW;
		AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, corr);
		IndexFunction ixfun = fullAgg ?
			ReduceAll.getReduceAllFnObject() : ReduceRow.getReduceRowFnObject();
		return new AggregateTernaryOperator(Multiply.getMultiplyFnObject(), agg, ixfun, numThreads);
	}

	/**
	 * Maps an aggregate opcode (e.g. {@code ak+}, {@code amax}, {@code avar})
	 * to its aggregate operator.
	 *
	 * @param opcode aggregate opcode
	 * @param corrExists "true"/"false" whether a correction exists; null selects
	 *        the per-opcode default (true for opcodes that use a correction)
	 * @param corrLoc name of a {@link CorrectionLocationType}; null selects the
	 *        per-opcode default
	 * @return the operator, or null if the opcode is not supported
	 */
	public static AggregateOperator parseAggregateOperator(String opcode, String corrExists, String corrLoc)
	{
		AggregateOperator agg = null;
		if ( opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("aktrace") ) {
			agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(),
				parseCorrectionExists(corrExists),
				parseCorrectionLocation(corrLoc, CorrectionLocationType.LASTCOLUMN));
		}
		else if ( opcode.equalsIgnoreCase("asqk+") ) {
			agg = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(),
				parseCorrectionExists(corrExists),
				parseCorrectionLocation(corrLoc, CorrectionLocationType.LASTCOLUMN));
		}
		else if ( opcode.equalsIgnoreCase("a+") ) {
			agg = new AggregateOperator(0, Plus.getPlusFnObject());
		}
		else if ( opcode.equalsIgnoreCase("a*") ) {
			agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
		}
		else if (opcode.equalsIgnoreCase("arimax")){
			agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
		}
		else if ( opcode.equalsIgnoreCase("amax") ) {
			agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
		}
		else if ( opcode.equalsIgnoreCase("amin") ) {
			agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
		}
		else if (opcode.equalsIgnoreCase("arimin")){
			agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
		}
		else if ( opcode.equalsIgnoreCase("amean") ) {
			// partial means are combined via Kahan addition with count correction
			agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(),
				parseCorrectionExists(corrExists),
				parseCorrectionLocation(corrLoc, CorrectionLocationType.LASTTWOCOLUMNS));
		}
		else if ( opcode.equalsIgnoreCase("avar") ) {
			CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
			agg = new AggregateOperator(0, varFn,
				parseCorrectionExists(corrExists),
				parseCorrectionLocation(corrLoc, CorrectionLocationType.LASTFOURCOLUMNS));
		}
		return agg;
	}

	/** Parses the optional correction flag; a missing flag (null) means "correction exists". */
	private static boolean parseCorrectionExists(String corrExists) {
		return (corrExists == null) || Boolean.parseBoolean(corrExists);
	}

	/** Parses the optional correction location, falling back to {@code defaultLoc} if null. */
	private static CorrectionLocationType parseCorrectionLocation(String corrLoc, CorrectionLocationType defaultLoc) {
		return (corrLoc == null) ? defaultLoc : CorrectionLocationType.valueOf(corrLoc);
	}

	/**
	 * Maps a cumulative builtin (cumsum, cumprod, cummin, cummax) to the
	 * column-wise aggregate that summarizes a block for that cumulative op.
	 *
	 * @param uop unary operator whose function is a {@link Builtin}
	 * @return the column aggregate operator
	 * @throws RuntimeException if the builtin is not a supported cumulative aggregate
	 */
	public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(UnaryOperator uop)
	{
		Builtin f = (Builtin)uop.fn;
		if( f.getBuiltinCode()==BuiltinCode.CUMSUM )
			return parseBasicAggregateUnaryOperator("uack+") ;
		else if( f.getBuiltinCode()==BuiltinCode.CUMPROD )
			return parseBasicAggregateUnaryOperator("uac*") ;
		else if( f.getBuiltinCode()==BuiltinCode.CUMMIN )
			return parseBasicAggregateUnaryOperator("uacmin") ;
		else if( f.getBuiltinCode()==BuiltinCode.CUMMAX )
			return parseBasicAggregateUnaryOperator("uacmax" ) ;
		throw new RuntimeException("Unsupported cumulative aggregate unary operator: "+f.getBuiltinCode());
	}

	/**
	 * Maps a cumulative aggregate opcode ({@code ucumack+}, {@code ucumac*},
	 * {@code ucumacmin}, {@code ucumacmax}) to a column-wise aggregate
	 * operator. Initial values are the identities of the respective functions
	 * (0 for sum, 1 for product, +/-Double.MAX_VALUE for min/max), consistent
	 * with {@link #parseBasicAggregateUnaryOperator(String)}.
	 *
	 * @param opcode cumulative aggregate opcode (case-sensitive)
	 * @return the operator, or null if the opcode is not supported
	 */
	public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode)
	{
		AggregateUnaryOperator aggun = null;
		if( "ucumack+".equals(opcode) ) {
			AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( "ucumac*".equals(opcode) ) {
			AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject(), false, CorrectionLocationType.NONE);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( "ucumacmin".equals(opcode) ) {
			AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"), false, CorrectionLocationType.NONE);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		else if ( "ucumacmax".equals(opcode) ) {
			AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"), false, CorrectionLocationType.NONE);
			aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
		}
		return aggun;
	}

	/**
	 * Maps a binary opcode (comparison, logical, arithmetic, min/max and the
	 * fused {@code +*}/{@code -*}) to a binary operator.
	 *
	 * @param opcode binary opcode
	 * @return the binary operator
	 * @throws DMLRuntimeException if the opcode is unknown
	 */
	public static BinaryOperator parseBinaryOperator(String opcode)
		throws DMLRuntimeException
	{
		if(opcode.equalsIgnoreCase("=="))
			return new BinaryOperator(Equals.getEqualsFnObject());
		else if(opcode.equalsIgnoreCase("!="))
			return new BinaryOperator(NotEquals.getNotEqualsFnObject());
		else if(opcode.equalsIgnoreCase("<"))
			return new BinaryOperator(LessThan.getLessThanFnObject());
		else if(opcode.equalsIgnoreCase(">"))
			return new BinaryOperator(GreaterThan.getGreaterThanFnObject());
		else if(opcode.equalsIgnoreCase("<="))
			return new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
		else if(opcode.equalsIgnoreCase(">="))
			return new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
		else if(opcode.equalsIgnoreCase("&&"))
			return new BinaryOperator(And.getAndFnObject());
		else if(opcode.equalsIgnoreCase("||"))
			return new BinaryOperator(Or.getOrFnObject());
		else if(opcode.equalsIgnoreCase("+"))
			return new BinaryOperator(Plus.getPlusFnObject());
		else if(opcode.equalsIgnoreCase("-"))
			return new BinaryOperator(Minus.getMinusFnObject());
		else if(opcode.equalsIgnoreCase("*"))
			return new BinaryOperator(Multiply.getMultiplyFnObject());
		else if(opcode.equalsIgnoreCase("1-*"))
			return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
		else if ( opcode.equalsIgnoreCase("*2") )
			return new BinaryOperator(Multiply2.getMultiply2FnObject());
		else if(opcode.equalsIgnoreCase("/"))
			return new BinaryOperator(Divide.getDivideFnObject());
		else if(opcode.equalsIgnoreCase("%%"))
			return new BinaryOperator(Modulus.getFnObject());
		else if(opcode.equalsIgnoreCase("%/%"))
			return new BinaryOperator(IntegerDivide.getFnObject());
		else if(opcode.equalsIgnoreCase("^"))
			return new BinaryOperator(Power.getPowerFnObject());
		else if ( opcode.equalsIgnoreCase("^2") )
			return new BinaryOperator(Power2.getPower2FnObject());
		else if ( opcode.equalsIgnoreCase("max") )
			return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
		else if ( opcode.equalsIgnoreCase("min") )
			return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
		else if ( opcode.equalsIgnoreCase("+*") )
			return new BinaryOperator(PlusMultiply.getPlusMultiplyFnObject());
		else if ( opcode.equalsIgnoreCase("-*") )
			return new BinaryOperator(MinusMultiply.getMinusMultiplyFnObject());

		throw new DMLRuntimeException("Unknown binary opcode " + opcode);
	}

	/**
	 * scalar-matrix operator with a default constant of 0, for runtimes that
	 * set the scalar constant dynamically (cp/spark).
	 *
	 * @param opcode the opcode
	 * @param arg1IsScalar true if the scalar is the left (first) operand
	 * @return scalar operator
	 * @throws DMLRuntimeException if the opcode is unknown
	 */
	public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar)
		throws DMLRuntimeException
	{
		//for all runtimes that set constant dynamically (cp/spark)
		double default_constant = 0;
		return parseScalarBinaryOperator(opcode, arg1IsScalar, default_constant);
	}

	/**
	 * scalar-matrix operator. Commutative operators always yield a
	 * {@link RightScalarOperator}; non-commutative ones yield a
	 * {@link LeftScalarOperator} if the scalar is the first operand.
	 *
	 * @param opcode the opcode
	 * @param arg1IsScalar true if the scalar is the left (first) operand
	 * @param constant the scalar constant operand
	 * @return scalar operator
	 * @throws DMLRuntimeException if the opcode is unknown
	 */
	public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar, double constant)
		throws DMLRuntimeException
	{
		//commutative operators
		if ( opcode.equalsIgnoreCase("+") ){
			return new RightScalarOperator(Plus.getPlusFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("*") ) {
			return new RightScalarOperator(Multiply.getMultiplyFnObject(), constant);
		}
		//non-commutative operators
		else if ( opcode.equalsIgnoreCase("-") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(Minus.getMinusFnObject(), constant);
			else return new RightScalarOperator(Minus.getMinusFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("-nz") ) {
			//no support for left scalar yet
			return new RightScalarOperator(MinusNz.getMinusNzFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("/") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
			else return new RightScalarOperator(Divide.getDivideFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("%%") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(Modulus.getFnObject(), constant);
			else return new RightScalarOperator(Modulus.getFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("%/%") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(IntegerDivide.getFnObject(), constant);
			else return new RightScalarOperator(IntegerDivide.getFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("^") ){
			if(arg1IsScalar)
				return new LeftScalarOperator(Power.getPowerFnObject(), constant);
			else return new RightScalarOperator(Power.getPowerFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("max") ) {
			return new RightScalarOperator(Builtin.getBuiltinFnObject("max"), constant);
		}
		else if ( opcode.equalsIgnoreCase("min") ) {
			return new RightScalarOperator(Builtin.getBuiltinFnObject("min"), constant);
		}
		else if ( opcode.equalsIgnoreCase("log") || opcode.equalsIgnoreCase("log_nz") ){
			if( arg1IsScalar )
				return new LeftScalarOperator(Builtin.getBuiltinFnObject(opcode), constant);
			return new RightScalarOperator(Builtin.getBuiltinFnObject(opcode), constant);
		}
		else if ( opcode.equalsIgnoreCase(">") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(GreaterThan.getGreaterThanFnObject(), constant);
			return new RightScalarOperator(GreaterThan.getGreaterThanFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase(">=") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), constant);
			return new RightScalarOperator(GreaterThanEquals.getGreaterThanEqualsFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("<") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(LessThan.getLessThanFnObject(), constant);
			return new RightScalarOperator(LessThan.getLessThanFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("<=") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), constant);
			return new RightScalarOperator(LessThanEquals.getLessThanEqualsFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("==") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(Equals.getEqualsFnObject(), constant);
			return new RightScalarOperator(Equals.getEqualsFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("!=") ) {
			if(arg1IsScalar)
				return new LeftScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
			return new RightScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
		}
		//operations that only exist for performance purposes (all unary or commutative operators)
		else if ( opcode.equalsIgnoreCase("*2") ) {
			return new RightScalarOperator(Multiply2.getMultiply2FnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("^2") ){
			return new RightScalarOperator(Power2.getPower2FnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("1-*") ) {
			return new RightScalarOperator(Minus1Multiply.getMinus1MultiplyFnObject(), constant);
		}
		//operations that only exist in mr
		else if ( opcode.equalsIgnoreCase("s-r") ) {
			return new LeftScalarOperator(Minus.getMinusFnObject(), constant);
		}
		else if ( opcode.equalsIgnoreCase("so") ) {
			return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
		}

		throw new DMLRuntimeException("Unknown binary opcode " + opcode);
	}

	/**
	 * Like {@link #parseBinaryOperator(String)} but additionally accepts the
	 * {@code map}-prefixed opcodes (e.g. {@code map+}, {@code mapmax}), which
	 * map to the same functions as their unprefixed counterparts.
	 *
	 * @param opcode binary opcode, optionally with {@code map} prefix
	 * @return the binary operator
	 * @throws DMLRuntimeException if the opcode is unknown
	 */
	public static BinaryOperator parseExtendedBinaryOperator(String opcode)
		throws DMLRuntimeException
	{
		if(opcode.equalsIgnoreCase("==") || opcode.equalsIgnoreCase("map=="))
			return new BinaryOperator(Equals.getEqualsFnObject());
		else if(opcode.equalsIgnoreCase("!=") || opcode.equalsIgnoreCase("map!="))
			return new BinaryOperator(NotEquals.getNotEqualsFnObject());
		else if(opcode.equalsIgnoreCase("<") || opcode.equalsIgnoreCase("map<"))
			return new BinaryOperator(LessThan.getLessThanFnObject());
		else if(opcode.equalsIgnoreCase(">") || opcode.equalsIgnoreCase("map>"))
			return new BinaryOperator(GreaterThan.getGreaterThanFnObject());
		else if(opcode.equalsIgnoreCase("<=") || opcode.equalsIgnoreCase("map<="))
			return new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject());
		else if(opcode.equalsIgnoreCase(">=") || opcode.equalsIgnoreCase("map>="))
			return new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject());
		else if(opcode.equalsIgnoreCase("&&"))
			return new BinaryOperator(And.getAndFnObject());
		else if(opcode.equalsIgnoreCase("||"))
			return new BinaryOperator(Or.getOrFnObject());
		else if(opcode.equalsIgnoreCase("+") || opcode.equalsIgnoreCase("map+"))
			return new BinaryOperator(Plus.getPlusFnObject());
		else if(opcode.equalsIgnoreCase("-") || opcode.equalsIgnoreCase("map-"))
			return new BinaryOperator(Minus.getMinusFnObject());
		else if(opcode.equalsIgnoreCase("*") || opcode.equalsIgnoreCase("map*"))
			return new BinaryOperator(Multiply.getMultiplyFnObject());
		else if(opcode.equalsIgnoreCase("1-*") || opcode.equalsIgnoreCase("map1-*"))
			return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
		else if ( opcode.equalsIgnoreCase("*2") )
			return new BinaryOperator(Multiply2.getMultiply2FnObject());
		else if(opcode.equalsIgnoreCase("/") || opcode.equalsIgnoreCase("map/"))
			return new BinaryOperator(Divide.getDivideFnObject());
		else if(opcode.equalsIgnoreCase("%%") || opcode.equalsIgnoreCase("map%%"))
			return new BinaryOperator(Modulus.getFnObject());
		else if(opcode.equalsIgnoreCase("%/%") || opcode.equalsIgnoreCase("map%/%"))
			return new BinaryOperator(IntegerDivide.getFnObject());
		else if(opcode.equalsIgnoreCase("^") || opcode.equalsIgnoreCase("map^"))
			return new BinaryOperator(Power.getPowerFnObject());
		else if ( opcode.equalsIgnoreCase("^2") )
			return new BinaryOperator(Power2.getPower2FnObject());
		else if ( opcode.equalsIgnoreCase("max") || opcode.equalsIgnoreCase("mapmax") )
			return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
		else if ( opcode.equalsIgnoreCase("min") || opcode.equalsIgnoreCase("mapmin") )
			return new BinaryOperator(Builtin.getBuiltinFnObject("min"));

		throw new DMLRuntimeException("Unknown binary opcode " + opcode);
	}

	/**
	 * Derives the opcode of the aggregate operator that combines partial
	 * results of the given unary aggregate (e.g. {@code uark+} to {@code ak+}).
	 *
	 * @param opcode unary aggregate opcode
	 * @return the aggregate opcode, or null if the opcode is not supported
	 */
	public static String deriveAggregateOperatorOpcode(String opcode)
	{
		if ( opcode.equalsIgnoreCase("uak+") || opcode.equalsIgnoreCase("uark+") || opcode.equalsIgnoreCase("uack+"))
			return "ak+";
		else if ( opcode.equalsIgnoreCase("uasqk+") || opcode.equalsIgnoreCase("uarsqk+") || opcode.equalsIgnoreCase("uacsqk+") )
			return "asqk+";
		else if ( opcode.equalsIgnoreCase("uamean") || opcode.equalsIgnoreCase("uarmean") || opcode.equalsIgnoreCase("uacmean") )
			return "amean";
		else if ( opcode.equalsIgnoreCase("uavar") || opcode.equalsIgnoreCase("uarvar") || opcode.equalsIgnoreCase("uacvar") )
			return "avar";
		else if ( opcode.equalsIgnoreCase("ua+") || opcode.equalsIgnoreCase("uar+") || opcode.equalsIgnoreCase("uac+") )
			return "a+";
		else if ( opcode.equalsIgnoreCase("ua*") || opcode.equalsIgnoreCase("uar*") || opcode.equalsIgnoreCase("uac*") )
			return "a*";
		else if ( opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") )
			return "aktrace";
		else if ( opcode.equalsIgnoreCase("uamax") || opcode.equalsIgnoreCase("uarmax") || opcode.equalsIgnoreCase("uacmax") )
			return "amax";
		else if ( opcode.equalsIgnoreCase("uamin") || opcode.equalsIgnoreCase("uarmin") || opcode.equalsIgnoreCase("uacmin") )
			return "amin";
		else if (opcode.equalsIgnoreCase("uarimax") )
			return "arimax";
		else if (opcode.equalsIgnoreCase("uarimin") )
			return "arimin";

		return null;
	}

	/**
	 * Derives the correction location used when aggregating partial results
	 * of the given unary aggregate.
	 *
	 * @param opcode unary aggregate opcode
	 * @return the correction location, {@link CorrectionLocationType#NONE} if
	 *         the opcode uses no correction or is unknown
	 */
	public static CorrectionLocationType deriveAggregateOperatorCorrectionLocation(String opcode)
	{
		if ( opcode.equalsIgnoreCase("uak+") || opcode.equalsIgnoreCase("uark+") ||
				opcode.equalsIgnoreCase("uasqk+") || opcode.equalsIgnoreCase("uarsqk+") ||
				opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") ||
				opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin") )
			return CorrectionLocationType.LASTCOLUMN;
		else if ( opcode.equalsIgnoreCase("uack+") || opcode.equalsIgnoreCase("uacsqk+") )
			return CorrectionLocationType.LASTROW;
		else if ( opcode.equalsIgnoreCase("uamean") || opcode.equalsIgnoreCase("uarmean") )
			return CorrectionLocationType.LASTTWOCOLUMNS;
		else if ( opcode.equalsIgnoreCase("uacmean") )
			return CorrectionLocationType.LASTTWOROWS;
		else if ( opcode.equalsIgnoreCase("uavar") || opcode.equalsIgnoreCase("uarvar") )
			return CorrectionLocationType.LASTFOURCOLUMNS;
		else if ( opcode.equalsIgnoreCase("uacvar") )
			return CorrectionLocationType.LASTFOURROWS;

		return CorrectionLocationType.NONE;
	}

	/**
	 * @param opcode opcode to test
	 * @return true if the opcode is one of the map-side or reduce-side weighted
	 *         quaternary operations (wsloss, wsigmoid, wdivmm, wcemm, wumm)
	 */
	public static boolean isDistQuaternaryOpcode(String opcode)
	{
		return WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)     //mapwsloss
			|| WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)    //redwsloss
			|| WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode)         //mapwsigmoid
			|| WeightedSigmoidR.OPCODE.equalsIgnoreCase(opcode)        //redwsigmoid
			|| WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)           //mapwdivmm
			|| WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)          //redwdivmm
			|| WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode)    //mapwcemm
			|| WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode)   //redwcemm
			|| WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode)         //mapwumm
			|| WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);       //redwumm
	}
}