/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.mr;
import java.util.HashMap;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.lops.Ternary.OperationTypes;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.CTableMap;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
/**
 * MR instruction for the ctable family of ternary operations.
 *
 * <p>Depending on the {@link OperationTypes} variant, the second and third
 * operands are either matrix inputs (identified by a byte index into the
 * cached values) or scalar constants. Results are written either into a
 * {@link MatrixBlock} (when both output dimensions are known) or into a
 * {@link CTableMap} aggregation map (otherwise).
 */
public class TernaryInstruction extends MRInstruction
{
	private OperationTypes _op;

	/** Index of the first (always matrix) input. */
	public byte input1;
	/** Index of the second input, when it is a matrix. */
	public byte input2;
	/** Index of the third input, when it is a matrix. */
	public byte input3;
	/** Second operand, when it is a scalar. */
	public double scalar_input2;
	/** Third operand, when it is a scalar (for CTABLE_EXPAND_SCALAR_WEIGHT it encodes a type flag). */
	public double scalar_input3;

	// output dimensions; values <= 0 mean "unknown" (see knownOutputDims())
	private long _outputDim1, _outputDim2;

	/**
	 * Single matrix input
	 *
	 * @param op operation type
	 * @param in1 input 1 (byte)
	 * @param scalar_in2 input 2 (double)
	 * @param scalar_in3 input 3 (double)
	 * @param out output
	 * @param outputDim1 output dimension 1
	 * @param outputDim2 output dimension 2
	 * @param istr instruction string
	 */
	public TernaryInstruction(OperationTypes op, byte in1, double scalar_in2, double scalar_in3, byte out, long outputDim1, long outputDim2, String istr)
	{
		super(null, out);
		init(op, outputDim1, outputDim2, istr);
		input1 = in1;
		scalar_input2 = scalar_in2;
		scalar_input3 = scalar_in3;
	}

	/**
	 * Two matrix inputs
	 *
	 * @param op operation type
	 * @param in1 input 1 (byte)
	 * @param in2 input 2 (byte)
	 * @param scalar_in3 input 3 (double)
	 * @param out output
	 * @param outputDim1 output dimension 1
	 * @param outputDim2 output dimension 2
	 * @param istr instruction string
	 */
	public TernaryInstruction(OperationTypes op, byte in1, byte in2, double scalar_in3, byte out, long outputDim1, long outputDim2, String istr)
	{
		super(null, out);
		init(op, outputDim1, outputDim2, istr);
		input1 = in1;
		input2 = in2;
		scalar_input3 = scalar_in3;
	}

	/**
	 * Two matrix inputs
	 *
	 * @param op operation type
	 * @param in1 input 1 (byte)
	 * @param scalar_in2 input 2 (double)
	 * @param in3 input 3 (byte)
	 * @param out output
	 * @param outputDim1 output dimension 1
	 * @param outputDim2 output dimension 2
	 * @param istr instruction string
	 */
	public TernaryInstruction(OperationTypes op, byte in1, double scalar_in2, byte in3, byte out, long outputDim1, long outputDim2, String istr)
	{
		super(null, out);
		init(op, outputDim1, outputDim2, istr);
		input1 = in1;
		scalar_input2 = scalar_in2;
		input3 = in3;
	}

	/**
	 * Three matrix inputs
	 *
	 * @param op operation type
	 * @param in1 input 1 (byte)
	 * @param in2 input 2 (byte)
	 * @param in3 input 3 (byte)
	 * @param out output
	 * @param outputDim1 output dimension 1
	 * @param outputDim2 output dimension 2
	 * @param istr instruction string
	 */
	public TernaryInstruction(OperationTypes op, byte in1, byte in2, byte in3, byte out, long outputDim1, long outputDim2, String istr)
	{
		super(null, out);
		init(op, outputDim1, outputDim2, istr);
		input1 = in1;
		input2 = in2;
		input3 = in3;
	}

	/**
	 * Shared field setup for all constructors (private, hence safe to call
	 * from a constructor since it cannot be overridden).
	 */
	private void init(OperationTypes op, long outputDim1, long outputDim2, String istr) {
		mrtype = MRINSTRUCTION_TYPE.Ternary;
		_op = op;
		_outputDim1 = outputDim1;
		_outputDim2 = outputDim2;
		instString = istr;
	}

