/* (c) 2016 Open Source Geospatial Foundation - all rights reserved
* This code is licensed under the GPL 2.0 license, available at the root
* application directory.
*/
package org.geoserver.security.oauth2.services;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
/**
 * Remote Token Services for Google token details.
 *
 * <p>Validates an access token by POSTing it to {@code checkTokenEndpointUrl}, then renames
 * Google's non-standard response keys ({@code issued_to}, {@code user_id}) to the standard names
 * ({@code client_id}, {@code user_name}) before handing the map to the {@link
 * GoogleAccessTokenConverter}.
 *
 * @author Alessio Fabiani, GeoSolutions S.A.S.
 */
public class GoogleTokenServices extends GeoServerOAuthRemoteTokenServices {

    /** Response type used to deserialize the token-info JSON body into a map. */
    private static final ParameterizedTypeReference<Map<String, Object>> MAP_RESPONSE_TYPE =
            new ParameterizedTypeReference<Map<String, Object>>() {};

    /**
     * Creates the token services with a Google-specific converter and a {@link RestTemplate} that
     * does not raise on HTTP 400 responses.
     */
    public GoogleTokenServices() {
        tokenConverter = new GoogleAccessTokenConverter();
        RestTemplate template = new RestTemplate();
        // HTTP 400 is let through so its body reaches loadAuthentication, which checks it for an
        // "error" key. NOTE(review): assumes Google reports invalid tokens as 400 + JSON error.
        template.setErrorHandler(new IgnoreBadRequestErrorHandler());
        restTemplate = template;
    }

    /**
     * Validates {@code accessToken} against the remote check-token endpoint and builds the
     * corresponding authentication.
     *
     * @param accessToken the bearer token to validate
     * @return the authentication extracted by the token converter
     * @throws InvalidTokenException if the endpoint returned no body or an {@code error} entry
     * @throws IllegalStateException if the response carries no client id
     */
    @Override
    public OAuth2Authentication loadAuthentication(String accessToken)
            throws AuthenticationException, InvalidTokenException {
        Map<String, Object> checkTokenResponse = checkToken(accessToken);
        if (checkTokenResponse == null) {
            logger.debug("check_token returned an empty response");
            throw new InvalidTokenException(accessToken);
        }
        if (checkTokenResponse.containsKey("error")) {
            logger.debug("check_token returned error: " + checkTokenResponse.get("error"));
            throw new InvalidTokenException(accessToken);
        }
        transformNonStandardValuesToStandardValues(checkTokenResponse);
        Assert.state(checkTokenResponse.containsKey("client_id"),
                "Client id must be present in response from auth server");
        return tokenConverter.extractAuthentication(checkTokenResponse);
    }

    /** Calls the check-token endpoint for {@code accessToken} and returns the parsed body. */
    private Map<String, Object> checkToken(String accessToken) {
        MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
        formData.add("token", accessToken);
        HttpHeaders headers = new HttpHeaders();
        headers.set("Authorization", getAuthorizationHeader(clientId, clientSecret));
        // The token is supplied as a URI template variable instead of being concatenated, so
        // RestTemplate percent-encodes it; a raw '&', '#' or '{' in the (untrusted) token would
        // otherwise inject query parameters or break template expansion.
        char separator = checkTokenEndpointUrl.indexOf('?') >= 0 ? '&' : '?';
        String accessTokenUrl = checkTokenEndpointUrl + separator + "access_token={access_token}";
        return postForMap(accessTokenUrl, formData, headers, accessToken);
    }

    /**
     * Renames Google's keys to the standard ones: {@code issued_to} -> {@code client_id} and
     * {@code user_id} -> {@code user_name}. A value is only copied when Google supplied one, so a
     * missing client id is still caught by the assertion in {@link #loadAuthentication}.
     */
    private void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
        LOGGER.debug("Original map = " + map);
        copyIfPresent(map, "issued_to", "client_id"); // Google sends 'client_id' as 'issued_to'
        copyIfPresent(map, "user_id", "user_name"); // Google sends 'user_name' as 'user_id'
        LOGGER.debug("Transformed = " + map);
    }

    /** Copies {@code map[from]} to {@code map[to]} when the source value is non-null. */
    private static void copyIfPresent(Map<String, Object> map, String from, String to) {
        Object value = map.get(from);
        if (value != null) {
            map.put(to, value);
        }
    }

    /** Builds an HTTP Basic {@code Authorization} header value from the client credentials. */
    private String getAuthorizationHeader(String clientId, String clientSecret) {
        String creds = String.format("%s:%s", clientId, clientSecret);
        byte[] encoded = Base64.encode(creds.getBytes(StandardCharsets.UTF_8));
        // Base64 output is pure ASCII; decode explicitly rather than with the platform charset.
        return "Basic " + new String(encoded, StandardCharsets.US_ASCII);
    }

    /**
     * POSTs {@code formData} to {@code path} and returns the JSON response body as a map.
     *
     * @param path URI template of the endpoint
     * @param formData form parameters sent as the request body
     * @param headers request headers; the content type defaults to form-urlencoded if unset
     * @param uriVariables values expanded (and encoded) into {@code path}'s template variables
     * @return the response body, possibly {@code null}
     */
    private Map<String, Object> postForMap(String path, MultiValueMap<String, String> formData,
            HttpHeaders headers, Object... uriVariables) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        return restTemplate
                .exchange(path, HttpMethod.POST, new HttpEntity<>(formData, headers),
                        MAP_RESPONSE_TYPE, uriVariables)
                .getBody();
    }

    /** Response error handler that ignores HTTP 400 and delegates every other status. */
    private static final class IgnoreBadRequestErrorHandler extends DefaultResponseErrorHandler {
        @Override
        public void handleError(ClientHttpResponse response) throws IOException {
            if (response.getRawStatusCode() != 400) {
                super.handleError(response);
            }
        }
    }
}