package com.griddynamics.jagger.invoker.v2;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.UnknownHttpStatusCodeException;
import org.springframework.web.util.UriTemplateHandler;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static com.google.common.collect.Maps.newHashMap;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.CONNECT_TIMEOUT_IN_MS;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.DEFAULT_URI_VARIABLES;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.ERROR_HANDLER;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.INTERCEPTORS;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.MAX_CONN_PER_ROUTE;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.MAX_CONN_TOTAL;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.MESSAGE_CONVERTERS;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.REQUEST_FACTORY;
import static com.griddynamics.jagger.invoker.v2.SpringBasedHttpClient.JSpringBasedHttpClientParameters.URI_TEMPLATE_HANDLER;
/**
* Implementation of {@link JHttpClient}. <p>
* This implementation is based on the Spring {@link RestTemplate}.
*
* @author Anton Antonenko
* @see JHttpClient
* @since 2.0
*
* @ingroup Main_Http_group
*/
@SuppressWarnings({"unused", "unchecked"})
public class SpringBasedHttpClient implements JHttpClient {
/** Logger for this class. */
private static final Logger log = LoggerFactory.getLogger(SpringBasedHttpClient.class);
/** Default total connection limit ({@code max_conn_total}) used when the parameter is not supplied; {@code Integer.MAX_VALUE}, i.e. no practical cap. */
private static final int DEFAULT_MAX_CONN_TOTAL = Integer.MAX_VALUE;
/** Default per-route connection limit ({@code max_conn_per_route}) used when the parameter is not supplied; {@code Integer.MAX_VALUE}, i.e. no practical cap. */
private static final int DEFAULT_MAX_CONN_PER_ROUTE = Integer.MAX_VALUE;
/** Default connect timeout ({@code connect_timeout}) in milliseconds used when the parameter is not supplied (60000 ms = 60 s). */
private static final int DEFAULT_CONNECT_TIMEOUT_IN_MS = 60000;
/**
 * Keys of the client parameters understood by {@link SpringBasedHttpClient#SpringBasedHttpClient(Map)}.
 * <p>
 * {@link #getValue()} returns the string that must be used as the key in the parameters map.
 * <p>
 * values: {@link JSpringBasedHttpClientParameters#DEFAULT_URI_VARIABLES}, {@link JSpringBasedHttpClientParameters#ERROR_HANDLER},
 * {@link JSpringBasedHttpClientParameters#MESSAGE_CONVERTERS}, {@link JSpringBasedHttpClientParameters#URI_TEMPLATE_HANDLER},
 * {@link JSpringBasedHttpClientParameters#INTERCEPTORS}, {@link JSpringBasedHttpClientParameters#REQUEST_FACTORY},
 * {@link JSpringBasedHttpClientParameters#MAX_CONN_TOTAL}, {@link JSpringBasedHttpClientParameters#MAX_CONN_PER_ROUTE},
 * {@link JSpringBasedHttpClientParameters#CONNECT_TIMEOUT_IN_MS}
 */
public enum JSpringBasedHttpClientParameters {
    DEFAULT_URI_VARIABLES("default_uri_variables"),
    ERROR_HANDLER("error_handler"),
    MESSAGE_CONVERTERS("message_converters"),
    URI_TEMPLATE_HANDLER("uri_template_handler"),
    INTERCEPTORS("interceptors"),
    REQUEST_FACTORY("request_factory"),
    MAX_CONN_TOTAL("max_conn_total"),
    MAX_CONN_PER_ROUTE("max_conn_per_route"),
    CONNECT_TIMEOUT_IN_MS("connect_timeout");

    /** Map key of the parameter; immutable so a constant can never change its key after class initialisation. */
    private final String value;

    JSpringBasedHttpClientParameters(String value) {
        this.value = value;
    }

