/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.instructions.spark; import java.util.List; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; import org.apache.sysml.lops.PickByCount.OperationTypes; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.util.UtilFunctions; public class QuantilePickSPInstruction extends BinarySPInstruction { private OperationTypes _type = null; public QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand out, OperationTypes type, boolean inmem, String opcode, String istr){ this(op, in, null, out, type, inmem, opcode, istr); } public QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, OperationTypes type, boolean inmem, String opcode, String istr){ super(op, in, in2, out, opcode, istr); _sptype = SPINSTRUCTION_TYPE.QPick; _type = type; //inmem ignored here } public static QuantilePickSPInstruction parseInstruction ( String str ) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; //sanity check opcode if ( !opcode.equalsIgnoreCase("qpick") ) { throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickCPInstruction: " + str); } //instruction parsing if( parts.length == 4 ) { //instructions of length 4 originate from unary - mr-iqm //TODO this should be refactored to use pickvaluecount lops CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); CPOperand out = new CPOperand(parts[3]); OperationTypes ptype = OperationTypes.IQM; boolean inmem = false; return new QuantilePickSPInstruction(null, in1, in2, out, ptype, inmem, opcode, str); } else if( parts.length == 5 ) { CPOperand in1 = new CPOperand(parts[1]); CPOperand out = new CPOperand(parts[2]); OperationTypes ptype = OperationTypes.valueOf(parts[3]); boolean inmem = Boolean.parseBoolean(parts[4]); return new QuantilePickSPInstruction(null, in1, out, ptype, inmem, opcode, str); } else if( parts.length == 6 ) { CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); CPOperand out = new CPOperand(parts[3]); OperationTypes ptype = OperationTypes.valueOf(parts[4]); boolean inmem = Boolean.parseBoolean(parts[5]); return new QuantilePickSPInstruction(null, in1, in2, out, ptype, inmem, opcode, str); } return null; } @Override public void processInstruction(ExecutionContext ec) throws DMLRuntimeException { SparkExecutionContext sec = (SparkExecutionContext)ec; MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName()); boolean weighted = (mc.getCols()==2); //get input rdds JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); //NOTE: no difference between inmem/mr pick (see related cp instruction), but wrt w/ w/o weights //(in contrast to cp instructions, w/o weights does not materializes weights of 1) switch( _type ) { case VALUEPICK: { double sum_wt = weighted ? sumWeights(in) : mc.getRows(); ScalarObject quantile = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()); long key = (long)Math.ceil(quantile.getDoubleValue()*sum_wt); double val = lookupKey(in, key, mc.getRowsPerBlock()); ec.setScalarOutput(output.getName(), new DoubleObject(val)); break; } case MEDIAN: { double sum_wt = weighted ? sumWeights(in) : mc.getRows(); long key = (long)Math.ceil(0.5*sum_wt); double val = lookupKey(in, key, mc.getRowsPerBlock()); ec.setScalarOutput(output.getName(), new DoubleObject(val)); break; } case IQM: { double sum_wt = weighted ? sumWeights(in) : mc.getRows(); long key25 = (long)Math.ceil(0.25*sum_wt); long key75 = (long)Math.ceil(0.75*sum_wt); double val25 = lookupKey(in, key25, mc.getRowsPerBlock()); double val75 = lookupKey(in, key75, mc.getRowsPerBlock()); JavaPairRDD<MatrixIndexes,MatrixBlock> out = in .filter(new FilterFunction(key25+1,key75,mc.getRowsPerBlock())) .mapToPair(new ExtractAndSumFunction(key25+1, key75, mc.getRowsPerBlock())); MatrixBlock mb = RDDAggregateUtils.sumStable(out); double val = (mb.getValue(0, 0) + (key25-0.25*sum_wt)*val25 - (key75-0.75*sum_wt)*val75) / (0.5*sum_wt); ec.setScalarOutput(output.getName(), new DoubleObject(val)); break; } default: throw new DMLRuntimeException("Unsupported qpick operation type: "+_type); } } private double lookupKey(JavaPairRDD<MatrixIndexes,MatrixBlock> in, long key, int brlen) { long rix = UtilFunctions.computeBlockIndex(key, brlen); long pos = UtilFunctions.computeCellInBlock(key, brlen); List<MatrixBlock> val = in.lookup(new MatrixIndexes(rix,1)); return val.get(0).quickGetValue((int)pos, 0); } private double sumWeights(JavaPairRDD<MatrixIndexes,MatrixBlock> in) { JavaPairRDD<MatrixIndexes,MatrixBlock> tmp = in .mapValues(new ExtractAndSumWeightsFunction()); MatrixBlock val = RDDAggregateUtils.sumStable(tmp); return val.quickGetValue(0, 0); } private static class FilterFunction implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> { private static final long serialVersionUID = -8249102381116157388L; //boundary keys (inclusive) private long _minRowIndex; private long _maxRowIndex; public FilterFunction(long key25, long key75, int brlen) { _minRowIndex = UtilFunctions.computeBlockIndex(key25, brlen); _maxRowIndex = UtilFunctions.computeBlockIndex(key75, brlen); } @Override public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception { long rowIndex = arg0._1().getRowIndex(); return (rowIndex>=_minRowIndex && rowIndex<=_maxRowIndex); } } private static class ExtractAndSumFunction implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock> { private static final long serialVersionUID = -584044441055250489L; //boundary keys (inclusive) private long _minRowIndex; private long _maxRowIndex; private int _minPos; private int _maxPos; public ExtractAndSumFunction(long key25, long key75, int brlen) { _minRowIndex = UtilFunctions.computeBlockIndex(key25, brlen); _maxRowIndex = UtilFunctions.computeBlockIndex(key75, brlen); _minPos = UtilFunctions.computeCellInBlock(key25, brlen); _maxPos = UtilFunctions.computeCellInBlock(key75, brlen); } @Override public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception { MatrixIndexes ix = arg0._1(); MatrixBlock mb = arg0._2(); if( _minRowIndex==_maxRowIndex ){ mb = mb.sliceOperations(_minPos-1, _maxPos-1, 0, 0, new MatrixBlock()); } else if( ix.getRowIndex() == _minRowIndex ) { mb = mb.sliceOperations(_minPos, mb.getNumRows()-1, 0, 0, new MatrixBlock()); } else if( ix.getRowIndex() == _maxRowIndex ) { mb = mb.sliceOperations(0, _maxPos, 0, 0, new MatrixBlock()); } //create output (with correction) MatrixBlock ret = new MatrixBlock(1,2,false); ret.setValue(0, 0, mb.sum()); return new Tuple2<MatrixIndexes,MatrixBlock>(new MatrixIndexes(1,1), ret); } } private static class ExtractAndSumWeightsFunction implements Function<MatrixBlock,MatrixBlock> { private static final long serialVersionUID = 7169831202450745373L; @Override public MatrixBlock call(MatrixBlock arg0) throws Exception { //slice operation (2nd column) MatrixBlock mb = arg0.sliceOperations(0, arg0.getNumRows()-1, 1, 1, new MatrixBlock()); //create output (with correction) MatrixBlock ret = new MatrixBlock(1,2,false); ret.setValue(0, 0, mb.sum()); return ret; } } }