/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.parser; import java.util.HashMap; import org.apache.sysml.parser.LanguageException.LanguageErrorCodes; public class RelationalExpression extends Expression { private Expression _left; private Expression _right; private RelationalOp _opcode; public RelationalExpression(RelationalOp bop) { _kind = Kind.RelationalOp; _opcode = bop; setFilename("MAIN SCRIPT"); setBeginLine(0); setBeginColumn(0); setEndLine(0); setEndColumn(0); } public RelationalExpression(RelationalOp bop, String filename, int beginLine, int beginColumn, int endLine, int endColumn) { _kind = Kind.RelationalOp; _opcode = bop; setFilename(filename); setBeginLine(beginLine); setBeginColumn(beginColumn); setEndLine(endLine); setEndColumn(endColumn); } public Expression rewriteExpression(String prefix) throws LanguageException{ RelationalExpression newExpr = new RelationalExpression(this._opcode, getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn()); newExpr.setLeft(_left.rewriteExpression(prefix)); newExpr.setRight(_right.rewriteExpression(prefix)); return newExpr; } public RelationalOp getOpCode(){ return _opcode; } public void setLeft(Expression l){ _left = l; // update script location information --> left expression is BEFORE in script if (_left != null){ setFilename(_left.getFilename()); setBeginLine(_left.getBeginLine()); setBeginColumn(_left.getBeginColumn()); } } public void setRight(Expression r){ _right = r; // update script location information --> right expression is AFTER in script if (_right != null){ setFilename(_right.getFilename()); setBeginLine(_right.getEndLine()); setBeginColumn(_right.getEndColumn()); } } public Expression getLeft(){ return _left; } public Expression getRight(){ return _right; } /** * Validate parse tree : Process Relational Expression */ @Override public void validateExpression(HashMap<String,DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) throws LanguageException { //check for functions calls in expression if (_left instanceof FunctionCallIdentifier){ raiseValidateError("user-defined function calls not supported in relational expressions", false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION); } if (_right instanceof FunctionCallIdentifier){ raiseValidateError("user-defined function calls not supported in relational expressions", false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION); } // handle <NUMERIC> == <BOOLEAN> --> convert <BOOLEAN> to numeric value if ((_left != null && _left instanceof BooleanIdentifier) || (_right != null && _right instanceof BooleanIdentifier)){ if ((_left instanceof IntIdentifier || _left instanceof DoubleIdentifier) || _right instanceof IntIdentifier || _right instanceof DoubleIdentifier){ if (_left instanceof BooleanIdentifier){ if (((BooleanIdentifier) _left).getValue()) this.setLeft(new IntIdentifier(1, _left.getFilename(), _left.getBeginLine(), _left.getBeginColumn(), _left.getEndLine(), _left.getEndColumn())); else this.setLeft(new IntIdentifier(0, _left.getFilename(), _left.getBeginLine(), _left.getBeginColumn(), _left.getEndLine(), _left.getEndColumn())); } else if (_right instanceof BooleanIdentifier){ if (((BooleanIdentifier) _right).getValue()) this.setRight(new IntIdentifier(1, _right.getFilename(), _right.getBeginLine(), _right.getBeginColumn(), _right.getEndLine(),_right.getEndColumn())); else this.setRight(new IntIdentifier(0, _right.getFilename(), _right.getBeginLine(), _right.getBeginColumn(), _right.getEndLine(),_right.getEndColumn())); } } } //recursive validate _left.validateExpression(ids, constVars, conditional); if( _right !=null ) _right.validateExpression(ids, constVars, conditional); //constant propagation (precondition for more complex constant folding rewrite) if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) ) _left = constVars.get(((DataIdentifier) _left).getName()); if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) ) _right = constVars.get(((DataIdentifier) _right).getName()); String outputName = getTempName(); DataIdentifier output = new DataIdentifier(outputName); output.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn()); boolean isLeftMatrix = (_left.getOutput() != null && _left.getOutput().getDataType() == DataType.MATRIX); boolean isRightMatrix = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.MATRIX); if(isLeftMatrix || isRightMatrix) { // Added to support matrix relational comparison if(isLeftMatrix && isRightMatrix) { checkMatchingDimensions(_left, _right, true); } long[] dims = getBinaryMatrixCharacteristics(_left, _right); output.setDataType(DataType.MATRIX); output.setDimensions(dims[0], dims[1]); output.setBlockDimensions(dims[2], dims[3]); //since SystemML only supports double matrices, the value type is forced to //double; once we support boolean matrices this needs to change output.setValueType(ValueType.DOUBLE); } else { output.setBooleanProperties(); } this.setOutput(output); } /** * This is same as the function from BuiltinFunctionExpression which is called by ppred * * @param expr1 expression 1 * @param expr2 expression 2 * @param allowsMV ? * @throws LanguageException */ private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV) throws LanguageException { if (expr1 != null && expr2 != null) { // if any matrix has unknown dimensions, simply return if( expr1.getOutput().getDim1() == -1 || expr2.getOutput().getDim1() == -1 ||expr1.getOutput().getDim2() == -1 || expr2.getOutput().getDim2() == -1 ) { return; } else if( (!allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1()) || (allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1() && expr2.getOutput().getDim1() != 1) || (!allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2()) || (allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2() && expr2.getOutput().getDim2() != 1) ) { raiseValidateError("Mismatch in matrix dimensions of parameters for function " + this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); } } } public String toString(){ String leftString; String rightString; if (_left instanceof StringIdentifier) { leftString = "\"" + _left.toString() + "\""; } else { leftString = _left.toString(); } if (_right instanceof StringIdentifier) { rightString = "\"" + _right.toString() + "\""; } else { rightString = _right.toString(); } return "(" + leftString + " " + _opcode.toString() + " " + rightString + ")"; } @Override public VariableSet variablesRead() { VariableSet result = new VariableSet(); result.addVariables(_left.variablesRead()); result.addVariables(_right.variablesRead()); return result; } @Override public VariableSet variablesUpdated() { VariableSet result = new VariableSet(); result.addVariables(_left.variablesUpdated()); result.addVariables(_right.variablesUpdated()); return result; } }