/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.test.context.support;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import org.springframework.beans.BeanUtils;
import org.springframework.core.GenericTypeResolver;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.test.context.TestSecurityContextHolder;
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestExecutionListener;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.util.MetaAnnotationUtils;
import org.springframework.test.web.servlet.MockMvc;
/**
* A {@link TestExecutionListener} that will find annotations that are annotated with
* {@link WithSecurityContext} on a test method or at the class level. If found, the
* {@link WithSecurityContext#factory()} is used to create a {@link SecurityContext} that
* will be used with this test. If using with {@link MockMvc} the
* {@link SecurityMockMvcRequestPostProcessors#testSecurityContext()} needs to be used
* too.
*
* @author Rob Winch
* @author Eddú Meléndez
* @since 4.0
*/
public class WithSecurityContextTestExecutionListener
		extends AbstractTestExecutionListener {

	/**
	 * Sets up the {@link SecurityContext} for each test method. First the specific
	 * method is inspected for a {@link WithSecurityContext} or {@link Annotation} that
	 * has {@link WithSecurityContext} on it. If that is not found, the class is
	 * inspected. If still not found, then no {@link SecurityContext} is populated.
	 *
	 * @param testContext the context of the test that is about to run
	 */
	@Override
	public void beforeTestMethod(TestContext testContext) throws Exception {
		// Method-level configuration wins; the class is only consulted as a fallback.
		SecurityContext securityContext = createSecurityContext(
				testContext.getTestMethod(), testContext);
		if (securityContext == null) {
			securityContext = createSecurityContext(testContext.getTestClass(),
					testContext);
		}
		if (securityContext != null) {
			TestSecurityContextHolder.setContext(securityContext);
		}
	}

	/**
	 * Creates a {@link SecurityContext} from a {@link WithSecurityContext} found on
	 * (or meta-present on) the given element.
	 *
	 * @param annotated the element to inspect (the test method in practice)
	 * @param context the current test context, used to obtain the factory
	 * @return the created context, or {@code null} if no {@link WithSecurityContext}
	 * was found
	 */
	private SecurityContext createSecurityContext(AnnotatedElement annotated,
			TestContext context) {
		WithSecurityContext withSecurityContext = AnnotationUtils
				.findAnnotation(annotated, WithSecurityContext.class);
		return createSecurityContext(annotated, withSecurityContext, context);
	}

	/**
	 * Creates a {@link SecurityContext} from a {@link WithSecurityContext} located via
	 * {@link MetaAnnotationUtils#findAnnotationDescriptor(Class, Class)} starting at
	 * the given class.
	 *
	 * @param annotated the test class at which the search starts
	 * @param context the current test context, used to obtain the factory
	 * @return the created context, or {@code null} if no {@link WithSecurityContext}
	 * was found
	 */
	private SecurityContext createSecurityContext(Class<?> annotated,
			TestContext context) {
		MetaAnnotationUtils.AnnotationDescriptor<WithSecurityContext> descriptor = MetaAnnotationUtils
				.findAnnotationDescriptor(annotated, WithSecurityContext.class);
		if (descriptor == null) {
			return null;
		}
		// The descriptor may have been found on a type other than 'annotated' (for
		// example a superclass). The subsequent lookup of the concrete annotation only
		// inspects the element it is given, so hand it the class that actually
		// declares the annotation rather than the class the search started from.
		return createSecurityContext(descriptor.getRootDeclaringClass(),
				descriptor.getAnnotation(), context);
	}

	/**
	 * Builds the {@link SecurityContext} by instantiating the factory named by
	 * {@code withSecurityContext} and invoking it with the annotation found on
	 * {@code annotated}.
	 *
	 * @param annotated the element that carries the annotation handed to the factory
	 * @param withSecurityContext the resolved meta-annotation, may be {@code null}
	 * @param context the current test context, used to obtain the factory
	 * @return the created context, or {@code null} if {@code withSecurityContext} is
	 * {@code null}
	 * @throws IllegalStateException if the factory throws a {@link RuntimeException}
	 */
	@SuppressWarnings({ "rawtypes", "unchecked" })
	private SecurityContext createSecurityContext(AnnotatedElement annotated,
			WithSecurityContext withSecurityContext, TestContext context) {
		if (withSecurityContext == null) {
			return null;
		}
		WithSecurityContextFactory factory = createFactory(withSecurityContext, context);
		// The factory's type argument tells us which annotation type it consumes.
		Class<? extends Annotation> type = (Class<? extends Annotation>) GenericTypeResolver
				.resolveTypeArgument(factory.getClass(),
						WithSecurityContextFactory.class);
		Annotation annotation = findAnnotation(annotated, type);
		try {
			return factory.createSecurityContext(annotation);
		}
		catch (RuntimeException e) {
			throw new IllegalStateException(
					"Unable to create SecurityContext using " + annotation, e);
		}
	}

	/**
	 * Locates the annotation to pass to the factory. An annotation of the requested
	 * {@code type} is preferred; otherwise the first annotation on the element whose
	 * annotation type itself carries {@link WithSecurityContext} is returned.
	 *
	 * @param annotated the element to inspect
	 * @param type the annotation type the factory expects
	 * @return the matching annotation, or {@code null} if none is found
	 */
	private Annotation findAnnotation(AnnotatedElement annotated,
			Class<? extends Annotation> type) {
		Annotation findAnnotation = AnnotationUtils.findAnnotation(annotated, type);
		if (findAnnotation != null) {
			return findAnnotation;
		}
		Annotation[] allAnnotations = AnnotationUtils.getAnnotations(annotated);
		for (Annotation annotationToTest : allAnnotations) {
			WithSecurityContext withSecurityContext = AnnotationUtils.findAnnotation(
					annotationToTest.annotationType(), WithSecurityContext.class);
			if (withSecurityContext != null) {
				return annotationToTest;
			}
		}
		return null;
	}

	/**
	 * Instantiates the factory declared by {@link WithSecurityContext#factory()}.
	 * The application context's bean factory is tried first; if that attempt raises
	 * an {@link IllegalStateException} the factory class is instantiated directly
	 * via {@link BeanUtils#instantiateClass(Class)}.
	 *
	 * @param withSecurityContext the annotation naming the factory class
	 * @param testContext the current test context
	 * @return a new factory instance
	 * @throws RuntimeException wrapping any other exception raised while creating
	 * the bean
	 */
	private WithSecurityContextFactory<? extends Annotation> createFactory(
			WithSecurityContext withSecurityContext, TestContext testContext) {
		Class<? extends WithSecurityContextFactory<? extends Annotation>> clazz = withSecurityContext
				.factory();
		try {
			return testContext.getApplicationContext().getAutowireCapableBeanFactory()
					.createBean(clazz);
		}
		catch (IllegalStateException e) {
			// Deliberate best-effort fallback: create the factory without the
			// application context.
			return BeanUtils.instantiateClass(clazz);
		}
		catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Clears out the {@link TestSecurityContextHolder} and the
	 * {@link SecurityContextHolder} after each test method.
	 */
	@Override
	public void afterTestMethod(TestContext testContext) throws Exception {
		TestSecurityContextHolder.clearContext();
	}

	/**
	 * Returns {@code 10000}.
	 */
	@Override
	public int getOrder() {
		return 10000;
	}
}