/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.mapred;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.hadoop.mapred.JobConf;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.TaggedFirstSecondIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
/**
 * Shared state and helpers for the MMCJ combiner/reducer: it holds the single
 * aggregate binary instruction of the job, decides which input tag plays the
 * "left" and "right" role, and pre-aggregates the values of one tagged key
 * into a reusable buffer before the cross product is computed.
 */
public class MMCJMRCombinerReducerBase extends ReduceBase
{
	/** The one aggregate binary instruction executed by this MMCJ job. */
	protected AggregateBinaryInstruction aggBinInstruction=null;

	/** Reusable holder for the aggregation result; handed back by {@link #performAggregateInstructions}. */
	protected MatrixValue buffer=null;

	/** Tags emitted for the left and right matrices of the MMCJ (swapped in {@link #configure} when required). */
	protected byte tagForLeft=0;
	protected byte tagForRight=1;

	/** Characteristics of input1 and input2 of {@link #aggBinInstruction}. */
	protected MatrixCharacteristics dim1;
	protected MatrixCharacteristics dim2;

	/**
	 * Configures the base reducer, then loads the aggregate binary instruction,
	 * the characteristics of both of its inputs, and allocates {@link #buffer}.
	 *
	 * @param job the Hadoop job configuration
	 * @throws RuntimeException if the instructions cannot be read, if the job
	 *         does not contain exactly one aggregate binary instruction, or if
	 *         the value class cannot be instantiated
	 */
	public void configure(JobConf job)
	{
		super.configure(job);

		AggregateBinaryInstruction[] binAggInstructions;
		try {
			binAggInstructions = MRJobConfiguration.getAggregateBinaryInstructions(job);
		}
		catch(DMLRuntimeException ex) {
			throw new RuntimeException(ex);
		}

		if( binAggInstructions.length != 1 )
			throw new RuntimeException("MMCJ only perform one aggregate binary instruction");
		aggBinInstruction = binAggInstructions[0];

		// Look up both operand sizes; they decide which matrix is cached for the cross product.
		dim1 = MRJobConfiguration.getMatrixCharactristicsForBinAgg(job, aggBinInstruction.input1);
		dim2 = MRJobConfiguration.getMatrixCharactristicsForBinAgg(job, aggBinInstruction.input2);

		// Swap the tag roles only when the left input has more rows than the right has
		// columns; otherwise the current tag values are left untouched.
		boolean swapTags = dim1.getRows() > dim2.getCols();
		if( swapTags ) {
			tagForLeft = 1;
			tagForRight = 0;
		}

		// One scratch value object, reused for every key.
		try {
			buffer = valueClass.newInstance();
		}
		catch(Exception ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Aggregates all values of one tagged key into {@link #buffer}.
	 *
	 * <p>The tag of {@code indexes} is mapped to input1 (if it equals
	 * {@link #tagForLeft}) or input2 (otherwise). That input's registered
	 * aggregate instruction is used; if none is registered, {@code defaultAggIns}
	 * is retargeted to the real tag and used instead.
	 *
	 * @param indexes the tagged key whose values are aggregated
	 * @param values  the values for that key
	 * @return {@link #buffer} holding the aggregate, or {@code null} when
	 *         {@code values} yields nothing. The same buffer object is returned
	 *         on every call; callers presumably must consume it before the next call.
	 * @throws IOException if more than one aggregate instruction targets the
	 *         input, if its input and output indexes differ, or wrapping any
	 *         exception raised while aggregating
	 */
	protected MatrixValue performAggregateInstructions(TaggedFirstSecondIndexes indexes, Iterator<MatrixValue> values)
		throws IOException
	{
		// Translate the emitted tag back to the instruction input it represents.
		byte realTag = indexes.getTag();
		byte representTag = (realTag == tagForLeft) ? aggBinInstruction.input1 : aggBinInstruction.input2;

		ArrayList<AggregateInstruction> candidates = agg_instructions.get(representTag);
		AggregateInstruction aggIns;
		if( candidates != null ) {
			if( candidates.size() > 1 )
				throw new IOException("only one aggregate operation on input "
					+indexes.getTag()+" is allowed in BlockMMCJMR");
			aggIns = candidates.get(0);
			if( aggIns.input != aggIns.output )
				throw new IOException("input index and output index have to be " +
					"the same for aggregate instructions in BlockMMCJMR");
		}
		else {
			defaultAggIns.input = realTag;
			defaultAggIns.output = realTag;
			aggIns = defaultAggIns;
		}

		// Pre-aggregate before the mmcj.
		// TODO: specialize this, since aggregation for matrix multiplication can only be a sum.
		boolean bufferInitialized = false;
		try {
			while( values.hasNext() ) {
				MatrixValue block = values.next();
				if( !bufferInitialized ) {
					// Shape the buffer after the first value seen.
					buffer.reset(block.getNumRows(), block.getNumColumns(), block.isInSparseFormat());
					bufferInitialized = true;
				}
				AggregateOperator aggOp = (AggregateOperator) aggIns.getOperator();
				buffer.binaryOperationsInPlace(aggOp.increOp, block);
			}
		}
		catch(Exception ex) {
			throw new IOException(ex);
		}

		return bufferInitialized ? buffer : null;
	}
}