/*******************************************************************************
 * Cloud Foundry
 * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved.
 *
 * This product is licensed to you under the Apache License, Version 2.0 (the "License").
 * You may not use this product except in compliance with the License.
 *
 * This product includes a number of subcomponents with
 * separate copyright notices and license terms. Your use of these
 * subcomponents is subject to the terms and conditions of the
 * subcomponent's license, as noted in the LICENSE file.
 *******************************************************************************/
package org.cloudfoundry.identity.uaa;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.cloudfoundry.identity.uaa.test.TestProfileEnvironment;
import org.cloudfoundry.identity.uaa.test.UrlHelper;
import org.junit.Assume;
import org.junit.internal.AssumptionViolatedException;
import org.junit.rules.MethodRule;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.support.HttpAccessor;
import org.springframework.security.oauth2.client.test.RestTemplateHolder;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriUtils;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * <p>
 * A rule that prevents integration tests from failing when the UAA server is
 * not running or not reachable. If a TCP connection to the configured host and
 * port cannot be opened, every test guarded by this rule is skipped through a
 * violated assumption instead of failing. When the {@code uaa.integration.test}
 * property is {@code true}, the connectivity check is bypassed and the tests
 * always run. Usage:
 * </p>
 *
 * <pre>
 * &#064;Rule
 * public ServerRunning serverRunning = ServerRunning.isRunning();
 *
 * &#064;Test
 * public void testSomething() throws Exception {
 *     // ... test using serverRunning.getForString(...) etc.
 * }
 * </pre>
 * <p>
 * JUnit 4 requires {@code @Rule} fields to be public and non-static. The
 * connectivity result is cached in a static map keyed by host and port, so the
 * socket check runs only once per JVM for each server address, however many
 * rule instances are created.
 * </p>
 * <p>
 * Connection settings come from the test environment: {@code uaa.host}
 * (default {@code localhost}), {@code uaa.port} (default {@code 8080}) and
 * {@code uaa.path} (default {@code /uaa}).
 * </p>
 *
 * @see Assume
 * @see AssumptionViolatedException
 *
 * @author Dave Syer
 *
 */
public class ServerRunning implements MethodRule, RestTemplateHolder, UrlHelper {

    private static final Log logger = LogFactory.getLog(ServerRunning.class);

    private static final int DEFAULT_PORT = 8080;

    private static final String DEFAULT_HOST = "localhost";

    private static final String DEFAULT_ROOT_PATH = "/uaa";

    // Static so that we only test once per server address: speeds up the test suite.
    // Keyed by "host:port" so rules pointing at different hosts never share a result.
    // The map is shared by all instances, so access is guarded by locking the map itself
    // (an instance-level lock would not protect it).
    private static final Map<String, Boolean> serverOnline = new HashMap<String, Boolean>();

    private final Environment environment;

    // When true, the connectivity check is skipped and tests always run.
    private final boolean integrationTest;

    private int port;

    private String hostName = DEFAULT_HOST;

    private String rootPath = DEFAULT_ROOT_PATH;

    private RestOperations client;

    /**
     * @return a new rule that assumes an existing running UAA server
     */
    public static ServerRunning isRunning() {
        return new ServerRunning();
    }

    private ServerRunning() {
        environment = TestProfileEnvironment.getEnvironment();
        integrationTest = environment.getProperty("uaa.integration.test", Boolean.class, false);
        client = getRestTemplate();
        setPort(environment.getProperty("uaa.port", Integer.class, DEFAULT_PORT));
        setRootPath(environment.getProperty("uaa.path", DEFAULT_ROOT_PATH));
        setHostName(environment.getProperty("uaa.host", DEFAULT_HOST));
    }

    /**
     * @param port the port to set
     */
    public void setPort(int port) {
        this.port = port;
    }

    /**
     * @param hostName the hostName to set
     */
    public void setHostName(String hostName) {
        this.hostName = hostName;
    }

    /**
     * @return the host name used to build URLs and to test connectivity
     */
    public String getHostName() {
        return hostName;
    }

    /**
     * The context root in the application, e.g. "/uaa" for a local deployment.
     * <p>
     * The value is normalized so that URL building never produces a double slash.
     * Trailing slashes are removed, and a leading slash is added to any non-empty
     * result. As a result "/" and "" map to "" (the server root), while "uaa",
     * "/uaa" and "/uaa/" all map to "/uaa".
     *
     * @param rootPath the rootPath to set (must not be null)
     */
    public void setRootPath(String rootPath) {
        String normalized = rootPath;
        while (normalized.endsWith("/")) {
            normalized = normalized.substring(0, normalized.length() - 1);
        }
        if (!normalized.isEmpty() && !normalized.startsWith("/")) {
            normalized = "/" + normalized;
        }
        this.rootPath = normalized;
    }

    /**
     * Skips the test (via a failed assumption) unless integration-test mode is
     * enabled or the server accepts a TCP connection.
     */
    @Override
    public Statement apply(Statement statement, FrameworkMethod frameworkMethod, Object o) {
        Assume.assumeTrue("Test ignored as the server cannot be reached at " + hostName + ":" + port,
            integrationTest || getStatus());
        return statement;
    }