	/** @return the first output dimension (<= 0 if unknown) */
	public long getOutputDim1() {
		return _outputDim1;
	}

	/** @return the second output dimension (<= 0 if unknown) */
	public long getOutputDim2() {
		return _outputDim2;
	}

	/** @return true if both output dimensions are strictly positive */
	public boolean knownOutputDims() {
		return (_outputDim1 > 0 && _outputDim2 > 0);
	}

	/**
	 * Parses a ctable instruction string.
	 *
	 * <p>Field layout (after the opcode): parts[1] = input 1 index,
	 * parts[2] / parts[3] = second / third operand (matrix index or scalar,
	 * depending on the opcode), parts[4] / parts[5] = output dimensions,
	 * parts[6] = output index.
	 *
	 * @param str instruction string
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the field count is wrong or the opcode is not a ctable variant
	 */
	public static TernaryInstruction parseInstruction ( String str )
		throws DMLRuntimeException
	{
		// example instruction string
		// - ctabletransform:::0:DOUBLE:::1:DOUBLE:::2:DOUBLE:::3:DOUBLE
		// - ctabletransformscalarweight:::0:DOUBLE:::1:DOUBLE:::1.0:DOUBLE:::3:DOUBLE
		// - ctabletransformhistogram:::0:DOUBLE:::1.0:DOUBLE:::1.0:DOUBLE:::3:DOUBLE
		// - ctabletransformweightedhistogram:::0:DOUBLE:::1:INT:::1:DOUBLE:::2:DOUBLE

		//check number of fields
		InstructionUtils.checkNumFields ( str, 6 );

		//common setup
		byte in1, in2, in3, out;
		String[] parts = InstructionUtils.getInstructionParts ( str );
		String opcode = parts[0];
		in1 = Byte.parseByte(parts[1]);
		long outputDim1 = (long) Double.parseDouble(parts[4]);
		long outputDim2 = (long) Double.parseDouble(parts[5]);
		out = Byte.parseByte(parts[6]);

		OperationTypes op = Ternary.getOperationType(opcode);

		switch( op )
		{
			case CTABLE_TRANSFORM: {
				in2 = Byte.parseByte(parts[2]);
				in3 = Byte.parseByte(parts[3]);
				return new TernaryInstruction(op, in1, in2, in3, out, outputDim1, outputDim2, str);
			}
			case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
				in2 = Byte.parseByte(parts[2]);
				double scalar_in3 = Double.parseDouble(parts[3]);
				return new TernaryInstruction(op, in1, in2, scalar_in3, out, outputDim1, outputDim2, str);
			}
			case CTABLE_EXPAND_SCALAR_WEIGHT: {
				double scalar_in2 = Double.parseDouble(parts[2]);
				double type = Double.parseDouble(parts[3]); //used as type (1 left, 0 right)
				return new TernaryInstruction(op, in1, scalar_in2, type, out, outputDim1, outputDim2, str);
			}
			case CTABLE_TRANSFORM_HISTOGRAM: {
				double scalar_in2 = Double.parseDouble(parts[2]);
				double scalar_in3 = Double.parseDouble(parts[3]);
				return new TernaryInstruction(op, in1, scalar_in2, scalar_in3, out, outputDim1, outputDim2, str);
			}
			case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
				double scalar_in2 = Double.parseDouble(parts[2]);
				in3 = Byte.parseByte(parts[3]);
				return new TernaryInstruction(op, in1, scalar_in2, in3, out, outputDim1, outputDim2, str);
			}
			default:
				throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + op);
		}
	}

	/**
	 * Executes the ctable operation on the cached input blocks.
	 *
	 * <p>If both output dimensions are known, results go into a MatrixBlock
	 * taken from (or newly put into) {@code resultBlocks}; otherwise into a
	 * CTableMap taken from (or newly put into) {@code resultMaps}. If any
	 * required matrix input is not present in {@code cachedValues}, the
	 * method returns without doing anything.
	 *
	 * @param valueClass matrix value class (not used here)
	 * @param cachedValues cache holding the input blocks
	 * @param zeroInput not used here
	 * @param resultMaps aggregation maps used when output dims are unknown
	 * @param resultBlocks result blocks used when output dims are known
	 * @param blockRowFactor passed to the CTABLE_EXPAND_SCALAR_WEIGHT variant
	 * @param blockColFactor not used here
	 * @throws DMLRuntimeException if dims are known but {@code resultBlocks} is null,
	 *         if the known dims do not fit in an int, or if the opcode is unrecognized
	 */
	public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues,
			IndexedMatrixValue zeroInput, HashMap<Byte, CTableMap> resultMaps, HashMap<Byte, MatrixBlock> resultBlocks,
			int blockRowFactor, int blockColFactor)
		throws DMLRuntimeException
	{
		IndexedMatrixValue in1, in2, in3 = null;
		in1 = cachedValues.getFirst(input1);

		CTableMap ctableResult = null;
		MatrixBlock ctableResultBlock = null;

		if ( knownOutputDims() ) {
			if ( resultBlocks != null ) {
				ctableResultBlock = resultBlocks.get(output);
				if ( ctableResultBlock == null ) {
					// MatrixBlock takes int dimensions; reject instead of silently truncating
					if ( _outputDim1 > Integer.MAX_VALUE || _outputDim2 > Integer.MAX_VALUE ) {
						throw new DMLRuntimeException("Output dimensions of ctable instruction exceed integer range: "
							+ _outputDim1 + " x " + _outputDim2 + " (" + instString + ").");
					}
					// From MR, output of ctable is set to be sparse since it is built from a single input block.
					ctableResultBlock = new MatrixBlock((int)_outputDim1, (int)_outputDim2, true);
					resultBlocks.put(output, ctableResultBlock);
				}
			}
			else {
				throw new DMLRuntimeException("Unexpected error in processing table instruction.");
			}
		}
		else {
			//prepare aggregation maps
			ctableResult = resultMaps.get(output);
			if( ctableResult == null )
			{
				ctableResult = new CTableMap();
				resultMaps.put(output, ctableResult);
			}
		}

		//get inputs and process instruction
		switch( _op )
		{
			case CTABLE_TRANSFORM: {
				// all three inputs are matrices
				in2 = cachedValues.getFirst(input2);
				in3 = cachedValues.getFirst(input3);
				if( in1 == null || in2 == null || in3 == null )
					return;
				OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(),
						in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
				break;
			}
			case CTABLE_TRANSFORM_SCALAR_WEIGHT: {
				// 3rd input is a scalar
				in2 = cachedValues.getFirst(input2);
				if( in1 == null || in2 == null )
					return;
				OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(),
						scalar_input3, ctableResult, ctableResultBlock, optr);
				break;
			}
			case CTABLE_EXPAND_SCALAR_WEIGHT: {
				// 2nd input is a scalar; 3rd input is the type flag parsed as "1 left, 0 right"
				if( in1 == null )
					return;
				OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, (scalar_input3==1),
						blockRowFactor, ctableResult, ctableResultBlock, optr);
				break;
			}
			case CTABLE_TRANSFORM_HISTOGRAM: {
				// 2nd and 3rd inputs are scalars
				if( in1 == null )
					return;
				OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, scalar_input3, ctableResult, ctableResultBlock, optr);
				break;
			}
			case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
				// 2nd input is a scalar, 3rd input is a matrix
				in3 = cachedValues.getFirst(input3);
				if( in1 == null || in3 == null )
					return;
				OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2,
						in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
				break;
			}
			default:
				throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + instString);
		}
	}

	/**
	 * Not supported for this instruction; use the overload taking result maps/blocks.
	 *
	 * @throws DMLRuntimeException always
	 */
	@Override
	public void processInstruction(Class<? extends MatrixValue> valueClass,
			CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput,
			int blockRowFactor, int blockColFactor)
		throws DMLRuntimeException
	{
		throw new DMLRuntimeException("This function should not be called!");
	}

	/**
	 * @return input1, input2, input3 and output indexes.
	 * NOTE(review): for scalar-operand variants, the unused input2/input3 slots
	 * keep their default value 0 and are still reported; confirm callers tolerate this.
	 */
	@Override
	public byte[] getAllIndexes() throws DMLRuntimeException {
		return new byte[]{input1, input2, input3, output};
	}

	/**
	 * @return input1, input2 and input3 indexes (see note on {@code getAllIndexes}
	 * regarding unused slots).
	 */
	@Override
	public byte[] getInputIndexes() throws DMLRuntimeException {
		return new byte[]{input1, input2, input3};
	}
}