/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.test;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.hamcrest.Matcher;
import org.slf4j.MDC;
/**
 * Utility class that makes writing tests more convenient.
 * <p>
 * It offers a per-thread "test name" holder (mirrored into the SLF4J {@link MDC}),
 * factory methods for matchers, helpers for locating class output directories,
 * and reflection helpers that can reach non-public fields and methods.
 *
 * @author Michael Zhou
 */
public class TestUtil {
    /** MDC key under which the current test name is published. */
    private static final String MDC_TEST_NAME_KEY = "testName";

    private final static ThreadLocal<String> TEST_NAME_HOLDER = new ThreadLocal<String>();

    /**
     * Returns the test name previously stored for the current thread.
     *
     * @return the stored name, or <code>null</code> if none has been set on this thread
     */
    public static String getTestName() {
        return TEST_NAME_HOLDER.get();
    }

    /**
     * Stores the test name for the current thread and mirrors it into the MDC
     * under the key <code>"testName"</code>.
     *
     * @param name the test name; <code>null</code> clears both the thread-local value
     *             and the MDC entry
     */
    public static void setTestName(String name) {
        if (name == null) {
            TEST_NAME_HOLDER.remove();
            MDC.remove(MDC_TEST_NAME_KEY);
        } else {
            TEST_NAME_HOLDER.set(name);
            MDC.put(MDC_TEST_NAME_KEY, name);
        }
    }

    /**
     * Creates an {@link ExceptionMatcher} from a cause type and message snippets.
     * The exact matching rules are defined by <code>ExceptionMatcher</code>.
     *
     * @param cause    the throwable type handed to the matcher
     * @param snippets the text snippets handed to the matcher
     * @return the new matcher
     */
    public static <T extends Throwable> Matcher<T> exception(Class<? extends Throwable> cause, String... snippets) {
        return new ExceptionMatcher<T>(cause, snippets);
    }

    /**
     * Creates an {@link ExceptionMatcher} from message snippets only.
     * The exact matching rules are defined by <code>ExceptionMatcher</code>.
     *
     * @param snippets the text snippets handed to the matcher
     * @return the new matcher
     */
    public static <T extends Throwable> Matcher<T> exception(String... snippets) {
        return new ExceptionMatcher<T>(snippets);
    }

    /**
     * Creates a {@link RegexMatcher} for the given regular expression.
     * The exact matching rules are defined by <code>RegexMatcher</code>.
     *
     * @param regex the regular expression handed to the matcher
     * @return the new matcher
     */
    public static Matcher<String> containsRegex(String regex) {
        return new RegexMatcher(regex);
    }

    /**
     * Creates a matcher that requires a string to contain every one of the given substrings.
     *
     * @param strs the substrings, each wrapped in a <code>containsString</code> matcher
     * @return an <code>allOf</code> combination of the individual matchers
     */
    public static Matcher<String> containsAll(String... strs) {
        List<Matcher<? extends String>> list = new ArrayList<Matcher<? extends String>>();

        for (String str : strs) {
            list.add(containsString(str));
        }

        return allOf(list);
    }

    /**
     * Creates a matcher that combines one {@link #containsRegex(String)} matcher per
     * given regular expression.
     *
     * @param regexes the regular expressions
     * @return an <code>allOf</code> combination of the individual matchers
     */
    public static Matcher<String> containsAllRegex(String... regexes) {
        List<Matcher<? extends String>> list = new ArrayList<Matcher<? extends String>>();

        for (String regex : regexes) {
            list.add(containsRegex(regex));
        }

        return allOf(list);
    }

    /**
     * Returns the directory named by the <code>java.home</code> system property,
     * or its parent directory when that directory is itself named <code>jre</code>.
     *
     * @return the resolved directory
     */
    public static File getJavaHome() {
        File javaHome = new File(System.getProperty("java.home"));

        if ("jre".equals(javaHome.getName())) {
            javaHome = javaHome.getParentFile();
        }

        return javaHome;
    }

    /**
     * Returns the root output directory that contains the given class.
     *
     * @param classWithinDir a class located inside the directory
     * @return the root directory
     * @throws IllegalArgumentException if <code>classWithinDir</code> is <code>null</code>
     *                                  or the class is not loaded from a plain file
     * @throws RuntimeException         if the directory cannot be determined
     * @see #getClassesDir(String)
     */
    public static File getClassesDir(Class<?> classWithinDir) {
        return getClassesDir(classWithinDir == null ? null : classWithinDir.getName());
    }