    /**
     * @return the key under which this parameter is looked up in the client parameters map
     */
    public String getValue() {
        return value;
    }
}
/**
 * This field is a container for {@link RestTemplate} parameters which can be passed by the
 * {@link SpringBasedHttpClient#SpringBasedHttpClient(Map)} constructor.<p>
 * <p>
 * The list of supported client params (look at {@link JSpringBasedHttpClientParameters}): <p>
 * - {@code Map<String, ?> default_uri_variables} (look at {@link RestTemplate#setDefaultUriVariables(Map)}) <p>
 * - {@code ResponseErrorHandler error_handler} (look at {@link RestTemplate#setErrorHandler(ResponseErrorHandler)}) <p>
 * - {@code List<HttpMessageConverter<?>> message_converters} (look at {@link RestTemplate#setMessageConverters(List)}) <p>
 * - {@code UriTemplateHandler uri_template_handler} (look at {@link RestTemplate#setUriTemplateHandler(UriTemplateHandler)}) <p>
 * - {@code List<ClientHttpRequestInterceptor> interceptors} (look at {@link RestTemplate#setInterceptors(List)}) <p>
 * - {@code ClientHttpRequestFactory request_factory} (look at {@link RestTemplate#setRequestFactory(ClientHttpRequestFactory)}) <p>
 * - {@code int max_conn_total} (look at {@link HttpClientBuilder#setMaxConnTotal(int)}) <p>
 * - {@code int max_conn_per_route} (look at {@link HttpClientBuilder#setMaxConnPerRoute(int)}) <p>
 * - {@code int connect_timeout} (look at {@link HttpComponentsClientHttpRequestFactory#setConnectTimeout(int)}) <p>
 */
private final Map<String, Object> clientParams;
/**
 * The {@link RestTemplate} every request is executed with. It is assigned exactly once, in the no-arg
 * constructor, so it is declared {@code final}; only its configuration is adjusted afterwards.
 */
private final RestTemplate restTemplate;
public SpringBasedHttpClient() {
clientParams = new HashMap<>();
restTemplate = new RestTemplate();
restTemplate.setRequestFactory(getRequestFactory());
restTemplate.setErrorHandler(new AllowAllCodesResponseErrorHandler());
}
public SpringBasedHttpClient(Map<String, Object> clientParams) {
this();
this.clientParams.putAll(clientParams);
setRestTemplateParams(this.clientParams);
}
/**
 * Executes the given query against the given endpoint.
 * <p>
 * When {@code query} is the {@link JHttpQuery#EMPTY_QUERY} instance (identity comparison) the call is
 * delegated to {@link #execute(JHttpEndpoint)}. Otherwise the target URI is built from the query path and
 * query parameters, and the response body is converted to {@code query.getResponseBodyType()} if that is
 * set, or returned as {@code byte[]} if it is {@code null}.
 *
 * @param endpoint endpoint to send the request to
 * @param query    description of the request (method, path, query params, headers, body, response type)
 * @return status code, body and headers of the response
 */
@Override
public JHttpResponse execute(JHttpEndpoint endpoint, JHttpQuery query) {
    if (query == JHttpQuery.EMPTY_QUERY) {
        return execute(endpoint);
    }

    final URI target = endpoint.getURI(query.getPath(), query.getQueryParams());
    final RequestEntity request = mapToRequestEntity(query, target);

    // No explicit response type requested: hand the body back as raw bytes.
    if (query.getResponseBodyType() == null) {
        final ResponseEntity rawResponse = restTemplate.exchange(target, query.getMethod(), request, byte[].class);
        return mapToJHttpResponse(rawResponse);
    }

    final ResponseEntity typedResponse =
            restTemplate.exchange(target, query.getMethod(), request, query.getResponseBodyType());
    return mapToJHttpResponse(typedResponse);
}
/**
 * Sends a plain {@code GET} request to the URI of the given endpoint and returns the response
 * with its body read as {@code byte[]}.
 *
 * @param endpoint endpoint to send the request to
 * @return status code, body and headers of the response
 */
public JHttpResponse execute(JHttpEndpoint endpoint) {
    final URI target = endpoint.getURI();
    final RequestEntity getRequest = mapToRequestEntity(target);
    final ResponseEntity response = restTemplate.exchange(target, HttpMethod.GET, getRequest, byte[].class);
    return mapToJHttpResponse(response);
}
/**
 * Applies the supported entries of {@code clientParams} to the underlying {@link RestTemplate}.
 * <p>
 * Keys are the values of {@link JSpringBasedHttpClientParameters}; entries with other keys are ignored.
 * An explicit {@code request_factory} is installed as is. If it is absent and at least one of
 * {@code max_conn_total}, {@code max_conn_per_route} or {@code connect_timeout} is present, a new request
 * factory is built from them, with the class defaults filling in for the missing ones.
 * <p>
 * The conflicting combination is rejected before anything is applied, so a failed call for that reason
 * leaves the {@link RestTemplate} untouched.
 *
 * @param clientParams parameters to apply
 * @throws IllegalArgumentException if {@code request_factory} is combined with {@code max_conn_total},
 *                                  {@code max_conn_per_route} or {@code connect_timeout}, or if one of those
 *                                  three values is neither a number nor a numeric string
 * @throws ClassCastException       if an object-valued parameter is not of the documented type
 */
private void setRestTemplateParams(Map<String, Object> clientParams) {
    final boolean hasCustomRequestFactory = clientParams.containsKey(REQUEST_FACTORY.value);
    final boolean hasRequestFactoryParams = containsAnyRequestFactoryParam(clientParams);

    // Fail fast: the pool/timeout settings only make sense for the factory built by this class.
    if (hasCustomRequestFactory && hasRequestFactoryParams) {
        throw new IllegalArgumentException("Parameters max_conn_total, max_conn_per_route and connect_timeout cannot be set if " +
                "request_factory parameter presents. You must configure these parameters in your request_factory entity.");
    }

    if (clientParams.containsKey(DEFAULT_URI_VARIABLES.value)) {
        restTemplate.setDefaultUriVariables((Map<String, ?>) clientParams.get(DEFAULT_URI_VARIABLES.value));
    }
    if (clientParams.containsKey(ERROR_HANDLER.value)) {
        restTemplate.setErrorHandler((ResponseErrorHandler) clientParams.get(ERROR_HANDLER.value));
    }
    if (clientParams.containsKey(MESSAGE_CONVERTERS.value)) {
        restTemplate.setMessageConverters((List<HttpMessageConverter<?>>) clientParams.get(MESSAGE_CONVERTERS.value));
    }
    if (clientParams.containsKey(URI_TEMPLATE_HANDLER.value)) {
        restTemplate.setUriTemplateHandler((UriTemplateHandler) clientParams.get(URI_TEMPLATE_HANDLER.value));
    }
    if (clientParams.containsKey(INTERCEPTORS.value)) {
        restTemplate.setInterceptors((List<ClientHttpRequestInterceptor>) clientParams.get(INTERCEPTORS.value));
    }

    if (hasCustomRequestFactory) {
        restTemplate.setRequestFactory((ClientHttpRequestFactory) clientParams.get(REQUEST_FACTORY.value));
    } else if (hasRequestFactoryParams) {
        int maxConnPerRoute = resolveIntParam(clientParams, MAX_CONN_PER_ROUTE.value, DEFAULT_MAX_CONN_PER_ROUTE);
        int maxConnTotal = resolveIntParam(clientParams, MAX_CONN_TOTAL.value, DEFAULT_MAX_CONN_TOTAL);
        int connectTimeoutInMs = resolveIntParam(clientParams, CONNECT_TIMEOUT_IN_MS.value, DEFAULT_CONNECT_TIMEOUT_IN_MS);
        restTemplate.setRequestFactory(getRequestFactory(maxConnPerRoute, maxConnTotal, connectTimeoutInMs));
    }
}

/**
 * Reads an integer parameter that may be supplied either as a {@link Number} or as a numeric {@link String}.
 *
 * @param clientParams parameters map to read from
 * @param key          key of the parameter
 * @param defaultValue value returned when the map has no entry for {@code key}
 * @return the parsed value, or {@code defaultValue} if the key is absent
 * @throws IllegalArgumentException if the value is {@code null}, of an unsupported type, or a string that
 *                                  is not a valid integer ({@link NumberFormatException})
 */
private static int resolveIntParam(Map<String, Object> clientParams, String key, int defaultValue) {
    if (!clientParams.containsKey(key)) {
        return defaultValue;
    }
    Object raw = clientParams.get(key);
    if (raw instanceof Number) {
        // Accepts Integer as before, and additionally Long/Short/etc. instead of failing on the cast.
        return ((Number) raw).intValue();
    }
    if (raw instanceof String) {
        return Integer.parseInt(((String) raw).trim());
    }
    throw new IllegalArgumentException("Parameter " + key + " must be an integer or a numeric string, but was: " + raw);
}
/**
 * Tells whether the map contains at least one of the parameters that are used to build the request
 * factory: {@code max_conn_per_route}, {@code max_conn_total} or {@code connect_timeout}.
 *
 * @param clientParams parameters map to inspect
 * @return {@code true} if any of the three keys is present (regardless of its value)
 */
private boolean containsAnyRequestFactoryParam(Map<String, Object> clientParams) {
    final String[] factoryKeys = {MAX_CONN_PER_ROUTE.value, MAX_CONN_TOTAL.value, CONNECT_TIMEOUT_IN_MS.value};
    for (String key : factoryKeys) {
        if (clientParams.containsKey(key)) {
            return true;
        }
    }
    return false;
}
/**
 * Builds the request factory used when no pool or timeout parameters are supplied.
 *
 * @return a request factory configured with the class default connection limits and connect timeout
 */
private HttpComponentsClientHttpRequestFactory getRequestFactory() {
    final HttpComponentsClientHttpRequestFactory defaultFactory = getRequestFactory(
            DEFAULT_MAX_CONN_PER_ROUTE,
            DEFAULT_MAX_CONN_TOTAL,
            DEFAULT_CONNECT_TIMEOUT_IN_MS);
    return defaultFactory;
}
/**
 * Builds a request factory backed by an Apache {@link CloseableHttpClient} with the given limits.
 *
 * @param maxConnPerRoute    value passed to {@link HttpClientBuilder#setMaxConnPerRoute(int)}
 * @param maxConnTotal       value passed to {@link HttpClientBuilder#setMaxConnTotal(int)}
 * @param connectTimeoutInMs value passed to {@link HttpComponentsClientHttpRequestFactory#setConnectTimeout(int)}
 * @return the configured request factory
 */
private HttpComponentsClientHttpRequestFactory getRequestFactory(int maxConnPerRoute, int maxConnTotal, int connectTimeoutInMs) {
    CloseableHttpClient httpClient = HttpClients.custom()
            // NOTE(security): TLS hostname verification is switched off here, so a certificate issued for
            // any host name is accepted. Supply your own request_factory if verification is required.
            .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE)
            .setMaxConnPerRoute(maxConnPerRoute)
            .setMaxConnTotal(maxConnTotal)
            .build();
    // Hand the client to the constructor: the no-arg constructor would first create its own default
    // HttpClient, which would then be replaced and never closed.
    HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
    requestFactory.setConnectTimeout(connectTimeoutInMs);
    return requestFactory;
}
/**
 * Converts a query into a Spring {@link RequestEntity} carrying the query's body, headers and method.
 *
 * @param query       source of body, headers and HTTP method
 * @param endpointURI URI the request is addressed to
 * @param <T>         type of the request body
 * @return the request entity for the given query and URI
 */
