/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.codegen.template;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.SpoofCompiler;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
/**
 * Memo table of candidate fused-operator plans used during code generation.
 * <p>
 * For every hop ID the table holds a list of {@link MemoTableEntry}s. Each entry
 * records a template type and up to three input references: hop IDs of inputs whose
 * plans are fused into the entry, or a negative value (by convention -1) if the input
 * at that position is not fused. The table also keeps the hops by ID, and a blacklist
 * of hop IDs whose plans are retained for plan construction but must not be used as
 * top-level plans.
 * <p>
 * No internal synchronization is performed.
 */
public class CPlanMemoTable 
{
	private static final Log LOG = LogFactory.getLog(CPlanMemoTable.class.getName());
	
	/** Candidate plans per hop ID (lists are always mutable ArrayLists). */
	protected HashMap<Long, List<MemoTableEntry>> _plans;
	/** Hops by hop ID, for every hop registered via addHop/add/addAll. */
	protected HashMap<Long, Hop> _hopRefs;
	/** Hop IDs whose plans must not be used as top-level plans. */
	protected HashSet<Long> _plansBlacklist;
	
	public CPlanMemoTable() {
		_plans = new HashMap<>();
		_hopRefs = new HashMap<>();
		_plansBlacklist = new HashSet<>();
	}
	
	/** Registers the given hop by its hop ID, without adding any plans. */
	public void addHop(Hop hop) {
		_hopRefs.put(hop.getHopID(), hop);
	}
	
	/** Returns true if the given hop has been registered, with or without plans. */
	public boolean containsHop(Hop hop) {
		return _hopRefs.containsKey(hop.getHopID());
	}
	
	/** Returns true if a plan list (possibly empty) exists for the given hop ID. */
	public boolean contains(long hopID) {
		return _plans.containsKey(hopID);
	}
	
	/** Returns true if the given hop ID has at least one plan of the given template type. */
	public boolean contains(long hopID, TemplateType type) {
		return contains(hopID) 
			&& get(hopID).stream().anyMatch(p -> p.type == type);
	}
	
	/**
	 * Returns true if the given hop ID is not blacklisted and has at least one plan
	 * that is valid as a top-level plan (see {@link #getBest(long)}).
	 */
	public boolean containsTopLevel(long hopID) {
		return !_plansBlacklist.contains(hopID)
			&& getBest(hopID) != null;
	}
	
	/** Adds a plan of the given type without any input references. */
	public void add(Hop hop, TemplateType type) {
		add(hop, type, -1, -1, -1);
	}
	
	/** Adds a plan of the given type referencing input 0 (-1 for none). */
	public void add(Hop hop, TemplateType type, long in1) {
		add(hop, type, in1, -1, -1);
	}
	
	/** Adds a plan of the given type referencing inputs 0 and 1 (-1 for none). */
	public void add(Hop hop, TemplateType type, long in1, long in2) {
		add(hop, type, in1, in2, -1);
	}
	
	/** Adds a plan of the given type referencing inputs 0, 1, and 2 (-1 for none). */
	public void add(Hop hop, TemplateType type, long in1, long in2, long in3) {
		add(hop, new MemoTableEntry(type, in1, in2, in3));
	}
	
	/** Registers the hop and appends the given plan entry to its plan list. */
	public void add(Hop hop, MemoTableEntry me) {
		_hopRefs.put(hop.getHopID(), hop);
		getOrCreatePlans(hop.getHopID()).add(me);
	}
	
	/** Registers the hop and appends all plans of the given entry set to its plan list. */
	public void addAll(Hop hop, MemoTableEntrySet P) {
		_hopRefs.put(hop.getHopID(), hop);
		getOrCreatePlans(hop.getHopID()).addAll(P.plans);
	}
	
	/** Returns the mutable plan list of the given hop ID, creating an empty one if absent. */
	private List<MemoTableEntry> getOrCreatePlans(long hopID) {
		return _plans.computeIfAbsent(hopID, k -> new ArrayList<>());
	}
	
