/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.matrix; import java.io.Serializable; import java.util.HashMap; import org.apache.sysml.lops.MMTSJ.MMTSJType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction; import org.apache.sysml.runtime.instructions.mr.AggregateInstruction; import org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction; import org.apache.sysml.runtime.instructions.mr.AppendInstruction; import org.apache.sysml.runtime.instructions.mr.BinaryInstruction; import org.apache.sysml.runtime.instructions.mr.BinaryMInstruction; import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase; import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction; import org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction; import org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction; import org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction; import org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction; import org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction; import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction; import org.apache.sysml.runtime.instructions.mr.GroupedAggregateMInstruction; import org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction; import org.apache.sysml.runtime.instructions.mr.MRInstruction; import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction; import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction; import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction; import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction; import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction; import org.apache.sysml.runtime.instructions.mr.RandInstruction; import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction; import org.apache.sysml.runtime.instructions.mr.ReblockInstruction; import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction; import org.apache.sysml.runtime.instructions.mr.ReorgInstruction; import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction; import org.apache.sysml.runtime.instructions.mr.ScalarInstruction; import org.apache.sysml.runtime.instructions.mr.SeqInstruction; import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase; import org.apache.sysml.runtime.instructions.mr.ZeroOutInstruction; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; public class MatrixCharacteristics implements Serializable { private static final long serialVersionUID = 8300479822915546000L; private long numRows = -1; private long numColumns = -1; private int numRowsPerBlock = 1; private int numColumnsPerBlock = 1; private long nonZero = -1; public MatrixCharacteristics() { } public MatrixCharacteristics(long nr, long nc, int bnr, int bnc) { set(nr, nc, bnr, bnc); } public MatrixCharacteristics(long nr, long nc, int bnr, int bnc, long nnz) { set(nr, nc, bnr, bnc, nnz); } public MatrixCharacteristics(MatrixCharacteristics that) { set(that.numRows, that.numColumns, that.numRowsPerBlock, that.numColumnsPerBlock, that.nonZero); } public void set(long nr, long nc, int bnr, int bnc) { numRows = nr; numColumns = nc; numRowsPerBlock = bnr; numColumnsPerBlock = bnc; } public void set(long nr, long nc, int bnr, int bnc, long nnz) { numRows = nr; numColumns = nc; numRowsPerBlock = bnr; numColumnsPerBlock = bnc; nonZero = nnz; } public void set(MatrixCharacteristics that) { numRows = that.numRows; numColumns = that.numColumns; numRowsPerBlock = that.numRowsPerBlock; numColumnsPerBlock = that.numColumnsPerBlock; nonZero = that.nonZero; } public long getRows(){ return numRows; } public long getCols(){ return numColumns; } public int getRowsPerBlock() { return numRowsPerBlock; } public void setRowsPerBlock( int brlen){ numRowsPerBlock = brlen; } public int getColsPerBlock() { return numColumnsPerBlock; } public void setColsPerBlock( int bclen){ numColumnsPerBlock = bclen; } public long getNumBlocks() { return getNumRowBlocks() * getNumColBlocks(); } public long getNumRowBlocks() { return (long) Math.ceil((double)getRows() / getRowsPerBlock()); } public long getNumColBlocks() { return (long) Math.ceil((double)getCols() / getColsPerBlock()); } public String toString() { return "["+numRows+" x "+numColumns+", nnz="+nonZero +", blocks ("+numRowsPerBlock+" x "+numColumnsPerBlock+")]"; } public void setDimension(long nr, long nc) { numRows = nr; numColumns = nc; } public void setBlockSize(int blen) { setBlockSize(blen, blen); } public void setBlockSize(int bnr, int bnc) { numRowsPerBlock = bnr; numColumnsPerBlock = bnc; } public void setNonZeros(long nnz) { nonZero = nnz; } public long getNonZeros() { return nonZero; } public boolean dimsKnown() { return ( numRows > 0 && numColumns > 0 ); } public boolean dimsKnown(boolean includeNnz) { return ( numRows > 0 && numColumns > 0 && (!includeNnz || nonZero>=0)); } public boolean rowsKnown() { return ( numRows > 0 ); } public boolean colsKnown() { return ( numColumns > 0 ); } public boolean nnzKnown() { return ( nonZero >= 0 ); } public boolean mightHaveEmptyBlocks() { long singleBlk = Math.min(numRows, numRowsPerBlock) * Math.min(numColumns, numColumnsPerBlock); return !nnzKnown() || (nonZero < numRows*numColumns - singleBlk); } public static void reorg(MatrixCharacteristics dim, ReorgOperator op, MatrixCharacteristics dimOut) throws DMLRuntimeException { op.fn.computeDimension(dim, dimOut); } public static void aggregateUnary(MatrixCharacteristics dim, AggregateUnaryOperator op, MatrixCharacteristics dimOut) throws DMLRuntimeException { op.indexFn.computeDimension(dim, dimOut); } public static void aggregateBinary(MatrixCharacteristics dim1, MatrixCharacteristics dim2, AggregateBinaryOperator op, MatrixCharacteristics dimOut) { //set dimension dimOut.set(dim1.numRows, dim2.numColumns, dim1.numRowsPerBlock, dim2.numColumnsPerBlock); } public static void computeDimension(HashMap<Byte, MatrixCharacteristics> dims, MRInstruction ins) throws DMLRuntimeException { MatrixCharacteristics dimOut=dims.get(ins.output); if(dimOut==null) { dimOut=new MatrixCharacteristics(); dims.put(ins.output, dimOut); } if(ins instanceof ReorgInstruction) { ReorgInstruction realIns=(ReorgInstruction)ins; reorg(dims.get(realIns.input), (ReorgOperator)realIns.getOperator(), dimOut); } else if(ins instanceof AppendInstruction ) { AppendInstruction realIns = (AppendInstruction)ins; MatrixCharacteristics in_dim1 = dims.get(realIns.input1); MatrixCharacteristics in_dim2 = dims.get(realIns.input2); if( realIns.isCBind() ) dimOut.set(in_dim1.numRows, in_dim1.numColumns+in_dim2.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock); else dimOut.set(in_dim1.numRows+in_dim2.numRows, in_dim1.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock); } else if(ins instanceof CumulativeAggregateInstruction) { AggregateUnaryInstruction realIns=(AggregateUnaryInstruction)ins; MatrixCharacteristics in = dims.get(realIns.input); dimOut.set((long)Math.ceil( (double)in.getRows()/in.getRowsPerBlock()), in.getCols(), in.getRowsPerBlock(), in.getColsPerBlock()); } else if(ins instanceof AggregateUnaryInstruction) { AggregateUnaryInstruction realIns=(AggregateUnaryInstruction)ins; aggregateUnary(dims.get(realIns.input), (AggregateUnaryOperator)realIns.getOperator(), dimOut); } else if(ins instanceof AggregateBinaryInstruction) { AggregateBinaryInstruction realIns=(AggregateBinaryInstruction)ins; aggregateBinary(dims.get(realIns.input1), dims.get(realIns.input2), (AggregateBinaryOperator)realIns.getOperator(), dimOut); } else if(ins instanceof MapMultChainInstruction) { //output size independent of chain type MapMultChainInstruction realIns=(MapMultChainInstruction)ins; MatrixCharacteristics mc1 = dims.get(realIns.getInput1()); MatrixCharacteristics mc2 = dims.get(realIns.getInput2()); dimOut.set(mc1.numColumns, mc2.numColumns, mc1.numRowsPerBlock, mc1.numColumnsPerBlock); } else if(ins instanceof QuaternaryInstruction) { QuaternaryInstruction realIns=(QuaternaryInstruction)ins; MatrixCharacteristics mc1 = dims.get(realIns.getInput1()); MatrixCharacteristics mc2 = dims.get(realIns.getInput2()); MatrixCharacteristics mc3 = dims.get(realIns.getInput3()); realIns.computeMatrixCharacteristics(mc1, mc2, mc3, dimOut); } else if(ins instanceof ReblockInstruction) { ReblockInstruction realIns=(ReblockInstruction)ins; MatrixCharacteristics in_dim=dims.get(realIns.input); dimOut.set(in_dim.numRows, in_dim.numColumns, realIns.brlen, realIns.bclen, in_dim.nonZero); } else if( ins instanceof MatrixReshapeMRInstruction ) { MatrixReshapeMRInstruction mrinst = (MatrixReshapeMRInstruction) ins; MatrixCharacteristics in_dim=dims.get(mrinst.input); dimOut.set(mrinst.getNumRows(),mrinst.getNumColunms(),in_dim.getRowsPerBlock(), in_dim.getColsPerBlock(), in_dim.getNonZeros()); } else if(ins instanceof RandInstruction || ins instanceof SeqInstruction) { DataGenMRInstruction dataIns=(DataGenMRInstruction)ins; dimOut.set(dims.get(dataIns.getInput())); } else if( ins instanceof ReplicateInstruction ) { ReplicateInstruction realIns=(ReplicateInstruction)ins; realIns.computeOutputDimension(dims.get(realIns.input), dimOut); } else if( ins instanceof ParameterizedBuiltinMRInstruction ) //before unary { ParameterizedBuiltinMRInstruction realIns = (ParameterizedBuiltinMRInstruction)ins; realIns.computeOutputCharacteristics(dims.get(realIns.input), dimOut); } else if(ins instanceof ScalarInstruction || ins instanceof AggregateInstruction ||(ins instanceof UnaryInstruction && !(ins instanceof MMTSJMRInstruction)) || ins instanceof ZeroOutInstruction) { UnaryMRInstructionBase realIns=(UnaryMRInstructionBase)ins; dimOut.set(dims.get(realIns.input)); } else if (ins instanceof MMTSJMRInstruction) { MMTSJMRInstruction mmtsj = (MMTSJMRInstruction)ins; MMTSJType tstype = mmtsj.getMMTSJType(); MatrixCharacteristics mc = dims.get(mmtsj.input); dimOut.set( tstype.isLeft() ? mc.numColumns : mc.numRows, tstype.isLeft() ? mc.numColumns : mc.numRows, mc.numRowsPerBlock, mc.numColumnsPerBlock ); } else if( ins instanceof PMMJMRInstruction ) { PMMJMRInstruction pmmins = (PMMJMRInstruction) ins; MatrixCharacteristics mc = dims.get(pmmins.input2); dimOut.set( pmmins.getNumRows(), mc.numColumns, mc.numRowsPerBlock, mc.numColumnsPerBlock ); } else if( ins instanceof RemoveEmptyMRInstruction ) { RemoveEmptyMRInstruction realIns=(RemoveEmptyMRInstruction)ins; MatrixCharacteristics mc = dims.get(realIns.input1); if( realIns.isRemoveRows() ) dimOut.set(realIns.getOutputLen(), mc.getCols(), mc.numRowsPerBlock, mc.numColumnsPerBlock); else dimOut.set(mc.getRows(), realIns.getOutputLen(), mc.numRowsPerBlock, mc.numColumnsPerBlock); } else if(ins instanceof UaggOuterChainInstruction) //needs to be checked before binary { UaggOuterChainInstruction realIns=(UaggOuterChainInstruction)ins; MatrixCharacteristics mc1 = dims.get(realIns.input1); MatrixCharacteristics mc2 = dims.get(realIns.input2); realIns.computeOutputCharacteristics(mc1, mc2, dimOut); } else if( ins instanceof GroupedAggregateMInstruction ) { GroupedAggregateMInstruction realIns = (GroupedAggregateMInstruction) ins; MatrixCharacteristics mc1 = dims.get(realIns.input1); realIns.computeOutputCharacteristics(mc1, dimOut); } else if(ins instanceof BinaryInstruction || ins instanceof BinaryMInstruction || ins instanceof CombineBinaryInstruction ) { BinaryMRInstructionBase realIns=(BinaryMRInstructionBase)ins; MatrixCharacteristics mc1 = dims.get(realIns.input1); MatrixCharacteristics mc2 = dims.get(realIns.input2); if( mc1.getRows()>1 && mc1.getCols()==1 && mc2.getRows()==1 && mc2.getCols()>1 ) //outer { dimOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc2.getColsPerBlock()); } else { //default case dimOut.set(mc1); } } else if (ins instanceof CombineTernaryInstruction ) { TernaryInstruction realIns=(TernaryInstruction)ins; dimOut.set(dims.get(realIns.input1)); } else if (ins instanceof CombineUnaryInstruction ) { dimOut.set( dims.get(((CombineUnaryInstruction) ins).input)); } else if(ins instanceof CM_N_COVInstruction || ins instanceof GroupedAggregateInstruction ) { dimOut.set(1, 1, 1, 1); } else if(ins instanceof RangeBasedReIndexInstruction) { RangeBasedReIndexInstruction realIns = (RangeBasedReIndexInstruction)ins; MatrixCharacteristics dimIn = dims.get(realIns.input); realIns.computeOutputCharacteristics(dimIn, dimOut); } else if (ins instanceof TernaryInstruction) { TernaryInstruction realIns = (TernaryInstruction)ins; MatrixCharacteristics in_dim=dims.get(realIns.input1); dimOut.set(realIns.getOutputDim1(), realIns.getOutputDim2(), in_dim.numRowsPerBlock, in_dim.numColumnsPerBlock); } else { /* * if ins is none of the above cases then we assume that dim_out dimensions are unknown */ dimOut.numRows = -1; dimOut.numColumns = -1; dimOut.numRowsPerBlock=1; dimOut.numColumnsPerBlock=1; } } @Override public boolean equals (Object anObject) { if (anObject instanceof MatrixCharacteristics) { MatrixCharacteristics mc = (MatrixCharacteristics) anObject; return ((numRows == mc.numRows) && (numColumns == mc.numColumns) && (numRowsPerBlock == mc.numRowsPerBlock) && (numColumnsPerBlock == mc.numColumnsPerBlock) && (nonZero == mc.nonZero)) ; } else return false; } @Override public int hashCode() { return super.hashCode(); } }