private <T> RequestEntity<T> mapToRequestEntity(JHttpQuery<T> query, URI endpointURI) {
    final RequestEntity<T> request = new RequestEntity<T>(
            query.getBody(),
            query.getHeaders(),
            query.getMethod(),
            endpointURI);
    return request;
}
/**
 * Creates a body-less and header-less {@code GET} {@link RequestEntity} for the given URI.
 *
 * @param endpointURI URI the request is addressed to
 * @param <T>         nominal type of the (absent) request body
 * @return the {@code GET} request entity
 */
private <T> RequestEntity<T> mapToRequestEntity(URI endpointURI) {
    final RequestEntity<T> getRequest = new RequestEntity<T>(HttpMethod.GET, endpointURI);
    return getRequest;
}
/**
 * Converts a Spring {@link ResponseEntity} into a {@link JHttpResponse} carrying the same
 * status code, body and headers.
 *
 * @param responseEntity response returned by the {@link RestTemplate}
 * @param <T>            type of the response body
 * @return the equivalent {@link JHttpResponse}
 */
private <T> JHttpResponse<T> mapToJHttpResponse(ResponseEntity<T> responseEntity) {
    final JHttpResponse<T> converted = new JHttpResponse<T>(
            responseEntity.getStatusCode(),
            responseEntity.getBody(),
            responseEntity.getHeaders());
    return converted;
}
/**
 * Returns the parameters this client was created with.
 *
 * @return a new mutable map holding a shallow copy of the client parameters; modifying it does not
 * affect the client
 */