	/**
	 * Replaces the plan list of the given hop by a new list without the plans
	 * contained in the given set (membership per MemoTableEntry.equals).
	 * Does nothing if the hop has no plan list.
	 */
	public void remove(Hop hop, HashSet<MemoTableEntry> blackList) {
		List<MemoTableEntry> plans = _plans.get(hop.getHopID());
		if( plans == null )
			return;
		_plans.put(hop.getHopID(), plans.stream()
			.filter(p -> !blackList.contains(p))
			.collect(Collectors.toCollection(ArrayList::new)));
	}
	
	/**
	 * Sets the plan list of the given hop ID to the given plans with duplicates
	 * (per MemoTableEntry.equals) removed, keeping first occurrences in order.
	 */
	public void setDistinct(long hopID, List<MemoTableEntry> plans) {
		_plans.put(hopID, plans.stream().distinct()
			.collect(Collectors.toCollection(ArrayList::new)));
	}
	
	/**
	 * Prunes redundant plans of the given hop ID. It first removes exact duplicates,
	 * then removes every plan subsumed by another plan of the same type that fuses a
	 * superset of its inputs. A subsumed plan is only removed if each input fused only
	 * by the subsuming plan has exactly one parent. Does nothing if the hop ID has no plans.
	 */
	public void pruneRedundant(long hopID) {
		if( !contains(hopID) )
			return;
		
		//prune redundant plans (i.e., equivalent)
		setDistinct(hopID, _plans.get(hopID));
		
		//prune dominated plans (e.g., opened plan subsumed
		//by fused plan if single consumer of input)
		HashSet<MemoTableEntry> rmList = new HashSet<>();
		List<MemoTableEntry> list = _plans.get(hopID);
		Hop hop = _hopRefs.get(hopID);
		for( MemoTableEntry e1 : list )
			for( MemoTableEntry e2 : list )
				if( e1 != e2 && e1.subsumes(e2) ) {
					//check that childs don't have multiple consumers
					//(only inputs fused by e1 but not by e2 are checked)
					boolean rmSafe = true; 
					for( int i=0; i<=2; i++ )
						rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
							hop.getInput().get(i).getParent().size()==1 : true;
					if( rmSafe )
						rmList.add(e2);
				}
		
		//update current entry list, by removing rmList
		remove(hop, rmList);
	}
	
	/**
	 * Prunes the memo table and runs plan selection. The steps are:
	 * <ol>
	 * <li>For hops not referenced by any plan, drops plans without input references,
	 *     and drops the hop's entry entirely if no plans remain.</li>
	 * <li>Blacklists every referenced input hop that has exactly one parent. These
	 *     plans are kept, because the chain of entries is still needed for cplan
	 *     construction.</li>
	 * <li>Delegates to the selector from {@code SpoofCompiler.createPlanSelector()}.</li>
	 * </ol>
	 * 
	 * @param roots the DAG roots handed to the plan selector
	 */
	public void pruneSuboptimal(ArrayList<Hop> roots) {
		if( LOG.isTraceEnabled() )
			LOG.trace("#1: Memo before plan selection ("+size()+" plans)\n"+this);
		
		//build index of referenced entries (may contain -1 for "no reference")
		HashSet<Long> ix = new HashSet<>();
		for( List<MemoTableEntry> plans : _plans.values() )
			for( MemoTableEntry me : plans ) {
				ix.add(me.input1);
				ix.add(me.input2);
				ix.add(me.input3);
			}
		
		//prune single-operation (not referenced, and no child references)
		Iterator<Entry<Long, List<MemoTableEntry>>> iter = _plans.entrySet().iterator();
		while( iter.hasNext() ) {
			Entry<Long, List<MemoTableEntry>> e = iter.next();
			if( !ix.contains(e.getKey()) ) {
				e.setValue(e.getValue().stream()
					.filter(p -> p.hasPlanRef())
					.collect(Collectors.toCollection(ArrayList::new)));
				if( e.getValue().isEmpty() )
					iter.remove();
			}
		}
		
		//prune dominated plans (e.g., plan referenced by other plan and this
		//other plan is single consumer) by marking it as blacklisted because
		//the chain of entries is still required for cplan construction
		for( List<MemoTableEntry> plans : _plans.values() )
			for( MemoTableEntry me : plans )
				for( int i=0; i<=2; i++ )
					if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
						_plansBlacklist.add(me.input(i));
		
		//core plan selection
		PlanSelection selector = SpoofCompiler.createPlanSelector();
		selector.selectPlans(this, roots);
		
		if( LOG.isTraceEnabled() )
			LOG.trace("#2: Memo after plan selection ("+size()+" plans)\n"+this);
	}
	
