/*
* Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.tck;
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.impl.TVMCI;
import com.oracle.truffle.api.nodes.RootNode;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.junit.Test;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import org.junit.runners.model.TestClass;
import com.oracle.truffle.tck.TruffleRunner.Inject;
/**
 * Helper that turns test methods taking {@link CallTarget} parameters into JUnit
 * {@link Statement}s and validates the shape of such test methods.
 *
 * <p>
 * For every parameter of a test method a {@link RootNode} is instantiated from the class named
 * by the parameter's {@link Inject} annotation, wrapped in a call target via
 * {@code createTestCallTarget}, and passed to the test method as an argument.
 *
 * @param <T> the concrete call target type handled by the underlying {@link TVMCI.Test}
 */
final class TruffleTestInvoker<T extends CallTarget> extends TVMCI.TestAccessor<T> {

    /** Number of warmup invocations used when a test method has no {@code @Warmup} annotation. */
    private static final int DEFAULT_WARMUP_ITERATIONS = 3;

    /**
     * Creates an invoker backed by the {@link TVMCI.Test} capability of the current Truffle
     * runtime.
     *
     * @return a new invoker
     */
    static TruffleTestInvoker<?> create() {
        TVMCI.Test<?> testTvmci = Truffle.getRuntime().getCapability(TVMCI.Test.class);
        return new TruffleTestInvoker<>(testTvmci);
    }

    private TruffleTestInvoker(TVMCI.Test<T> testTvmci) {
        super(testTvmci);
    }

    /**
     * Returns the number of warmup invocations for a test method: the value of its
     * {@link TruffleRunner.Warmup} annotation if present, otherwise
     * {@link #DEFAULT_WARMUP_ITERATIONS}.
     */
    private static int getWarmupIterations(FrameworkMethod method) {
        TruffleRunner.Warmup warmup = method.getAnnotation(TruffleRunner.Warmup.class);
        return warmup != null ? warmup.value() : DEFAULT_WARMUP_ITERATIONS;
    }

    /**
     * Instantiates one root node per parameter of the test method.
     *
     * @param testClass the class declaring the test method
     * @param testMethod the test method whose parameters are inspected
     * @param test the test instance, handed to node constructors that accept it
     * @return the root nodes in parameter order, or {@code null} if the method has no parameters
     * @throws IllegalStateException if a parameter has no {@link Inject} annotation or the
     *             referenced node class has no usable constructor; {@link #validateTestMethods}
     *             reports the same conditions up front
     */
    private static RootNode[] createTestRootNodes(TestClass testClass, FrameworkMethod testMethod, Object test) {
        int paramCount = testMethod.getMethod().getParameterCount();
        if (paramCount == 0) {
            // non-truffle test
            return null;
        }
        // Loop-invariant, so look the annotations up once instead of once per parameter.
        Annotation[][] parameterAnnotations = testMethod.getMethod().getParameterAnnotations();
        RootNode[] testNodes = new RootNode[paramCount];
        for (int i = 0; i < paramCount; i++) {
            Inject testRootNode = findRootNodeAnnotation(parameterAnnotations[i]);
            if (testRootNode == null) {
                // Fail with a diagnosable message rather than a bare NullPointerException.
                throw new IllegalStateException("CallTarget parameter of test method " + testMethod.getName() + " should have @Inject annotation");
            }
            Function<Object, RootNode> cons = getNodeConstructor(testRootNode, testClass);
            if (cons == null) {
                throw new IllegalStateException(
                                "Node " + testRootNode.value().getName() + " should have a default constructor or a constructor taking a " + testClass.getName());
            }
            testNodes[i] = cons.apply(test);
        }
        return testNodes;
    }

