/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.runtime.instructions.mr; import java.util.HashMap; import com.ibm.bi.dml.lops.Ternary; import com.ibm.bi.dml.lops.Ternary.OperationTypes; import com.ibm.bi.dml.runtime.DMLRuntimeException; import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException; import com.ibm.bi.dml.runtime.instructions.Instruction; import com.ibm.bi.dml.runtime.instructions.InstructionUtils; import com.ibm.bi.dml.runtime.matrix.data.CTableMap; import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock; import com.ibm.bi.dml.runtime.matrix.data.MatrixValue; import com.ibm.bi.dml.runtime.matrix.data.OperationsOnMatrixValues; import com.ibm.bi.dml.runtime.matrix.mapred.CachedValueMap; import com.ibm.bi.dml.runtime.matrix.mapred.IndexedMatrixValue; public class TernaryInstruction extends MRInstruction { private OperationTypes _op; public byte input1; public byte input2; public byte input3; public double scalar_input2; public double scalar_input3; private long _outputDim1, _outputDim2; /** * Single matrix input * * @param op * @param in1 * @param scalar_in2 * @param scalar_in3 * @param out * @param istr */ public TernaryInstruction(OperationTypes op, byte in1, double scalar_in2, double scalar_in3, byte out, long outputDim1, long outputDim2, String istr) { super(null, out); mrtype = MRINSTRUCTION_TYPE.Ternary; _op = op; input1 = in1; scalar_input2 = scalar_in2; scalar_input3 = scalar_in3; _outputDim1 = outputDim1; _outputDim2 = outputDim2; instString = istr; } /** * Two matrix inputs * * @param op * @param in1 * @param in2 * @param scalar_in3 * @param out * @param istr */ public TernaryInstruction(OperationTypes op, byte in1, byte in2, double scalar_in3, byte out, long outputDim1, long outputDim2, String istr) { super(null, out); mrtype = MRINSTRUCTION_TYPE.Ternary; _op = op; input1 = in1; input2 = in2; scalar_input3 = scalar_in3; _outputDim1 = outputDim1; _outputDim2 = outputDim2; instString = istr; } /** * Two matrix input * * @param op * @param in1 * @param scalar_in2 * @param in3 * @param out * @param istr */ public TernaryInstruction(OperationTypes op, byte in1, double scalar_in2, byte in3, byte out, long outputDim1, long outputDim2, String istr) { super(null, out); mrtype = MRINSTRUCTION_TYPE.Ternary; _op = op; input1 = in1; scalar_input2 = scalar_in2; input3 = in3; _outputDim1 = outputDim1; _outputDim2 = outputDim2; instString = istr; } /** * Three matrix inputs * * @param op * @param in1 * @param in2 * @param in3 * @param out * @param istr */ public TernaryInstruction(OperationTypes op, byte in1, byte in2, byte in3, byte out, long outputDim1, long outputDim2, String istr) { super(null, out); mrtype = MRINSTRUCTION_TYPE.Ternary; _op = op; input1 = in1; input2 = in2; input3 = in3; _outputDim1 = outputDim1; _outputDim2 = outputDim2; instString = istr; } public long getOutputDim1() { return _outputDim1; } public long getOutputDim2() { return _outputDim2; } public boolean knownOutputDims() { return (_outputDim1 >0 && _outputDim2>0); } /** * * @param str * @return * @throws DMLRuntimeException */ public static Instruction parseInstruction ( String str ) throws DMLRuntimeException { // example instruction string // - ctabletransform:::0:DOUBLE:::1:DOUBLE:::2:DOUBLE:::3:DOUBLE // - ctabletransformscalarweight:::0:DOUBLE:::1:DOUBLE:::1.0:DOUBLE:::3:DOUBLE // - ctabletransformhistogram:::0:DOUBLE:::1.0:DOUBLE:::1.0:DOUBLE:::3:DOUBLE // - ctabletransformweightedhistogram:::0:DOUBLE:::1:INT:::1:DOUBLE:::2:DOUBLE //check number of fields InstructionUtils.checkNumFields ( str, 6 ); //common setup byte in1, in2, in3, out; String[] parts = InstructionUtils.getInstructionParts ( str ); String opcode = parts[0]; in1 = Byte.parseByte(parts[1]); long outputDim1 = (long) Double.parseDouble(parts[4]); long outputDim2 = (long) Double.parseDouble(parts[5]); out = Byte.parseByte(parts[6]); OperationTypes op = Ternary.getOperationType(opcode); switch( op ) { case CTABLE_TRANSFORM: { in2 = Byte.parseByte(parts[2]); in3 = Byte.parseByte(parts[3]); return new TernaryInstruction(op, in1, in2, in3, out, outputDim1, outputDim2, str); } case CTABLE_TRANSFORM_SCALAR_WEIGHT: { in2 = Byte.parseByte(parts[2]); double scalar_in3 = Double.parseDouble(parts[3]); return new TernaryInstruction(op, in1, in2, scalar_in3, out, outputDim1, outputDim2, str); } case CTABLE_EXPAND_SCALAR_WEIGHT: { double scalar_in2 = Double.parseDouble(parts[2]); double type = Double.parseDouble(parts[3]); //used as type (1 left, 0 right) return new TernaryInstruction(op, in1, scalar_in2, type, out, outputDim1, outputDim2, str); } case CTABLE_TRANSFORM_HISTOGRAM: { double scalar_in2 = Double.parseDouble(parts[2]); double scalar_in3 = Double.parseDouble(parts[3]); return new TernaryInstruction(op, in1, scalar_in2, scalar_in3, out, outputDim1, outputDim2, str); } case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { double scalar_in2 = Double.parseDouble(parts[2]); in3 = Byte.parseByte(parts[3]); return new TernaryInstruction(op, in1, scalar_in2, in3, out, outputDim1, outputDim2, str); } default: throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + op); } } public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue zeroInput, HashMap<Byte, CTableMap> resultMaps, HashMap<Byte, MatrixBlock> resultBlocks, int blockRowFactor, int blockColFactor) throws DMLUnsupportedOperationException, DMLRuntimeException { IndexedMatrixValue in1, in2, in3 = null; in1 = cachedValues.getFirst(input1); CTableMap ctableResult = null; MatrixBlock ctableResultBlock = null; if ( knownOutputDims() ) { if ( resultBlocks != null ) { ctableResultBlock = resultBlocks.get(output); if ( ctableResultBlock == null ) { // From MR, output of ctable is set to be sparse since it is built from a single input block. ctableResultBlock = new MatrixBlock((int)_outputDim1, (int)_outputDim2, true); resultBlocks.put(output, ctableResultBlock); } } else { throw new DMLRuntimeException("Unexpected error in processing table instruction."); } } else { //prepare aggregation maps ctableResult=resultMaps.get(output); if(ctableResult==null) { ctableResult = new CTableMap(); resultMaps.put(output, ctableResult); } } //get inputs and process instruction switch( _op ) { case CTABLE_TRANSFORM: { in2 = cachedValues.getFirst(input2); in3 = cachedValues.getFirst(input3); if(in1==null || in2==null || in3 == null ) return; OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); break; } case CTABLE_TRANSFORM_SCALAR_WEIGHT: { // 3rd input is a scalar in2 = cachedValues.getFirst(input2); if(in1==null || in2==null ) return; OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), scalar_input3, ctableResult, ctableResultBlock, optr); break; } case CTABLE_EXPAND_SCALAR_WEIGHT: { // 2nd and 3rd input is a scalar if(in1==null ) return; OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, (scalar_input3==1), blockRowFactor, ctableResult, ctableResultBlock, optr); break; } case CTABLE_TRANSFORM_HISTOGRAM: { // 2nd and 3rd inputs are scalars if(in1==null ) return; OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, scalar_input3, ctableResult, ctableResultBlock, optr); break; } case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { // 2nd and 3rd inputs are scalars in3 = cachedValues.getFirst(input3); if(in1==null || in3==null) return; OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); break; } default: throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString); } } @Override public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLUnsupportedOperationException, DMLRuntimeException { throw new DMLRuntimeException("This function should not be called!"); } @Override public byte[] getAllIndexes() throws DMLRuntimeException { return new byte[]{input1, input2, input3, output}; } @Override public byte[] getInputIndexes() throws DMLRuntimeException { return new byte[]{input1, input2, input3}; } }