/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.mr;
import java.util.ArrayList;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.PMMJ.CacheType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
public class PMMJMRInstruction extends BinaryMRInstructionBase implements IDistributedCacheConsumer
{
private long _rlen = -1;
private boolean _outputEmptyBlocks = true;
public PMMJMRInstruction(Operator op, byte in1, byte in2, byte out, long nrow, CacheType ctype, boolean outputEmpty, String istr)
{
super(op, in1, in2, out);
instString = istr;
_rlen = nrow;
_outputEmptyBlocks = outputEmpty;
//NOTE: cache type only used by distributed cache input
}
public long getNumRows() {
return _rlen;
}
public boolean getOutputEmptyBlocks() {
return _outputEmptyBlocks;
}
public static PMMJMRInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
InstructionUtils.checkNumFields ( str, 6 );
String[] parts = InstructionUtils.getInstructionParts(str);
String opcode = parts[0];
byte in1 = Byte.parseByte(parts[1]);
byte in2 = Byte.parseByte(parts[2]);
long nrow = UtilFunctions.toLong(Double.parseDouble(parts[3]));
byte out = Byte.parseByte(parts[4]);
CacheType ctype = CacheType.valueOf(parts[5]);
boolean outputEmpty = Boolean.parseBoolean(parts[6]);
if(!opcode.equalsIgnoreCase("pmm"))
throw new DMLRuntimeException("Unknown opcode while parsing an PmmMRInstruction: " + str);
return new PMMJMRInstruction(new Operator(true), in1, in2, out, nrow, ctype, outputEmpty, str);
}
@Override
public void processInstruction(Class<? extends MatrixValue> valueClass,
CachedValueMap cachedValues, IndexedMatrixValue tempValue,
IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor)
throws DMLRuntimeException
{
//get both matrix inputs (left side always permutation)
DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(input1);
IndexedMatrixValue in2 = cachedValues.getFirst(input2);
IndexedMatrixValue in1 = dcInput.getDataBlock((int)in2.getIndexes().getRowIndex(), 1);
MatrixBlock mb1 = (MatrixBlock)in1.getValue();
MatrixBlock mb2 = (MatrixBlock)in2.getValue();
//compute target block indexes
long minPos = UtilFunctions.toLong( mb1.minNonZero() );
long maxPos = UtilFunctions.toLong( mb1.max() );
long rowIX1 = (minPos-1)/blockRowFactor+1;
long rowIX2 = (maxPos-1)/blockRowFactor+1;
boolean multipleOuts = (rowIX1 != rowIX2);
if( minPos >= 1 ) //at least one row selected
{
//output sparsity estimate
double spmb1 = OptimizerUtils.getSparsity(mb1.getNumRows(), 1, mb1.getNonZeros());
long estnnz = (long) (spmb1 * mb2.getNonZeros());
boolean sparse = MatrixBlock.evalSparseFormatInMemory(blockRowFactor, mb2.getNumColumns(), estnnz);
//compute and allocate output blocks
IndexedMatrixValue out1 = cachedValues.holdPlace(output, valueClass);
IndexedMatrixValue out2 = multipleOuts ? cachedValues.holdPlace(output, valueClass) : null;
out1.getValue().reset(blockRowFactor, mb2.getNumColumns(), sparse);
if( out2 != null )
out2.getValue().reset(UtilFunctions.computeBlockSize(_rlen, rowIX2, blockRowFactor), mb2.getNumColumns(), sparse);
//compute core matrix permutation (assumes that out1 has default blocksize,
//hence we do a meta data correction afterwards)
mb1.permutationMatrixMultOperations(mb2, out1.getValue(), (out2!=null)?out2.getValue():null);
((MatrixBlock)out1.getValue()).setNumRows(UtilFunctions.computeBlockSize(_rlen, rowIX1, blockRowFactor));
out1.getIndexes().setIndexes(rowIX1, in2.getIndexes().getColumnIndex());
if( out2 != null )
out2.getIndexes().setIndexes(rowIX2, in2.getIndexes().getColumnIndex());
//empty block output filter (enabled by compiler consumer operation is in CP)
if( !_outputEmptyBlocks && out1.getValue().isEmpty()
&& (out2==null || out2.getValue().isEmpty() ) )
{
cachedValues.remove(output);
}
}
}
@Override //IDistributedCacheConsumer
public boolean isDistCacheOnlyIndex( String inst, byte index )
{
return (index==input1 && index!=input2);
}
@Override //IDistributedCacheConsumer
public void addDistCacheIndex( String inst, ArrayList<Byte> indexes )
{
indexes.add(input1);
}
}