	/** Returns the (live) plan list of the given hop ID, or null if none exists. */
	public List<MemoTableEntry> get(long hopID) {
		return _plans.get(hopID);
	}
	
	/**
	 * Returns one reference-free entry per distinct (type, closed) combination among
	 * the plans of the given hop ID, in order of first occurrence. The check is done
	 * explicitly, because MemoTableEntry.equals ignores the closed attribute.
	 */
	public List<MemoTableEntry> getDistinct(long hopID) {
		List<MemoTableEntry> ret = new ArrayList<>();
		for( MemoTableEntry p : _plans.get(hopID) ) {
			boolean seen = false;
			for( MemoTableEntry q : ret )
				seen |= (q.type == p.type && q.closed == p.closed);
			if( !seen )
				ret.add(new MemoTableEntry(p.type, -1, -1, -1, p.closed));
		}
		return ret;
	}
	
	/**
	 * Returns the plan with the smallest template rank among the plans of the given
	 * hop ID that PlanSelection.isValid accepts as top-level plans. Returns null if
	 * there are no plans or none is valid.
	 */
	public MemoTableEntry getBest(long hopID) {
		List<MemoTableEntry> tmp = get(hopID);
		if( tmp == null || tmp.isEmpty() )
			return null;
		
		//single plan per type, get plan w/ best rank in preferred order
		//but ensure that the plans valid as a top-level plan
		Hop hop = _hopRefs.get(hopID);
		return tmp.stream().filter(p -> PlanSelection.isValid(p, hop))
			.min(Comparator.comparingInt(p -> p.type.getRank())).orElse(null);
	}
	
	/**
	 * Returns the best plan of the given hop ID, given a preferred template type, or
	 * null if there are no plans. Plans are compared by a key: {@code -countPlanRefs()}
	 * for plans of the preferred type, and {@code getRank()+1} otherwise. The plan with
	 * the smallest key wins, and the first one wins on ties. Presumably ranks are
	 * non-negative, in which case plans of the preferred type that fuse the most
	 * inputs are favored.
	 */
	public MemoTableEntry getBest(long hopID, TemplateType pref) {
		List<MemoTableEntry> tmp = get(hopID);
		if( tmp == null || tmp.isEmpty() )
			return null;
		
		//single plan per type, get plan w/ best rank in preferred order
		return Collections.min(tmp, Comparator.comparingInt(
			p -> (p.type==pref) ? -p.countPlanRefs() : p.type.getRank()+1));
	}
	
	/**
	 * Returns, per input position 0..2, the bitwise OR of all input references at that
	 * position over the plans of the given hop ID. The value is 0 where no plan has a
	 * reference. NOTE(review): this presumes all plans reference the same input hop at a
	 * given position (so the OR yields that hop ID); verify against callers.
	 */
	public long[] getAllRefs(long hopID) {
		long[] refs = new long[3];
		for( MemoTableEntry me : get(hopID) )
			for( int i=0; i<3; i++ )
				if( me.isPlanRef(i) )
					refs[i] |= me.input(i);
		return refs;
	}
	
