/* (c) 2016 Open Source Geospatial Foundation - all rights reserved
 * This code is licensed under the GPL 2.0 license, available at the root
 * application directory.
 */
package org.geoserver.security.oauth2.services;

import java.io.IOException;
import java.util.Map;
import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;

/**
 * Remote Token Services for GitHub token details.
 *
 * <p>The access token is checked by calling the configured check-token endpoint with the token
 * in the {@code Authorization} header; the returned attributes are then mapped to the standard
 * names expected by the token converter.
 *
 * @author Alessio Fabiani, GeoSolutions S.A.S.
 */
public class GitHubTokenServices extends GeoServerOAuthRemoteTokenServices {

    /**
     * Installs the GitHub specific token converter and a {@link RestTemplate} whose error handler
     * does not raise on HTTP 400, so that the body of such a response is handed back to {@link
     * #loadAuthentication(String)} instead of surfacing as an exception.
     */
    public GitHubTokenServices() {
        tokenConverter = new GitHubAccessTokenConverter();

        // Configure the template through a typed local rather than casting the field back.
        RestTemplate template = new RestTemplate();
        template.setErrorHandler(
                new DefaultResponseErrorHandler() {
                    @Override
                    // Ignore 400
                    public void handleError(ClientHttpResponse response) throws IOException {
                        if (response.getRawStatusCode() != 400) {
                            super.handleError(response);
                        }
                    }
                });
        restTemplate = template;
    }

    /**
     * Validates the given access token against the check-token endpoint and builds the
     * corresponding authentication.
     *
     * @param accessToken the OAuth2 access token to verify
     * @return the authentication extracted by the token converter from the endpoint response
     * @throws InvalidTokenException if the endpoint returns no body, or a body whose {@code
     *     message} attribute starts with {@code "Problems"}
     * @throws AuthenticationException as declared by the overridden method
     */
    @Override
    public OAuth2Authentication loadAuthentication(String accessToken)
            throws AuthenticationException, InvalidTokenException {
        Map<String, Object> checkTokenResponse = checkToken(accessToken);

        // The response body may be absent (e.g. an empty 400 response that the error handler
        // lets through); treat that as an invalid token instead of failing with a
        // NullPointerException.
        if (checkTokenResponse == null) {
            logger.debug("check_token returned an empty response");
            throw new InvalidTokenException(accessToken);
        }

        // Read the value once: the key may be present with a null value, in which case there
        // is no error message to inspect.
        Object message = checkTokenResponse.get("message");
        if (message != null && message.toString().startsWith("Problems")) {
            logger.debug("check_token returned error: " + message);
            throw new InvalidTokenException(accessToken);
        }

        transformNonStandardValuesToStandardValues(checkTokenResponse);

        // NOTE(review): the transform step above always puts the "client_id" key (even when the
        // configured client id is null), so this check only guards against that step being
        // removed; it does not verify that the value is non-null.
        Assert.state(
                checkTokenResponse.containsKey("client_id"),
                "Client id must be present in response from auth server");
        return tokenConverter.extractAuthentication(checkTokenResponse);
    }

    /**
     * Calls the check-token endpoint for the given access token.
     *
     * <p>The token is sent only in the {@code Authorization} header. It is deliberately not
     * appended to the URL as an {@code access_token} query parameter: that exposes the secret
     * in URLs (and whatever logs them), and GitHub has deprecated query-parameter
     * authentication in favour of the header.
     *
     * @param accessToken the access token to check
     * @return the attributes returned by the endpoint, possibly {@code null} when the response
     *     has no body
     */
    private Map<String, Object> checkToken(String accessToken) {
        // NOTE(review): this form entry is handed to a GET exchange; whether it is actually
        // transmitted depends on the underlying HTTP request factory - confirm it is needed.
        MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
        formData.add("token", accessToken);

        HttpHeaders headers = new HttpHeaders();
        headers.set("Authorization", getAuthorizationHeader(accessToken));

        return postForMap(checkTokenEndpointUrl, formData, headers);
    }

    /**
     * Adds the standard attribute names expected downstream to the GitHub response, in place.
     *
     * @param map the mutable response attributes; gains {@code client_id} and {@code user_name}
     */
    private void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
        LOGGER.debug("Original map = " + map);
        map.put("client_id", clientId); // GitHub does not send 'client_id'
        map.put("user_name", map.get("login")); // GitHub sends 'user_name' as 'login'
        LOGGER.debug("Transformed = " + map);
    }

    /**
     * Builds the value of the {@code Authorization} header for the given token.
     *
     * @param accessToken the access token
     * @return {@code "Bearer "} followed by the token
     */
    private String getAuthorizationHeader(String accessToken) {
        return "Bearer " + accessToken;
    }

    /**
     * Performs the HTTP exchange against the given URL and returns the response body as a map.
     *
     * <p>Despite its name, the request is issued with {@link HttpMethod#GET}. When the supplied
     * headers carry no content type, {@code application/x-www-form-urlencoded} is set on them.
     *
     * @param path the URL to call
     * @param formData the form data wrapped in the request entity
     * @param headers the request headers; modified in place when no content type is set
     * @return the response body, or {@code null} when the response has none
     */
    private Map<String, Object> postForMap(
            String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        ParameterizedTypeReference<Map<String, Object>> responseType =
                new ParameterizedTypeReference<Map<String, Object>>() {};
        HttpEntity<MultiValueMap<String, String>> request = new HttpEntity<>(formData, headers);
        return restTemplate.exchange(path, HttpMethod.GET, request, responseType).getBody();
    }
}