/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.instructions.mr; import java.util.ArrayList; import org.apache.sysml.lops.MMCJ.MMCJType; import org.apache.sysml.lops.MapMult; import org.apache.sysml.lops.MapMult.CacheType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.Operator; public class AggregateBinaryInstruction extends BinaryMRInstructionBase implements IDistributedCacheConsumer { private String _opcode = null; //optional argument for cpmm private MMCJType _aggType = MMCJType.AGG; //optional argument for mapmm private CacheType _cacheType = null; private boolean _outputEmptyBlocks = true; public AggregateBinaryInstruction(Operator op, String opcode, byte in1, byte in2, byte out, String istr) { super(op, in1, in2, out); mrtype = MRINSTRUCTION_TYPE.AggregateBinary; instString = istr; _opcode = opcode; } public void setCacheTypeMapMult( CacheType type ) { _cacheType = type; } public void setOutputEmptyBlocksMapMult( boolean flag ) { _outputEmptyBlocks = flag; } public boolean getOutputEmptyBlocks() { return _outputEmptyBlocks; } public void setMMCJType( MMCJType type ) { _aggType = type; } public MMCJType getMMCJType() { return _aggType; } public static AggregateBinaryInstruction parseInstruction ( String str ) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionParts ( str ); byte in1, in2, out; String opcode = parts[0]; in1 = Byte.parseByte(parts[1]); in2 = Byte.parseByte(parts[2]); out = Byte.parseByte(parts[3]); if ( opcode.equalsIgnoreCase("cpmm") || opcode.equalsIgnoreCase("rmm") || opcode.equalsIgnoreCase(MapMult.OPCODE) ) { AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); AggregateBinaryInstruction inst = new AggregateBinaryInstruction(aggbin, opcode, in1, in2, out, str); if( parts.length==5 ){ inst.setMMCJType(MMCJType.valueOf(parts[4])); } else if( parts.length==6 ) { //mapmm inst.setCacheTypeMapMult( CacheType.valueOf(parts[4]) ); inst.setOutputEmptyBlocksMapMult( Boolean.parseBoolean(parts[5]) ); } return inst; } throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode); } @Override //IDistributedCacheConsumer public boolean isDistCacheOnlyIndex( String inst, byte index ) { return _cacheType.isRight() ? (index==input2 && index!=input1) : (index==input1 && index!=input2); } @Override //IDistributedCacheConsumer public void addDistCacheIndex( String inst, ArrayList<Byte> indexes ) { indexes.add( _cacheType.isRight() ? input2 : input1 ); } @Override public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException { IndexedMatrixValue in1=cachedValues.getFirst(input1); IndexedMatrixValue in2=cachedValues.getFirst(input2); if ( _opcode.equals(MapMult.OPCODE) ) { //check empty inputs (data for different instructions) if( _cacheType.isRight() ? in1==null : in2==null ) return; // one of the input is from distributed cache. processMapMultInstruction(valueClass, cachedValues, in1, in2, blockRowFactor, blockColFactor); } else //generic matrix mult { //check empty inputs (data for different instructions) if(in1==null || in2==null) return; //allocate space for the output value IndexedMatrixValue out; if(output==input1 || output==input2) out=tempValue; else out=cachedValues.holdPlace(output, valueClass); //process instruction OperationsOnMatrixValues.performAggregateBinary( in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), out.getIndexes(), out.getValue(), ((AggregateBinaryOperator)optr)); //put the output value in the cache if(out==tempValue) cachedValues.add(output, out); } } /** * Helper function to perform map-side matrix-matrix multiplication. * * @param valueClass matrix value class * @param cachedValues cached value map * @param in1 indexed matrix value 1 * @param in2 indexed matrix value 2 * @param blockRowFactor ? * @param blockColFactor ? * @throws DMLRuntimeException if DMLRuntimeException occurs */ private void processMapMultInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue in1, IndexedMatrixValue in2, int blockRowFactor, int blockColFactor) throws DMLRuntimeException { boolean removeOutput = true; if( _cacheType.isRight() ) { DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(input2); long in2_cols = dcInput.getNumCols(); long in2_colBlocks = (long)Math.ceil(((double)in2_cols)/dcInput.getNumColsPerBlock()); for(int bidx=1; bidx <= in2_colBlocks; bidx++) { // Matrix multiply A[i,k] %*% B[k,bid] // Setup input2 block IndexedMatrixValue in2Block = dcInput.getDataBlock((int)in1.getIndexes().getColumnIndex(), bidx); MatrixValue in2BlockValue = in2Block.getValue(); MatrixIndexes in2BlockIndex = in2Block.getIndexes(); //allocate space for the output value IndexedMatrixValue out = cachedValues.holdPlace(output, valueClass); //process instruction OperationsOnMatrixValues.performAggregateBinary(in1.getIndexes(), in1.getValue(), in2BlockIndex, in2BlockValue, out.getIndexes(), out.getValue(), ((AggregateBinaryOperator)optr)); removeOutput &= ( !_outputEmptyBlocks && out.getValue().isEmpty() ); } } else { DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(input1); long in1_rows = dcInput.getNumRows(); long in1_rowsBlocks = (long) Math.ceil(((double)in1_rows)/dcInput.getNumRowsPerBlock()); for(int bidx=1; bidx <= in1_rowsBlocks; bidx++) { // Matrix multiply A[i,k] %*% B[k,bid] // Setup input2 block IndexedMatrixValue in1Block = dcInput.getDataBlock(bidx, (int)in2.getIndexes().getRowIndex()); MatrixValue in1BlockValue = in1Block.getValue(); MatrixIndexes in1BlockIndex = in1Block.getIndexes(); //allocate space for the output value IndexedMatrixValue out = cachedValues.holdPlace(output, valueClass); //process instruction OperationsOnMatrixValues.performAggregateBinary(in1BlockIndex, in1BlockValue, in2.getIndexes(), in2.getValue(), out.getIndexes(), out.getValue(), ((AggregateBinaryOperator)optr)); removeOutput &= ( !_outputEmptyBlocks && out.getValue().isEmpty() ); } } //empty block output filter (enabled by compiler consumer operation is in CP) if( removeOutput ) cachedValues.remove(output); } }