/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.template; import java.util.ArrayList; import java.util.Comparator; import java.util.Map.Entry; import java.util.HashSet; import java.util.List; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; /** * This plan selection heuristic aims for fusion without any redundant * computation, which, however, potentially leads to more materialized * intermediates than the fuse all heuristic. * <p> * NOTE: This heuristic is essentially the same as FuseAll, except that * any plans that refer to a hop with multiple consumers are removed in * a pre-processing step. * */ public class PlanSelectionFuseNoRedundancy extends PlanSelection { @Override public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) { //pruning and collection pass for( Hop hop : roots ) rSelectPlans(memo, hop, null); //take all distinct best plans for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() ) memo.setDistinct(e.getKey(), e.getValue()); } private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) { if( isVisited(current.getHopID(), currentType) ) return; //step 0: remove plans that refer to a common partial plan if( memo.contains(current.getHopID()) ) { HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); List<MemoTableEntry> hopP = memo.get(current.getHopID()); for( MemoTableEntry e1 : hopP ) for( int i=0; i<3; i++ ) if( e1.isPlanRef(i) && current.getInput().get(i).getParent().size()>1 ) rmSet.add(e1); //remove references to hops w/ multiple consumers memo.remove(current, rmSet); } //step 1: prune subsumed plans of same type if( memo.contains(current.getHopID()) ) { HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>(); List<MemoTableEntry> hopP = memo.get(current.getHopID()); for( MemoTableEntry e1 : hopP ) for( MemoTableEntry e2 : hopP ) if( e1 != e2 && e1.subsumes(e2) ) rmSet.add(e2); memo.remove(current, rmSet); } //step 2: select plan for current path MemoTableEntry best = null; if( memo.contains(current.getHopID()) ) { if( currentType == null ) { best = memo.get(current.getHopID()).stream() .filter(p -> isValid(p, current)) .min(new BasicPlanComparator()).orElse(null); } else { best = memo.get(current.getHopID()).stream() .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl) .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs())) .orElse(null); } addBestPlan(current.getHopID(), best); } //step 3: recursively process children for( int i=0; i< current.getInput().size(); i++ ) { TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null; rSelectPlans(memo, current.getInput().get(i), pref); } setVisited(current.getHopID(), currentType); } }