    /**
     * Creates the statement that executes a test method taking call target parameters.
     *
     * <p>
     * When evaluated, the statement creates a call target for every root node, invokes the test
     * method once per warmup iteration, calls {@code finishWarmup} on every call target and then
     * invokes the test method one final time. The same call targets are passed on every
     * invocation.
     *
     * @param testName name passed to {@code createTestCallTarget} for every call target
     * @param testClass the class declaring the test method
     * @param method the test method
     * @param test the test instance the method is invoked on
     * @return the statement, or {@code null} if the method takes no parameters
     */
    Statement createStatement(String testName, TestClass testClass, FrameworkMethod method, Object test) {
        final RootNode[] testNodes = createTestRootNodes(testClass, method, test);
        if (testNodes == null) {
            return null;
        }
        final int warmupIterations = getWarmupIterations(method);
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                List<T> callTargets = new ArrayList<>(testNodes.length);
                for (RootNode testNode : testNodes) {
                    callTargets.add(createTestCallTarget(testName, testNode));
                }
                Object[] args = callTargets.toArray();

                for (int i = 0; i < warmupIterations; i++) {
                    method.invokeExplosively(test, args);
                }
                // NOTE(review): finishWarmup is not defined in this file (presumably inherited
                // from TVMCI.TestAccessor); its exact effect is not visible here.
                for (T callTarget : callTargets) {
                    finishWarmup(callTarget);
                }
                method.invokeExplosively(test, args);
            }
        };
    }

    /**
     * Returns the first {@link Inject} annotation among the given annotations, or {@code null} if
     * there is none.
     */
    private static Inject findRootNodeAnnotation(Annotation[] annotations) {
        for (Annotation a : annotations) {
            if (a instanceof Inject) {
                return (Inject) a;
            }
        }
        return null;
    }

    /**
     * Looks up a factory for the node class named by the annotation. A public constructor taking
     * the test class is preferred; a public no-argument constructor is the fallback, in which
     * case the factory ignores its argument.
     *
     * @param annotation the annotation naming the node class; must not be {@code null}
     * @param testClass the test class whose Java class is tried as constructor parameter type
     * @return a function from test instance to new root node, or {@code null} if neither
     *         constructor exists
     */
    private static Function<Object, RootNode> getNodeConstructor(Inject annotation, TestClass testClass) {
        Class<? extends RootNode> nodeClass = annotation.value();
        try {
            Constructor<? extends RootNode> cons = nodeClass.getConstructor(testClass.getJavaClass());
            return (obj) -> instantiate(cons, obj);
        } catch (NoSuchMethodException e) {
            // No constructor taking the test class; fall back to the no-argument constructor.
            try {
                Constructor<? extends RootNode> cons = nodeClass.getConstructor();
                return (obj) -> instantiate(cons);
            } catch (NoSuchMethodException ex) {
                return null;
            }
        }
    }

    /**
     * Invokes the constructor reflectively, rethrowing any reflective failure as an
     * {@link AssertionError} that carries the original exception.
     */
    private static RootNode instantiate(Constructor<? extends RootNode> cons, Object... args) {
        try {
            return cons.newInstance(args);
        } catch (IllegalAccessException | IllegalArgumentException | InstantiationException | InvocationTargetException ex) {
            throw new AssertionError(ex);
        }
    }

    /**
     * Validates all {@link Test} methods of the given class and appends a description of every
     * problem to {@code errors}. Each method is checked with
     * {@link FrameworkMethod#validatePublicVoid(boolean, List)}; in addition every parameter must
     * be declared as exactly {@link CallTarget}, carry an {@link Inject} annotation, and name a
     * node class for which {@link #getNodeConstructor} finds a constructor.
     *
     * @param testClass the test class to validate
     * @param errors list that collects the validation errors
     */
    static void validateTestMethods(TestClass testClass, List<Throwable> errors) {
        List<FrameworkMethod> methods = testClass.getAnnotatedMethods(Test.class);
        for (FrameworkMethod method : methods) {
            method.validatePublicVoid(false, errors);

            Annotation[][] parameterAnnotations = method.getMethod().getParameterAnnotations();
            Class<?>[] parameterTypes = method.getMethod().getParameterTypes();
            for (int i = 0; i < parameterTypes.length; i++) {
                if (parameterTypes[i] != CallTarget.class) {
                    errors.add(new Exception("Invalid parameter type " + parameterTypes[i].getSimpleName() + " on test method " + method.getName()));
                    continue;
                }
                TruffleRunner.Inject testRootNode = findRootNodeAnnotation(parameterAnnotations[i]);
                if (testRootNode == null) {
                    errors.add(new Exception("CallTarget parameter of test method " + method.getName() + " should have @Inject annotation"));
                } else if (getNodeConstructor(testRootNode, testClass) == null) {
                    errors.add(new Exception("Node " + testRootNode.value().getName() + " should have a default constructor or a constructor taking a " + testClass.getName()));
                }
            }
        }
    }
}