/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.mapred;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;

import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.COV;
import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction;
import org.apache.sysml.runtime.matrix.data.CM_N_COVCell;
import org.apache.sysml.runtime.matrix.data.TaggedFirstSecondIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedPair;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.COVOperator;

/**
 * Mapper for CM / COV instructions.
 *
 * <p>For every mapped record, the weighted input found under each configured
 * tag is folded into a per-tag {@link CM_N_COVCell} kept in a mapper-local
 * cache. Nothing is emitted while mapping; the accumulated cells are written
 * to the collector once, in {@link #close()}.
 */
public class CMCOVMRMapper extends MapperBase
    implements Mapper<Writable, Writable, Writable, Writable>
{
    /** True until the first call to {@code map}; used to capture the collector once. */
    private boolean firsttime = true;

    /** CM function object per input tag, filled in {@code configure}. */
    private HashMap<Byte, CM> cmFn = new HashMap<Byte, CM>();

    /** Function object shared by all COV tags. */
    private COV covFn = COV.getCOMFnObject();

    /** Collector seen on the first {@code map} call; {@code null} if no record was ever mapped. */
    private OutputCollector<Writable, Writable> cachedCollector = null;

    /** Per-tag running aggregates; one slot is reserved for every tag in {@code configure}. */
    private CachedValueMap cmNcovCache = new CachedValueMap();

    /** Input tags of the instructions that are not COV instructions. */
    protected HashSet<Byte> cmTags = new HashSet<Byte>();

    /** Input tags of the instructions whose operator is a {@link COVOperator}. */
    protected HashSet<Byte> covTags = new HashSet<Byte>();

    /**
     * Remembers the collector of the first invocation (it is needed again in
     * {@link #close()}) and hands the record over to {@code commonMap}.
     *
     * @param index  key of the incoming record
     * @param cell   value of the incoming record
     * @param out    collector supplied by the framework
     * @param report reporter supplied by the framework
     * @throws IOException if processing the record fails
     */
    @Override
    public void map(Writable index, Writable cell, OutputCollector<Writable, Writable> out, Reporter report)
        throws IOException
    {
        if (firsttime) {
            firsttime = false;
            cachedCollector = out;
        }
        commonMap(index, cell, out, report);
    }

    /**
     * Runs the mapper instructions for the given input and then updates the
     * running aggregate of every CM tag, followed by every COV tag, for which
     * a cached input value is present.
     *
     * @param index    index of the input matrix being processed
     * @param out      collector supplied by the framework (not written to here)
     * @param reporter reporter supplied by the framework
     * @throws IOException wrapping any {@link DMLRuntimeException} raised by a function object
     */
    @Override
    protected void specialOperationsForActualMap(int index, OutputCollector<Writable, Writable> out,
        Reporter reporter) throws IOException
    {
        // apply all instructions
        processMapperInstructionsForMatrix(index);

        for (byte tag : cmTags) {
            IndexedMatrixValue cached = cachedValues.getFirst(tag);
            if (cached != null) {
                CM tagFn = cmFn.get(tag);
                WeightedPair pair = (WeightedPair) cached.getValue();
                CM_N_COVCell aggregate = runningAggregate(tag);
                try {
                    tagFn.execute(aggregate.getCM_N_COVObject(), pair.getValue(), pair.getWeight());
                }
                catch (DMLRuntimeException ex) {
                    throw new IOException(ex);
                }
            }
        }

        for (byte tag : covTags) {
            IndexedMatrixValue cached = cachedValues.getFirst(tag);
            if (cached != null) {
                WeightedPair pair = (WeightedPair) cached.getValue();
                CM_N_COVCell aggregate = runningAggregate(tag);
                try {
                    covFn.execute(aggregate.getCM_N_COVObject(), pair.getValue(), pair.getOtherValue(),
                        pair.getWeight());
                }
                catch (DMLRuntimeException ex) {
                    throw new IOException(ex);
                }
            }
        }
    }

    /**
     * Emits the accumulated aggregate of every CM tag and then of every COV tag.
     * Does nothing when {@code map} was never called, because no collector is
     * available in that case.
     *
     * @throws IOException if the collector fails to accept a value
     */
    public void close() throws IOException
    {
        if (cachedCollector == null) {
            return;
        }
        emitAggregates(cmTags);
        emitAggregates(covTags);
    }

    /**
     * Reads the CM / COV instructions from the job, sorts their input tags into
     * {@code cmTags} and {@code covTags}, registers a CM function object for
     * each CM tag, and reserves a cache slot for every tag.
     *
     * @param job job configuration
     * @throws RuntimeException wrapping any exception raised while reading the instructions
     */
    public void configure(JobConf job)
    {
        super.configure(job);

        try {
            CM_N_COVInstruction[] instructions = MRJobConfiguration.getCM_N_COVInstructions(job);
            for (CM_N_COVInstruction ins : instructions) {
                if (ins.getOperator() instanceof COVOperator) {
                    covTags.add(ins.input);
                    continue;
                }
                // anything that is not a COV operator is treated as a CM operator
                cmTags.add(ins.input);
                CMOperator cmOp = (CMOperator) ins.getOperator();
                cmFn.put(ins.input, CM.getCMFnObject(cmOp.getAggOpType()));
            }
        }
        catch (Exception ex) {
            throw new RuntimeException(ex);
        }

        reserveCacheSlots(cmTags);
        reserveCacheSlots(covTags);
    }

    /** Returns the cached running aggregate stored under the given tag. */
    private CM_N_COVCell runningAggregate(byte tag)
    {
        return (CM_N_COVCell) cmNcovCache.getFirst(tag).getValue();
    }

    /** Writes the running aggregate of each given tag to the cached collector. */
    private void emitAggregates(HashSet<Byte> tags) throws IOException
    {
        for (byte tag : tags) {
            CM_N_COVCell aggregate = runningAggregate(tag);
            cachedCollector.collect(new TaggedFirstSecondIndexes(1, tag, 1), aggregate);
        }
    }

    /** Reserves a {@link CM_N_COVCell} slot in the aggregate cache for each given tag. */
    private void reserveCacheSlots(HashSet<Byte> tags)
    {
        for (byte tag : tags) {
            cmNcovCache.holdPlace(tag, CM_N_COVCell.class);
        }
    }
}