/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen.cplan; import java.util.ArrayList; import java.util.Arrays; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; public abstract class CNode { private static final IDSequence _seqVar = new IDSequence(); private static final IDSequence _seqID = new IDSequence(); protected final long _ID; protected ArrayList<CNode> _inputs = null; protected CNode _output = null; protected boolean _visited = false; protected boolean _generated = false; protected String _genVar = null; protected long _rows = -1; protected long _cols = -1; protected DataType _dataType; protected boolean _literal = false; //cached hash to allow memoization in DAG structures and repeated //recursive hash computation over all inputs (w/ reset on updates) protected int _hash = 0; public CNode() { _ID = _seqID.getNextID(); _inputs = new ArrayList<CNode>(); _generated = false; } public long getID() { return _ID; } public ArrayList<CNode> getInput() { return _inputs; } public String createVarname() { _genVar = "TMP"+_seqVar.getNextID(); return _genVar; } protected String getCurrentVarName() { return "TMP"+(_seqVar.getCurrentID()-1); } public String getVarname() { return _genVar; } public String getClassname() { return getVarname(); } public void resetGenerated() { if( _generated ) for( CNode cn : _inputs ) cn.resetGenerated(); _generated = false; } public void setNumRows(long rows) { _rows = rows; } public long getNumRows() { return _rows; } public void setNumCols(long cols) { _cols = cols; } public long getNumCols() { return _cols; } public DataType getDataType() { return _dataType; } public void setDataType(DataType dt) { _dataType = dt; _hash = 0; } public boolean isLiteral() { return _literal; } public void setLiteral(boolean literal) { _literal = literal; _hash = 0; } public CNode getOutput() { return _output; } public void setOutput(CNode output) { _output = output; _hash = 0; } public boolean isVisited() { return _visited; } public void setVisited() { setVisited(true); } public void setVisited(boolean flag) { _visited = flag; } public void resetVisitStatus() { if( !isVisited() ) return; for( CNode h : getInput() ) h.resetVisitStatus(); setVisited(false); } public abstract String codegen(boolean sparse) ; public abstract void setOutputDims(); /////////////////////////////////////// // Functionality for plan cache //note: genvar/generated changed on codegen and not considered, //rows and cols also not include to increase reuse potential @Override public int hashCode() { if( _hash == 0 ) { int numIn = _inputs.size(); int[] tmp = new int[numIn + 3]; //include inputs, partitioned by matrices and scalars to increase //reuse in case of interleaved inputs (see CNodeTpl.renameInputs) int pos = 0; for( CNode c : _inputs ) if( c.getDataType()==DataType.MATRIX ) tmp[pos++] = c.hashCode(); for( CNode c : _inputs ) if( c.getDataType()!=DataType.MATRIX ) tmp[pos++] = c.hashCode(); tmp[numIn+0] = (_output!=null)?_output.hashCode():0; tmp[numIn+1] = (_dataType!=null)?_dataType.hashCode():0; tmp[numIn+2] = Boolean.valueOf(_literal).hashCode(); _hash = Arrays.hashCode(tmp); } return _hash; } @Override public boolean equals(Object that) { if( !(that instanceof CNode) ) return false; CNode cthat = (CNode) that; boolean ret = _inputs.size() == cthat._inputs.size(); for( int i=0; i<_inputs.size() && ret; i++ ) ret &= _inputs.get(i).equals(cthat._inputs.get(i)); return ret && (_output == cthat._output || _output.equals(cthat._output)) && _dataType == cthat._dataType && _literal == cthat._literal; } }