/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.matrix.mapred;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import com.ibm.bi.dml.hops.OptimizerUtils;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;
import com.ibm.bi.dml.runtime.matrix.data.MatrixValue;
import com.ibm.bi.dml.runtime.matrix.data.OperationsOnMatrixValues;
import com.ibm.bi.dml.runtime.matrix.data.TaggedFirstSecondIndexes;
import com.ibm.bi.dml.runtime.matrix.operators.AggregateBinaryOperator;
import com.ibm.bi.dml.runtime.util.MapReduceTool;
/**
 * Reducer of the MMCJ MapReduce job.
 *
 * <p>Keys arrive as {@link TaggedFirstSecondIndexes}. For every distinct first
 * index, the blocks with tag 0 are kept in an in-memory cache; each block with
 * any other tag is then combined with every cached block through the aggregate
 * binary operator of {@code aggBinInstruction}. The results are partially
 * aggregated in a bounded output cache that is flushed in {@link #close()}.
 *
 * <p>The reducer relies on keys of the same first index arriving with
 * non-decreasing tags; a violation is reported with a {@link RuntimeException}.
 */
public class MMCJMRReducer extends MMCJMRCombinerReducerBase
implements Reducer<TaggedFirstSecondIndexes, MatrixValue, Writable, Writable>
{
    /**
     * Holder of one matrix value together with the index that is not part of
     * the join (the "remaining" index). Instances own a private copy of the
     * value so they can be reused across records.
     */
    private static class RemainIndexValue
    {
        public long remainIndex;
        public MatrixValue value;
        private Class<? extends MatrixValue> valueClass;

        /**
         * Creates an empty holder with index -1 and a fresh value instance.
         *
         * @param cls concrete value class to instantiate
         * @throws Exception if the value class cannot be instantiated
         */
        public RemainIndexValue(Class<? extends MatrixValue> cls) throws Exception
        {
            remainIndex=-1;
            valueClass=cls;
            value=valueClass.newInstance();
        }

        /**
         * Creates a holder containing a copy of the given value.
         *
         * @param ind remaining index
         * @param b value to copy; its runtime class is used for the copy
         * @throws Exception if the value class cannot be instantiated
         */
        public RemainIndexValue(long ind, MatrixValue b) throws Exception
        {
            remainIndex=ind;
            valueClass=b.getClass();
            value=valueClass.newInstance();
            value.copy(b);
        }

        /**
         * Overwrites this holder with the given index and a copy of the value.
         *
         * @param ind remaining index
         * @param b value to copy into the existing value object
         */
        public void set(long ind, MatrixValue b)
        {
            remainIndex=ind;
            value.copy(b);
        }
    }

    //in memory cache to hold the records from one input matrix for the cross product;
    //entries beyond cacheSize are stale objects kept only for reuse
    private ArrayList<RemainIndexValue> cache=new ArrayList<RemainIndexValue>(100);
    private int cacheSize=0;

    //to cache output, so that we can do some partial aggregation here
    private long OUT_CACHE_SIZE;
    private HashMap<MatrixIndexes, MatrixValue> outCache;

    //variables to keep track of the flow
    //(kept as long: a double cannot represent every long index exactly)
    private long prevFirstIndex=-1;
    private int prevTag=-1;

    //temporary variables, reused across calls
    private MatrixIndexes indexesbuffer=new MatrixIndexes();
    private RemainIndexValue remainingbuffer=null;
    private MatrixValue valueBuffer=null;

    //set in configure(); when true, close() emits an empty block for every output position
    private boolean outputDummyRecords = false;

    /**
     * Aggregates the values of one key and feeds the result into the join.
     *
     * <p>When the first index differs from the previous call the join cache is
     * reset. Nothing is done if the aggregation yields {@code null}.
     *
     * @param indexes tagged key (tag, first index, second index)
     * @param values values of this key
     * @param out output collector (not used directly by this method)
     * @param report reporter used for setup and the time counter
     * @throws IOException if the join fails
     * @throws RuntimeException if a tag smaller than the previous one arrives
     *         for the same first index
     */
    @Override
    public void reduce(TaggedFirstSecondIndexes indexes, Iterator<MatrixValue> values,
            OutputCollector<Writable, Writable> out,
            Reporter report) throws IOException {
        long start=System.currentTimeMillis();

        commonSetup(report);

        //perform aggregate
        MatrixValue aggregateValue=performAggregateInstructions(indexes, values);
        if(aggregateValue==null)
            return;

        int tag=indexes.getTag();
        long firstIndex=indexes.getFirstIndex();
        long secondIndex=indexes.getSecondIndex();

        //for a different k
        if(prevFirstIndex!=firstIndex)
        {
            resetCache();
            prevFirstIndex=firstIndex;
        }
        else if(prevTag>tag)
            throw new RuntimeException("tag is not ordered correctly: "+prevTag+" > "+tag);

        remainingbuffer.set(secondIndex, aggregateValue);
        try {
            processJoin(tag, remainingbuffer);
        } catch (IOException e) {
            //already the declared exception type, do not wrap it a second time
            throw e;
        } catch (Exception e) {
            throw new IOException(e);
        }
        prevTag=tag;

        report.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis()-start);
    }

    /**
     * Caches the value if it carries tag 0; otherwise combines it with every
     * cached value and collects each result.
     *
     * @param tag tag of the incoming value
     * @param rValue incoming value with its remaining index
     * @throws Exception if caching, the binary operation or output collection fails
     */
    private void processJoin(int tag, RemainIndexValue rValue)
    throws Exception
    {
        //for the cached matrix
        if(tag==0)
        {
            addToCache(rValue, tag);
            return;
        }

        //for the probing matrix
        for(int i=0; i<cacheSize; i++)
        {
            //tagForLeft decides which side of the operation the cached entries take
            RemainIndexValue left, right;
            if(tagForLeft==0)
            {
                left=cache.get(i);
                right=rValue;
            }
            else
            {
                right=cache.get(i);
                left=rValue;
            }
            indexesbuffer.setIndexes(left.remainIndex, right.remainIndex);
            try {
                OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(left.value,
                        right.value, valueBuffer, (AggregateBinaryOperator)aggBinInstruction.getOperator());
            } catch (DMLUnsupportedOperationException e) {
                throw new IOException(e);
            }

            collectOutput(indexesbuffer, valueBuffer);
        }
    }

    /**
     * Merges one result block into the output cache, or writes it out directly
     * when the cache is full.
     *
     * <ul>
     * <li>If a block for {@code indexes} is cached, {@code value_out} is
     *     aggregated into it in place.</li>
     * <li>Else, if the cache holds fewer than {@code OUT_CACHE_SIZE} entries, a
     *     new block of the same dimensions is created, aggregated with
     *     {@code value_out} and stored.</li>
     * <li>Otherwise {@code value_out} is written without further aggregation.</li>
     * </ul>
     *
     * @param indexes output block indexes (a reused buffer, copied before caching)
     * @param value_out result block to merge
     * @throws Exception if the value class cannot be instantiated or the
     *         aggregation / write fails
     */
    private void collectOutput(MatrixIndexes indexes,
            MatrixValue value_out)
    throws Exception
    {
        MatrixValue value=outCache.get(indexes);

        try {
            if(value!=null)
            {
                value.binaryOperationsInPlace(((AggregateBinaryOperator)aggBinInstruction.getOperator()).aggOp.increOp,
                        value_out);
            }
            else if(outCache.size()<OUT_CACHE_SIZE)
            {
                value=valueClass.newInstance();
                value.reset(value_out.getNumRows(), value_out.getNumColumns(), value.isInSparseFormat());
                value.binaryOperationsInPlace(((AggregateBinaryOperator)aggBinInstruction.getOperator()).aggOp.increOp,
                        value_out);
                //the key must be a copy because the caller reuses its indexes object
                outCache.put(new MatrixIndexes(indexes), value);
            }
            else
            {
                realWriteToCollector(indexes, value_out);
            }
        } catch (DMLUnsupportedOperationException e) {
            throw new IOException(e);
        }
    }

    /**
     * Logically empties the join cache; the entry objects are kept for reuse.
     */
    private void resetCache() {
        cacheSize=0;
    }

    /**
     * Appends a copy of the given value to the join cache, reusing a
     * previously allocated entry when one is available.
     *
     * @param rValue value to cache
     * @param tag tag of the value (currently unused)
     * @throws Exception if a new cache entry cannot be instantiated
     */
    private void addToCache(RemainIndexValue rValue, int tag) throws Exception {
        if(cacheSize<cache.size())
            cache.get(cacheSize).set(rValue.remainIndex, rValue.value);
        else
            cache.add(new RemainIndexValue(rValue.remainIndex, rValue.value));
        cacheSize++;
    }

    /**
     * Writes out every block held in the output cache, optionally emits empty
     * blocks for all output positions, updates the time counter and finally
     * delegates to the superclass.
     *
     * @throws IOException if writing an output block fails
     */
    @Override
    public void close() throws IOException
    {
        long start=System.currentTimeMillis();

        //output the records in the outCache.
        for(Entry<MatrixIndexes, MatrixValue> entry : outCache.entrySet())
            realWriteToCollector(entry.getKey(), entry.getValue());

        //handle empty block output (on first reduce task only)
        if( outputDummyRecords ) //required for rejecting empty blocks in mappers
        {
            long rlen = dim1.getRows();
            long clen = dim2.getCols();
            int brlen = dim1.getRowsPerBlock();
            int bclen = dim2.getColsPerBlock();
            MatrixIndexes tmpIx = new MatrixIndexes();
            MatrixBlock tmpVal = new MatrixBlock();
            for(long i=0, r=1; i<rlen; i+=brlen, r++)
                for(long j=0, c=1; j<clen; j+=bclen, c++)
                {
                    //blocks in the last block row / column may be smaller than the block size
                    int realBrlen=(int)Math.min((long)brlen, rlen-(r-1)*brlen);
                    int realBclen=(int)Math.min((long)bclen, clen-(c-1)*bclen);
                    tmpIx.setIndexes(r, c);
                    tmpVal.reset(realBrlen,realBclen);
                    collectFinalMultipleOutputs.collectOutput(tmpIx, tmpVal, 0, cachedReporter);
                }
        }

        if(cachedReporter!=null)
            cachedReporter.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis()-start);
        super.close();
    }

    /**
     * Writes one block to output 0 via {@code collectOutput_N_Increase_Counter}.
     *
     * @param indexes block indexes
     * @param value block to write
     * @throws IOException if the write fails
     */
    public void realWriteToCollector(MatrixIndexes indexes, MatrixValue value) throws IOException
    {
        collectOutput_N_Increase_Counter(indexes, value, 0, cachedReporter);
    }

    /**
     * Initializes the reducer from the job configuration: verifies that there
     * is a single result, decides whether this task emits dummy records,
     * allocates the reusable buffers and sizes the output cache.
     *
     * @param job job configuration
     * @throws RuntimeException if more than one result index is configured or
     *         the buffers cannot be instantiated
     */
    @Override
    public void configure(JobConf job)
    {
        super.configure(job);
        if(resultIndexes.length>1)
            throw new RuntimeException("MMCJMR only outputs one result");

        //only the task whose unique key is "0" emits the dummy records in close()
        outputDummyRecords = MapReduceTool.getUniqueKeyPerTask(job, false).equals("0");

        try {
            //the inherited buffer is reused as the result buffer of the binary operation
            valueBuffer=buffer;
            remainingbuffer=new RemainIndexValue(valueClass);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        int blockRlen=dim1.getRowsPerBlock();
        int blockClen=dim2.getColsPerBlock();
        //estimated size of one output cache entry (presumably bytes, with 8 per cell and
        //0.75 mirroring the hash map load factor - TODO confirm the constants);
        //computed in long arithmetic so that large block sizes cannot overflow int
        long elementSize=(long)Math.ceil((77+8L*blockRlen*blockClen+20+12)/0.75);
        OUT_CACHE_SIZE=((long)OptimizerUtils.getLocalMemBudget() //current jvm max mem
                -MRJobConfiguration.getMMCJCacheSize(job))/elementSize;
        outCache=new HashMap<MatrixIndexes, MatrixValue>(1024);
    }
}