/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.hops.codegen; import java.util.ArrayList; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.MultiThreadedHop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.MemoTable; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.SpoofFused; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; public class SpoofFusedOp extends Hop implements MultiThreadedHop { public enum SpoofOutputDimsType { INPUT_DIMS, ROW_DIMS, COLUMN_DIMS_ROWS, COLUMN_DIMS_COLS, SCALAR, MULTI_SCALAR, ROW_RANK_DIMS, // right wdivmm COLUMN_RANK_DIMS // left wdivmm } private Class<?> _class = null; private boolean _distSupported = false; private int _numThreads = -1; private SpoofOutputDimsType _dimsType; public SpoofFusedOp ( ) { } public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, boolean dist, SpoofOutputDimsType type ) { super(name, dt, vt); _class = cla; _distSupported = dist; _dimsType = type; } @Override public void setMaxNumThreads(int k) { _numThreads = k; } @Override public int getMaxNumThreads() { return _numThreads; } @Override public boolean allowsAllExecTypes() { return _distSupported; } @Override protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { return OptimizerUtils.estimateSize(dim1, dim2); } @Override protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { return 0; } @Override protected long[] inferOutputCharacteristics(MemoTable memo) { return null; } @Override public Lop constructLops() throws HopsException, LopsException { if( getLops() != null ) return getLops(); ExecType et = optFindExecType(); ArrayList<Lop> inputs = new ArrayList<Lop>(); for( Hop c : getInput() ) inputs.add(c.constructLops()); int k = OptimizerUtils.getConstrainedNumThreads(_numThreads); SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class, k, et); setOutputDimensions(lop); setLineNumbers(lop); setLops(lop); return lop; } @Override protected ExecType optFindExecType() throws HopsException { checkAndSetForcedPlatform(); if( _etypeForced != null ) { _etype = _etypeForced; } else { _etype = findExecTypeByMemEstimate(); checkAndSetInvalidCPDimsAndSize(); } //ensure valid execution plans if( _etype == ExecType.MR ) _etype = ExecType.CP; return _etype; } @Override public String getOpString() { return "spoof("+_class.getSimpleName()+")"; } @Override public void refreshSizeInformation() { switch(_dimsType) { case ROW_DIMS: setDim1(getInput().get(0).getDim1()); setDim2(1); break; case COLUMN_DIMS_ROWS: setDim1(getInput().get(0).getDim2()); setDim2(1); break; case COLUMN_DIMS_COLS: setDim1(1); setDim2(getInput().get(0).getDim2()); break; case INPUT_DIMS: setDim1(getInput().get(0).getDim1()); setDim2(getInput().get(0).getDim2()); break; case SCALAR: setDim1(0); setDim2(0); break; case MULTI_SCALAR: setDim1(1); //row vector //dim2 statically set from outside break; case ROW_RANK_DIMS: setDim1(getInput().get(0).getDim1()); setDim2(getInput().get(1).getDim2()); break; case COLUMN_RANK_DIMS: setDim1(getInput().get(0).getDim2()); setDim2(getInput().get(1).getDim2()); break; default: throw new RuntimeException("Failed to refresh size information " + "for type: "+_dimsType.toString()); } } @Override public Object clone() throws CloneNotSupportedException { SpoofFusedOp ret = new SpoofFusedOp(); //copy generic attributes ret.clone(this, false); //copy specific attributes ret._class = _class; ret._distSupported = _distSupported; ret._numThreads = _numThreads; ret._dimsType = _dimsType; return ret; } @Override public boolean compare( Hop that ) { if( !(that instanceof SpoofFusedOp) ) return false; SpoofFusedOp that2 = (SpoofFusedOp)that; boolean ret = ( _class.equals(that2._class) && _distSupported == that2._distSupported && _numThreads == that2._numThreads && getInput().size() == that2.getInput().size()); if( ret ) { for( int i=0; i<getInput().size(); i++ ) ret &= (getInput().get(i) == that2.getInput().get(i)); } return ret; } }