/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.mapred;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.COV;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction;
import org.apache.sysml.runtime.matrix.data.CM_N_COVCell;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.TaggedFirstSecondIndexes;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.COVOperator;

/**
 * Reducer that folds the partial {@link CM_N_COVCell} values received for a
 * tag into a single accumulator, using either the covariance function or the
 * central-moment function registered for that tag, and then emits the result
 * requested by every instruction whose input is that tag.
 */
public class CMCOVMRReducer extends ReduceBase
	implements Reducer<TaggedFirstSecondIndexes, MatrixValue, MatrixIndexes, MatrixValue>
{
	/** Instructions read from the job configuration in {@link #configure(JobConf)}. */
	private CM_N_COVInstruction[] cmNcovInstructions=null;

	/** Accumulator that is reset at the start of every {@link #reduce} call. */
	private CM_N_COVCell cmNcovCell=new CM_N_COVCell();

	/** Function applied for tags listed in {@link #covTags}. */
	private COV covFn=COV.getCOMFnObject();

	/** Function applied for all other tags, keyed by the instruction's input tag. */
	private HashMap<Byte, CM> cmFn = new HashMap<Byte, CM>();

	/** Reused output key; always (1, 1). */
	private MatrixIndexes outIndex=new MatrixIndexes(1, 1);

	/** Reused output value holding the result of the current instruction. */
	private MatrixCell outCell=new MatrixCell();

	/** Output indexes (as returned by {@code getOutputIndexes}) per instruction output tag. */
	private HashMap<Byte, ArrayList<Integer>> outputIndexesMapping=new HashMap<Byte, ArrayList<Integer>>();

	/** Input tags whose instruction carries a {@link COVOperator}. */
	protected HashSet<Byte> covTags=new HashSet<Byte>();

	/**
	 * Stand-in for the zero-valued cells that never reach the reducer; its
	 * weight {@code w} is overwritten with the number of such cells before
	 * it is merged into the accumulator.
	 */
	private CM_COV_Object zeroObj=null;

	//the dimension for all the representative matrices
	//(they are all the same, since coming from the same files)
	protected HashMap<Byte, Long> rlens=null;
	protected HashMap<Byte, Long> clens=null;

	/**
	 * Aggregates all values of one tag and writes the requested results.
	 *
	 * @param index  key whose tag selects the function and the instructions to serve
	 * @param values partial results; every element is cast to {@link CM_N_COVCell}
	 * @param out    output collector supplied by the framework (not used directly here)
	 * @param report reporter handed to {@code commonSetup} and to the output helper
	 * @throws IOException wrapping any {@link DMLRuntimeException} raised while
	 *                     merging values or extracting a result
	 */
	@Override
	public void reduce(TaggedFirstSecondIndexes index, Iterator<MatrixValue> values,
			OutputCollector<MatrixIndexes, MatrixValue> out, Reporter report)
		throws IOException
	{
		commonSetup(report);
		cmNcovCell.setCM_N_COVObject(0, 0, 0);

		final byte tag = index.getTag();

		//covariance tags take precedence over a central-moment function
		final ValueFunction fn;
		if( covTags.contains(tag) )
			fn = covFn;
		else
			fn = cmFn.get(tag);

		try
		{
			//merge every incoming partial result into the accumulator
			while( values.hasNext() )
			{
				CM_N_COVCell partial = (CM_N_COVCell) values.next();
				fn.execute(cmNcovCell.getCM_N_COVObject(), partial.getCM_N_COVObject());
			}

			//add 0 values back in: the cells not covered by the accumulated
			//weight are merged in a single step, by giving the zero object
			//their total count as weight, instead of once per missing cell
			long totalCells = rlens.get(tag) * clens.get(tag);
			long missingCells = totalCells - (long)(cmNcovCell.getCM_N_COVObject().w);
			if( missingCells > 0 )
			{
				zeroObj.w = missingCells;
				fn.execute(cmNcovCell.getCM_N_COVObject(), zeroObj);
			}
		}
		catch(DMLRuntimeException e)
		{
			throw new IOException(e);
		}

		//emit one result per instruction that consumes this tag
		for( CM_N_COVInstruction inst : cmNcovInstructions )
		{
			if( inst.input != tag )
				continue;

			try
			{
				outCell.setValue(cmNcovCell.getCM_N_COVObject().getRequiredResult(inst.getOperator()));
			}
			catch(DMLRuntimeException e)
			{
				throw new IOException(e);
			}

			ArrayList<Integer> targets = outputIndexesMapping.get(inst.output);
			for( int target : targets )
				collectOutput_N_Increase_Counter(outIndex, outCell, target, report);
		}
	}

	/**
	 * Reads the CM/COV instructions from the job and builds the per-tag lookup
	 * tables (function, output indexes, row and column counts) used by
	 * {@link #reduce}.
	 *
	 * @param job job configuration to read from
	 * @throws RuntimeException wrapping any exception raised while parsing the instructions
	 */
	public void configure(JobConf job)
	{
		super.configure(job);

		try
		{
			cmNcovInstructions = MRJobConfiguration.getCM_N_COVInstructions(job);
		}
		catch(Exception e)
		{
			throw new RuntimeException(e);
		}

		rlens = new HashMap<Byte, Long>();
		clens = new HashMap<Byte, Long>();

		for( CM_N_COVInstruction inst : cmNcovInstructions )
		{
			if( inst.getOperator() instanceof COVOperator )
			{
				covTags.add(inst.input);
			}
			else //CMOperator
			{
				cmFn.put(inst.input,
					CM.getCMFnObject(((CMOperator)inst.getOperator()).getAggOpType()));
			}

			outputIndexesMapping.put(inst.output, getOutputIndexes(inst.output));
			rlens.put(inst.input, MRJobConfiguration.getNumRows(job, inst.input));
			clens.put(inst.input, MRJobConfiguration.getNumColumns(job, inst.input));
		}

		zeroObj = new CM_COV_Object();
		zeroObj.w = 1;
	}
}