    /**
     * Returns the cached connectivity result for the current host and port,
     * probing the server only on the first call for that address.
     */
    private boolean getStatus() {
        String key = hostName + ":" + port;
        synchronized (serverOnline) {
            Boolean available = serverOnline.get(key);
            if (available == null) {
                available = connectionAvailable();
                serverOnline.put(key, available);
            }
            return available;
        }
    }

    /**
     * Attempts to open, then immediately close, a socket to the configured host
     * and port.
     *
     * @return true if the connection succeeded, false on any IOException
     */
    private boolean connectionAvailable() {
        logger.info("Testing connectivity for " + hostName + ":" + port);
        try (Socket socket = new Socket(hostName, port)) {
            logger.info("Connectivity test succeeded for " + hostName + ":" + port);
            return true;
        } catch (IOException e) {
            logger.warn("Connectivity test failed for " + hostName + ":" + port, e);
            return false;
        }
    }

    /**
     * @return "http://host[:port]" plus the root path; the port is omitted when it is 80
     */
    @Override
    public String getBaseUrl() {
        return "http://" + hostName + (port == 80 ? "" : ":" + port) + rootPath;
    }

    @Override
    public String getAccessTokenUri() {
        return getUrl("/oauth/token");
    }

    @Override
    public String getAuthorizationUri() {
        return getUrl("/oauth/authorize");
    }

    @Override
    public String getClientsUri() {
        return getUrl("/oauth/clients");
    }

    @Override
    public String getUsersUri() {
        return getUrl("/Users");
    }

    /**
     * @return the same "/Users" endpoint as {@link #getUsersUri()}
     */
    @Override
    public String getUserUri() {
        return getUrl("/Users");
    }

    /**
     * Resolves a path against {@link #getBaseUrl()}.
     *
     * @param path a path relative to the context root (a leading slash is added
     *             if missing), or an absolute http/https URL, which is returned
     *             unchanged
     * @return the absolute URL
     */
    @Override
    public String getUrl(String path) {
        if (path.startsWith("http:") || path.startsWith("https:")) {
            return path;
        }
        if (!path.startsWith("/")) {
            path = "/" + path;
        }
        return getBaseUrl() + path;
    }

