/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.method.annotation;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.ui.Model;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.support.DefaultDataBinderFactory;
import org.springframework.web.bind.support.DefaultSessionAttributeStore;
import org.springframework.web.bind.support.SessionAttributeStore;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite;
import org.springframework.web.method.support.InvocableHandlerMethod;
import org.springframework.web.method.support.ModelAndViewContainer;
import static org.junit.Assert.*;
/**
 * Unit tests verifying {@code @ModelAttribute} method inter-dependencies.
 *
 * <p>Each test hands a controller's {@code @ModelAttribute} methods to a
 * {@link ModelFactory} in random order. Every controller method records its own
 * name in a list stored under the {@code "methods"} model attribute, and the
 * tests then assert on the relative positions of those names.
 *
 * @author Rossen Stoyanchev
 */
public class ModelFactoryOrderingTests {

	private static final Log logger = LogFactory.getLog(ModelFactoryOrderingTests.class);

	private NativeWebRequest webRequest;

	private ModelAndViewContainer mavContainer;

	private SessionAttributeStore sessionAttributeStore;


	/**
	 * Create a fresh request and model container for every test and seed the model
	 * with the empty {@code "methods"} list that the controller methods append to.
	 */
	@Before
	public void setup() {
		this.sessionAttributeStore = new DefaultSessionAttributeStore();
		this.webRequest = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse());
		this.mavContainer = new ModelAndViewContainer();
		this.mavContainer.addAttribute("methods", new ArrayList<String>());
	}


	/**
	 * Chain of dependencies: A, B1, B2, C1, C2, C3, C4 where each method takes the
	 * return type of the previous one as a {@code @ModelAttribute} argument.
	 */
	@Test
	public void straightLineDependency() throws Exception {
		runTest(new StraightLineDependencyController());
		assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
		assertInvokedBefore("getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
		assertInvokedBefore("getB2", "getC1", "getC2", "getC3", "getC4");
		assertInvokedBefore("getC1", "getC2", "getC3", "getC4");
		assertInvokedBefore("getC2", "getC3", "getC4");
		assertInvokedBefore("getC3", "getC4");
	}

	/**
	 * Tree of dependencies: B1 and B2 take A; C1 and C2 take B1; C3 and C4 take B2.
	 */
	@Test
	public void treeDependency() throws Exception {
		runTest(new TreeDependencyController());
		assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4");
		assertInvokedBefore("getB1", "getC1", "getC2");
		assertInvokedBefore("getB2", "getC3", "getC4");
	}

	/**
	 * Inverted tree of dependencies: A takes B1 and B2; B1 takes C1 and C2;
	 * B2 takes C3 and C4.
	 */
	@Test
	public void InvertedTreeDependency() throws Exception {
		runTest(new InvertedTreeDependencyController());
		assertInvokedBefore("getC1", "getA", "getB1");
		assertInvokedBefore("getC2", "getA", "getB1");
		assertInvokedBefore("getC3", "getA", "getB2");
		assertInvokedBefore("getC4", "getA", "getB2");
		assertInvokedBefore("getB1", "getA");
		assertInvokedBefore("getB2", "getA");
	}

	/**
	 * C1..C4 take B1/B2 arguments, but the controller declares no method that
	 * produces B1 or B2.
	 */
	@Test
	public void unresolvedDependency() throws Exception {
		runTest(new UnresolvedDependencyController());
		assertInvokedBefore("getA", "getC1", "getC2", "getC3", "getC4");

		// No other order guarantees for methods with unresolvable dependencies
		// (and methods that depend on them). Required dependencies will be
		// created via default constructor.
	}


	/**
	 * Wrap every method of the given controller that matches {@link #METHOD_FILTER}
	 * in an {@link InvocableHandlerMethod}, shuffle the resulting list, and let a
	 * {@link ModelFactory} initialize the model for the controller's {@code handle}
	 * method. The recorded invocation order is logged at debug level.
	 * @param controller the controller instance whose model methods are exercised
	 */
	private void runTest(Object controller) throws Exception {
		HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite();
		resolvers.addResolver(new ModelAttributeMethodProcessor(false));
		resolvers.addResolver(new ModelMethodProcessor());

		WebDataBinderFactory dataBinderFactory = new DefaultDataBinderFactory(null);

		Class<?> type = controller.getClass();
		Set<Method> methods = MethodIntrospector.selectMethods(type, METHOD_FILTER);
		List<InvocableHandlerMethod> modelMethods = new ArrayList<>();
		for (Method method : methods) {
			InvocableHandlerMethod modelMethod = new InvocableHandlerMethod(controller, method);
			modelMethod.setHandlerMethodArgumentResolvers(resolvers);
			modelMethod.setDataBinderFactory(dataBinderFactory);
			modelMethods.add(modelMethod);
		}
		// Randomize the order in which the methods are handed to the factory.
		Collections.shuffle(modelMethods);

		SessionAttributesHandler sessionHandler = new SessionAttributesHandler(type, this.sessionAttributeStore);
		ModelFactory factory = new ModelFactory(modelMethods, dataBinderFactory, sessionHandler);
		factory.initModel(this.webRequest, this.mavContainer, new HandlerMethod(controller, "handle"));

		if (logger.isDebugEnabled()) {
			StringBuilder sb = new StringBuilder();
			for (String name : getInvokedMethods()) {
				sb.append(" >> ").append(name);
			}
			logger.debug(sb);
		}
	}

	/**
	 * Assert that {@code beforeMethod} was recorded earlier than each of the given
	 * {@code afterMethods}.
	 * <p>Both sides must actually have been recorded: {@link List#indexOf} returns
	 * {@code -1} for a missing entry, so a bare index comparison would pass
	 * whenever {@code beforeMethod} was never invoked at all.
	 * @param beforeMethod name of the method expected to run first
	 * @param afterMethods names of the methods expected to run later
	 */
	private void assertInvokedBefore(String beforeMethod, String... afterMethods) {
		List<String> actual = getInvokedMethods();
		int beforeIndex = actual.indexOf(beforeMethod);
		assertTrue(beforeMethod + " should have been invoked. Actual order: " + actual.toString(),
				beforeIndex != -1);
		for (String afterMethod : afterMethods) {
			int afterIndex = actual.indexOf(afterMethod);
			assertTrue(afterMethod + " should have been invoked. Actual order: " + actual.toString(),
					afterIndex != -1);
			assertTrue(beforeMethod + " should be before " + afterMethod + ". Actual order: " +
					actual.toString(), beforeIndex < afterIndex);
		}
	}

	/**
	 * Return the list stored under the {@code "methods"} model attribute, i.e. the
	 * method names recorded through {@code AbstractController#updateAndReturn}.
	 */
	@SuppressWarnings("unchecked")
	private List<String> getInvokedMethods() {
		return (List<String>) this.mavContainer.getModel().get("methods");
	}


	/**
	 * Base class for the test controllers: provides the {@code handle} request
	 * method and the helper that records an invocation.
	 */
	private static class AbstractController {

		@RequestMapping
		public void handle() {
		}

		/**
		 * Append {@code methodName} to the {@code "methods"} list held in the given
		 * model and hand back {@code returnValue} unchanged.
		 */
		@SuppressWarnings("unchecked")
		<T> T updateAndReturn(Model model, String methodName, T returnValue) throws IOException {
			((List<String>) model.asMap().get("methods")).add(methodName);
			return returnValue;
		}
	}


	/**
	 * Controller whose model methods form a single chain:
	 * A, B1, B2, C1, C2, C3, C4, each taking the previous type as argument.
	 */
	private static class StraightLineDependencyController extends AbstractController {

		@ModelAttribute
		public A getA(Model model) throws IOException {
			return updateAndReturn(model, "getA", new A());
		}

		@ModelAttribute
		public B1 getB1(@ModelAttribute A a, Model model) throws IOException {
			return updateAndReturn(model, "getB1", new B1());
		}

		@ModelAttribute
		public B2 getB2(@ModelAttribute B1 b1, Model model) throws IOException {
			return updateAndReturn(model, "getB2", new B2());
		}

		@ModelAttribute
		public C1 getC1(@ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getC1", new C1());
		}

		@ModelAttribute
		public C2 getC2(@ModelAttribute C1 c1, Model model) throws IOException {
			return updateAndReturn(model, "getC2", new C2());
		}

		@ModelAttribute
		public C3 getC3(@ModelAttribute C2 c2, Model model) throws IOException {
			return updateAndReturn(model, "getC3", new C3());
		}

		@ModelAttribute
		public C4 getC4(@ModelAttribute C3 c3, Model model) throws IOException {
			return updateAndReturn(model, "getC4", new C4());
		}
	}


	/**
	 * Controller whose model methods form a tree rooted at A:
	 * B1 and B2 take A; C1 and C2 take B1; C3 and C4 take B2.
	 */
	private static class TreeDependencyController extends AbstractController {

		@ModelAttribute
		public A getA(Model model) throws IOException {
			return updateAndReturn(model, "getA", new A());
		}

		@ModelAttribute
		public B1 getB1(@ModelAttribute A a, Model model) throws IOException {
			return updateAndReturn(model, "getB1", new B1());
		}

		@ModelAttribute
		public B2 getB2(@ModelAttribute A a, Model model) throws IOException {
			return updateAndReturn(model, "getB2", new B2());
		}

		@ModelAttribute
		public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException {
			return updateAndReturn(model, "getC1", new C1());
		}

		@ModelAttribute
		public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException {
			return updateAndReturn(model, "getC2", new C2());
		}

		@ModelAttribute
		public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getC3", new C3());
		}

		@ModelAttribute
		public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getC4", new C4());
		}
	}


	/**
	 * Controller whose model methods form an inverted tree ending at A:
	 * C1..C4 take no model attribute arguments; B1 takes C1 and C2;
	 * B2 takes C3 and C4; A takes B1 and B2.
	 */
	private static class InvertedTreeDependencyController extends AbstractController {

		@ModelAttribute
		public C1 getC1(Model model) throws IOException {
			return updateAndReturn(model, "getC1", new C1());
		}

		@ModelAttribute
		public C2 getC2(Model model) throws IOException {
			return updateAndReturn(model, "getC2", new C2());
		}

		@ModelAttribute
		public C3 getC3(Model model) throws IOException {
			return updateAndReturn(model, "getC3", new C3());
		}

		@ModelAttribute
		public C4 getC4(Model model) throws IOException {
			return updateAndReturn(model, "getC4", new C4());
		}

		@ModelAttribute
		public B1 getB1(@ModelAttribute C1 c1, @ModelAttribute C2 c2, Model model) throws IOException {
			return updateAndReturn(model, "getB1", new B1());
		}

		@ModelAttribute
		public B2 getB2(@ModelAttribute C3 c3, @ModelAttribute C4 c4, Model model) throws IOException {
			return updateAndReturn(model, "getB2", new B2());
		}

		@ModelAttribute
		public A getA(@ModelAttribute B1 b1, @ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getA", new A());
		}
	}


	/**
	 * Controller in which C1..C4 take B1 or B2 arguments although no method
	 * declared here produces B1 or B2.
	 */
	private static class UnresolvedDependencyController extends AbstractController {

		@ModelAttribute
		public A getA(Model model) throws IOException {
			return updateAndReturn(model, "getA", new A());
		}

		@ModelAttribute
		public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException {
			return updateAndReturn(model, "getC1", new C1());
		}

		@ModelAttribute
		public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException {
			return updateAndReturn(model, "getC2", new C2());
		}

		@ModelAttribute
		public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getC3", new C3());
		}

		@ModelAttribute
		public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException {
			return updateAndReturn(model, "getC4", new C4());
		}
	}


	// Empty marker types used as model attribute values and method arguments.

	private static class A { }

	private static class B1 { }

	private static class B2 { }

	private static class C1 { }

	private static class C2 { }

	private static class C3 { }

	private static class C4 { }


	/**
	 * Matches methods on which a {@code @ModelAttribute} annotation is found and
	 * no {@code @RequestMapping} annotation is found.
	 */
	private static final ReflectionUtils.MethodFilter METHOD_FILTER = new ReflectionUtils.MethodFilter() {

		@Override
		public boolean matches(Method method) {
			return ((AnnotationUtils.findAnnotation(method, RequestMapping.class) == null) &&
					(AnnotationUtils.findAnnotation(method, ModelAttribute.class) != null));
		}
	};

}