/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.codegen.cplan;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.parser.Expression.DataType;
/**
 * Abstract base class for code-generation template nodes. A template keeps a
 * de-duplicated list of non-literal inputs (see {@link #addInput(CNode)}) and
 * an output node, and offers helpers to rename, replace, or wrap data nodes
 * inside the DAG rooted at the output.
 */
public abstract class CNodeTpl extends CNode implements Cloneable
{
	/**
	 * Creates a template over the given inputs and output.
	 *
	 * @param inputs candidate inputs; literals and duplicate data nodes
	 *        (same name and hop ID) are dropped
	 * @param output root node of the template's DAG
	 * @throws RuntimeException if {@code inputs} is empty
	 */
	public CNodeTpl(ArrayList<CNode> inputs, CNode output ) {
		if(inputs.size() < 1)
			throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl");
		for(CNode input : inputs)
			addInput(input);
		_output = output;
	}
	
	/**
	 * Registers an input unless it is a literal or an equivalent data node is
	 * already registered.
	 *
	 * @param in candidate input node
	 */
	public void addInput(CNode in) {
		//check for duplicate entries or literals
		if( containsInput(in) || in.isLiteral() )
			return;
		_inputs.add(in);
	}
	
	/**
	 * Retains only the data-node inputs whose hop ID is contained in the
	 * filter; every other input (including non-data nodes) is removed.
	 *
	 * @param filter hop IDs of the inputs to keep
	 */
	public void cleanupInputs(HashSet<Long> filter) {
		ArrayList<CNode> tmp = new ArrayList<CNode>();
		for( CNode in : _inputs )
			if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) )
				tmp.add(in);
		_inputs = tmp;
	}
	
	/**
	 * Returns the variable names of all inputs, in input order.
	 *
	 * @return array of input variable names
	 */
	public String[] getInputNames() {
		String[] ret = new String[_inputs.size()];
		for( int i=0; i<_inputs.size(); i++ )
			ret[i] = _inputs.get(i).getVarname();
		return ret;
	}
	
	/**
	 * Resets the visit status of the template output node.
	 */
	public void resetVisitStatusOutputs() {
		getOutput().resetVisitStatus();
	}
	
	/**
	 * Generates code for this template; shorthand for {@code codegen(false)}.
	 *
	 * @return generated code
	 */
	public String codegen() {
		return codegen(false);
	}
	
	/** @return a copy of this template */
	public abstract CNodeTpl clone();
	
	/** @return the output dimension type of this template */
	public abstract SpoofOutputDimsType getOutputDimType();
	
	/** @return a textual info string describing this template */
	public abstract String getTemplateInfo();
	
	/**
	 * Renames the inputs (from {@code startIndex} on) within the DAG rooted at
	 * this template's output.
	 *
	 * @param inputs template inputs
	 * @param startIndex index of the first input to rename
	 */
	protected void renameInputs(ArrayList<CNode> inputs, int startIndex) {
		renameInputs(Collections.singletonList(_output), inputs, startIndex);
	}
	
	/**
	 * Assigns generated variable names to the inputs from {@code startIndex}
	 * on and replaces all references in the given output DAGs. Scalar inputs
	 * and inputs with 0 rows and 0 columns become {@code scalars[k]}, all
	 * others {@code b[k]}; both counters start at 0 and follow input order.
	 * Literal inputs are skipped.
	 * NOTE(review): every non-literal input is assumed to be a CNodeData
	 * (it is cast unconditionally).
	 *
	 * @param outputs roots of the DAGs to rewrite
	 * @param inputs template inputs
	 * @param startIndex index of the first input to rename
	 */
	protected void renameInputs(List<CNode> outputs, ArrayList<CNode> inputs, int startIndex) {
		//create map of hopID to data nodes with new names, used for CSE
		HashMap<Long, CNode> nodes = new HashMap<Long, CNode>();
		for(int i=startIndex, matrixPos=0, scalarPos=0; i < inputs.size(); i++) {
			CNode cnode = inputs.get(i);
			if( cnode instanceof CNodeData && ((CNodeData)cnode).isLiteral() )
				continue;
			CNodeData cdata = (CNodeData)cnode;
			if( cdata.getDataType() == DataType.SCALAR || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) )
				nodes.put(cdata.getHopID(), new CNodeData(cdata, "scalars["+ scalarPos++ +"]"));
			else
				nodes.put(cdata.getHopID(), new CNodeData(cdata, "b["+ matrixPos++ +"]"));
		}
		
		//single pass to replace all names
		for( CNode output : outputs )
			rReplaceDataNode(output, nodes, new HashMap<Long, CNode>());
	}
	
	/**
	 * Replaces all data nodes with the same hop ID as {@code input} in the DAG
	 * rooted at {@code root} by a copy named {@code newName}. Does nothing if
	 * {@code input} is not a data node.
	 *
	 * @param root root of the DAG to rewrite
	 * @param input data node identifying the hop ID to replace
	 * @param newName new variable name
	 */
	protected void rReplaceDataNode( CNode root, CNode input, String newName ) {
		if( !(input instanceof CNodeData) )
			return;
		
		//create temporary name mapping
		HashMap<Long, CNode> names = new HashMap<Long, CNode>();
		CNodeData tmp = (CNodeData)input;
		names.put(tmp.getHopID(), new CNodeData(tmp, newName));
		
		rReplaceDataNode(root, names, new HashMap<Long,CNode>());
	}
	
	/**
	 * Same as {@link #rReplaceDataNode(CNode, CNode, String)} but applied to
	 * every root in {@code roots}; a fresh lookup memo is used per root.
	 *
	 * @param roots roots of the DAGs to rewrite
	 * @param input data node identifying the hop ID to replace
	 * @param newName new variable name
	 */
	protected void rReplaceDataNode( ArrayList<CNode> roots, CNode input, String newName ) {
		if( !(input instanceof CNodeData) )
			return;
		
		//create temporary name mapping
		HashMap<Long, CNode> names = new HashMap<Long, CNode>();
		CNodeData tmp = (CNodeData)input;
		names.put(tmp.getHopID(), new CNodeData(tmp, newName));
		
		for( CNode root : roots )
			rReplaceDataNode(root, names, new HashMap<Long,CNode>());
	}
	
	/**
	 * Recursively searches for data nodes and replaces them if found.
	 * Additionally, LOOKUP_R/LOOKUP_RC nodes directly on top of a data node are
	 * deduplicated (CSE): the first such lookup per hop ID is memoized and
	 * later ones are replaced by it.
	 * NOTE(review): the lookup memo is keyed by hop ID only, so a LOOKUP_R and
	 * a LOOKUP_RC over the same data node would be merged — confirm callers
	 * never mix lookup types for one hop within a template.
	 * 
	 * @param node current node in recursive descend
	 * @param dnodes prepared data nodes, identified by own hop id
	 * @param lnodes memoized lookup nodes, identified by data node hop id
	 */
	protected void rReplaceDataNode( CNode node, HashMap<Long, CNode> dnodes, HashMap<Long, CNode> lnodes ) 
	{
		for( int i=0; i<node._inputs.size(); i++ ) {
			//recursively process children
			rReplaceDataNode(node._inputs.get(i), dnodes, lnodes);
			
			//replace leaf data node
			if( node._inputs.get(i) instanceof CNodeData ) {
				CNodeData tmp = (CNodeData)node._inputs.get(i);
				if( dnodes.containsKey(tmp.getHopID()) )
					node._inputs.set(i, dnodes.get(tmp.getHopID()));
			}
			
			//replace lookup on top of leaf data node (for CSE only)
			if( node._inputs.get(i) instanceof CNodeUnary
				&& node._inputs.get(i)._inputs.get(0) instanceof CNodeData
				&& (((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_R
				|| ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_RC)) {
				CNodeData tmp = (CNodeData)node._inputs.get(i)._inputs.get(0);
				if( !lnodes.containsKey(tmp.getHopID()) )
					lnodes.put(tmp.getHopID(), node._inputs.get(i));
				else
					node._inputs.set(i, lnodes.get(tmp.getHopID()));
			}
		}
	}
	
	/**
	 * Recursively replaces every data node with the given hop ID by
	 * {@code newNode}, and removes LOOKUP_R/LOOKUP_RC nodes whose input is of
	 * scalar data type (the lookup is replaced by its input).
	 * NOTE(review): the leaf is replaced before descending, so the traversal
	 * continues into {@code newNode}; if {@code newNode}'s subtree itself
	 * contains a data node with {@code hopID}, this creates a self-reference
	 * and unbounded recursion. Callers must avoid that case.
	 *
	 * @param node current node in recursive descend
	 * @param hopID hop ID of the data node to replace
	 * @param newNode replacement node
	 */
	public void rReplaceDataNode( CNode node, long hopID, CNode newNode ) 
	{
		for( int i=0; i<node._inputs.size(); i++ ) {
			//replace leaf node
			if( node._inputs.get(i) instanceof CNodeData ) {
				CNodeData tmp = (CNodeData)node._inputs.get(i);
				if( tmp.getHopID() == hopID )
					node._inputs.set(i, newNode);
			}
			//recursively process children
			rReplaceDataNode(node._inputs.get(i), hopID, newNode);
			
			//remove unnecessary lookups
			if( node._inputs.get(i) instanceof CNodeUnary
				&& (((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_R
				|| ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP_RC)
				&& node._inputs.get(i)._inputs.get(0).getDataType()==DataType.SCALAR)
				node._inputs.set(i, node._inputs.get(i)._inputs.get(0));
		}
	}
	
	/**
	 * Recursively wraps every data node with the given hop ID in a lookup node
	 * of type {@code lookupType}. If the parent of such a data node is itself
	 * a lookup (per {@code TemplateUtils.isLookup}), that lookup is only
	 * retyped instead of being wrapped again.
	 *
	 * @param node current node in recursive descend
	 * @param hopID hop ID of the data node to wrap
	 * @param memo memoized lookup nodes by hop ID; all consumers share a single
	 *        lookup node, which retains the DAG structure
	 * @param lookupType type of the lookup to insert or set
	 */
	public void rInsertLookupNode( CNode node, long hopID, HashMap<Long, CNode> memo, UnaryType lookupType ) 
	{
		for( int i=0; i<node._inputs.size(); i++ ) {
			//recursively process children
			rInsertLookupNode(node._inputs.get(i), hopID, memo, lookupType);
			
			//replace leaf node
			if( !(node._inputs.get(i) instanceof CNodeData) )
				continue;
			CNodeData tmp = (CNodeData)node._inputs.get(i);
			if( tmp.getHopID() != hopID )
				continue;
			
			if( TemplateUtils.isLookup(node) ) {
				//existing lookup on top of the data node: only update its type
				((CNodeUnary)node).setType(lookupType);
			}
			else {
				//use memo structure to retain DAG structure; the lookup created
				//on first occurrence must be wired into this consumer as well
				CNode lookup = memo.get(hopID);
				if( lookup == null ) {
					lookup = new CNodeUnary(tmp, lookupType);
					memo.put(hopID, lookup);
				}
				node._inputs.set(i, lookup);
			}
		}
	}
	
	/**
	 * Checks whether an equivalent data node is already registered as input:
	 * either the very same object or a data node with equal hop ID and name.
	 * Non-data nodes are never considered duplicates.
	 * 
	 * @param input new input node
	 * @return true if duplicate, false otherwise
	 */
	private boolean containsInput(CNode input) {
		if( !(input instanceof CNodeData) )
			return false;
		
		CNodeData input2 = (CNodeData)input;
		for( CNode cnode : _inputs ) {
			if( !(cnode instanceof CNodeData) )
				continue;
			CNodeData cnode2 = (CNodeData)cnode;
			if( cnode2 == input2
				|| (cnode2._hopID==input2._hopID && cnode2._name.equals(input2._name)) )
				return true;
		}
		
		return false;
	}
	
	/** Delegates to {@link CNode#hashCode()}. */
	@Override
	public int hashCode() {
		return super.hashCode();
	}
	
	/** Equal iff {@code o} is a CNodeTpl and {@link CNode#equals(Object)} holds. */
	@Override 
	public boolean equals(Object o) {
		return (o instanceof CNodeTpl
			&& super.equals(o));
	}
	
	/**
	 * Recursively checks that two DAGs have the same number of inputs at every
	 * node and that corresponding data nodes refer to the same position in
	 * their respective template input lists (positions are matched by hop ID;
	 * -1 when absent). Other node properties are not compared here.
	 *
	 * @param current1 current node of the first DAG
	 * @param current2 current node of the second DAG
	 * @param input1 template inputs of the first DAG
	 * @param input2 template inputs of the second DAG
	 * @return true if the input references are equivalent
	 */
	protected static boolean equalInputReferences(CNode current1, CNode current2, ArrayList<CNode> input1, ArrayList<CNode> input2) {
		boolean ret = (current1.getInput().size() == current2.getInput().size());
		
		//process childs recursively
		for( int i=0; ret && i<current1.getInput().size(); i++ )
			ret &= equalInputReferences(
					current1.getInput().get(i), current2.getInput().get(i), input1, input2);
		
		//compare data leaves symmetrically: if either side is a data node,
		//both must be data nodes referring to the same input position
		if( ret && (current1 instanceof CNodeData || current2 instanceof CNodeData) ) {
			ret = current1 instanceof CNodeData
				&& current2 instanceof CNodeData
				&& indexOf(input1, (CNodeData)current1)
					== indexOf(input2, (CNodeData)current2);
		}
		
		return ret;
	}
	
	/**
	 * Pairwise {@link #equalInputReferences(CNode, CNode, ArrayList, ArrayList)}
	 * over two lists of roots, which must have equal length.
	 *
	 * @param current1 roots of the first set of DAGs
	 * @param current2 roots of the second set of DAGs
	 * @param input1 template inputs of the first set
	 * @param input2 template inputs of the second set
	 * @return true if all pairs have equivalent input references
	 */
	protected static boolean equalInputReferences(ArrayList<CNode> current1, ArrayList<CNode> current2, ArrayList<CNode> input1, ArrayList<CNode> input2) {
		boolean ret = (current1.size() == current2.size());
		for( int i=0; ret && i<current1.size(); i++ )
			ret &= equalInputReferences(current1.get(i), current2.get(i), input1, input2);
		return ret;
	}
	
	/**
	 * Returns the position of the first data-node input with the probe's hop
	 * ID, or -1 if none matches. Non-data inputs are skipped.
	 */
	private static int indexOf(ArrayList<CNode> inputs, CNodeData probe) {
		for( int i=0; i<inputs.size(); i++ ) {
			CNode in = inputs.get(i);
			if( in instanceof CNodeData && ((CNodeData)in).getHopID()==probe.getHopID() )
				return i;
		}
		return -1;
	}
}