/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.test.runner;
import static org.junit.Assert.*;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import com.alibaba.citrus.test.TestUtil;
import org.junit.runner.Runner;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.Parameterized;
import org.junit.runners.Suite;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.Statement;
import org.junit.runners.model.TestClass;
/**
* 类似{@link Parameterized},作了如下改进和变化:
* <p>
* <ul>
* <li>通过原型对象来创建test case对象</li> *
* <li>可通过@TestName注释指定的方法(非static方法)来设置测试组的名称,如“[0] xxx”,而不只是[0]、[1]、[2]。</li>
* <li>支持多个@Prototypes方法。</li>
* <li>支持<code>TestUtil.getTestName()</code>,以便在测试中取得当前测试的名称。</li>
* </ul>
* </p>
*/
public class Prototyped extends Suite {
/** 系统将通过标记此注释的方法,取得测试的名称。 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public static @interface TestName {
}
/** 系统将通过标记此注释的方法,取得测试的数据。 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public static @interface Prototypes {
}
/** 用来简化Prototyped测试。 */
public static class TestData<T> extends LinkedList<T> {
private static final long serialVersionUID = 2818372350747718688L;
private final Class<T> prototypeClass;
public static <T> TestData<T> getInstance(Class<T> prototypeClass) {
return new TestData<T>(prototypeClass);
}
public TestData(Class<T> prototypeClass) {
assertNotNull("prototypeClass not specified", prototypeClass);
this.prototypeClass = prototypeClass;
}
public T newPrototype() {
T prototype = null;
try {
prototype = prototypeClass.newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
super.add(prototype);
return prototype;
}
}
private static class TestClassRunnerForPrototypes extends BlockJUnit4ClassRunner {
private final Object fPrototype;
private final int fPrototypeNumber;
TestClassRunnerForPrototypes(Class<?> type, Object prototype, int i) throws InitializationError {
super(type);
fPrototype = prototype;
fPrototypeNumber = i;
}
@Override
public Object createTest() throws Exception {
if (fPrototype instanceof Cloneable && getTestClass().getJavaClass().isInstance(fPrototype)) {
Method cloneMethod = null;
for (Class<?> clazz = fPrototype.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
try {
cloneMethod = clazz.getDeclaredMethod("clone");
cloneMethod.setAccessible(true);
break;
} catch (NoSuchMethodException e) {
}
}
return cloneMethod.invoke(fPrototype);
}
fail(String.format("Class %s is not Cloneable", getTestClass().getJavaClass().getSimpleName()));
return null;
}
@Override
protected String getName() {
List<FrameworkMethod> methods = new ArrayList<FrameworkMethod>(getTestClass().getAnnotatedMethods(
TestName.class));
for (FrameworkMethod each : methods) {
int modifiers = each.getMethod().getModifiers();
if (!Modifier.isStatic(modifiers) && Modifier.isPublic(modifiers)
&& String.class.equals(each.getMethod().getReturnType())
&& each.getMethod().getParameterTypes().length == 0) {
String name = null;
try {
name = (String) each.invokeExplosively(fPrototype);
assertNotNull(String.format("%s.%s() returned null", getTestClass().getName(), each.getName()),
name);
return String.format("[%s] %s", fPrototypeNumber, name);
} catch (Error e) {
throw e;
} catch (RuntimeException e) {
throw e;
} catch (Throwable e) {
throw new RuntimeException(e);
}
} else {
throw new RuntimeException(String.format(
"%s.%s() should be public, non-static, accept no arguments, and return String",
getTestClass().getName(), each.getName()));
}
}
return String.format("[%s]", fPrototypeNumber);
}
@Override
protected String testName(final FrameworkMethod method) {
return String.format("%s[%s]", method.getName(), fPrototypeNumber);
}
@Override
protected void validateZeroArgConstructor(List<Throwable> errors) {
// constructor can, nay, should have args.
}
@Override
protected Statement classBlock(RunNotifier notifier) {
return childrenInvoker(notifier);
}
@Override
protected void runChild(FrameworkMethod method, RunNotifier notifier) {
TestUtil.setTestName(method.getName());
try {
super.runChild(method, notifier);
} finally {
TestUtil.setTestName(null);
}
}
}
/** Only called reflectively. Do not use programmatically. */
public Prototyped(Class<?> klass) throws Throwable {
super(klass, getRunners(klass));
}
private static List<Runner> getRunners(Class<?> klass) throws Throwable, InitializationError {
List<Runner> runners = new ArrayList<Runner>();
int i = 0;
for (final Object each : getPrototypesList(klass)) {
runners.add(new TestClassRunnerForPrototypes(klass, each, i++));
}
return runners;
}
private static Collection<Object> getPrototypesList(Class<?> klass) throws Throwable {
Collection<Object> prototypeList = new ArrayList<Object>();
for (FrameworkMethod method : getPrototypesMethods(klass)) {
@SuppressWarnings("unchecked")
Collection<Object> results = (Collection<Object>) method.invokeExplosively(null);
for (final Object each : results) {
if (!klass.isInstance(each)) {
throw new Exception(String.format("%s.%s() must return a Collection of test object.",
klass.getName(), method.getName()));
}
}
prototypeList.addAll(results);
}
return prototypeList;
}
private static List<FrameworkMethod> getPrototypesMethods(Class<?> testClass) throws Exception {
List<FrameworkMethod> methods = new TestClass(testClass).getAnnotatedMethods(Prototypes.class);
if (methods.isEmpty()) {
throw new Exception("No public static prototypes method on class " + testClass.getName());
}
for (FrameworkMethod each : methods) {
int modifiers = each.getMethod().getModifiers();
if (!Modifier.isStatic(modifiers) || !Modifier.isPublic(modifiers)) {
throw new Exception(String.format("%s.%s() should be public static method: ", testClass.getName(),
each.getName()));
}
}
return methods;
}
}