/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.instructions.mr;
import java.util.ArrayList;
import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
import com.ibm.bi.dml.lops.WeightedDivMM.WDivMMType;
import com.ibm.bi.dml.lops.WeightedDivMMR;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.functionobjects.SwapIndex;
import com.ibm.bi.dml.runtime.instructions.Instruction;
import com.ibm.bi.dml.runtime.instructions.InstructionUtils;
import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
import com.ibm.bi.dml.runtime.matrix.data.LibMatrixReorg;
import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;
import com.ibm.bi.dml.runtime.matrix.data.MatrixValue;
import com.ibm.bi.dml.runtime.matrix.mapred.CachedValueMap;
import com.ibm.bi.dml.runtime.matrix.mapred.IndexedMatrixValue;
import com.ibm.bi.dml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import com.ibm.bi.dml.runtime.matrix.operators.Operator;
import com.ibm.bi.dml.runtime.matrix.operators.QuaternaryOperator;
import com.ibm.bi.dml.runtime.matrix.operators.ReorgOperator;
/**
*
*/
public class QuaternaryInstruction extends MRInstruction implements IDistributedCacheConsumer
{
private byte _input1 = -1;
private byte _input2 = -1;
private byte _input3 = -1;
private byte _input4 = -1;
private boolean _cacheU = false;
private boolean _cacheV = false;
/**
*
* @param type
* @param in1
* @param in2
* @param out
* @param istr
*/
public QuaternaryInstruction(Operator op, byte in1, byte in2, byte in3, byte in4, byte out, boolean cacheU, boolean cacheV, String istr)
{
super(op, out);
mrtype = MRINSTRUCTION_TYPE.Quaternary;
instString = istr;
_input1 = in1;
_input2 = in2;
_input3 = in3;
_input4 = in4;
_cacheU = cacheU;
_cacheV = cacheV;
}
public byte getInput1() {
return _input1;
}
public byte getInput2() {
return _input2;
}
public byte getInput3() {
return _input3;
}
public byte getInput4() {
return _input3;
}
/**
*
* @param mc1
* @param mc2
* @param mc3
* @param dimOut
*/
public void computeMatrixCharacteristics(MatrixCharacteristics mc1, MatrixCharacteristics mc2,
MatrixCharacteristics mc3, MatrixCharacteristics dimOut)
{
QuaternaryOperator qop = (QuaternaryOperator)optr;
if( qop.wtype1 != null || qop.wtype4 != null ) { //wsloss/wcemm
//output size independent of chain type (scalar)
dimOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
else if( qop.wtype2 != null ) { //wsigmoid
//output size determined by main input
dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
else if(qop.wtype3 != null ) { //wdivmm
//note: cannot directly consume mc2 or mc3 for redwdivmm because rep instruction changed
//the relevant dimensions; as a workaround the original dims are passed via nnz
boolean mapwdivmm = _cacheU && _cacheV;
long rank = qop.wtype3.isLeft() ? mapwdivmm?mc3.getCols():mc3.getNonZeros() : mapwdivmm?mc2.getCols():mc2.getNonZeros();
MatrixCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mc1.getRows(), mc1.getCols(), rank);
dimOut.set(mcTmp.getRows(), mcTmp.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
}
/**
*
* @param str
* @return
* @throws DMLRuntimeException
*/
public static Instruction parseInstruction( String str )
throws DMLRuntimeException
{
String opcode = InstructionUtils.getOpCode(str);
//validity check
if ( !InstructionUtils.isDistQuaternaryOpcode(opcode) ){
throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
}
//instruction parsing
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //wsloss
|| WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) )
{
boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields ( str, 8 );
else
InstructionUtils.checkNumFields ( str, 6 );
//parse instruction parts (without exec type)
String[] parts = InstructionUtils.getInstructionParts(str);
byte in1 = Byte.parseByte(parts[1]);
byte in2 = Byte.parseByte(parts[2]);
byte in3 = Byte.parseByte(parts[3]);
byte in4 = Byte.parseByte(parts[4]);
byte out = Byte.parseByte(parts[5]);
WeightsType wtype = WeightsType.valueOf(parts[6]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
}
else //wsigmoid / wdivmm / wcemm
{
boolean isRed = opcode.startsWith("red");
//check number of fields (3 inputs, output, type)
if( WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) )
InstructionUtils.checkNumFields ( str, 7, 8 );
else if( isRed )
InstructionUtils.checkNumFields ( str, 7 );
else
InstructionUtils.checkNumFields ( str, 5 );
//parse instruction parts (without exec type)
String[] parts = InstructionUtils.getInstructionParts(str);
boolean wdivmmMinus = (parts.length==9);
byte in1 = Byte.parseByte(parts[1]);
byte in2 = Byte.parseByte(parts[2]);
byte in3 = Byte.parseByte(parts[3]);
byte in4 = wdivmmMinus?Byte.parseByte(parts[4]):-1;
byte out = Byte.parseByte(parts[wdivmmMinus?5:4]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[wdivmmMinus?7:6]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[wdivmmMinus?8:7]) : true;
if( opcode.endsWith("wsigmoid") )
return new QuaternaryInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str);
else if( opcode.endsWith("wdivmm") )
return new QuaternaryInstruction(new QuaternaryOperator(WDivMMType.valueOf(parts[wdivmmMinus?6:5])), in1, in2, in3, in4, out, cacheU, cacheV, str);
else if( opcode.endsWith("wcemm") )
return new QuaternaryInstruction(new QuaternaryOperator(WCeMMType.valueOf(parts[5])), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str);
}
return null;
}
@Override //IDistributedCacheConsumer
public boolean isDistCacheOnlyIndex( String inst, byte index )
{
if( _cacheU && _cacheV )
{
return (index==_input2 && index!=_input1 && index!=_input4)
|| (index==_input3 && index!=_input1 && index!=_input4);
}
else
{
return (_cacheU && index==_input2 && index!=_input1 && index!=_input4)
|| (_cacheV && index==_input3 && index!=_input1 && index!=_input4);
}
}
@Override //IDistributedCacheConsumer
public void addDistCacheIndex( String inst, ArrayList<Byte> indexes )
{
if( _cacheU )
indexes.add(_input2);
if( _cacheV )
indexes.add(_input3);
}
@Override
public byte[] getInputIndexes()
{
QuaternaryOperator qop = (QuaternaryOperator)optr;
if( qop.wtype1 == null || !qop.wtype1.hasFourInputs() )
return new byte[]{_input1, _input2, _input3};
else
return new byte[]{_input1, _input2, _input3, _input4};
}
@Override
public byte[] getAllIndexes()
{
QuaternaryOperator qop = (QuaternaryOperator)optr;
if( qop.wtype1 == null || !qop.wtype1.hasFourInputs() )
return new byte[]{_input1, _input2, _input3, output};
else
return new byte[]{_input1, _input2, _input3, _input4, output};
}
@Override
public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues,
IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor)
throws DMLUnsupportedOperationException, DMLRuntimeException
{
QuaternaryOperator qop = (QuaternaryOperator)optr;
ArrayList<IndexedMatrixValue> blkList = cachedValues.get(_input1);
if( blkList !=null )
for(IndexedMatrixValue imv : blkList)
{
//Step 1: prepare inputs and output
if( imv==null )
continue;
MatrixIndexes inIx = imv.getIndexes();
MatrixValue inVal = imv.getValue();
//allocate space for the output value
IndexedMatrixValue iout = null;
if(output==_input1)
iout=tempValue;
else
iout=cachedValues.holdPlace(output, valueClass);
MatrixIndexes outIx = iout.getIndexes();
MatrixValue outVal = iout.getValue();
//Step 2: get remaining inputs: Wij, Ui, Vj
MatrixValue Xij = inVal;
//get Wij if existing (null of WeightsType.NONE or WSigmoid any type)
IndexedMatrixValue iWij = cachedValues.getFirst(_input4);
MatrixValue Wij = (iWij!=null) ? iWij.getValue() : null;
//get Ui and Vj, potentially through distributed cache
MatrixValue Ui = (!_cacheU) ? cachedValues.getFirst(_input2).getValue() //U
: MRBaseForCommonInstructions.dcValues.get(_input2)
.getDataBlock((int)inIx.getRowIndex(), 1).getValue();
MatrixValue Vj = (!_cacheV) ? cachedValues.getFirst(_input3).getValue() //t(V)
: MRBaseForCommonInstructions.dcValues.get(_input3)
.getDataBlock((int)inIx.getColumnIndex(), 1).getValue();
//handle special input case: //V through shuffle -> t(V)
if( Ui.getNumColumns()!=Vj.getNumColumns() ) {
Vj = LibMatrixReorg.reorg((MatrixBlock) Vj, new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()), new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
}
//Step 3: process instruction
Xij.quaternaryOperations(qop, Ui, Vj, Wij, outVal);
//set output indexes
if( qop.wtype1 != null || qop.wtype4 != null)
outIx.setIndexes(1, 1); //wsloss
else if ( qop.wtype2 != null || qop.wtype3!=null && qop.wtype3.isBasic() )
outIx.setIndexes(inIx); //wsigmoid/wdivmm-basic
else { //wdivmm
boolean left = qop.wtype3.isLeft();
outIx.setIndexes(left?inIx.getColumnIndex():inIx.getRowIndex(), 1);
}
//put the output value in the cache
if(iout==tempValue)
cachedValues.add(output, iout);
}
}
}