/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.codegen.template;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeRow;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.matrix.data.Pair;
/**
 * Code generation template of type {@code TemplateType.RowTpl}.
 *
 * <p>The template decides which hops may open, extend, merge into, or close a
 * row-wise fused operator, and converts a selected hop sub-DAG into a
 * {@link CNodeRow} plan via {@link #constructCplan(Hop, CPlanMemoTable, boolean)}.
 */
public class TemplateRow extends TemplateBase
{
    /** Aggregation functions accepted for unary aggregates in this template. */
    private static final Hop.AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX};
    /** Unary operations for which a {@code VECT_<op>} primitive is looked up. */
    private static final Hop.OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
        OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN};
    /** Binary operations for which a {@code VECT_<op>_SCALAR} primitive is looked up. */
    private static final Hop.OpOp2[] SUPPORTED_VECT_BINARY = new OpOp2[]{
        OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS, OpOp2.PLUS, OpOp2.POW, OpOp2.MIN, OpOp2.MAX,
        OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL};

    /** Creates a row template using the default state of {@link TemplateBase}. */
    public TemplateRow() {
        super(TemplateType.RowTpl);
    }

    /**
     * Creates a row template with an explicit closed flag.
     *
     * @param closed closed flag forwarded to {@link TemplateBase}
     */
    public TemplateRow(boolean closed) {
        super(TemplateType.RowTpl, closed);
    }

    /**
     * Indicates whether a new row template can be opened at the given hop.
     *
     * <p>Accepted are (1) binary operations whose first input has more than one
     * column and whose second input has exactly one column, if valid according to
     * {@code TemplateCell.isValidOperation}; (2) matrix multiplications with a
     * single output column over a left input with more than one row and column;
     * and (3) constructible unary aggregates (see
     * {@link #isConstructibleAggUnary(Hop)}) over an input with more than one
     * row and column.
     *
     * @param hop candidate root hop
     * @return true if the template can be opened at this hop
     */
    @Override
    public boolean open(Hop hop) {
        return (hop instanceof BinaryOp && hop.getInput().get(0).getDim2()>1
                && hop.getInput().get(1).getDim2()==1 && TemplateCell.isValidOperation(hop))
            || (hop instanceof AggBinaryOp && hop.getDim2()==1
                && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
            || (isConstructibleAggUnary(hop)
                && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1);
    }

    /**
     * Indicates whether the given hop can be fused into an open template.
     *
     * @param hop   hop to be added to the template
     * @param input input hop through which the template is reached (not inspected here)
     * @return true if the template is not closed and the hop is a supported
     *         matrix-colvector/matrix-scalar binary operation, a valid unary or
     *         parameterized builtin operation, a constructible unary aggregate,
     *         or a matrix multiplication with more than one output row whose
     *         left input is a transpose
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        return !isClosed() &&
            (  (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop)
                && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
                    || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) )
            || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
                && TemplateCell.isValidOperation(hop))
            || isConstructibleAggUnary(hop)
            || (hop instanceof AggBinaryOp && hop.getDim1()>1
                && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
    }

    /**
     * Indicates whether this template can be merged at the given hop, which is
     * only considered if the connecting input is a column vector.
     *
     * @param hop   hop at which plans would be merged
     * @param input input hop that carries the other plan
     * @return true if the template is not closed, the input has exactly one
     *         column, and the hop is either a supported binary operation or a
     *         matrix multiplication whose left input is a transpose
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        //merge rowagg tpl with cell tpl if input is a vector
        return !isClosed() &&
            (  (hop instanceof BinaryOp && input.getDim2()==1 //matrix-scalar/vector-vector ops
                && TemplateUtils.isOperationSupported(hop))
            || (hop instanceof AggBinaryOp && input.getDim2()==1
                && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
    }

    /**
     * Determines whether the template is closed by the given hop.
     *
     * @param hop hop that was just added to the template
     * @return {@code CLOSED_VALID} for column aggregates and for matrix
     *         multiplications with transposed left input, {@code OPEN} otherwise
     */
    @Override
    public CloseType close(Hop hop) {
        //close on column aggregate (e.g., colSums, t(X)%*%y)
        boolean colAgg = hop instanceof AggUnaryOp
            && ((AggUnaryOp)hop).getDirection()==Direction.Col;
        boolean transMatMult = hop instanceof AggBinaryOp
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
        return (colAgg || transMatMult) ? CloseType.CLOSED_VALID : CloseType.OPEN;
    }

    /**
     * Constructs the row cplan for the sub-DAG rooted at the given hop.
     *
     * @param hop             root hop of the fused operator
     * @param memo            memo table used to decide which inputs are part of the plan
     * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData}
     * @return pair of ordered input hops and the constructed {@link CNodeRow}
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        //recursively process required cplan output
        HashSet<Hop> inHops = new HashSet<>();
        HashMap<String, Hop> inHops2 = new HashMap<>();
        HashMap<Long, CNode> tmp = new HashMap<>();
        hop.resetVisitStatus();
        rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
        hop.resetVisitStatus();

        //reorder inputs (ensure matrix is first input, and other inputs ordered by size);
        //scalar inputs that were compiled as literals are not passed as inputs
        List<Hop> sinHops = inHops.stream()
            .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
            .sorted(new HopInputComparator(inHops2.get("X")))
            .collect(Collectors.toList());

        //construct template node
        ArrayList<CNode> inputs = new ArrayList<>();
        for( Hop in : sinHops )
            inputs.add(tmp.get(in.getHopID()));
        CNode output = tmp.get(hop.getHopID());
        CNodeRow tpl = new CNodeRow(inputs, output);
        tpl.setRowType(TemplateUtils.getRowType(hop, sinHops.get(0)));
        tpl.setNumVectorIntermediates(TemplateUtils
            .countVectorIntermediates(output, new HashSet<Long>()));

        // return cplan instance
        return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
    }

    /**
     * Recursively constructs the cnode for the given hop and registers it in
     * {@code tmp} under the hop ID.
     *
     * @param hop             current hop
     * @param memo            memo table; its best RowTpl entry for the hop decides
     *                        per input whether to recurse or to treat it as template input
     * @param tmp             hop ID to cnode mapping, also used for memoization
     * @param inHops          collected template input hops
     * @param inHops2         named inputs; key "X" marks the main input
     * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData}
     * @throws RuntimeException if no cnode could be constructed for the hop
     */
    private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals)
    {
        //memoization for common subexpression elimination and to avoid redundant work
        if( tmp.containsKey(hop.getHopID()) )
            return;

        //recursively process required childs
        MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowTpl);
        for( int i=0; i<hop.getInput().size(); i++ ) {
            Hop c = hop.getInput().get(i);
            if( me.isPlanRef(i) )
                rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
            else {
                CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
                tmp.put(c.getHopID(), cdata);
                inHops.add(c);
            }
        }

        //construct cnode for current hop
        CNode out = null;
        if( hop instanceof AggUnaryOp )
            out = createAggUnaryNode(hop, tmp, inHops2);
        else if( hop instanceof AggBinaryOp )
            out = createAggBinaryNode(hop, tmp, inHops, inHops2, compileLiterals);
        else if( hop instanceof UnaryOp )
            out = createUnaryNode(hop, tmp);
        else if( hop instanceof BinaryOp )
            out = createBinaryNode(hop, tmp);
        else if( hop instanceof TernaryOp )
            out = createTernaryNode(hop, tmp);
        else if( hop instanceof ParameterizedBuiltinOp )
            out = createParamBuiltinNode(hop, tmp);
        else if( hop instanceof IndexingOp )
            out = createIndexingNode(hop, tmp);

        if( out == null ) {
            throw new RuntimeException("Unsupported operation in row template: "
                + hop.getHopID()+" "+hop.getOpString());
        }

        if( out.getDataType().isMatrix() ) {
            out.setNumRows(hop.getDim1());
            out.setNumCols(hop.getDim2());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Checks whether the hop is a unary aggregate for which
     * {@link #createAggUnaryNode} has a construction branch: a row aggregate
     * with an operation in {@link #SUPPORTED_ROW_AGG}, or a column sum. Other
     * column aggregates (e.g., column min/max) have no branch and would only
     * fail later during cplan construction, so they are rejected up front.
     */
    private static boolean isConstructibleAggUnary(Hop hop) {
        if( !(hop instanceof AggUnaryOp) || !HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) )
            return false;
        AggUnaryOp agg = (AggUnaryOp) hop;
        return agg.getDirection() == Direction.Row
            || (agg.getDirection() == Direction.Col && agg.getOp() == AggOp.SUM);
    }

    /** Returns the node itself if it is a scalar, otherwise wraps it into the given lookup. */
    private static CNode wrapIfNotScalar(CNode node, UnaryType lookup) {
        return (node.getDataType()==DataType.SCALAR) ? node : new CNodeUnary(node, lookup);
    }

    /** Builds the cnode of a unary aggregate; returns null if the aggregate is not supported. */
    private static CNode createAggUnaryNode(Hop hop, HashMap<Long, CNode> tmp, HashMap<String, Hop> inHops2) {
        AggUnaryOp agg = (AggUnaryOp) hop;
        Hop in = hop.getInput().get(0);
        CNode cdata1 = tmp.get(in.getHopID());

        if( agg.getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) {
            //row aggregate over a single column degenerates to a value lookup
            if( in.getDim2()==1 )
                return wrapIfNotScalar(cdata1, UnaryType.LOOKUP_R);
            String opcode = "ROW_"+agg.getOp().name().toUpperCase()+"S";
            CNode out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
            inHops2.put("X", in);
            return out;
        }
        if( agg.getDirection() == Direction.Col && agg.getOp() == AggOp.SUM ) {
            //vector add without temporary copy
            if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
                return new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
                    ((CNodeBinary)cdata1).getType().getVectorAddPrimitive());
            return cdata1;
        }
        return null;
    }

    /** Builds the cnode of a matrix multiplication (transposed-left, vector-vector, or dot product). */
    private static CNode createAggBinaryNode(Hop hop, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        CNode cdata1 = tmp.get(left.getHopID());
        CNode cdata2 = tmp.get(right.getHopID());

        if( HopRewriteUtils.isTransposeOperation(left) ) {
            //correct input under transpose: the transpose input replaces the transpose as template input
            cdata1 = TemplateUtils.skipTranspose(cdata1, left, tmp, compileLiterals);
            inHops.remove(left);
            inHops.add(left.getInput().get(0));
            return new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
        if( left.getDim2()==1 && right.getDim2()==1 ) {
            return new CNodeBinary(wrapIfNotScalar(cdata1, UnaryType.LOOKUP0),
                wrapIfNotScalar(cdata2, UnaryType.LOOKUP0), BinType.MULT);
        }
        CNode out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        inHops2.put("X", left);
        return out;
    }

    /**
     * Builds the cnode of a unary operation, using a vector primitive if the
     * input has more than one row and column and a scalar primitive otherwise.
     *
     * @throws RuntimeException if a matrix input is combined with an operation
     *         outside {@link #SUPPORTED_VECT_UNARY}
     */
    private static CNode createUnaryNode(Hop hop, HashMap<Long, CNode> tmp) {
        UnaryOp uop = (UnaryOp) hop;
        Hop in = hop.getInput().get(0);
        CNode cdata1 = tmp.get(in.getHopID());

        // if one input is a matrix then we need to do vector by scalar operations
        if( in.getDim1() > 1 && in.getDim2() > 1 ) {
            if( !HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) )
                throw new RuntimeException("Unsupported unary matrix "
                    + "operation: " + uop.getOp().name());
            String opname = "VECT_"+uop.getOp().name();
            return new CNodeUnary(cdata1, UnaryType.valueOf(opname));
        }

        //general scalar case
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in);
        String primitiveOpName = uop.getOp().toString();
        return new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    }

    /**
     * Builds the cnode of a binary operation, using a vector-scalar primitive
     * if the first input has more than one row and column and a scalar
     * primitive otherwise; column vector operands are accessed via row lookups.
     *
     * @throws RuntimeException if a matrix input is combined with an operation
     *         outside {@link #SUPPORTED_VECT_BINARY}
     */
    private static CNode createBinaryNode(Hop hop, HashMap<Long, CNode> tmp) {
        BinaryOp bop = (BinaryOp) hop;
        Hop left = hop.getInput().get(0);
        CNode cdata1 = tmp.get(left.getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());

        // if one input is a matrix then we need to do vector by scalar operations
        if( left.getDim1() > 1 && left.getDim2() > 1 ) {
            if( !HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) )
                throw new RuntimeException("Unsupported binary matrix "
                    + "operation: " + bop.getOp().name());
            String opname = "VECT_"+bop.getOp().name()+"_SCALAR";
            if( TemplateUtils.isColVector(cdata2) )
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            return new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
        }

        //one input is a vector/scalar other is a scalar
        String primitiveOpName = bop.getOp().toString();
        if( TemplateUtils.isColVector(cdata1) )
            cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
        if( TemplateUtils.isColVector(cdata2) )
            cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
        return new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    }

    /** Builds the cnode of a ternary operation; the primitive is derived from the hop's operation name. */
    private static CNode createTernaryNode(Hop hop, HashMap<Long, CNode> tmp) {
        TernaryOp top = (TernaryOp) hop;
        Hop in1 = hop.getInput().get(0);
        Hop in3 = hop.getInput().get(2);
        CNode cdata1 = tmp.get(in1.getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(in3.getHopID());

        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in1);
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, in3);

        //construct ternary cnode, primitive operation derived from OpOp3
        return new CNodeTernary(cdata1, cdata2, cdata3,
            TernaryType.valueOf(top.getOp().toString()));
    }

    /**
     * Builds the replace cnode of a parameterized builtin operation from its
     * target, "pattern" and "replacement" parameters.
     */
    private static CNode createParamBuiltinNode(Hop hop, HashMap<Long, CNode> tmp) {
        ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) hop;
        //the lookup decision is made on the target hop, i.e., the hop the node
        //was created for (as at the unary/ternary call sites), instead of
        //relying on the target being the first input
        Hop target = pop.getTargetHop();
        CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(target.getHopID()), target);
        CNode cdata2 = tmp.get(pop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pop.getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
            TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        return new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    }

    /**
     * Builds the lookup cnode of an indexing operation from the indexed input,
     * a literal holding the input's number of columns, and the hop's input at
     * position 4 (presumably a column index bound -- confirm against IndexingOp).
     */
    private static CNode createIndexingNode(Hop hop, HashMap<Long, CNode> tmp) {
        Hop in = hop.getInput().get(0);
        CNode cdata1 = tmp.get(in.getHopID());
        return new CNodeTernary(cdata1,
            TemplateUtils.createCNodeData(new LiteralOp(in.getDim2()), true),
            TemplateUtils.createCNodeData(hop.getInput().get(4), true),
            TernaryType.LOOKUP_RC1);
    }

    /**
     * Comparator to order input hops of the row aggregate template. We try
     * to order matrices-vectors-scalars via sorting by number of cells but
     * we keep the given main input always at the first position.
     */
    public static class HopInputComparator implements Comparator<Hop>
    {
        /** Largest key assigned to an input with known dimensions. */
        private static final long MAX_KNOWN_CELLS = Long.MAX_VALUE-2;

        private final Hop _X;

        /**
         * @param X main input that is always ordered first (may be null)
         */
        public HopInputComparator(Hop X) {
            _X = X;
        }

        /**
         * Orders hops by descending sort key, i.e., main input, inputs with
         * unknown dimensions, inputs with known dimensions by decreasing
         * number of cells, and finally scalars.
         */
        @Override
        public int compare(Hop h1, Hop h2) {
            return Long.compare(sortKey(h2), sortKey(h1));
        }

        /** Computes the sort key of a hop; larger keys are ordered first. */
        private long sortKey(Hop h) {
            if( h.getDataType()==DataType.SCALAR )
                return Long.MIN_VALUE;
            if( h==_X )
                return Long.MAX_VALUE;
            if( !h.dimsKnown() )
                return Long.MAX_VALUE-1;
            long rows = h.getDim1();
            long cols = h.getDim2();
            //saturate instead of overflowing so that huge inputs neither turn
            //negative nor collide with the keys reserved above
            if( rows > 0 && cols > MAX_KNOWN_CELLS / rows )
                return MAX_KNOWN_CELLS;
            return rows * cols;
        }
    }
}