package com.tngtech.archunit.core.domain; import java.io.File; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableSet; import com.tngtech.archunit.base.DescribedPredicate; import com.tngtech.archunit.base.Optional; import com.tngtech.archunit.core.domain.AccessTarget.ConstructorCallTarget; import com.tngtech.archunit.core.domain.AccessTarget.FieldAccessTarget; import com.tngtech.archunit.core.domain.AccessTarget.MethodCallTarget; import com.tngtech.archunit.core.domain.JavaFieldAccess.AccessType; import com.tngtech.archunit.core.domain.Source.Md5sum; import com.tngtech.archunit.core.importer.ClassFileImporter; import com.tngtech.archunit.core.importer.DomainBuilders.JavaMethodCallBuilder; import com.tngtech.archunit.core.importer.ImportTestUtils; import com.tngtech.archunit.core.importer.ImportTestUtils.ImportedTestClasses; import org.assertj.core.util.Files; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; import static com.tngtech.archunit.core.domain.Formatters.formatMethod; import static com.tngtech.archunit.core.domain.JavaConstructor.CONSTRUCTOR_NAME; import static com.tngtech.archunit.core.importer.ImportTestUtils.newFieldAccess; import static com.tngtech.archunit.core.importer.ImportTestUtils.newMethodCall; import static org.assertj.core.util.Files.temporaryFolderPath; import static org.assertj.core.util.Strings.concat; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.when;

/**
 * Collection of static helpers for tests of the domain model: creation of temporary folders,
 * creation of domain objects ({@link JavaClass}, {@link JavaMethod}, {@link JavaField}, ...) from
 * reflection types, and small fluent simulators / retrievers for accesses between code units.
 */
public class TestUtils {
    public static final Md5sum MD5_SUM_DISABLED = Md5sum.DISABLED;

    private static final Random random = new Random();

    /**
     * NOTE: The resolution of {@link Files#newTemporaryFolder()}, using {@link System#currentTimeMillis()}
     * is not good enough and makes tests flaky.
     *
     * @return a freshly created folder below the temporary folder path, registered for deletion on exit
     */
    public static File newTemporaryFolder() {
        // nano time plus a random long keeps names distinct even for calls in quick succession
        String uniqueName = "archtmp" + System.nanoTime() + random.nextLong();
        File tempDir = new File(concat(temporaryFolderPath(), uniqueName));
        if (tempDir.exists()) {
            Files.delete(tempDir);
        }
        boolean created = tempDir.mkdirs();
        checkArgument(created, "Folder %s already exists", tempDir.getAbsolutePath());
        tempDir.deleteOnExit();
        return tempDir;
    }

