/* * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.web.method.annotation; import java.io.IOException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.Before; import org.junit.Test; import org.springframework.core.MethodIntrospector; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.ui.Model; import org.springframework.util.ReflectionUtils; import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.support.DefaultDataBinderFactory; import org.springframework.web.bind.support.DefaultSessionAttributeStore; import org.springframework.web.bind.support.SessionAttributeStore; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.support.HandlerMethodArgumentResolverComposite; import org.springframework.web.method.support.InvocableHandlerMethod; import org.springframework.web.method.support.ModelAndViewContainer; import static org.junit.Assert.*; /** * Unit tests verifying {@code @ModelAttribute} method inter-dependencies. * * @author Rossen Stoyanchev */ public class ModelFactoryOrderingTests { private static final Log logger = LogFactory.getLog(ModelFactoryOrderingTests.class); private NativeWebRequest webRequest; private ModelAndViewContainer mavContainer; private SessionAttributeStore sessionAttributeStore; @Before public void setup() { this.sessionAttributeStore = new DefaultSessionAttributeStore(); this.webRequest = new ServletWebRequest(new MockHttpServletRequest(), new MockHttpServletResponse()); this.mavContainer = new ModelAndViewContainer(); this.mavContainer.addAttribute("methods", new ArrayList<String>()); } @Test public void straightLineDependency() throws Exception { runTest(new StraightLineDependencyController()); assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); assertInvokedBefore("getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); assertInvokedBefore("getB2", "getC1", "getC2", "getC3", "getC4"); assertInvokedBefore("getC1", "getC2", "getC3", "getC4"); assertInvokedBefore("getC2", "getC3", "getC4"); assertInvokedBefore("getC3", "getC4"); } @Test public void treeDependency() throws Exception { runTest(new TreeDependencyController()); assertInvokedBefore("getA", "getB1", "getB2", "getC1", "getC2", "getC3", "getC4"); assertInvokedBefore("getB1", "getC1", "getC2"); assertInvokedBefore("getB2", "getC3", "getC4"); } @Test public void InvertedTreeDependency() throws Exception { runTest(new InvertedTreeDependencyController()); assertInvokedBefore("getC1", "getA", "getB1"); assertInvokedBefore("getC2", "getA", "getB1"); assertInvokedBefore("getC3", "getA", "getB2"); assertInvokedBefore("getC4", "getA", "getB2"); assertInvokedBefore("getB1", "getA"); assertInvokedBefore("getB2", "getA"); } @Test public void unresolvedDependency() throws Exception { runTest(new UnresolvedDependencyController()); assertInvokedBefore("getA", "getC1", "getC2", "getC3", "getC4"); // No other order guarantees for methods with unresolvable dependencies (and methods that depend on them), // Required dependencies will be created via default constructor. } private void runTest(Object controller) throws Exception { HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite(); resolvers.addResolver(new ModelAttributeMethodProcessor(false)); resolvers.addResolver(new ModelMethodProcessor()); WebDataBinderFactory dataBinderFactory = new DefaultDataBinderFactory(null); Class<?> type = controller.getClass(); Set<Method> methods = MethodIntrospector.selectMethods(type, METHOD_FILTER); List<InvocableHandlerMethod> modelMethods = new ArrayList<>(); for (Method method : methods) { InvocableHandlerMethod modelMethod = new InvocableHandlerMethod(controller, method); modelMethod.setHandlerMethodArgumentResolvers(resolvers); modelMethod.setDataBinderFactory(dataBinderFactory); modelMethods.add(modelMethod); } Collections.shuffle(modelMethods); SessionAttributesHandler sessionHandler = new SessionAttributesHandler(type, this.sessionAttributeStore); ModelFactory factory = new ModelFactory(modelMethods, dataBinderFactory, sessionHandler); factory.initModel(this.webRequest, this.mavContainer, new HandlerMethod(controller, "handle")); if (logger.isDebugEnabled()) { StringBuilder sb = new StringBuilder(); for (String name : getInvokedMethods()) { sb.append(" >> ").append(name); } logger.debug(sb); } } private void assertInvokedBefore(String beforeMethod, String... afterMethods) { List<String> actual = getInvokedMethods(); for (String afterMethod : afterMethods) { assertTrue(beforeMethod + " should be before " + afterMethod + ". Actual order: " + actual.toString(), actual.indexOf(beforeMethod) < actual.indexOf(afterMethod)); } } @SuppressWarnings("unchecked") private List<String> getInvokedMethods() { return (List<String>) this.mavContainer.getModel().get("methods"); } private static class AbstractController { @RequestMapping public void handle() { } @SuppressWarnings("unchecked") <T> T updateAndReturn(Model model, String methodName, T returnValue) throws IOException { ((List<String>) model.asMap().get("methods")).add(methodName); return returnValue; } } private static class StraightLineDependencyController extends AbstractController { @ModelAttribute public A getA(Model model) throws IOException { return updateAndReturn(model, "getA", new A()); } @ModelAttribute public B1 getB1(@ModelAttribute A a, Model model) throws IOException { return updateAndReturn(model, "getB1", new B1()); } @ModelAttribute public B2 getB2(@ModelAttribute B1 b1, Model model) throws IOException { return updateAndReturn(model, "getB2", new B2()); } @ModelAttribute public C1 getC1(@ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getC1", new C1()); } @ModelAttribute public C2 getC2(@ModelAttribute C1 c1, Model model) throws IOException { return updateAndReturn(model, "getC2", new C2()); } @ModelAttribute public C3 getC3(@ModelAttribute C2 c2, Model model) throws IOException { return updateAndReturn(model, "getC3", new C3()); } @ModelAttribute public C4 getC4(@ModelAttribute C3 c3, Model model) throws IOException { return updateAndReturn(model, "getC4", new C4()); } } private static class TreeDependencyController extends AbstractController { @ModelAttribute public A getA(Model model) throws IOException { return updateAndReturn(model, "getA", new A()); } @ModelAttribute public B1 getB1(@ModelAttribute A a, Model model) throws IOException { return updateAndReturn(model, "getB1", new B1()); } @ModelAttribute public B2 getB2(@ModelAttribute A a, Model model) throws IOException { return updateAndReturn(model, "getB2", new B2()); } @ModelAttribute public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException { return updateAndReturn(model, "getC1", new C1()); } @ModelAttribute public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException { return updateAndReturn(model, "getC2", new C2()); } @ModelAttribute public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getC3", new C3()); } @ModelAttribute public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getC4", new C4()); } } private static class InvertedTreeDependencyController extends AbstractController { @ModelAttribute public C1 getC1(Model model) throws IOException { return updateAndReturn(model, "getC1", new C1()); } @ModelAttribute public C2 getC2(Model model) throws IOException { return updateAndReturn(model, "getC2", new C2()); } @ModelAttribute public C3 getC3(Model model) throws IOException { return updateAndReturn(model, "getC3", new C3()); } @ModelAttribute public C4 getC4(Model model) throws IOException { return updateAndReturn(model, "getC4", new C4()); } @ModelAttribute public B1 getB1(@ModelAttribute C1 c1, @ModelAttribute C2 c2, Model model) throws IOException { return updateAndReturn(model, "getB1", new B1()); } @ModelAttribute public B2 getB2(@ModelAttribute C3 c3, @ModelAttribute C4 c4, Model model) throws IOException { return updateAndReturn(model, "getB2", new B2()); } @ModelAttribute public A getA(@ModelAttribute B1 b1, @ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getA", new A()); } } private static class UnresolvedDependencyController extends AbstractController { @ModelAttribute public A getA(Model model) throws IOException { return updateAndReturn(model, "getA", new A()); } @ModelAttribute public C1 getC1(@ModelAttribute B1 b1, Model model) throws IOException { return updateAndReturn(model, "getC1", new C1()); } @ModelAttribute public C2 getC2(@ModelAttribute B1 b1, Model model) throws IOException { return updateAndReturn(model, "getC2", new C2()); } @ModelAttribute public C3 getC3(@ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getC3", new C3()); } @ModelAttribute public C4 getC4(@ModelAttribute B2 b2, Model model) throws IOException { return updateAndReturn(model, "getC4", new C4()); } } private static class A { } private static class B1 { } private static class B2 { } private static class C1 { } private static class C2 { } private static class C3 { } private static class C4 { } private static final ReflectionUtils.MethodFilter METHOD_FILTER = new ReflectionUtils.MethodFilter() { @Override public boolean matches(Method method) { return ((AnnotationUtils.findAnnotation(method, RequestMapping.class) == null) && (AnnotationUtils.findAnnotation(method, ModelAttribute.class) != null)); } }; }