/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.painless.node;
import org.elasticsearch.painless.CompilerSettings;
import org.elasticsearch.painless.Constant;
import org.elasticsearch.painless.Def;
import org.elasticsearch.painless.Definition;
import org.elasticsearch.painless.Definition.Method;
import org.elasticsearch.painless.Definition.Sort;
import org.elasticsearch.painless.Definition.Type;
import org.elasticsearch.painless.Globals;
import org.elasticsearch.painless.Locals;
import org.elasticsearch.painless.Locals.Parameter;
import org.elasticsearch.painless.Locals.Variable;
import org.elasticsearch.painless.Location;
import org.elasticsearch.painless.MethodWriter;
import org.elasticsearch.painless.WriterConstants;
import org.elasticsearch.painless.node.SSource.Reserved;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Opcodes;
import java.lang.invoke.MethodType;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import static java.util.Collections.emptyList;
import static org.elasticsearch.painless.WriterConstants.CLASS_TYPE;
/**
 * Represents a user-defined function.
 * <p>
 * Lifecycle visible in this class: {@link #generateSignature(Definition)} resolves the declared return and
 * parameter types, {@link #analyze(Locals)} analyzes the body statements, and
 * {@link #write(ClassVisitor, CompilerSettings, Globals)} emits the bytecode method.
 */
public final class SFunction extends AStatement {

    /**
     * {@link Reserved} implementation used for a function body. It ignores variable usage and only
     * records the maximum loop counter that was configured for the function.
     */
    public static final class FunctionReserved implements Reserved {
        private int maxLoopCounter = 0;

        @Override
        public void markUsedVariable(String name) {
            // Do nothing: this implementation does not track used variables.
        }

        @Override
        public void setMaxLoopCounter(int max) {
            maxLoopCounter = max;
        }

        @Override
        public int getMaxLoopCounter() {
            return maxLoopCounter;
        }
    }

    final FunctionReserved reserved;
    private final String rtnTypeStr;
    public final String name;
    private final List<String> paramTypeStrs;
    private final List<String> paramNameStrs;
    private final List<AStatement> statements;
    public final boolean synthetic;

    // Resolved by generateSignature(Definition).
    Type rtnType = null;
    List<Parameter> parameters = new ArrayList<>();
    Method method = null;

    // Loop-counter variable; only looked up by analyze(Locals) when loop protection is enabled.
    private Variable loop = null;

    /**
     * Creates a function node.
     *
     * @param reserved   reserved-state holder for this function (must not be null)
     * @param location   source location used for error reporting
     * @param rtnType    name of the declared return type (must not be null)
     * @param name       function name (must not be null)
     * @param paramTypes names of the declared parameter types
     * @param paramNames parameter names, expected to be the same length as {@code paramTypes}
     * @param statements body statements
     * @param synthetic  whether the emitted method is additionally marked {@code ACC_SYNTHETIC}
     */
    public SFunction(FunctionReserved reserved, Location location, String rtnType, String name,
                     List<String> paramTypes, List<String> paramNames, List<AStatement> statements,
                     boolean synthetic) {
        super(location);

        this.reserved = Objects.requireNonNull(reserved);
        this.rtnTypeStr = Objects.requireNonNull(rtnType);
        this.name = Objects.requireNonNull(name);
        // NOTE(review): these are read-only views over the caller's lists, not copies.
        this.paramTypeStrs = Collections.unmodifiableList(paramTypes);
        this.paramNameStrs = Collections.unmodifiableList(paramNames);
        this.statements = Collections.unmodifiableList(statements);
        this.synthetic = synthetic;
    }

    @Override
    void extractVariables(Set<String> variables) {
        // we should never be extracting from a function, as functions are top-level!
        throw new IllegalStateException("Illegal tree structure");
    }

    /**
     * Resolves the declared return and parameter types against {@code definition}.
     * It populates {@link #rtnType}, {@link #parameters} and {@link #method}.
     *
     * @param definition type definitions used to resolve the declared type names
     * @throws IllegalArgumentException (via {@code createError}) if a return or parameter type cannot be
     *         resolved; the resolution failure is attached as the cause
     * @throws IllegalStateException (via {@code createError}) if the parameter type and name lists differ in size
     */
    void generateSignature(Definition definition) {
        try {
            rtnType = definition.getType(rtnTypeStr);
        } catch (IllegalArgumentException exception) {
            // Keep the original failure as the cause so its diagnostic is not lost.
            throw createError(new IllegalArgumentException(
                "Illegal return type [" + rtnTypeStr + "] for function [" + name + "].", exception));
        }

        if (paramTypeStrs.size() != paramNameStrs.size()) {
            throw createError(new IllegalStateException("Illegal tree structure."));
        }

        Class<?>[] paramClasses = new Class<?>[paramTypeStrs.size()];
        List<Type> paramTypes = new ArrayList<>();

        for (int param = 0; param < paramTypeStrs.size(); ++param) {
            try {
                Type paramType = definition.getType(paramTypeStrs.get(param));

                paramClasses[param] = paramType.clazz;
                paramTypes.add(paramType);
                parameters.add(new Parameter(location, paramNameStrs.get(param), paramType));
            } catch (IllegalArgumentException exception) {
                throw createError(new IllegalArgumentException(
                    "Illegal parameter type [" + paramTypeStrs.get(param) + "] for function [" + name + "].",
                    exception));
            }
        }

        // Named asmMethod (not "method") so it does not shadow the field assigned just below.
        org.objectweb.asm.commons.Method asmMethod = new org.objectweb.asm.commons.Method(
            name, MethodType.methodType(rtnType.clazz, paramClasses).toMethodDescriptorString());
        // NOTE(review): the Method is recorded as STATIC | PRIVATE, while write(ClassVisitor, ...) emits the
        // bytecode method as ACC_PUBLIC | ACC_STATIC; confirm which modifiers consumers of `method` rely on.
        this.method = new Method(name, null, false, rtnType, paramTypes, asmMethod,
            Modifier.STATIC | Modifier.PRIVATE, null);
    }

