/* * Copyright 2017 ThoughtWorks, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.thoughtworks.go.server.plugin.controller; import com.thoughtworks.go.plugin.api.request.GoPluginApiRequest; import com.thoughtworks.go.plugin.api.response.DefaultGoPluginApiResponse; import com.thoughtworks.go.plugin.infra.PluginManager; import com.thoughtworks.go.util.ReflectionUtil; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.PrintWriter; import java.util.*; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; public class PluginControllerTest { public static final String PLUGIN_ID = "plugin.id"; public static final String REQUEST_NAME = "request.name"; @Mock private PluginManager pluginManager; @Mock private HttpServletRequest servletRequest; @Mock private HttpServletResponse servletResponse; @Mock private PrintWriter writer; private PluginController pluginController; private ArgumentCaptor<GoPluginApiRequest> requestArgumentCaptor; private ArgumentCaptor<Integer> responseCodeArgumentCaptor; private ArgumentCaptor<String> contentTypeArgument; @Before public void setUp() throws Exception { initMocks(this); requestArgumentCaptor = ArgumentCaptor.forClass(GoPluginApiRequest.class); responseCodeArgumentCaptor = ArgumentCaptor.forClass(Integer.class); contentTypeArgument = ArgumentCaptor.forClass(String.class); when(servletResponse.getWriter()).thenReturn(writer); doNothing().when(servletResponse).setStatus(responseCodeArgumentCaptor.capture()); doNothing().when(servletResponse).setHeader(anyString(), contentTypeArgument.capture()); pluginController = new PluginController(pluginManager); } @Test public void shouldForwardWebRequestToPlugin() throws Exception { when(pluginManager.submitTo(eq(PLUGIN_ID), requestArgumentCaptor.capture())).thenReturn(new DefaultGoPluginApiResponse(200)); when(pluginManager.isPluginOfType(any(String.class), any(String.class))).thenReturn(true); Map<String, String[]> springParameterMap = new HashMap<>(); springParameterMap.put("k1", new String[]{"v1"}); springParameterMap.put("k2", new String[]{"v2", "v3"}); springParameterMap.put("k3", new String[]{}); springParameterMap.put("k4", null); when(servletRequest.getParameterMap()).thenReturn(springParameterMap); List<String> elements = Arrays.asList("h1", "h2", "h3"); when(servletRequest.getHeader("h1")).thenReturn("v1"); when(servletRequest.getHeader("h2")).thenReturn(""); when(servletRequest.getHeader("h3")).thenReturn(null); when(servletRequest.getHeaderNames()).thenReturn(getMockEnumeration(elements)); pluginController.handlePluginInteractRequest(PLUGIN_ID, REQUEST_NAME, servletRequest, servletResponse); Map<String, String> requestParameters = new HashMap<>(); requestParameters.put("k1", "v1"); requestParameters.put("k2", "v2"); requestParameters.put("k3", null); requestParameters.put("k4", null); Map<String, String> requestHeaders = new HashMap<>(); requestHeaders.put("h1", "v1"); requestHeaders.put("h2", ""); requestHeaders.put("h3", null); assertRequest(requestArgumentCaptor.getValue(), REQUEST_NAME, requestParameters, requestHeaders); } @Test public void shouldRenderPluginResponseWithDefaultContentTypeOn200() throws Exception { when(pluginManager.isPluginOfType(any(String.class), any(String.class))).thenReturn(true); DefaultGoPluginApiResponse apiResponse = new DefaultGoPluginApiResponse(200); String responseBody = "response-body"; apiResponse.setResponseBody(responseBody); when(pluginManager.submitTo(eq(PLUGIN_ID), requestArgumentCaptor.capture())).thenReturn(apiResponse); when(servletRequest.getParameterMap()).thenReturn(new HashMap<>()); when(servletRequest.getHeaderNames()).thenReturn(getMockEnumeration(new ArrayList<>())); pluginController.handlePluginInteractRequest(PLUGIN_ID, REQUEST_NAME, servletRequest, servletResponse); assertThat(contentTypeArgument.getValue(), is(PluginController.CONTENT_TYPE_HTML)); verify(writer).write(responseBody); assertRequest(requestArgumentCaptor.getValue(), REQUEST_NAME, new HashMap<>(), new HashMap<>()); } @Test public void shouldRenderPluginResponseWithSpecifiedContentTypeOn200() throws Exception { when(pluginManager.isPluginOfType(any(String.class), any(String.class))).thenReturn(true); DefaultGoPluginApiResponse apiResponse = new DefaultGoPluginApiResponse(200); String contentType = "image/png"; apiResponse.responseHeaders().put("Content-Type", contentType); String responseBody = "response-body"; apiResponse.setResponseBody(responseBody); when(pluginManager.submitTo(eq(PLUGIN_ID), any(GoPluginApiRequest.class))).thenReturn(apiResponse); when(servletRequest.getParameterMap()).thenReturn(new HashMap<>()); when(servletRequest.getHeaderNames()).thenReturn(getMockEnumeration(new ArrayList<>())); pluginController.handlePluginInteractRequest(PLUGIN_ID, REQUEST_NAME, servletRequest, servletResponse); assertThat(contentTypeArgument.getValue(), is(contentType)); verify(writer).write(responseBody); } @Test public void shouldRedirectToSpecifiedLocationOn302() throws Exception { when(pluginManager.isPluginOfType(any(String.class), any(String.class))).thenReturn(true); DefaultGoPluginApiResponse apiResponse = new DefaultGoPluginApiResponse(302); String redirectLocation = "/go/plugin/interact/plugin.id/request.name"; apiResponse.responseHeaders().put("Location", redirectLocation); when(pluginManager.submitTo(eq(PLUGIN_ID), any(GoPluginApiRequest.class))).thenReturn(apiResponse); when(servletRequest.getParameterMap()).thenReturn(new HashMap<>()); when(servletRequest.getHeaderNames()).thenReturn(getMockEnumeration(new ArrayList<>())); pluginController.handlePluginInteractRequest(PLUGIN_ID, REQUEST_NAME, servletRequest, servletResponse); verify(servletResponse, times(1)).sendRedirect(anyString()); } @Test public void shouldAllowInteractionOnlyForAuthPlugins() throws IOException { when(pluginManager.isPluginOfType("authentication", "github.pr")).thenReturn(false); pluginController.handlePluginInteractRequest(PLUGIN_ID, REQUEST_NAME, servletRequest, servletResponse); assertThat(responseCodeArgumentCaptor.getValue(), is(403)); } @Test public void shouldDisallowRequestsWhichNeedAuthentication() throws IOException { when(pluginManager.isPluginOfType(any(String.class), any(String.class))).thenReturn(true); List<String> restrictedRequests = Arrays.asList("go.plugin-settings.get-configuration", "go.plugin-settings.get-view", "go.plugin-settings.validate-configuration", "go.authentication.plugin-configuration", "go.authentication.authenticate-user", "go.authentication.search-user"); for (String requestName : restrictedRequests) { pluginController.handlePluginInteractRequest(PLUGIN_ID, requestName, servletRequest, servletResponse); assertThat(responseCodeArgumentCaptor.getValue(), is(403)); } } private Enumeration<String> getMockEnumeration(List<String> elements) { Enumeration<String> enumeration = new Enumeration<String>() { private List<String> elements; int i = 0; @Override public boolean hasMoreElements() { return i < elements.size(); } @Override public String nextElement() { return elements.get(i++); } }; ReflectionUtil.setField(enumeration, "elements", elements); return enumeration; } private void assertRequest(GoPluginApiRequest goPluginApiRequest, String requestName, Map<String, String> requestParameters, Map<String, String> requestHeaders) { assertThat(goPluginApiRequest.extension(), is(nullValue())); assertThat(goPluginApiRequest.extensionVersion(), is(nullValue())); assertThat(goPluginApiRequest.requestName(), is(requestName)); assertEquals(requestParameters, goPluginApiRequest.requestParameters()); assertEquals(requestHeaders, goPluginApiRequest.requestHeaders()); assertThat(goPluginApiRequest.requestBody(), is(nullValue())); } }