package me.tomassetti.turin.compiler;

import com.google.common.collect.ImmutableList;
import me.tomassetti.bytecode_generation.*;
import me.tomassetti.bytecode_generation.pushop.PushStaticField;
import me.tomassetti.bytecode_generation.returnop.ReturnValueBS;
import me.tomassetti.bytecode_generation.returnop.ReturnVoidBS;
import me.tomassetti.jvm.*;
import me.tomassetti.turin.definitions.ContextDefinition;
import me.tomassetti.turin.parser.ast.expressions.Expression;
import me.tomassetti.turin.parser.ast.statements.*;
import me.tomassetti.turin.parser.ast.typeusage.TypeUsageNode;
import me.tomassetti.turin.symbols.Symbol;
import me.tomassetti.turin.typesystem.TypeUsage;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import turin.context.Context;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Translates statement nodes of the AST into {@link BytecodeSequence}s.
 *
 * <p>The instance is stateful: while the body of a {@link ContextScope} is being emitted,
 * {@link #codeToExecuteBeforeReturning} holds the code that every {@code return} compiled in
 * that body has to run before actually returning.
 */
public class CompilationOfStatements {

    private final Compilation compilation;

    /**
     * Code to emit right before each return instruction, or {@code null} when no
     * {@link ContextScope} is currently being emitted. For nested scopes it contains the leave
     * code of the innermost scope followed by the leave code of the enclosing ones.
     */
    private BytecodeSequence codeToExecuteBeforeReturning;

    public CompilationOfStatements(Compilation compilation) {
        this.compilation = compilation;
    }

    /**
     * Compiles a single statement by dispatching on its concrete type.
     *
     * @param statement the statement to translate
     * @return the bytecode sequence implementing the statement
     * @throws UnsupportedOperationException if the statement type is not handled here
     */
    BytecodeSequence compile(Statement statement) {
        if (statement instanceof VariableDeclaration) {
            return compileVariableDeclaration((VariableDeclaration) statement);
        } else if (statement instanceof ExpressionStatement) {
            Expression expression = ((ExpressionStatement) statement).getExpression();
            return new CompilationOfPush(compilation).pushExpression(expression);
        } else if (statement instanceof BlockStatement) {
            BlockStatement blockStatement = (BlockStatement) statement;
            List<BytecodeSequence> elements = blockStatement.getStatements().stream()
                    .map((s) -> compile(s))
                    .collect(Collectors.toList());
            return new ComposedBytecodeSequence(elements);
        } else if (statement instanceof ReturnStatement) {
            return compileReturn((ReturnStatement) statement);
        } else if (statement instanceof IfStatement) {
            return compileIf((IfStatement) statement);
        } else if (statement instanceof ThrowStatement) {
            ThrowStatement throwStatement = (ThrowStatement) statement;
            return new ThrowBS(compilation.getPushUtils().pushExpression(throwStatement.getException()));
        } else if (statement instanceof TryCatchStatement) {
            return compile((TryCatchStatement) statement);
        } else if (statement instanceof ContextScope) {
            return compile((ContextScope) statement);
        } else {
            throw new UnsupportedOperationException(statement.toString());
        }
    }

    /** Registers the variable in the local variables table and stores its initial value in it. */
    private BytecodeSequence compileVariableDeclaration(VariableDeclaration variableDeclaration) {
        // The slot is reserved before the initial value is compiled
        int pos = compilation.getLocalVarsSymbolTable().add(variableDeclaration.getName(), variableDeclaration);
        JvmTypeCategory typeCategory =
                variableDeclaration.varType(compilation.getResolver()).jvmType().typeCategory();
        return new ComposedBytecodeSequence(ImmutableList.of(
                compilation.getPushUtils().pushExpression(variableDeclaration.getValue()),
                new LocalVarAssignmentBS(pos, typeCategory)));
    }

    /** Compiles a return, running {@link #codeToExecuteBeforeReturning} (if any) before returning. */
    private BytecodeSequence compileReturn(ReturnStatement returnStatement) {
        BytecodeSequence cleanup = codeToExecuteBeforeReturning;
        if (!returnStatement.hasValue()) {
            if (cleanup == null) {
                return new ReturnVoidBS();
            }
            return new ComposedBytecodeSequence(cleanup, new ReturnVoidBS());
        }

        Expression returnedValue = returnStatement.getValue();
        TypeUsage returnedValueType = returnedValue.calcType();
        int returnType = returnedValueType.jvmType().returnOpcode();
        BytecodeSequence pushValue = compilation.getPushUtils().pushExpression(returnedValue);
        if (cleanup == null) {
            return new ReturnValueBS(returnType, pushValue);
        }
        // So we first calculate the value and then we clean the context
        return new ReturnValueBS(returnType, new ComposedBytecodeSequence(pushValue, cleanup));
    }

    /** Compiles an if statement together with its elif branches and optional else branch. */
    private BytecodeSequence compileIf(IfStatement ifStatement) {
        BytecodeSequence ifCondition = compilation.getPushUtils().pushExpression(ifStatement.getCondition());
        BytecodeSequence ifBody = compile(ifStatement.getIfBody());
        List<BytecodeSequence> elifConditions = ifStatement.getElifStatements().stream()
                .map((ec) -> compilation.getPushUtils().pushExpression(ec.getCondition()))
                .collect(Collectors.toList());
        List<BytecodeSequence> elifBodys = ifStatement.getElifStatements().stream()
                .map((ec) -> compile(ec.getBody()))
                .collect(Collectors.toList());
        if (ifStatement.hasElse()) {
            return new IfBS(ifCondition, ifBody, elifConditions, elifBodys, compile(ifStatement.getElseBody()));
        } else {
            return new IfBS(ifCondition, ifBody, elifConditions, elifBodys);
        }
    }

    /** Emits the instructions pushing the static {@code INSTANCE} field of the given context class. */
    private void pushContextInstance(ContextDefinition contextSymbol, MethodVisitor mv) {
        String internalName = JvmNameUtils.canonicalToInternal(contextSymbol.getClassQualifiedName());
        JvmFieldDefinition fieldDefinition = new JvmFieldDefinition(
                internalName, "INSTANCE", "L" + internalName + ";", true);
        new PushStaticField(fieldDefinition).operate(mv);
    }

    /**
     * Builds the code to run when leaving the given scope: for each context assignment the
     * {@code INSTANCE} field of the context is pushed and {@code exitContext} is invoked on it.
     *
     * @param contextScope the scope being left
     * @return a sequence emitting one {@code exitContext} call per assignment of the scope
     */
    BytecodeSequence codeOnLeavingContextScope(ContextScope contextScope) {
        return new BytecodeSequence() {
            @Override
            public void operate(MethodVisitor mv) {
                for (ContextAssignment assignment : contextScope.getAssignments()) {
                    ContextDefinition contextSymbol = assignment.contextSymbol().get();
                    // We need to get the INSTANCE field
                    pushContextInstance(contextSymbol, mv);
                    // and then call exitContext
                    JvmMethodDefinition exitContext = new JvmMethodDefinition(
                            JvmNameUtils.internalName(Context.class),
                            "exitContext",
                            "()V",
                            false,
                            false);
                    new MethodInvocationBS(exitContext).operate(mv);
                }
            }
        };
    }

    /**
     * Compiles a context scope: enters every assigned context, emits the statements of the scope
     * and finally leaves the contexts.
     *
     * <p>While the statements are emitted, returns compiled inside the scope run the leave code of
     * this scope followed by the leave code of any enclosing scope currently being emitted.
     *
     * @param contextScope the scope to translate
     * @return the bytecode sequence implementing the scope
     */
    BytecodeSequence compile(ContextScope contextScope) {
        // TODO catch exceptions just to execute the "leave context" module
        return new BytecodeSequence() {
            @Override
            public void operate(MethodVisitor mv) {
                Label start = new Label();
                Label end = new Label();
                mv.visitLabel(start);
                for (ContextAssignment assignment : contextScope.getAssignments()) {
                    ContextDefinition contextSymbol = assignment.contextSymbol().get();
                    // We need to get the INSTANCE field
                    pushContextInstance(contextSymbol, mv);
                    // and then call enterContext
                    // push the parameter
                    CompilationOfStatements.this.compilation.getPushUtils()
                            .pushExpression(assignment.getContextValue()).operate(mv);
                    JvmMethodDefinition enterContext = new JvmMethodDefinition(
                            JvmNameUtils.internalName(Context.class),
                            "enterContext",
                            "(Ljava/lang/Object;)V",
                            false,
                            false);
                    new MethodInvocationBS(enterContext).operate(mv);
                }

                BytecodeSequence leaveThisScope = codeOnLeavingContextScope(contextScope);
                // Keep the cleanup of the enclosing scope (if any): a return inside this scope
                // has to leave this scope first and then all the enclosing ones
                BytecodeSequence enclosingCleanup = CompilationOfStatements.this.codeToExecuteBeforeReturning;
                CompilationOfStatements.this.codeToExecuteBeforeReturning = enclosingCleanup == null
                        ? leaveThisScope
                        : new ComposedBytecodeSequence(leaveThisScope, enclosingCleanup);
                try {
                    contextScope.getStatements().forEach((s) -> compile(s).operate(mv));
                } finally {
                    // Statements following this scope must see the cleanup of the enclosing scope again
                    CompilationOfStatements.this.codeToExecuteBeforeReturning = enclosingCleanup;
                }

                leaveThisScope.operate(mv);
                mv.visitLabel(end);
            }
        };
    }

    /**
     * Compiles a try-catch statement: one exception table entry per catch clause, the protected
     * body, and one handler per clause storing the caught exception in a local variable.
     *
     * @param tryCatchStatement the statement to translate
     * @return the bytecode sequence implementing the statement
     */
    BytecodeSequence compile(TryCatchStatement tryCatchStatement) {
        return new BytecodeSequence() {
            @Override
            public void operate(MethodVisitor mv) {
                Label tryStart = new Label();
                Label tryEnd = new Label();
                Label afterTryCatch = new Label();

                // The try-catch blocks are declared before the labels they refer to are visited
                List<Label> catchSpecificLabels = new ArrayList<Label>();
                for (CatchClause catchClause : tryCatchStatement.getCatchClauses()) {
                    Label catchSpecificLabel = new Label();
                    String exceptionInternalName = JvmNameUtils.canonicalToInternal(
                            catchClause.getExceptionType().resolve(compilation.getResolver()).getQualifiedName());
                    mv.visitTryCatchBlock(tryStart, tryEnd, catchSpecificLabel, exceptionInternalName);
                    catchSpecificLabels.add(catchSpecificLabel);
                }

                mv.visitLabel(tryStart);
                compile(tryCatchStatement.getBody()).operate(mv);
                mv.visitLabel(tryEnd);
                // Normal completion of the body: skip all the handlers
                mv.visitJumpInsn(Opcodes.GOTO, afterTryCatch);

                int i = 0;
                for (CatchClause catchClause : tryCatchStatement.getCatchClauses()) {
                    mv.visitLabel(catchSpecificLabels.get(i));
                    compilation.getLocalVarsSymbolTable().enterBlock();
                    int catchedExcIndex = compilation.getLocalVarsSymbolTable()
                            .add(catchClause.getVariableName(), catchClause);
                    // Store the caught exception in the local variable of the clause
                    new LocalVarAssignmentBS(catchedExcIndex, JvmTypeCategory.REFERENCE).operate(mv);
                    compile(catchClause.getBody()).operate(mv);
                    compilation.getLocalVarsSymbolTable().exitBlock();
                    mv.visitJumpInsn(Opcodes.GOTO, afterTryCatch);
                    i++;
                }
                mv.visitLabel(afterTryCatch);
            }
        };
    }
}