/*
* JBoss, Home of Professional Open Source
* Copyright 2008-10 Red Hat and individual contributors
* by the @authors tag. See the copyright.txt in the distribution for a
* full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*
* @authors Andrew Dinn
*/
package org.jboss.byteman.rule.expression;
import org.jboss.byteman.rule.compiler.CompileContext;
import org.jboss.byteman.rule.type.Type;
import org.jboss.byteman.rule.exception.TypeException;
import org.jboss.byteman.rule.exception.ExecuteException;
import org.jboss.byteman.rule.exception.CompileException;
import org.jboss.byteman.rule.Rule;
import org.jboss.byteman.rule.helper.HelperAdapter;
import org.jboss.byteman.rule.grammar.ParseNode;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
/**
* A binary arithmetic operator expression
*/
public class ArithmeticExpression extends BinaryOperExpression
{
public ArithmeticExpression(Rule rule, int oper, ParseNode token, Expression left, Expression right)
throws TypeException
{
super(rule, oper, Type.promote(left.getType(), right.getType()), token, left, right);
}
public Type typeCheck(Type expected) throws TypeException {
Type type1 = getOperand(0).typeCheck(Type.N);
Type type2 = getOperand(1).typeCheck(Type.N);
type = Type.promote(type1, type2);
if (Type.dereference(expected).isDefined() && !expected.isAssignableFrom(type) ) {
throw new TypeException("ArithmenticExpression.typeCheck : invalid expected result type " + expected.getName() + getPos());
}
return type;
}
public Object interpret(HelperAdapter helper)
throws ExecuteException
{
try {
// n.b. be careful with characters here
Object objValue1 = getOperand(0).interpret(helper);
Object objValue2 = getOperand(1).interpret(helper);
Number value1;
Number value2;
if (objValue1 instanceof Character) {
value1 = Integer.valueOf((Character)objValue1);
} else {
value1 = (Number)objValue1;
}
if (objValue2 instanceof Character) {
value2 = Integer.valueOf((Character)objValue2);
} else {
value2 = (Number)objValue2;
}
// type is the result of promoting one or other or both of the operands
// and they should be converted to this type before doing the arithmetic operation
if (type == type.B) {
byte b1 = value1.byteValue();
byte b2 = value2.byteValue();
byte result;
// TODO we should probably only respect the byte, short and char types for + and -
// TODO also need to decide how to handle divide by zero
switch (oper)
{
case MUL:
result = (byte)(b1 * b2);
break;
case DIV:
result = (byte)(b1 / b2);
break;
case PLUS:
result = (byte)(b1 + b2);
break;
case MINUS:
result = (byte)(b1 - b2);
break;
case MOD:
result = (byte)(b1 % b2);
break;
default:
result = 0;
break;
}
return Byte.valueOf(result);
} else if (type == type.S) {
short s1 = value1.shortValue();
short s2 = value2.shortValue();
short result;
switch (oper)
{
case MUL:
result = (short)(s1 * s2);
break;
case DIV:
result = (short)(s1 / s2);
break;
case PLUS:
result = (short)(s1 + s2);
break;
case MINUS:
result = (short)(s1 - s2);
break;
case MOD:
result = (short)(s1 % s2);
break;
default:
result = 0;
break;
}
return Short.valueOf(result);
} else if (type == type.I) {
int i1 = value1.intValue();
int i2 = value2.intValue();
int result;
switch (oper)
{
case MUL:
result = (i1 * i2);
break;
case DIV:
result = (i1 / i2);
break;
case PLUS:
result = (i1 + i2);
break;
case MINUS:
result = (i1 - i2);
break;
case MOD:
result = (i1 % i2);
break;
default:
result = 0;
break;
}
return Integer.valueOf(result);
} else if (type == type.J) {
long l1 = value1.longValue();
long l2 = value2.longValue();
long result;
switch (oper)
{
case MUL:
result = (l1 * l2);
break;
case DIV:
result = (l1 / l2);
break;
case PLUS:
result = (l1 + l2);
break;
case MINUS:
result = (l1 - l2);
break;
case MOD:
result = (l1 % l2);
break;
default:
result = 0;
break;
}
return Long.valueOf(result);
} else if (type == type.F) {
float f1 = value1.floatValue();
float f2 = value2.floatValue();
float result;
switch (oper)
{
case MUL:
result = (f1 * f2);
break;
case DIV:
result = (f1 / f2);
break;
case PLUS:
result = (f1 + f2);
break;
case MINUS:
result = (f1 - f2);
break;
case MOD:
result = (f1 % f2);
break;
default:
result = 0;
break;
}
return Float.valueOf(result);
} else if (type == type.D) {
double d1 = value1.doubleValue();
double d2 = value2.doubleValue();
double result;
switch (oper)
{
case MUL:
result = (d1 * d2);
break;
case DIV:
result = (d1 / d2);
break;
case PLUS:
result = (d1 + d2);
break;
case MINUS:
result = (d1 - d2);
break;
case MOD:
result = (d1 % d2);
break;
default:
result = 0;
break;
}
return Double.valueOf(result);
} else { // (type == type.C)
// use integers here but be careful about conversions
int s1 = value1.intValue();
int s2 = value2.intValue();
char result;
switch (oper)
{
case MUL:
result = (char)(s1 * s2);
break;
case DIV:
result = (char)(s1 / s2);
break;
case PLUS:
result = (char)(s1 + s2);
break;
case MINUS:
result = (char)(s1 - s2);
break;
case MOD:
result = (char)(s1 % s2);
break;
default:
result = 0;
break;
}
return Integer.valueOf(result);
}
} catch (ExecuteException e) {
throw e;
} catch (Exception e) {
throw new ExecuteException("ArithmeticExpression.interpret : unexpected exception for operation " + token + getPos() + " in rule " + helper.getName(), e);
}
}
public void compile(MethodVisitor mv, CompileContext compileContext) throws CompileException
{
// make sure we are at the right source line
compileContext.notifySourceLine(line);
int currentStack = compileContext.getStackCount();
int expectedStack = 0;
Expression operand0 = getOperand(0);
Expression operand1 = getOperand(1);
Type type0 = operand0.getType();
Type type1 = operand1.getType();
// compile lhs -- it adds 1 or 2 to the stack height
operand0.compile(mv, compileContext);
// do any required type conversion
compileTypeConversion(type0, type, mv, compileContext);
// compile rhs -- it adds 1 or 2 to the stack height
operand1.compile(mv, compileContext);
// do any required type conversion
compileTypeConversion(type1, type, mv, compileContext);
try {
// n.b. be careful with characters here
if (type == type.B || type == type.S || type == type.C || type == type.I) {
// TODO we should probably only respect the byte, short and char types for + and -
// TODO also need to decide how to handle divide by zero
expectedStack = 1;
switch (oper)
{
case MUL:
mv.visitInsn(Opcodes.IMUL);
break;
case DIV:
mv.visitInsn(Opcodes.IDIV);
break;
case PLUS:
mv.visitInsn(Opcodes.IADD);
break;
case MINUS:
mv.visitInsn(Opcodes.ISUB);
break;
case MOD:
mv.visitInsn(Opcodes.IREM);
break;
default:
// should never happen
throw new CompileException("ArithmeticExpression.compile : unexpected operator " + oper);
}
// now coerce back to appropriate type
if (type == type.B) {
mv.visitInsn(Opcodes.I2B);
} else if (type == type.S) {
mv.visitInsn(Opcodes.I2S);
} else if (type == type.C) {
mv.visitInsn(Opcodes.I2C);
} // else if (type == type.I) { do nothing }
// ok, we popped two bytes but added one
compileContext.addStackCount(-1);
} else if (type == type.J) {
expectedStack = 2;
switch (oper)
{
case MUL:
mv.visitInsn(Opcodes.LMUL);
break;
case DIV:
mv.visitInsn(Opcodes.LDIV);
break;
case PLUS:
mv.visitInsn(Opcodes.LADD);
break;
case MINUS:
mv.visitInsn(Opcodes.LSUB);
break;
case MOD:
mv.visitInsn(Opcodes.LREM);
break;
default:
// should never happen
throw new CompileException("ArithmeticExpression.compile : unexpected operator " + oper);
}
// ok, we popped four bytes but added two
compileContext.addStackCount(-2);
} else if (type == type.F) {
expectedStack = 1;
switch (oper)
{
case MUL:
mv.visitInsn(Opcodes.FMUL);
break;
case DIV:
mv.visitInsn(Opcodes.FDIV);
break;
case PLUS:
mv.visitInsn(Opcodes.FADD);
break;
case MINUS:
mv.visitInsn(Opcodes.FSUB);
break;
case MOD:
mv.visitInsn(Opcodes.FREM);
break;
default:
// should never happen
throw new CompileException("ArithmeticExpression.compile : unexpected operator " + oper);
}
// ok, we popped two bytes but added one
compileContext.addStackCount(-1);
} else if (type == type.D) {
expectedStack = 2;
switch (oper)
{
case MUL:
mv.visitInsn(Opcodes.DMUL);
break;
case DIV:
mv.visitInsn(Opcodes.DDIV);
break;
case PLUS:
mv.visitInsn(Opcodes.DADD);
break;
case MINUS:
mv.visitInsn(Opcodes.DSUB);
break;
case MOD:
mv.visitInsn(Opcodes.DREM);
break;
default:
// should never happen
throw new CompileException("ArithmeticExpression.compile : unexpected operator " + oper);
}
// ok, we popped four bytes but added two
compileContext.addStackCount(-2);
} else {
throw new CompileException("ArithmeticExpression.compile : unexpected result type " + type.getName());
}
} catch (CompileException e) {
throw e;
} catch (Exception e) {
throw new CompileException("ArithmeticExpression.compile : unexpected exception for operation " + token + getPos() + " in rule " + rule.getName(), e);
}
// check stack heights
if (compileContext.getStackCount() != currentStack + expectedStack) {
throw new CompileException("ArithmeticExpression.compile : invalid stack height " + compileContext.getStackCount() + " expecting " + (currentStack + expectedStack));
}
}
}