	/** Returns the total number of plans over all hops. */
	public int size() {
		return _plans.values().stream()
			.mapToInt(List::size).sum();
	}
	
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		sb.append("----------------------------------\n");
		sb.append("MEMO TABLE: \n");
		sb.append("----------------------------------\n");
		for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) {
			sb.append(e.getKey() + " "+_hopRefs.get(e.getKey()).getOpString()+": ");
			sb.append(Arrays.toString(e.getValue().toArray(new MemoTableEntry[0]))+"\n");
		}
		sb.append("----------------------------------\n");
		sb.append("Blacklisted Plans: ");
		sb.append(Arrays.toString(_plansBlacklist.toArray(new Long[0]))+"\n");
		sb.append("----------------------------------\n");
		return sb.toString();
	}
	
	////////////////////////////////////////
	// Memo table entry abstractions
	//////
	
	/**
	 * A single candidate plan: template type, up to three input references (hop IDs,
	 * negative meaning "not fused"), and a closed flag. equals/hashCode consider the
	 * type and the input references only, not the closed flag.
	 */
	public static class MemoTableEntry 
	{
		public TemplateType type;
		public final long input1; 
		public final long input2; 
		public final long input3;
		public boolean closed = false;
		
		public MemoTableEntry(TemplateType t, long in1, long in2, long in3) {
			this(t, in1, in2, in3, false);
		}
		
		public MemoTableEntry(TemplateType t, long in1, long in2, long in3, boolean close) {
			type = t;
			input1 = in1;
			input2 = in2;
			input3 = in3;
			closed = close;
		}
		
		/** Returns true if the input at the given position (0..2) is referenced. */
		public boolean isPlanRef(int index) {
			return (index==0 && input1 >=0)
				|| (index==1 && input2 >=0)
				|| (index==2 && input3 >=0);
		}
		
		/** Returns true if at least one input is referenced. */
		public boolean hasPlanRef() {
			return isPlanRef(0) || isPlanRef(1) || isPlanRef(2);
		}
		
		/** Returns the number of referenced inputs (0..3). */
		public int countPlanRefs() {
			return ((input1 >= 0) ? 1 : 0)
				+ ((input2 >= 0) ? 1 : 0)
				+ ((input3 >= 0) ? 1 : 0);
		}
		
		/** Returns input1 for index 0, input2 for index 1, and input3 otherwise. */
		public long input(int index) {
			return (index==0) ? input1 : (index==1) ? input2 : input3;
		}
		
		/**
		 * Returns true if this entry has the same type and references every input
		 * referenced by the given entry (i.e., its references are a superset).
		 */
		public boolean subsumes(MemoTableEntry that) {
			return (type == that.type 
				&& !(!isPlanRef(0) && that.isPlanRef(0))
				&& !(!isPlanRef(1) && that.isPlanRef(1))
				&& !(!isPlanRef(2) && that.isPlanRef(2)));
		}
		
		@Override
		public int hashCode() {
			return Arrays.hashCode(
				new long[]{(long)type.ordinal(), input1, input2, input3});
		}
		
		@Override
		public boolean equals(Object obj) {
			if( !(obj instanceof MemoTableEntry) )
				return false;
			MemoTableEntry that = (MemoTableEntry)obj;
			return type == that.type && input1 == that.input1
				&& input2 == that.input2 && input3 == that.input3;
		}
		
		@Override
		public String toString() {
			return type.name()+"("+input1+","+input2+","+input3+")";
		}
	}
	
	/** A mutable set of candidate plans under construction, expanded via crossProduct. */
	public static class MemoTableEntrySet 
	{
		public ArrayList<MemoTableEntry> plans = new ArrayList<>();
		
		/** Creates a set with a single plan of the given type without input references. */
		public MemoTableEntrySet(TemplateType type, boolean close) {
			plans.add(new MemoTableEntry(type, -1, -1, -1, close));
		}
		
		/**
		 * Creates a set with a single plan of the given type that references the given
		 * hop ID at input position pos (0..2). The closed flag is now honored; it was
		 * previously ignored.
		 */
		public MemoTableEntrySet(TemplateType type, int pos, long hopID, boolean close) {
			plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, 
				(pos==1)?hopID:-1, (pos==2)?hopID:-1, close));
		}
		
		/**
		 * Replaces the plans by their cross product with the given references. Each
		 * existing plan yields one new plan per reference, with input position pos set
		 * to that reference. The type and closed flag of each plan are preserved.
		 */
		public void crossProduct(int pos, Long... refs) {
			ArrayList<MemoTableEntry> tmp = new ArrayList<>();
			for( MemoTableEntry me : plans )
				for( Long ref : refs )
					tmp.add(new MemoTableEntry(me.type, (pos==0)?ref:me.input1, 
						(pos==1)?ref:me.input2, (pos==2)?ref:me.input3, me.closed));
			plans = tmp;
		}
		
		@Override
		public String toString() {
			return Arrays.toString(plans.toArray(new MemoTableEntry[0]));
		}
	}
}