    /**
     * Makes the supplied method accessible and invokes it reflectively.
     *
     * @param method the method to invoke
     * @param owner  the instance to invoke the method on
     * @param params the arguments passed to the method
     * @return whatever the invoked method returns
     * @throws RuntimeException wrapping any exception raised while invoking
     */
    public static Object invoke(Method method, Object owner, Object... params) {
        try {
            method.setAccessible(true);
            Object returned = method.invoke(owner, params);
            return returned;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Looks up the declared method {@code name(args)} of {@code owner} and returns it as {@link JavaMethod}.
     */
    public static JavaMethod javaMethodViaReflection(Class<?> owner, String name, Class<?>... args) {
        JavaClass javaOwner = javaClassViaReflection(owner);
        return javaMethodViaReflection(javaOwner, name, args);
    }

    /**
     * Looks up the declared method {@code name(args)} on the reflected type of {@code clazz}
     * and returns it as {@link JavaMethod}.
     *
     * @throws RuntimeException wrapping the {@link NoSuchMethodException} if there is no such declared method
     */
    public static JavaMethod javaMethodViaReflection(JavaClass clazz, String name, Class<?>... args) {
        Method declared;
        try {
            declared = clazz.reflect().getDeclaredMethod(name, args);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
        return javaMethodViaReflection(clazz, declared);
    }

    /**
     * Delegates to {@link ImportTestUtils#javaMethodViaReflection} for the supplied owner and method.
     */
    public static JavaMethod javaMethodViaReflection(JavaClass clazz, Method method) {
        return ImportTestUtils.javaMethodViaReflection(clazz, method);
    }

    /**
     * Converts each of the supplied types via {@link #javaClassViaReflection(Class)}
     * and wraps the results, in the given order, into a {@link JavaClassList}.
     */
    public static JavaClassList javaClassList(Class<?>... types) {
        List<JavaClass> converted = new ArrayList<>();
        for (int i = 0; i < types.length; i++) {
            converted.add(javaClassViaReflection(types[i]));
        }
        return new JavaClassList(converted);
    }

    /**
     * Imports the supplied classes with a new {@link ClassFileImporter}.
     */
    public static JavaClasses importClasses(Class<?>... classes) {
        ClassFileImporter importer = new ClassFileImporter();
        return importer.importClasses(classes);
    }

    /**
     * Imports the supplied classes and returns an {@link ImportedContext} to retrieve calls from them.
     */
    public static ImportedContext withinImportedClasses(Class<?>... contextClasses) {
        JavaClasses imported = importClasses(contextClasses);
        return new ImportedContext(imported);
    }

    /**
     * Delegates to {@link Md5sum#of(byte[])}.
     */
    public static Md5sum md5sumOf(byte[] bytes) {
        return Md5sum.of(bytes);
    }

    /**
     * Creates the single {@link JavaClass} obtained from {@link #javaClassesViaReflection(Class[])} for {@code owner}.
     */
    public static JavaClass javaClassViaReflection(Class<?> owner) {
        JavaClasses singleClass = javaClassesViaReflection(owner);
        return getOnlyElement(singleClass);
    }

    /**
     * Looks up the declared field {@code name} of {@code owner} and returns it as {@link JavaField}.
     */
    public static JavaField javaFieldViaReflection(Class<?> owner, String name) {
        JavaClass javaOwner = javaClassViaReflection(owner);
        return javaFieldViaReflection(javaOwner, name);
    }

    /**
     * Looks up the declared field {@code name} on the reflected type of {@code owner}
     * and returns it as {@link JavaField}.
     *
     * @throws RuntimeException wrapping the {@link NoSuchFieldException} if there is no such declared field
     */
    public static JavaField javaFieldViaReflection(JavaClass owner, String name) {
        try {
            Field declared = owner.reflect().getDeclaredField(name);
            return ImportTestUtils.javaFieldViaReflection(declared, owner);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Simulates the import of the supplied classes via {@link ImportTestUtils#simulateImport} and
     * completes the class hierarchy of each resulting {@link JavaClass} from a mocked {@link ImportContext}.
     */
    public static JavaClasses javaClassesViaReflection(Class<?>... classes) {
        final ImportedTestClasses imported = ImportTestUtils.simpleImportedClasses();
        final Map<String, JavaClass> classesByName = new HashMap<>();
        for (Class<?> type : classes) {
            JavaClass simulated = ImportTestUtils.simulateImport(type, imported);
            classesByName.put(simulated.getName(), simulated);
        }

        ImportContext completionContext = simulateContextForCompletion(imported);
        for (JavaClass simulated : classesByName.values()) {
            simulated.completeClassHierarchyFrom(completionContext);
        }
        return JavaClasses.of(classesByName, completionContext);
    }

    /**
     * Builds a mocked {@link ImportContext} that answers requests for super class, interfaces,
     * enclosing class and classes by type name by looking the respective type up in {@code importedClasses}.
     */
    private static ImportContext simulateContextForCompletion(final ImportedTestClasses importedClasses) {
        ImportContext context = mock(ImportContext.class);

        when(context.createSuperClass(any(JavaClass.class))).thenAnswer(new Answer<Optional<JavaClass>>() {
            @Override
            public Optional<JavaClass> answer(InvocationOnMock invocation) throws Throwable {
                Class<?> superclass = classOfFirstArgument(invocation).getSuperclass();
                if (superclass == null) {
                    return Optional.<JavaClass>absent();
                }
                return Optional.of(importedClasses.get(superclass.getName()));
            }
        });

        when(context.createInterfaces(any(JavaClass.class))).thenAnswer(new Answer<Set<JavaClass>>() {
            @Override
            public Set<JavaClass> answer(InvocationOnMock invocation) throws Throwable {
                Class<?> reflected = classOfFirstArgument(invocation);
                ImmutableSet.Builder<JavaClass> interfaces = ImmutableSet.builder();
                for (Class<?> implemented : reflected.getInterfaces()) {
                    interfaces.add(importedClasses.get(implemented.getName()));
                }
                return interfaces.build();
            }
        });

        when(context.getJavaClassWithType(anyString())).thenAnswer(new Answer<JavaClass>() {
            @Override
            public JavaClass answer(InvocationOnMock invocation) throws Throwable {
                return importedClasses.get((String) invocation.getArguments()[0]);
            }
        });

        when(context.createEnclosingClass(any(JavaClass.class))).thenAnswer(new Answer<Optional<JavaClass>>() {
            @Override
            public Optional<JavaClass> answer(InvocationOnMock invocation) throws Throwable {
                Class<?> enclosing = classOfFirstArgument(invocation).getEnclosingClass();
                if (enclosing == null) {
                    return Optional.<JavaClass>absent();
                }
                return Optional.of(importedClasses.get(enclosing.getName()));
            }
        });

        return context;
    }

    /**
     * Resolves the reflection type named like the {@link JavaClass} passed as first argument of the invocation.
     */
    private static Class<?> classOfFirstArgument(InvocationOnMock invocation) {
        JavaClass javaClass = (JavaClass) invocation.getArguments()[0];
        return classForName(javaClass.getName());
    }

    /**
     * Delegates to {@link ImportTestUtils#newMethodCallBuilder}.
     */
    public static JavaMethodCallBuilder newMethodCallBuilder(JavaMethod origin, MethodCallTarget target, int lineNumber) {
        return ImportTestUtils.newMethodCallBuilder(origin, target, lineNumber);
    }

    /**
     * Entry point of the fluent API to simulate accesses, i.e. {@code simulateCall().from(..).to(..)}.
     */
    public static AccessesSimulator simulateCall() {
        return new AccessesSimulator();
    }

    /**
     * Creates an always-true predicate carrying the supplied description.
     */
    public static DescribedPredicate<Object> predicateWithDescription(String description) {
        DescribedPredicate<Object> alwaysTrue = DescribedPredicate.alwaysTrue();
        return alwaysTrue.as(description);
    }

    /**
     * Creates a call target for {@code target} whose supplied set of methods consists of exactly {@code target}.
     */
    public static MethodCallTarget resolvedTargetFrom(JavaMethod target) {
        Set<JavaMethod> onlyTarget = Collections.singleton(target);
        return ImportTestUtils.targetFrom(target, Suppliers.ofInstance(onlyTarget));
    }

    /**
     * Creates a call target for {@code target} whose supplied set of methods is empty.
     */
    static MethodCallTarget unresolvedTargetFrom(JavaMethod target) {
        Set<JavaMethod> nothing = Collections.emptySet();
        return ImportTestUtils.targetFrom(target, Suppliers.ofInstance(nothing));
    }

    /**
     * Reflects each of the supplied {@link JavaClass JavaClasses} and returns the results as array, in the given order.
     */
    public static Class[] asClasses(List<JavaClass> parameters) {
        Class[] reflected = new Class[parameters.size()];
        int next = 0;
        for (JavaClass parameter : parameters) {
            reflected[next++] = parameter.reflect();
        }
        return reflected;
    }

    /**
     * Resolves the class with the supplied name via {@link JavaType}.
     */
    public static Class<?> classForName(String name) {
        return JavaType.From.name(name).resolveClass();
    }

    /**
     * Delegates to {@link ImportTestUtils#simpleImportedClasses()}.
     */
    static ImportedTestClasses simpleImportedClasses() {
        return ImportTestUtils.simpleImportedClasses();
    }

    /**
     * Delegates to {@link ImportTestUtils#javaAnnotationFrom}.
     */
    static JavaAnnotation javaAnnotationFrom(Annotation annotation) {
        return ImportTestUtils.javaAnnotationFrom(annotation);
    }

    /**
     * Delegates to {@link ImportTestUtils#targetFrom} for a field.
     */
    public static FieldAccessTarget targetFrom(JavaField javaField) {
        return ImportTestUtils.targetFrom(javaField);
    }

    /**
     * Delegates to {@link ImportTestUtils#targetFrom} for a constructor.
     */
    public static ConstructorCallTarget targetFrom(JavaConstructor constructor) {
        return ImportTestUtils.targetFrom(constructor);
    }

    /**
     * Delegates to {@link Dependency#from(JavaAccess)}.
     */
    public static Dependency dependencyFrom(JavaAccess<?> access) {
        return Dependency.from(access);
    }

    /**
     * First step of the fluent simulation API: picks the origin method of the simulated access.
     * All {@link AccessSimulator AccessSimulators} created from one instance share the same set of call targets.
     */
    public static class AccessesSimulator {
        private final Set<MethodCallTarget> targets = new HashSet<>();

        /**
         * Starts a simulation originating from {@code method} at the supplied line number.
         */
        public AccessSimulator from(JavaMethod method, int lineNumber) {
            return new AccessSimulator(targets, method, lineNumber);
        }

        /**
         * Starts a simulation originating from the declared method of {@code clazz}, using line number 0.
         */
        public AccessSimulator from(Class<?> clazz, String methodName, Class<?>... params) {
            JavaMethod origin = javaMethodViaReflection(clazz, methodName, params);
            return new AccessSimulator(targets, origin, 0);
        }

        /**
         * Starts a simulation originating from the method of {@code clazz} with the given name and
         * parameter types, using line number 0.
         */
        public AccessSimulator from(JavaClass clazz, String methodName, Class<?>... params) {
            JavaMethod origin = clazz.getMethod(methodName, params);
            return from(origin, 0);
        }
    }

    /**
     * Second step of the fluent simulation API: picks the target of the simulated access and completes
     * the origin method from a mocked {@link ImportContext} reporting the simulated accesses.
     */
    public static class AccessSimulator {
        private final Set<MethodCallTarget> targets;
        private final JavaMethod method;
        private final int lineNumber;

        private AccessSimulator(Set<MethodCallTarget> targets, JavaMethod method, int lineNumber) {
            this.targets = targets;
            this.method = method;
            this.lineNumber = lineNumber;
        }

        /**
         * Simulates a call from the origin method to {@code target}, using a resolved call target.
         *
         * @return the call from the origin method to the target
         */
        public JavaMethodCall to(JavaMethod target) {
            MethodCallTarget resolved = resolvedTargetFrom(target);
            return to(resolved);
        }

        /**
         * Records the target, then completes the origin method from a fresh mocked context that reports
         * one call per target recorded so far, and returns the call to the supplied target.
         */
        private JavaMethodCall to(MethodCallTarget methodCallTarget) {
            targets.add(methodCallTarget);

            ImportContext context = mock(ImportContext.class);
            Set<JavaMethodCall> simulatedCalls = new HashSet<>();
            for (MethodCallTarget recorded : targets) {
                simulatedCalls.add(newMethodCall(method, recorded, lineNumber));
            }
            when(context.getMethodCallsFor(method)).thenReturn(ImmutableSet.copyOf(simulatedCalls));
            method.completeFrom(context);

            return getCallToTarget(methodCallTarget);
        }

        /**
         * Simulates a call from the origin method to the declared method of {@code clazz},
         * using a resolved call target.
         */
        public JavaMethodCall to(Class<?> clazz, String methodName, Class<?>... params) {
            JavaMethod callee = javaMethodViaReflection(clazz, methodName, params);
            return to(resolvedTargetFrom(callee));
        }

        /**
         * Simulates a call from the origin method to the declared method of {@code clazz},
         * using an unresolved call target.
         */
        public JavaMethodCall toUnresolved(Class<?> clazz, String methodName, Class<?>... params) {
            JavaMethod callee = javaMethodViaReflection(clazz, methodName, params);
            return to(unresolvedTargetFrom(callee));
        }

        /**
         * Returns the one call of the origin method whose target equals {@code callTarget}.
         */
        private JavaMethodCall getCallToTarget(MethodCallTarget callTarget) {
            Set<JavaMethodCall> hits = new HashSet<>();
            for (JavaMethodCall candidate : method.getMethodCallsFromSelf()) {
                if (!candidate.getTarget().equals(callTarget)) {
                    continue;
                }
                hits.add(candidate);
            }
            return getOnlyElement(hits);
        }

        /**
         * Simulates a field access of the supplied type from the origin method to {@code target}, by
         * completing the origin method from a fresh mocked context reporting exactly this field access.
         */
        public void to(JavaField target, AccessType accessType) {
            ImportContext context = mock(ImportContext.class);
            when(context.getFieldAccessesFor(method))
                    .thenReturn(ImmutableSet.of(newFieldAccess(method, target, lineNumber, accessType)));
            method.completeFrom(context);
        }
    }

    /**
     * Wraps imported classes to look up calls originating from their code units.
     */
    public static class ImportedContext {
        private final JavaClasses classes;

        private ImportedContext(JavaClasses classes) {
            this.classes = classes;
        }

        /**
         * Selects the code unit of {@code originClass} with the given name and parameter types as origin.
         */
        public CallRetriever getCallFrom(Class<?> originClass, String codeUnitName, Class<?>... paramTypes) {
            JavaClass owner = classes.get(originClass);
            JavaCodeUnit origin = owner.getCodeUnitWithParameterTypes(codeUnitName, paramTypes);
            return new CallRetriever(origin);
        }
    }

    /**
     * Finds calls from a fixed origin code unit to a constructor or method described by owner,
     * name and parameter types.
     */
    public static class CallRetriever {
        private final JavaCodeUnit codeUnit;

        private CallRetriever(JavaCodeUnit codeUnit) {
            this.codeUnit = codeUnit;
        }

        /**
         * Returns a constructor call from the origin to the constructor of {@code targetOwner}
         * with the supplied parameter types.
         *
         * @throws IllegalStateException if there is no such call
         */
        public JavaConstructorCall toConstructor(Class<?> targetOwner, Class<?>... paramTypes) {
            Set<JavaConstructorCall> constructorCalls = codeUnit.getConstructorCallsFromSelf();
            return findMethod(constructorCalls, targetOwner, CONSTRUCTOR_NAME, paramTypes);
        }

        /**
         * Returns a method call from the origin to the method of {@code targetOwner}
         * with the supplied name and parameter types.
         *
         * @throws IllegalStateException if there is no such call
         */
        public JavaMethodCall toMethod(Class<?> targetOwner, String methodName, Class<?>... paramTypes) {
            Set<JavaMethodCall> methodCalls = codeUnit.getMethodCallsFromSelf();
            return findMethod(methodCalls, targetOwner, methodName, paramTypes);
        }

        /**
         * Returns the first call encountered whose target owner is equivalent to {@code targetOwner}
         * and whose target matches the supplied name and parameter type names.
         */
        private <T extends JavaCall<?>> T findMethod(Set<T> callsFromSelf, Class<?> targetOwner, String methodName, Class<?>[] paramTypes) {
            List<String> paramNames = JavaClass.namesOf(paramTypes);
            for (T candidate : callsFromSelf) {
                if (!candidate.getTargetOwner().isEquivalentTo(targetOwner)) {
                    continue;
                }
                if (!candidate.getTarget().getName().equals(methodName)) {
                    continue;
                }
                if (!candidate.getTarget().getParameters().getNames().equals(paramNames)) {
                    continue;
                }
                return candidate;
            }
            throw new IllegalStateException(
                    "Couldn't find any call with target " + formatMethod(targetOwner.getName(), methodName, paramNames));
        }
    }
}