    /**
     * Returns the root output directory that contains the named class.
     * <p>
     * The class file is looked up as a resource through the thread context class
     * loader (falling back to this class's loader when there is none). Starting
     * from the class file's own directory, parent directories are visited until
     * one is found from which the resource name resolves back to the class file.
     *
     * @param classWithinDir fully qualified name of a class located inside the directory
     * @return the root directory
     * @throws IllegalArgumentException if <code>classWithinDir</code> is <code>null</code>
     *                                  or the class is not loaded from a plain file
     *                                  (for example, it lives inside a jar)
     * @throws RuntimeException         if the class file or its root directory cannot be found
     */
    public static File getClassesDir(String classWithinDir) {
        if (classWithinDir == null) {
            // Previously this ended up looking for "null.class" and failed with a bare NPE.
            throw new IllegalArgumentException("missing classWithinDir");
        }

        String clazzResourceName = getResourceNameOfClass(classWithinDir) + ".class";

        ClassLoader loader = Thread.currentThread().getContextClassLoader();

        if (loader == null) {
            // The context class loader may legitimately be unset.
            loader = TestUtil.class.getClassLoader();
        }

        java.net.URL classUrl = loader.getResource(clazzResourceName);

        if (classUrl == null) {
            // getResource() signals "not found" with null rather than an exception.
            throw new RuntimeException("Could not find classes dir of " + classWithinDir);
        }

        File classFile;

        try {
            classFile = new File(classUrl.toURI());
        } catch (URISyntaxException e) {
            throw new RuntimeException("Could not find classes dir of " + classWithinDir, e);
        } catch (IllegalArgumentException e) {
            // new File(URI) rejects anything that is not a hierarchical file: URI.
            // Keep the exception type but say which class and URL were involved.
            throw new IllegalArgumentException("Class " + classWithinDir
                                               + " is not located in a directory: " + classUrl, e);
        }

        File classesdir = classFile.getParentFile();

        // Walk upwards until the resource name, resolved against the candidate
        // directory, points back to the class file itself.
        for (; classesdir != null && classesdir.isDirectory(); classesdir = classesdir.getParentFile()) {
            if (new File(classesdir, clazzResourceName).equals(classFile)) {
                break;
            }
        }

        // The loop can also end without a match; verify the result explicitly.
        if (!classFile.equals(new File(classesdir, clazzResourceName))) {
            throw new RuntimeException("Could not find classes dir of " + classWithinDir);
        }

        return classesdir;
    }

    /**
     * Returns the distinct root directories of the given classes, in the order in
     * which they are first encountered.
     *
     * @param classes the classes to locate
     * @return the root directories without duplicates
     * @see #getClassesDir(Class)
     */
    public static File[] getClassDirs(Class<?>... classes) {
        Set<File> dirs = new LinkedHashSet<File>();

        for (Class<?> clazz : classes) {
            dirs.add(getClassesDir(clazz));
        }

        return dirs.toArray(new File[dirs.size()]);
    }

    /** Converts a class name into a resource path (without the <code>.class</code> suffix). */
    private static String getResourceNameOfClass(String className) {
        if (className == null) {
            return null;
        }

        return className.trim().replace('.', '/');
    }

    /**
     * Builds an {@link AssertionError} carrying the same message that
     * <code>fail(cause.toString())</code> would produce, but with the original
     * exception attached as the cause so its stack trace is not lost.
     */
    private static AssertionError assertionFailure(Throwable cause) {
        AssertionError error = new AssertionError(cause.toString());

        error.initCause(cause);
        return error;
    }

    /**
     * Finds a declared field on the given type or one of its superclasses and
     * makes it accessible.
     *
     * @param targetType the type at which the search starts
     * @param fieldName  the name of the field
     * @return the accessible field
     * @throws AssertionError if <code>targetType</code> is <code>null</code>, the field
     *                        is not found, or the lookup fails unexpectedly
     */
    public static Field getAccessibleField(Class<?> targetType, String fieldName) {
        assertNotNull("missing targetType", targetType);

        Field field = null;

        for (Class<?> c = targetType; c != null && field == null; c = c.getSuperclass()) {
            try {
                field = c.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared on this class; continue with the superclass.
            } catch (Exception e) {
                throw assertionFailure(e);
            }
        }

        assertNotNull("field " + fieldName + " not found in " + targetType, field);

        field.setAccessible(true);

        return field;
    }

