/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.instructions.mr; import java.util.ArrayList; import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType; import org.apache.sysml.lops.WeightedDivMM; import org.apache.sysml.lops.WeightedDivMM.WDivMMType; import org.apache.sysml.lops.WeightedDivMMR; import org.apache.sysml.lops.WeightedSigmoid.WSigmoidType; import org.apache.sysml.lops.WeightedSquaredLoss; import org.apache.sysml.lops.WeightedSquaredLoss.WeightsType; import org.apache.sysml.lops.WeightedSquaredLossR; import org.apache.sysml.lops.WeightedUnaryMM; import org.apache.sysml.lops.WeightedUnaryMM.WUMMType; import org.apache.sysml.lops.WeightedUnaryMMR; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.SwapIndex; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.LibMatrixReorg; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions; import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; public class QuaternaryInstruction extends MRInstruction implements IDistributedCacheConsumer { private byte _input1 = -1; private byte _input2 = -1; private byte _input3 = -1; private byte _input4 = -1; private boolean _cacheU = false; private boolean _cacheV = false; public QuaternaryInstruction(Operator op, byte in1, byte in2, byte in3, byte in4, byte out, boolean cacheU, boolean cacheV, String istr) { super(op, out); mrtype = MRINSTRUCTION_TYPE.Quaternary; instString = istr; _input1 = in1; _input2 = in2; _input3 = in3; _input4 = in4; _cacheU = cacheU; _cacheV = cacheV; } public byte getInput1() { return _input1; } public byte getInput2() { return _input2; } public byte getInput3() { return _input3; } public byte getInput4() { return _input4; } public void computeMatrixCharacteristics(MatrixCharacteristics mc1, MatrixCharacteristics mc2, MatrixCharacteristics mc3, MatrixCharacteristics dimOut) { QuaternaryOperator qop = (QuaternaryOperator)optr; if( qop.wtype1 != null || qop.wtype4 != null ) { //wsloss/wcemm //output size independent of chain type (scalar) dimOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock()); } else if( qop.wtype2 != null || qop.wtype5 != null ) { //wsigmoid/wumm //output size determined by main input dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock()); } else if(qop.wtype3 != null ) { //wdivmm //note: cannot directly consume mc2 or mc3 for redwdivmm because rep instruction changed //the relevant dimensions; as a workaround the original dims are passed via nnz boolean mapwdivmm = _cacheU && _cacheV; long rank = qop.wtype3.isLeft() ? mapwdivmm?mc3.getCols():mc3.getNonZeros() : mapwdivmm?mc2.getCols():mc2.getNonZeros(); MatrixCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mc1.getRows(), mc1.getCols(), rank); dimOut.set(mcTmp.getRows(), mcTmp.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock()); } } public static QuaternaryInstruction parseInstruction( String str ) throws DMLRuntimeException { String opcode = InstructionUtils.getOpCode(str); //validity check if ( !InstructionUtils.isDistQuaternaryOpcode(opcode) ){ throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str); } //instruction parsing if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //wsloss || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) ) { boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode); //check number of fields (4 inputs, output, type) if( isRed ) InstructionUtils.checkNumFields ( str, 8 ); else InstructionUtils.checkNumFields ( str, 6 ); //parse instruction parts (without exec type) String[] parts = InstructionUtils.getInstructionParts(str); byte in1 = Byte.parseByte(parts[1]); byte in2 = Byte.parseByte(parts[2]); byte in3 = Byte.parseByte(parts[3]); byte in4 = Byte.parseByte(parts[4]); byte out = Byte.parseByte(parts[5]); WeightsType wtype = WeightsType.valueOf(parts[6]); //in mappers always through distcache, in reducers through distcache/shuffle boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str); } else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) ) { boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); //check number of fields (4 inputs, output, type) if( isRed ) InstructionUtils.checkNumFields ( str, 8 ); else InstructionUtils.checkNumFields ( str, 6 ); //parse instruction parts (without exec type) String[] parts = InstructionUtils.getInstructionParts(str); String uopcode = parts[1]; byte in1 = Byte.parseByte(parts[2]); byte in2 = Byte.parseByte(parts[3]); byte in3 = Byte.parseByte(parts[4]); byte out = Byte.parseByte(parts[5]); WUMMType wtype = WUMMType.valueOf(parts[6]); //in mappers always through distcache, in reducers through distcache/shuffle boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; return new QuaternaryInstruction(new QuaternaryOperator(wtype,uopcode), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str); } else if( WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //wdivmm || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) ) { boolean isRed = opcode.startsWith("red"); //check number of fields (4 inputs, output, type) if( isRed ) InstructionUtils.checkNumFields ( str, 8 ); else InstructionUtils.checkNumFields ( str, 6 ); //parse instruction parts (without exec type) String[] parts = InstructionUtils.getInstructionParts(str); final WDivMMType wtype = WDivMMType.valueOf(parts[6]); byte in1 = Byte.parseByte(parts[1]); byte in2 = Byte.parseByte(parts[2]); byte in3 = Byte.parseByte(parts[3]); byte in4 = wtype.hasScalar() ? -1 : Byte.parseByte(parts[4]); byte out = Byte.parseByte(parts[5]); //in mappers always through distcache, in reducers through distcache/shuffle boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true; boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true; return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str); } else //wsigmoid / wcemm { boolean isRed = opcode.startsWith("red"); int addInput4 = (opcode.endsWith("wcemm")) ? 1 : 0; //check number of fields (3 or 4 inputs, output, type) if( isRed ) InstructionUtils.checkNumFields ( str, 7 + addInput4 ); else InstructionUtils.checkNumFields ( str, 5 + addInput4 ); //parse instruction parts (without exec type) String[] parts = InstructionUtils.getInstructionParts(str); byte in1 = Byte.parseByte(parts[1]); byte in2 = Byte.parseByte(parts[2]); byte in3 = Byte.parseByte(parts[3]); byte out = Byte.parseByte(parts[4 + addInput4]); //in mappers always through distcache, in reducers through distcache/shuffle boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true; boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true; if( opcode.endsWith("wsigmoid") ) return new QuaternaryInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str); else if( opcode.endsWith("wcemm") ) return new QuaternaryInstruction(new QuaternaryOperator(WCeMMType.valueOf(parts[6])), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str); } return null; } @Override //IDistributedCacheConsumer public boolean isDistCacheOnlyIndex( String inst, byte index ) { if( _cacheU && _cacheV ) { return (index==_input2 && index!=_input1 && index!=_input4) || (index==_input3 && index!=_input1 && index!=_input4); } else { return (_cacheU && index==_input2 && index!=_input1 && index!=_input4) || (_cacheV && index==_input3 && index!=_input1 && index!=_input4); } } @Override //IDistributedCacheConsumer public void addDistCacheIndex( String inst, ArrayList<Byte> indexes ) { if( _cacheU ) indexes.add(_input2); if( _cacheV ) indexes.add(_input3); } @Override public byte[] getInputIndexes() { QuaternaryOperator qop = (QuaternaryOperator)optr; if( qop.hasFourInputs() ) return new byte[]{_input1, _input2, _input3, _input4}; else return new byte[]{_input1, _input2, _input3}; } @Override public byte[] getAllIndexes() { QuaternaryOperator qop = (QuaternaryOperator)optr; if( qop.hasFourInputs() ) return new byte[]{_input1, _input2, _input3, _input4, output}; else return new byte[]{_input1, _input2, _input3, output}; } @Override public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException { QuaternaryOperator qop = (QuaternaryOperator)optr; ArrayList<IndexedMatrixValue> blkList = cachedValues.get(_input1); if( blkList !=null ) for(IndexedMatrixValue imv : blkList) { //Step 1: prepare inputs and output if( imv==null ) continue; MatrixIndexes inIx = imv.getIndexes(); MatrixValue inVal = imv.getValue(); //allocate space for the output value IndexedMatrixValue iout = null; if(output==_input1) iout=tempValue; else iout=cachedValues.holdPlace(output, valueClass); MatrixIndexes outIx = iout.getIndexes(); MatrixValue outVal = iout.getValue(); //Step 2: get remaining inputs: Wij, Ui, Vj MatrixValue Xij = inVal; //get Wij if existing (null of WeightsType.NONE or WSigmoid any type) IndexedMatrixValue iWij = (_input4 != -1) ? cachedValues.getFirst(_input4) : null; MatrixValue Wij = (iWij!=null) ? iWij.getValue() : null; if (null == Wij && qop.hasFourInputs()) { MatrixBlock mb = new MatrixBlock(1, 1, false); String[] parts = InstructionUtils.getInstructionParts(instString); mb.quickSetValue(0, 0, Double.valueOf(parts[4])); Wij = mb; } //get Ui and Vj, potentially through distributed cache MatrixValue Ui = (!_cacheU) ? cachedValues.getFirst(_input2).getValue() //U : MRBaseForCommonInstructions.dcValues.get(_input2) .getDataBlock((int)inIx.getRowIndex(), 1).getValue(); MatrixValue Vj = (!_cacheV) ? cachedValues.getFirst(_input3).getValue() //t(V) : MRBaseForCommonInstructions.dcValues.get(_input3) .getDataBlock((int)inIx.getColumnIndex(), 1).getValue(); //handle special input case: //V through shuffle -> t(V) if( Ui.getNumColumns()!=Vj.getNumColumns() ) { Vj = LibMatrixReorg.reorg((MatrixBlock) Vj, new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()), new ReorgOperator(SwapIndex.getSwapIndexFnObject())); } //Step 3: process instruction Xij.quaternaryOperations(qop, Ui, Vj, Wij, outVal); //set output indexes if( qop.wtype1 != null || qop.wtype4 != null) outIx.setIndexes(1, 1); //wsloss else if ( qop.wtype2 != null || qop.wtype5 != null || qop.wtype3!=null && qop.wtype3.isBasic() ) outIx.setIndexes(inIx); //wsigmoid/wdivmm-basic else { //wdivmm boolean left = qop.wtype3.isLeft(); outIx.setIndexes(left?inIx.getColumnIndex():inIx.getRowIndex(), 1); } //put the output value in the cache if(iout==tempValue) cachedValues.add(output, iout); } } }