/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.matrix.mapred;
import java.io.IOException;
import java.util.Iterator;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.TaggedMatrixValue;
import org.apache.sysml.runtime.matrix.data.TripleIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
/**
 * Reducer of the MMRJ job: for every incoming group it aggregates the tagged
 * input blocks, applies the configured aggregate-binary instructions to the
 * aggregated operands, and accumulates the products per output block.
 *
 * <p>Groups are keyed by {@link TripleIndexes}. Results accumulated for one
 * (first index, second index) pair are flushed when a group with a different
 * pair arrives, and once more in {@link #close()}.
 * NOTE(review): this presumes that all groups sharing the same first and
 * second index reach the reducer consecutively - confirm against the job's
 * sort/grouping setup.
 */
public class MMRJMRReducer extends ReduceBase
implements Reducer<TripleIndexes, TaggedMatrixValue, MatrixIndexes, MatrixValue>
{
    /**
     * Reporter captured from the most recent {@code reduce} call so that
     * {@link #close()}, which receives no reporter, can still emit the final
     * results and update the time counter.
     */
    private Reporter cachedReporter=null;

    /** Scratch block reused across groups to hold one aggregate-binary product. */
    private MatrixValue resultblock=null;

    /** Reusable (first, second) index pair of the group currently being processed. */
    private final MatrixIndexes aggIndexes=new MatrixIndexes();

    /** Indexes of the previously processed group; starts at (-1, -1, -1). */
    private final TripleIndexes prevIndexes=new TripleIndexes(-1, -1, -1);

    /** Aggregate binary instructions for the mmrj, loaded in {@link #configure(JobConf)}. */
    protected AggregateBinaryInstruction[] aggBinInstructions=null;

    /**
     * Processes one group of tagged values.
     *
     * @param triple key of the group; its first and second index identify the
     *               output block the group contributes to
     * @param values tagged input blocks of this group
     * @param out    output collector supplied by the framework (not used
     *               directly in this method)
     * @param report reporter used for output and counters; also remembered
     *               for {@link #close()}
     * @throws IOException if aggregation, instruction processing or output fails
     */
    @Override
    public void reduce(TripleIndexes triple, Iterator<TaggedMatrixValue> values,
            OutputCollector<MatrixIndexes, MatrixValue> out, Reporter report)
    throws IOException {
        long start=System.currentTimeMillis();

        commonSetup(report);
        // The private field above is the one close() reads, and nothing else
        // in this class assigns it; without this assignment close() would
        // always see null.
        cachedReporter=report;

        boolean sameOutputBlock =
                prevIndexes.getFirstIndex()==triple.getFirstIndex()
                && prevIndexes.getSecondIndex()==triple.getSecondIndex();

        if(!sameOutputBlock)
        {
            // a new output block starts: finish and emit the previous one
            processReducerInstructions();
            outputResultsFromCachedValues(report);
            cachedValues.reset();
        }
        else
        {
            // same output block: drop only the operands of the previous
            // group, keeping the accumulated outputs in the cache
            for(AggregateBinaryInstruction aggBinInstruction: aggBinInstructions)
            {
                cachedValues.remove(aggBinInstruction.input1);
                cachedValues.remove(aggBinInstruction.input2);
            }
        }

        // perform aggregation first
        aggIndexes.setIndexes(triple.getFirstIndex(), triple.getSecondIndex());
        processAggregateInstructions(aggIndexes, values);

        // perform aggbinary for this group
        for(AggregateBinaryInstruction aggBinInstruction: aggBinInstructions)
        {
            processAggBinaryPerGroup(aggIndexes, aggBinInstruction);
        }

        prevIndexes.setIndexes(triple);

        report.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis()-start);
    }

    /**
     * Performs one pairwise aggregate binary operation on the cached operands
     * of the instruction and adds the product to the instruction's cached
     * output. Does nothing if either operand is missing from the cache.
     *
     * @param indexes           indexes assigned to the output block when it is
     *                          first created
     * @param aggBinInstruction instruction naming the two inputs, the output
     *                          and the operator
     * @throws IOException wrapping any exception raised by the matrix operations
     */
    private void processAggBinaryPerGroup(MatrixIndexes indexes, AggregateBinaryInstruction aggBinInstruction) throws IOException
    {
        IndexedMatrixValue left = cachedValues.getFirst(aggBinInstruction.input1);
        IndexedMatrixValue right= cachedValues.getFirst(aggBinInstruction.input2);
        if(left==null || right==null)
        {
            return;
        }

        try {
            AggregateBinaryOperator operator=(AggregateBinaryOperator) aggBinInstruction.getOperator();

            resultblock=left.getValue().aggregateBinaryOperations(left.getValue(), right.getValue(),
                    resultblock, operator);

            IndexedMatrixValue out=cachedValues.getFirst(aggBinInstruction.output);
            if(out==null)
            {
                // first product for this output: allocate and initialise the accumulator
                out=cachedValues.holdPlace(aggBinInstruction.output, valueClass);
                out.getIndexes().setIndexes(indexes);
                OperationsOnMatrixValues.startAggregation(out.getValue(), null, operator.aggOp,
                        resultblock.getNumRows(), resultblock.getNumColumns(), resultblock.isInSparseFormat(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(out.getValue(), null, resultblock,
                    operator.aggOp, false);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Flushes the results of the last output block and closes the reducer.
     * The reporter remembered from {@code reduce} is used; it is still
     * {@code null} if {@code reduce} was never called.
     *
     * @throws IOException if instruction processing, output or the superclass
     *                     close fails
     */
    @Override
    public void close() throws IOException
    {
        long start=System.currentTimeMillis();

        // perform mixed operations
        processReducerInstructions();

        // output results
        outputResultsFromCachedValues(cachedReporter);

        if(cachedReporter!=null)
        {
            cachedReporter.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis()-start);
        }

        super.close();
    }

    /**
     * Configures the reducer and loads the aggregate binary instructions from
     * the job configuration.
     *
     * @param job job configuration
     * @throws RuntimeException wrapping a {@link DMLRuntimeException} if the
     *                          instructions cannot be obtained
     */
    @Override
    public void configure(JobConf job)
    {
        super.configure(job);
        try {
            aggBinInstructions = MRJobConfiguration.getAggregateBinaryInstructions(job);
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }
}