/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeCell; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateCell extends TemplateBase { private static final AggOp[] SUPPORTED_AGG = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX}; public TemplateCell() { super(TemplateType.CellTpl); } public TemplateCell(boolean closed) { super(TemplateType.CellTpl, closed); } public TemplateCell(TemplateType type, boolean closed) { super(type, closed); } @Override public boolean open(Hop hop) { return isValidOperation(hop) || (hop instanceof IndexingOp && ((IndexingOp)hop).isColLowerEqualsUpper()); } @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && (isValidOperation(hop) || (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))); } @Override public boolean merge(Hop hop, Hop input) { //merge of other cell tpl possible return (!isClosed() && isValidOperation(hop)); } @Override public CloseType close(Hop hop) { //need to close cell tpl after aggregation, see fuse for exact properties if( (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) ) return CloseType.CLOSED_VALID; else if( hop instanceof AggUnaryOp || hop instanceof AggBinaryOp ) return CloseType.CLOSED_INVALID; else return CloseType.OPEN; } @Override public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { //recursively process required cplan output HashSet<Hop> inHops = new HashSet<Hop>(); HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); hop.resetVisitStatus(); rConstructCplan(hop, memo, tmp, inHops, compileLiterals); hop.resetVisitStatus(); //reorder inputs (ensure matrices/vectors come first) and prune literals //note: we order by number of cells and subsequently sparsity to ensure //that sparse inputs are used as the main input w/o unnecessary conversion List<Hop> sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) .sorted(new HopInputComparator()).collect(Collectors.toList()); //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); for( Hop in : sinHops ) inputs.add(tmp.get(in.getHopID())); CNode output = tmp.get(hop.getHopID()); CNodeCell tpl = new CNodeCell(inputs, output); tpl.setCellType(TemplateUtils.getCellType(hop)); tpl.setAggOp(TemplateUtils.getAggOp(hop)); tpl.setSparseSafe((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(sinHops.get(0))) || (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == sinHops.get(0))); tpl.setRequiresCastDtm(hop instanceof AggBinaryOp); // return cplan instance return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); } protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) { //memoization for common subexpression elimination and to avoid redundant work if( tmp.containsKey(hop.getHopID()) ) return; MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl); //recursively process required childs if( me!=null && (me.type == TemplateType.RowTpl || me.type == TemplateType.OuterProdTpl) ) { CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals); tmp.put(hop.getHopID(), cdata); inHops.add(hop); return; } for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type!=TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl))) rConstructCplan(c, memo, tmp, inHops, compileLiterals); else { CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); tmp.put(c.getHopID(), cdata); inHops.add(c); } } //construct cnode for current hop CNode out = null; if(hop instanceof UnaryOp) { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); String primitiveOpName = ((UnaryOp)hop).getOp().name(); out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); } else if(hop instanceof BinaryOp) { BinaryOp bop = (BinaryOp) hop; CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); String primitiveOpName = bop.getOp().name(); //add lookups if required cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); if( bop.getOp()==OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2") ) out = new CNodeUnary(cdata1, UnaryType.POW2); else if( bop.getOp()==OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2") ) out = new CNodeUnary(cdata1, UnaryType.MULT2); else //default binary out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } else if(hop instanceof TernaryOp) { TernaryOp top = (TernaryOp) hop; CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID()); //add lookups if required cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); //construct ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name())); } else if( hop instanceof ParameterizedBuiltinOp ) { CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE; out = new CNodeTernary(cdata1, cdata2, cdata3, ttype); } else if( hop instanceof IndexingOp ) { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1); } else if( HopRewriteUtils.isTransposeOperation(hop) ) { out = tmp.get(hop.getInput().get(0).getHopID()); } else if( hop instanceof AggUnaryOp ) { //aggregation handled in template implementation (note: we do not compile //^2 of SUM_SQ into the operator to simplify the detection of single operators) out = tmp.get(hop.getInput().get(0).getHopID()); } else if( hop instanceof AggBinaryOp ) { //guaranteed to be a dot product, so there are two cases: //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y) if( HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1)) ) { CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID()); out = new CNodeUnary(cdata1, UnaryType.POW2); } else { CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals); if( TemplateUtils.isColVector(cdata1) ) cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); if( TemplateUtils.isColVector(cdata2) ) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.MULT); } } tmp.put(hop.getHopID(), out); } protected static boolean isValidOperation(Hop hop) { //prepare indicators for binary operations boolean isBinaryMatrixScalar = false; boolean isBinaryMatrixVector = false; boolean isBinaryMatrixMatrixDense = false; if( hop instanceof BinaryOp && hop.getDataType().isMatrix() ) { Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(1); DataType ldt = left.getDataType(); DataType rdt = right.getDataType(); isBinaryMatrixScalar = (ldt.isScalar() || rdt.isScalar()); isBinaryMatrixVector = hop.dimsKnown() && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right)) || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left)) ); isBinaryMatrixMatrixDense = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) && ldt.isMatrix() && rdt.isMatrix() && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right); } //prepare indicators for ternary operations boolean isTernaryVectorScalarVector = false; boolean isTernaryMatrixScalarMatrixDense = false; if( hop instanceof TernaryOp && hop.getInput().size()==3 && hop.dimsKnown() && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) { Hop left = hop.getInput().get(0); Hop right = hop.getInput().get(2); isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right); isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right); } //check supported unary, binary, ternary operations return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrixDense || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE)); } /** * Comparator to order input hops of the cell template. We try to order * matrices-vectors-scalars via sorting by number of cells and for * equal number of cells by sparsity to prefer sparse inputs as the main * input for sparsity exploitation. */ public static class HopInputComparator implements Comparator<Hop> { @Override public int compare(Hop h1, Hop h2) { long ncells1 = h1.getDataType()==DataType.SCALAR ? Long.MIN_VALUE : h1.dimsKnown() ? h1.getDim1()*h1.getDim2() : Long.MAX_VALUE; long ncells2 = h2.getDataType()==DataType.SCALAR ? Long.MIN_VALUE : h2.dimsKnown() ? h2.getDim1()*h2.getDim2() : Long.MAX_VALUE; if( ncells1 > ncells2 ) return -1; else if( ncells1 < ncells2) return 1; return Long.compare( h1.dimsKnown(true) ? h1.getNnz() : ncells1, h2.dimsKnown(true) ? h2.getNnz() : ncells2); } } }