/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.codegen.cplan.CNode; import org.apache.sysml.hops.codegen.cplan.CNodeData; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.runtime.matrix.data.Pair; public class TemplateMultiAgg extends TemplateCell { public TemplateMultiAgg() { super(TemplateType.MultiAggTpl, false); } public TemplateMultiAgg(boolean closed) { super(TemplateType.MultiAggTpl, closed); } @Override public boolean open(Hop hop) { //multiagg is a composite templates, which is not //created via open-fuse-merge-close return false; } @Override public boolean fuse(Hop hop, Hop input) { return false; } @Override public boolean merge(Hop hop, Hop input) { return false; } @Override public CloseType close(Hop hop) { return CloseType.CLOSED_INVALID; } public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) { //get all root nodes for multi aggregation MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateType.MultiAggTpl); ArrayList<Hop> roots = new ArrayList<Hop>(); for( int i=0; i<3; i++ ) if( multiAgg.isPlanRef(i) ) roots.add(memo._hopRefs.get(multiAgg.input(i))); Hop.resetVisitStatus(roots); //recursively process required cplan outputs HashSet<Hop> inHops = new HashSet<Hop>(); HashMap<Long, CNode> tmp = new HashMap<Long, CNode>(); for( Hop root : roots ) //use celltpl cplan construction super.rConstructCplan(root, memo, tmp, inHops, compileLiterals); Hop.resetVisitStatus(roots); //reorder inputs (ensure matrices/vectors come first) and prune literals //note: we order by number of cells and subsequently sparsity to ensure //that sparse inputs are used as the main input w/o unnecessary conversion List<Hop> sinHops = inHops.stream() .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())) .sorted(new HopInputComparator()).collect(Collectors.toList()); //construct template node ArrayList<CNode> inputs = new ArrayList<CNode>(); for( Hop in : sinHops ) inputs.add(tmp.get(in.getHopID())); ArrayList<CNode> outputs = new ArrayList<CNode>(); ArrayList<AggOp> aggOps = new ArrayList<AggOp>(); for( Hop root : roots ) { CNode node = tmp.get(root.getHopID()); if( node instanceof CNodeData //add indexing ops for sideways data inputs && ((CNodeData)inputs.get(0)).getHopID() != ((CNodeData)node).getHopID() ) node = new CNodeUnary(node, (roots.get(0).getDim2()==1) ? UnaryType.LOOKUP_R : UnaryType.LOOKUP_RC); outputs.add(node); aggOps.add(TemplateUtils.getAggOp(root)); } CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs); tpl.setAggOps(aggOps); tpl.setRootNodes(roots); // return cplan instance return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); } }