    /**
     * Finds a declared method on the given type or one of its superclasses and
     * makes it accessible.
     *
     * @param targetType the type at which the search starts
     * @param methodName the name of the method
     * @param argTypes   the parameter types of the method
     * @return the accessible method
     * @throws AssertionError if <code>targetType</code> is <code>null</code>, the method
     *                        is not found, or the lookup fails unexpectedly
     */
    public static Method getAccessibleMethod(Class<?> targetType, String methodName, Class<?>[] argTypes) {
        assertNotNull("missing targetType", targetType);

        Method method = null;

        for (Class<?> c = targetType; c != null && method == null; c = c.getSuperclass()) {
            try {
                method = c.getDeclaredMethod(methodName, argTypes);
            } catch (NoSuchMethodException ignored) {
                // Not declared on this class; continue with the superclass.
            } catch (Exception e) {
                throw assertionFailure(e);
            }
        }

        assertNotNull("method " + methodName + " not found in " + targetType, method);

        method.setAccessible(true);

        return method;
    }

    /**
     * Reads a field value, even if the field is private. The field is looked up
     * starting from the runtime class of <code>target</code>.
     *
     * @param target    the object to read from
     * @param fieldName the name of the field
     * @param fieldType the expected type of the value, or <code>null</code> to skip the checked cast
     * @return the field value
     */
    public static <T> T getFieldValue(Object target, String fieldName, Class<T> fieldType) {
        return getFieldValue(target, null, fieldName, fieldType);
    }

    /**
     * Reads a field value, even if the field is private.
     *
     * @param target     the object to read from
     * @param targetType the type at which the field search starts; when <code>null</code>,
     *                   the runtime class of <code>target</code> is used
     * @param fieldName  the name of the field
     * @param fieldType  the expected type of the value, or <code>null</code> to skip the checked cast
     * @return the field value
     * @throws AssertionError if the field cannot be found or read
     */
    @SuppressWarnings("unchecked")
    public static <T> T getFieldValue(Object target, Class<?> targetType, String fieldName, Class<T> fieldType) {
        if (targetType == null && target != null) {
            targetType = target.getClass();
        }

        Field field = getAccessibleField(targetType, fieldName);
        Object value;

        try {
            value = field.get(target);
        } catch (Exception e) {
            throw assertionFailure(e);
        }

        if (fieldType != null) {
            return fieldType.cast(value);
        } else {
            return (T) value;
        }
    }

    /**
     * Invokes a method, even if it is private. The method is looked up starting
     * from the runtime class of <code>target</code>.
     *
     * @param target     the object to invoke the method on
     * @param methodName the name of the method
     * @param argTypes   the parameter types of the method
     * @param args       the arguments to pass
     * @param returnType the expected type of the result, or <code>null</code> to skip the checked cast
     * @return the value returned by the method
     * @throws InvocationTargetException if the invoked method itself throws
     */
    public static <T> T invokeMethod(Object target, String methodName, Class<?>[] argTypes, Object[] args,
                                     Class<T> returnType) throws IllegalArgumentException, IllegalAccessException,
                                                                 InvocationTargetException {
        return invokeMethod(target, null, methodName, argTypes, args, returnType);
    }

    /**
     * Invokes a method, even if it is private.
     *
     * @param target     the object to invoke the method on
     * @param targetType the type at which the method search starts; when <code>null</code>,
     *                   the runtime class of <code>target</code> is used
     * @param methodName the name of the method
     * @param argTypes   the parameter types of the method
     * @param args       the arguments to pass
     * @param returnType the expected type of the result, or <code>null</code> to skip the checked cast
     * @return the value returned by the method
     * @throws InvocationTargetException if the invoked method itself throws
     */
    @SuppressWarnings("unchecked")
    public static <T> T invokeMethod(Object target, Class<?> targetType, String methodName, Class<?>[] argTypes,
                                     Object[] args, Class<T> returnType) throws IllegalArgumentException,
                                                                                IllegalAccessException,
                                                                                InvocationTargetException {
        if (targetType == null && target != null) {
            targetType = target.getClass();
        }

        Method method = getAccessibleMethod(targetType, methodName, argTypes);
        Object value = method.invoke(target, args);

        if (returnType != null) {
            return returnType.cast(value);
        } else {
            return (T) value;
        }
    }
}