    /**
     * Analyzes the body statements in a new local scope.
     * <p>
     * It rejects empty bodies and any statement that follows one where all paths escape. A non-void function
     * must return on every path. If a max loop counter is configured, the loop-counter variable is looked up
     * for use in {@link #write(MethodWriter, Globals)}.
     */
    @Override
    void analyze(Locals locals) {
        // statements is never null: Collections.unmodifiableList in the constructor rejects null.
        if (statements.isEmpty()) {
            throw createError(new IllegalArgumentException("Cannot generate an empty function [" + name + "]."));
        }

        locals = Locals.newLocalScope(locals);

        AStatement last = statements.get(statements.size() - 1);

        for (AStatement statement : statements) {
            // Note that we do not need to check after the last statement because
            // there is no statement that can be unreachable after the last.
            if (allEscape) {
                throw createError(new IllegalArgumentException("Unreachable statement."));
            }

            statement.lastSource = statement == last;

            statement.analyze(locals);

            methodEscape = statement.methodEscape;
            allEscape = statement.allEscape;
        }

        if (!methodEscape && rtnType.sort != Sort.VOID) {
            throw createError(new IllegalArgumentException("Not all paths provide a return value for method [" + name + "]."));
        }

        if (reserved.getMaxLoopCounter() > 0) {
            loop = locals.getVariable(null, Locals.LOOP);
        }
    }

    /**
     * Writes the function to the given ClassVisitor as a public static method.
     * It is additionally marked synthetic when {@link #synthetic} is set.
     */
    void write(ClassVisitor writer, CompilerSettings settings, Globals globals) {
        int access = Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC;
        if (synthetic) {
            access |= Opcodes.ACC_SYNTHETIC;
        }
        final MethodWriter function = new MethodWriter(access, method.method, writer, globals.getStatements(), settings);
        function.visitCode();
        write(function, globals);
        function.endMethod();
    }

    /**
     * Emits the function body.
     * <p>
     * It initializes the loop counter when protection is enabled, writes each statement, and adds an implicit
     * return for void functions that do not escape. It then registers a constant initializer for the static
     * method-handle field of this function.
     */
    @Override
    void write(MethodWriter function, Globals globals) {
        if (reserved.getMaxLoopCounter() > 0) {
            // if there is infinite loop protection, we do this once:
            // int #loop = settings.getMaxLoopCounter()
            function.push(reserved.getMaxLoopCounter());
            function.visitVarInsn(Opcodes.ISTORE, loop.getSlot());
        }

        for (AStatement statement : statements) {
            statement.write(function, globals);
        }

        if (!methodEscape) {
            if (rtnType.sort == Sort.VOID) {
                function.returnValue();
            } else {
                // analyze() already rejects non-void functions without a return on every path.
                throw createError(new IllegalStateException("Illegal tree structure."));
            }
        }

        String staticHandleFieldName = Def.getUserFunctionHandleFieldName(name, parameters.size());
        globals.addConstantInitializer(new Constant(location, WriterConstants.METHOD_HANDLE_TYPE,
                                                    staticHandleFieldName, this::initializeConstant));
    }

    /** Pushes an {@code H_INVOKESTATIC} handle to this function, declared on {@code CLASS_TYPE}. */
    private void initializeConstant(MethodWriter writer) {
        final Handle handle = new Handle(Opcodes.H_INVOKESTATIC,
                CLASS_TYPE.getInternalName(),
                name,
                method.method.getDescriptor(),
                false);
        writer.push(handle);
    }

    /** Describes the return type, name, parameters (if any) and body statements of this function. */
    @Override
    public String toString() {
        List<Object> description = new ArrayList<>();
        description.add(rtnTypeStr);
        description.add(name);
        if (false == (paramTypeStrs.isEmpty() && paramNameStrs.isEmpty())) {
            description.add(joinWithName("Args", pairwiseToString(paramTypeStrs, paramNameStrs), emptyList()));
        }
        return multilineToString(description, statements);
    }
}