public Map<String, Object> getClientParams() {
    final Map<String, Object> snapshot = new HashMap<>(clientParams);
    return snapshot;
}
/**
 * {@link ResponseErrorHandler} that never reports a response as erroneous, whatever its status code.
 * <p>
 * {@link #hasError(ClientHttpResponse)} always returns {@code false}. {@link #handleError(ClientHttpResponse)},
 * if it is called anyway, only throws when the status code of the response cannot be resolved.
 */
public static class AllowAllCodesResponseErrorHandler implements ResponseErrorHandler {

    /**
     * @param response the response to inspect (not read)
     * @return always {@code false}
     */
    @Override
    public boolean hasError(ClientHttpResponse response) throws IOException {
        return false;
    }

    /**
     * Throws {@link UnknownHttpStatusCodeException} if resolving the status code of the response fails
     * with an {@link IllegalArgumentException}; does nothing otherwise.
     *
     * @param response the response to check
     * @throws IOException if reading the status or headers of the response fails
     */
    @Override
    public void handleError(ClientHttpResponse response) throws IOException {
        try {
            response.getStatusCode();
        } catch (IllegalArgumentException unknownStatus) {
            final int rawStatus = response.getRawStatusCode();
            final String statusText = response.getStatusText();
            final HttpHeaders headers = response.getHeaders();
            final byte[] body = getResponseBody(response);
            final Charset charset = getCharset(response);
            throw new UnknownHttpStatusCodeException(rawStatus, statusText, headers, body, charset);
        }
    }

    /**
     * Reads the whole response body on a best-effort basis.
     *
     * @return the body bytes, or an empty array if there is no body or it cannot be read
     */
    private byte[] getResponseBody(ClientHttpResponse response) {
        try {
            final InputStream bodyStream = response.getBody();
            if (bodyStream == null) {
                return new byte[0];
            }
            return FileCopyUtils.copyToByteArray(bodyStream);
        } catch (IOException ignored) {
            // Deliberately ignored: the body is only extra detail for the exception being built.
            return new byte[0];
        }
    }

    /**
     * @return the charset declared in the {@code Content-Type} header, or {@code null} if the header is absent
     */
    private Charset getCharset(ClientHttpResponse response) {
        final MediaType contentType = response.getHeaders().getContentType();
        if (contentType == null) {
            return null;
        }
        return contentType.getCharset();
    }
}
}