    /**
     * POSTs form data and returns the body as a String. If {@code headers} has
     * no content type, it is set to application/x-www-form-urlencoded in place.
     */
    public ResponseEntity<String> postForString(String path, MultiValueMap<String, String> formData,
                                                HttpHeaders headers) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<>(formData, headers), String.class);
    }

    /**
     * POSTs form data and returns the body deserialized as a Map. If
     * {@code headers} has no content type, it is set to
     * application/x-www-form-urlencoded in place.
     */
    @SuppressWarnings("rawtypes")
    public ResponseEntity<Map> postForMap(String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<>(formData, headers), Map.class);
    }

    /** GETs the given path with empty headers and returns the body as a String. */
    public ResponseEntity<String> getForString(String path) {
        return getForString(path, new HttpHeaders());
    }

    /** GETs the given path with the given headers, converting the body to {@code type}. */
    public <T> ResponseEntity<T> getForObject(String path, Class<T> type, final HttpHeaders headers) {
        return client.exchange(getUrl(path), HttpMethod.GET, new HttpEntity<>(null, headers), type);
    }

    /** GETs the given path with empty headers, converting the body to {@code type}. */
    public <T> ResponseEntity<T> getForObject(String path, Class<T> type) {
        return getForObject(path, type, new HttpHeaders());
    }

    /** GETs the given path with the given headers and returns the body as a String. */
    public ResponseEntity<String> getForString(String path, final HttpHeaders headers) {
        HttpEntity<Void> request = new HttpEntity<>(null, headers);
        return client.exchange(getUrl(path), HttpMethod.GET, request, String.class);
    }

    /**
     * GETs the given path, expanding URI template variables, and discards the body.
     */
    public ResponseEntity<Void> getForResponse(String path, final HttpHeaders headers, Object... uriVariables) {
        HttpEntity<Void> request = new HttpEntity<>(null, headers);
        return client.exchange(getUrl(path), HttpMethod.GET, request, Void.class, uriVariables);
    }

    /**
     * POSTs form parameters and discards the body. The caller's headers are
     * copied, and the content type is forced to form-urlencoded on the copy.
     */
    public ResponseEntity<Void> postForResponse(String path, HttpHeaders headers,
                                                MultiValueMap<String, String> params) {
        HttpHeaders actualHeaders = new HttpHeaders();
        actualHeaders.putAll(headers);
        actualHeaders.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        return client.exchange(getUrl(path), HttpMethod.POST, new HttpEntity<>(params, actualHeaders),
            Void.class);
    }

    /**
     * POSTs form parameters, expects a 302, then GETs the redirect target.
     * <p>
     * Side effect: the caller's {@code headers} are modified in place. Any
     * existing Cookie header is removed, and each Set-Cookie value from the POST
     * response is added as a Cookie header. The same headers are then used for
     * the follow-up GET.
     *
     * @throws IllegalStateException if the POST does not return 302, or the 302
     *                               has no Location header
     */
    public ResponseEntity<Void> postForRedirect(String path, HttpHeaders headers,
                                                MultiValueMap<String, String> params) {
        ResponseEntity<Void> exchange = postForResponse(path, headers, params);

        if (exchange.getStatusCode() != HttpStatus.FOUND) {
            throw new IllegalStateException("Expected 302 but server returned status code "
                + exchange.getStatusCode());
        }

        headers.remove("Cookie");
        List<String> cookies = exchange.getHeaders().get("Set-Cookie");
        if (cookies != null) {
            for (String cookie : cookies) {
                headers.add("Cookie", cookie);
            }
        }

        URI location = exchange.getHeaders().getLocation();
        if (location == null) {
            throw new IllegalStateException("Expected a Location header in the 302 response from "
                + getUrl(path));
        }

        return client.exchange(location.toString(), HttpMethod.GET, new HttpEntity<Void>(null, headers),
            Void.class);
    }

    /** Returns the client, lazily creating one via {@link #createRestTemplate()}. */
    @Override
    public RestOperations getRestTemplate() {
        if (client == null) {
            client = createRestTemplate();
        }
        return client;
    }

    /**
     * Replaces the client. If it is an {@link HttpAccessor}, its request factory
     * is replaced with one that neither follows redirects nor manages cookies.
     */
    @Override
    public void setRestTemplate(RestOperations restTemplate) {
        this.client = restTemplate;
        if (restTemplate instanceof HttpAccessor) {
            ((HttpAccessor) restTemplate).setRequestFactory(new StatelessRequestFactory());
        }
    }

    /**
     * Creates a RestTemplate that does not follow redirects or manage cookies,
     * and never treats an HTTP status as an error. Error responses are returned
     * in the ResponseEntity so that tests can assert on the status code.
     */
    public RestTemplate createRestTemplate() {
        RestTemplate template = new RestTemplate();
        template.setRequestFactory(new StatelessRequestFactory());
        template.setErrorHandler(new ResponseErrorHandler() {
            // Pass errors through in response entity for status code analysis
            @Override
            public boolean hasError(ClientHttpResponse response) throws IOException {
                return false;
            }

            @Override
            public void handleError(ClientHttpResponse response) throws IOException {
            }
        });
        return template;
    }

    /**
     * @param url a path relative to the context root, or an absolute http/https URL
     * @return a builder for adding query parameters to the resolved URL
     */
    public UriBuilder buildUri(String url) {
        return UriBuilder.fromUri(getUrl(url));
    }

    /**
     * Request factory that builds a fresh HttpClient for each call. Each client
     * does not follow redirects, does not manage cookies, and honours JVM system
     * properties (e.g. proxies).
     */
    private static class StatelessRequestFactory extends HttpComponentsClientHttpRequestFactory {
        @Override
        public HttpClient getHttpClient() {
            return HttpClientBuilder.create()
                .useSystemProperties()
                .disableRedirectHandling()
                .disableCookieManagement()
                .build();
        }
    }

    /**
     * Minimal builder that appends UTF-8 encoded query parameters to a base URL.
     */
    public static class UriBuilder {

        private final String url;

        // Insertion-ordered; a key may carry several values (emitted as key=v1&key=v2).
        private final MultiValueMap<String, String> params = new LinkedMultiValueMap<>();

        public UriBuilder(String url) {
            this.url = url;
        }

        public static UriBuilder fromUri(String url) {
            return new UriBuilder(url);
        }

        /** Adds a query parameter; repeated keys accumulate values. */
        public UriBuilder queryParam(String key, String value) {
            params.add(key, value);
            return this;
        }

        /**
         * Builds the URI. Every key=value pair is separated by '&'. The first pair
         * is preceded by '?', or by '&' if the base URL already has a query string.
         *
         * @throws IllegalArgumentException if the result is not a valid URI
         */
        public URI build() {
            StringBuilder builder = new StringBuilder(url);
            try {
                char separator = url.indexOf('?') >= 0 ? '&' : '?';
                for (Map.Entry<String, List<String>> entry : params.entrySet()) {
                    String encodedKey = UriUtils.encodeQueryParam(entry.getKey(), "UTF-8");
                    for (String value : entry.getValue()) {
                        builder.append(separator)
                            .append(encodedKey)
                            .append('=')
                            .append(UriUtils.encodeQueryParam(value, "UTF-8"));
                        separator = '&';
                    }
                }
                return new URI(builder.toString());
            } catch (UnsupportedEncodingException ex) {
                // should not happen, UTF-8 is always supported
                throw new IllegalStateException(ex);
            } catch (URISyntaxException ex) {
                throw new IllegalArgumentException("Could not create URI from [" + builder + "]: " + ex, ex);
            }
        }
    }
}