/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; import org.apache.commons.lang.ArrayUtils; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysml.runtime.util.UtilFunctions; public class TemplateUtils { public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()}; public static boolean isVector(Hop hop) { return (hop.getDataType() == DataType.MATRIX && (hop.getDim1() != 1 && hop.getDim2() == 1 || hop.getDim1() == 1 && hop.getDim2() != 1 ) ); } public static boolean isColVector(CNode hop) { return (hop.getDataType() == DataType.MATRIX && hop.getNumRows() != 1 && hop.getNumCols() == 1); } public static boolean isRowVector(CNode hop) { return (hop.getDataType() == DataType.MATRIX && hop.getNumRows() == 1 && hop.getNumCols() != 1); } public static CNode wrapLookupIfNecessary(CNode node, Hop hop) { CNode ret = node; if( isColVector(node) ) ret = new CNodeUnary(node, UnaryType.LOOKUP_R); else if( isRowVector(node) ) ret = new CNodeUnary(node, UnaryType.LOOKUP_C); else if( node instanceof CNodeData && hop.getDataType().isMatrix() ) ret = new CNodeUnary(node, UnaryType.LOOKUP_RC); return ret; } public static boolean isMatrix(Hop hop) { return (hop.getDataType() == DataType.MATRIX && hop.getDim1() != 1 && hop.getDim2()!=1); } public static boolean isVectorOrScalar(Hop hop) { return hop.dimsKnown() && (hop.getDataType() == DataType.SCALAR || isVector(hop) ); } public static boolean isBinaryMatrixRowVector(Hop hop) { if( !(hop instanceof BinaryOp) ) return false; Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(1); return left.dimsKnown() && right.dimsKnown() && left.getDataType().isMatrix() && right.getDataType().isMatrix() && left.getDim1() > right.getDim1(); } public static boolean isBinaryMatrixColVector(Hop hop) { if( !(hop instanceof BinaryOp) ) return false; Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(1); return left.dimsKnown() && right.dimsKnown() && left.getDataType().isMatrix() && right.getDataType().isMatrix() && left.getDim2() > right.getDim2(); } public static boolean hasMatrixInput( Hop hop ) { for( Hop c : hop.getInput() ) if( isMatrix(c) ) return true; return false; } public static boolean isOperationSupported(Hop h) { if(h instanceof UnaryOp) return UnaryType.contains(((UnaryOp)h).getOp().name()); else if(h instanceof BinaryOp) return BinType.contains(((BinaryOp)h).getOp().name()); else if(h instanceof TernaryOp) return TernaryType.contains(((TernaryOp)h).getOp().name()); else if(h instanceof ParameterizedBuiltinOp) return TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name()); return false; } private static void rfindChildren(Hop hop, HashSet<Hop> children ) { if( hop instanceof UnaryOp || (hop instanceof BinaryOp && hop.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar( hop.getInput().get(1))) || (hop instanceof BinaryOp && TemplateUtils.isVectorOrScalar( hop.getInput().get(0)) && hop.getInput().get(1).getDataType() == DataType.MATRIX) //unary operation or binary operaiton with one matrix and a scalar && hop.getDataType() == DataType.MATRIX ) { if(!children.contains(hop)) children.add(hop); Hop matrix = TemplateUtils.isMatrix(hop.getInput().get(0)) ? hop.getInput().get(0) : hop.getInput().get(1); rfindChildren(matrix,children); } else children.add(hop); } private static Hop findCommonChild(Hop hop1, Hop hop2) { //this method assumes that each two nodes have at most one common child LinkedHashSet<Hop> children1 = new LinkedHashSet<Hop>(); LinkedHashSet<Hop> children2 = new LinkedHashSet<Hop>(); rfindChildren(hop1, children1 ); rfindChildren(hop2, children2 ); //iterate on one set and find the first common child in the other set Iterator<Hop> iter = children1.iterator(); while (iter.hasNext()) { Hop candidate = iter.next(); if(children2.contains(candidate)) return candidate; } return null; } public static Hop commonChild(ArrayList<Hop> _adddedMatrices, Hop input) { Hop currentChild = null; //loop on every added matrix and find its common child with the input, if all of them have the same common child then return it, otherwise null for(Hop addedMatrix : _adddedMatrices) { Hop child = findCommonChild(addedMatrix,input); if(child == null) // did not find a common child return null; if(currentChild == null) // first common child to be seen currentChild = child; else if(child.getHopID() != currentChild.getHopID()) return null; } return currentChild; } public static HashSet<Long> rGetInputHopIDs( CNode node, HashSet<Long> ids ) { if( node instanceof CNodeData && !node.isLiteral() ) ids.add(((CNodeData)node).getHopID()); for( CNode c : node.getInput() ) rGetInputHopIDs(c, ids); return ids; } public static Hop[] mergeDistinct(HashSet<Long> ids, Hop[] input1, Hop[] input2) { Hop[] ret = new Hop[ids.size()]; int pos = 0; for( Hop[] input : new Hop[][]{input1, input2} ) for( Hop c : input ) if( ids.contains(c.getHopID()) ) ret[pos++] = c; return ret; } public static TemplateBase createTemplate(TemplateType type) { return createTemplate(type, false); } public static TemplateBase createTemplate(TemplateType type, boolean closed) { TemplateBase tpl = null; switch( type ) { case CellTpl: tpl = new TemplateCell(closed); break; case RowTpl: tpl = new TemplateRow(closed); break; case MultiAggTpl: tpl = new TemplateMultiAgg(closed); break; case OuterProdTpl: tpl = new TemplateOuterProduct(closed); break; } return tpl; } public static TemplateBase[] createCompatibleTemplates(TemplateType type, boolean closed) { TemplateBase[] tpl = null; switch( type ) { case CellTpl: tpl = new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)}; break; case RowTpl: tpl = new TemplateBase[]{new TemplateRow(closed)}; break; case MultiAggTpl: tpl = new TemplateBase[]{new TemplateMultiAgg(closed)}; break; case OuterProdTpl: tpl = new TemplateBase[]{new TemplateOuterProduct(closed)}; break; } return tpl; } public static CellType getCellType(Hop hop) { return (hop instanceof AggBinaryOp) ? CellType.FULL_AGG : (hop instanceof AggUnaryOp) ? ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG; } public static RowType getRowType(Hop output, Hop input) { if( HopRewriteUtils.isEqualSize(output, input) ) return RowType.NO_AGG; else if( output.getDim1()==input.getDim1() && output.getDim2()==1 ) return RowType.ROW_AGG; else if( output.getDim1()==input.getDim2() && output.getDim2()==1 ) return RowType.COL_AGG_T; else return RowType.COL_AGG; } public static AggOp getAggOp(Hop hop) { return (hop instanceof AggUnaryOp) ? ((AggUnaryOp)hop).getOp() : (hop instanceof AggBinaryOp) ? AggOp.SUM : null; } public static OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) { if( out.getDataType() == DataType.SCALAR ) return OutProdType.AGG_OUTER_PRODUCT; else if( (out instanceof AggBinaryOp && (out.getInput().get(0) == U || HopRewriteUtils.isTransposeOperation(out.getInput().get(0)) && out.getInput().get(0).getInput().get(0) == U)) || HopRewriteUtils.isTransposeOperation(out) ) return OutProdType.LEFT_OUTER_PRODUCT; else if( out instanceof AggBinaryOp && (out.getInput().get(1) == V || HopRewriteUtils.isTransposeOperation(out.getInput().get(1)) && out.getInput().get(1).getInput().get(0) == V ) ) return OutProdType.RIGHT_OUTER_PRODUCT; else if( out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1)) ) return OutProdType.CELLWISE_OUTER_PRODUCT; //should never come here throw new RuntimeException("Undefined outer product type for hop "+out.getHopID()); } public static boolean isLookup(CNode node) { return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC) || isTernary(node, TernaryType.LOOKUP_RC1); } public static boolean isUnary(CNode node, UnaryType...types) { return node instanceof CNodeUnary && ArrayUtils.contains(types, ((CNodeUnary)node).getType()); } public static boolean isTernary(CNode node, TernaryType...types) { return node instanceof CNodeTernary && ArrayUtils.contains(types, ((CNodeTernary)node).getType()); } public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) { CNodeData cdata = new CNodeData(hop); cdata.setLiteral(hop instanceof LiteralOp && (compileLiterals || UtilFunctions.isIntegerNumber(((LiteralOp)hop).getStringValue()))); return cdata; } public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap<Long, CNode> tmp, boolean compileLiterals) { if( HopRewriteUtils.isTransposeOperation(hop) ) { CNode cdata = tmp.get(hop.getInput().get(0).getHopID()); if( cdata == null ) { //never accessed cdata = TemplateUtils.createCNodeData(hop.getInput().get(0), compileLiterals); tmp.put(hop.getInput().get(0).getHopID(), cdata); } tmp.put(hop.getHopID(), cdata); return cdata; } else { return cdataOrig; } } public static boolean hasTransposeParentUnderOuterProduct(Hop hop) { for( Hop p : hop.getParent() ) if( HopRewriteUtils.isTransposeOperation(p) ) for( Hop p2 : p.getParent() ) if( HopRewriteUtils.isOuterProductLikeMM(p2) ) return true; return false; } public static boolean hasSingleOperation(CNodeTpl tpl) { CNode output = tpl.getOutput(); return (output instanceof CNodeUnary || output instanceof CNodeBinary || output instanceof CNodeTernary) && hasOnlyDataNodeOrLookupInputs(output); } public static boolean hasNoOperation(CNodeTpl tpl) { return tpl.getOutput() instanceof CNodeData || isLookup(tpl.getOutput()); } public static boolean hasOnlyDataNodeOrLookupInputs(CNode node) { boolean ret = true; for( CNode c : node.getInput() ) ret &= (c instanceof CNodeData || (c instanceof CNodeUnary && (((CNodeUnary)c).getType()==UnaryType.LOOKUP0 || ((CNodeUnary)c).getType()==UnaryType.LOOKUP_R || ((CNodeUnary)c).getType()==UnaryType.LOOKUP_RC))); return ret; } public static int countVectorIntermediates(CNode node, HashSet<Long> memo) { //memoization to prevent double counting if( memo.contains(node.getID()) ) return 0; memo.add(node.getID()); //compute vector requirements over all inputs int ret = 0; for( CNode c : node.getInput() ) ret += countVectorIntermediates(c, memo); //compute vector requirements of current node int cntBin = ((node instanceof CNodeBinary && ((CNodeBinary)node).getType().isVectorScalarPrimitive()) ? 1 : 0); int cntUn = ((node instanceof CNodeUnary && ((CNodeUnary)node).getType().isVectorScalarPrimitive()) ? 1 : 0); return ret + cntBin + cntUn; } public static boolean isType(TemplateType type, TemplateType... validTypes) { return ArrayUtils.contains(validTypes, type); } public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) { //if second input has no row template, it's always true if( !memo.contains(input2.getHopID(), TemplateType.RowTpl) ) return true; //check for common row template input long tmp1 = getRowTemplateMatrixInput(input1, memo); long tmp2 = getRowTemplateMatrixInput(input2, memo); return (tmp1 == tmp2); } public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) { MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl); long ret = -1; for( int i=0; ret<0 && i<current.getInput().size(); i++ ) { Hop input = current.getInput().get(i); if( me.isPlanRef(i) && memo.contains(input.getHopID(), TemplateType.RowTpl) ) ret = getRowTemplateMatrixInput(input, memo); else if( !me.isPlanRef(i) && isMatrix(input) ) ret = input.getHopID(); } return ret; } }