/* * Copyright 2012 SURFnet bv, The Netherlands * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.surfnet.oaaas.conext; import nl.surfnet.coin.api.client.OpenConextOAuthClient; import nl.surfnet.coin.api.client.domain.Group20; import nl.surfnet.spring.security.opensaml.AuthnRequestGenerator; import nl.surfnet.spring.security.opensaml.Provisioner; import nl.surfnet.spring.security.opensaml.SAMLMessageHandler; import nl.surfnet.spring.security.opensaml.ServiceProviderAuthenticationException; import nl.surfnet.spring.security.opensaml.util.IDService; import nl.surfnet.spring.security.opensaml.util.TimeService; import nl.surfnet.spring.security.opensaml.xml.EndpointGenerator; import org.apache.commons.beanutils.BeanUtils; import org.apache.commons.lang.StringUtils; import org.opensaml.common.binding.SAMLMessageContext; import org.opensaml.saml2.core.AuthnRequest; import org.opensaml.saml2.core.RequesterID; import org.opensaml.saml2.core.Response; import org.opensaml.saml2.core.Scoping; import org.opensaml.saml2.core.impl.RequesterIDBuilder; import org.opensaml.saml2.core.impl.ScopingBuilder; import org.opensaml.saml2.metadata.Endpoint; import org.opensaml.saml2.metadata.SingleSignOnService; import org.opensaml.ws.message.decoder.MessageDecodingException; import org.opensaml.ws.message.encoder.MessageEncodingException; import org.opensaml.xml.security.CriteriaSet; import org.opensaml.xml.security.credential.Credential; import org.opensaml.xml.security.credential.UsageType; import org.opensaml.xml.security.criteria.EntityIDCriteria; import org.opensaml.xml.security.criteria.UsageCriteria; import org.opensaml.xml.validation.ValidationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.io.support.PropertiesLoaderUtils; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import org.surfnet.oaaas.auth.AbstractAuthenticator; import org.surfnet.oaaas.auth.principal.AuthenticatedPrincipal; import org.surfnet.oaaas.model.AuthorizationRequest; import org.surfnet.oaaas.model.Client; import org.surfnet.oaaas.repository.AuthorizationRequestRepository; import javax.inject.Inject; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.List; import java.util.Properties; @Component public class SAMLAuthenticator extends AbstractAuthenticator { private static final Logger LOG = LoggerFactory.getLogger(SAMLAuthenticator.class); private static final String RELAY_STATE_FROM_SAML = "RELAY_STATE_FROM_SAML"; private static final String PRINCIPAL_FROM_SAML = "PRINCIPAL_FROM_SAML"; private static final String CLIENT_SAML_ENTITY_NAME = "CLIENT_SAML_ENTITY_NAME"; private TimeService timeService = new TimeService(); private IDService idService = new IDService(); private ScopingBuilder scopingBuilder = new ScopingBuilder(); private RequesterIDBuilder requesterIDBuilder = new RequesterIDBuilder(); private OpenSAMLContext openSAMLContext; private OpenConextOAuthClient apiClient; private String callbackFlagParameter = "apiOauthCallback"; private boolean enrichPricipal; private String adminGroup; @Inject private AuthorizationRequestRepository authorizationRequestRepository; private final Properties properties; { try { properties = PropertiesLoaderUtils.loadAllProperties("surfconext.authn.properties"); } catch (IOException e) { throw new RuntimeException(e); } } @Override public void init(FilterConfig filterConfig) throws ServletException { try { super.init(filterConfig); openSAMLContext = createOpenSAMLContext(properties); enrichPricipal = Boolean.valueOf(properties.getProperty("api-enrich-principal")); if (enrichPricipal) { apiClient = createOpenConextOAuthClient(properties); adminGroup = properties.getProperty("admin.client.apis.teamname"); } } catch (Exception e) { throw new ServletException(e); } } protected OpenConextOAuthClient createOpenConextOAuthClient(Properties properties) throws ClassNotFoundException, IllegalAccessException, InstantiationException, InvocationTargetException { OpenConextOAuthClient apiClient = (OpenConextOAuthClient) getClass().getClassLoader().loadClass(properties.getProperty("openConextApiClient")).newInstance(); BeanUtils.setProperty(apiClient, "callbackUrl", properties.getProperty("api-callbackuri")); BeanUtils.setProperty(apiClient, "consumerSecret", properties.getProperty("api-consumersecret")); BeanUtils.setProperty(apiClient, "consumerKey", properties.getProperty("api-consumerkey")); BeanUtils.setProperty(apiClient, "endpointBaseUrl", properties.getProperty("api-baseurl")); return apiClient; } /** * Default Context factory method. */ protected OpenSAMLContext createOpenSAMLContext(Properties properties) { return new OpenSAMLContext(properties, createProvisioner()); } /** * Default Provisioner factory method. */ protected Provisioner createProvisioner() { SAMLProvisioner samlProvisioner = new SAMLProvisioner(); samlProvisioner.setUuidAttribute((String) properties.get("samlUuidAttribute")); return samlProvisioner; } @Override public boolean canCommence(HttpServletRequest request) { return isSAMLResponse(request) || isOAuthCallback(request); } @Override public void authenticate(HttpServletRequest request, HttpServletResponse response, FilterChain chain, String authStateValue, String returnUri) throws IOException, ServletException { LOG.debug("Hitting SAML Authenticator filter"); if (isSAMLResponse(request)) { Response samlResponse = extractSamlResponse(request); SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) openSAMLContext.assertionConsumer().consume(samlResponse); if (enrichPricipal) { //need to save the Principal and the AuthState somewhere request.getSession().setAttribute(PRINCIPAL_FROM_SAML, principal); request.getSession().setAttribute(RELAY_STATE_FROM_SAML, getSAMLRelayState(request)); response.sendRedirect(apiClient.getAuthorizationUrl()); } else { proceedWithChain(request, response, chain, principal, getSAMLRelayState(request)); } } else if (isOAuthCallback(request)) { SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) request.getSession().getAttribute(PRINCIPAL_FROM_SAML); String authState = (String) request.getSession().getAttribute(RELAY_STATE_FROM_SAML); if (principal == null) { //huh throw new ServiceProviderAuthenticationException("No principal anymore in the session"); } String userId = principal.getName(); if (StringUtils.isEmpty(userId)) { throw new ServiceProviderAuthenticationException("No userId in SAML assertion!"); } apiClient.oauthCallback(request, userId); List<Group20> groups = apiClient.getGroups20(userId, userId); if (!CollectionUtils.isEmpty(groups)) { for (Group20 group : groups) { principal.addGroup(group.getId()); if (StringUtils.isNotBlank(this.adminGroup) && adminGroup.equalsIgnoreCase(group.getId())) { principal.setAdminPrincipal(true); } } } proceedWithChain(request, response, chain, principal, authState); } else { sendAuthnRequest(response, authStateValue, getReturnUri(request)); } } private void proceedWithChain(HttpServletRequest request, HttpServletResponse response, FilterChain chain, AuthenticatedPrincipal principal, String authStateValue) throws IOException, ServletException { super.setPrincipal(request, principal); super.setAuthStateValue(request, authStateValue); chain.doFilter(request, response); } private boolean isOAuthCallback(HttpServletRequest request) { return request.getParameter(callbackFlagParameter) != null; } protected String getSAMLRelayState(HttpServletRequest request) { return request.getParameter("RelayState"); } protected boolean isSAMLResponse(HttpServletRequest request) { return request.getParameter("SAMLResponse") != null; } private Response extractSamlResponse(HttpServletRequest request) { SAMLMessageContext messageContext; final SAMLMessageHandler samlMessageHandler = openSAMLContext.samlMessageHandler(); try { messageContext = samlMessageHandler.extractSAMLMessageContext(request); } catch (MessageDecodingException me) { throw new ServiceProviderAuthenticationException("Could not decode SAML Response", me); } catch (org.opensaml.xml.security.SecurityException se) { throw new ServiceProviderAuthenticationException("Could not decode SAML Response", se); } LOG.debug("Message received from issuer: " + messageContext.getInboundMessageIssuer()); if (!(messageContext.getInboundSAMLMessage() instanceof Response)) { throw new ServiceProviderAuthenticationException("SAML Message was not a Response."); } final Response inboundSAMLMessage = (Response) messageContext.getInboundSAMLMessage(); try { openSAMLContext.validatorSuite().validate(inboundSAMLMessage); return inboundSAMLMessage; } catch (ValidationException ve) { LOG.warn("Response Message failed Validation", ve); throw new RuntimeException("Invalid SAML Response Message", ve); } } private void sendAuthnRequest(HttpServletResponse response, String authState, String returnUri) throws IOException { AuthnRequestGenerator authnRequestGenerator = new AuthnRequestGenerator(openSAMLContext.entityId(), timeService, idService); EndpointGenerator endpointGenerator = new EndpointGenerator(); final String target = openSAMLContext.getIdpUrl(); Endpoint endpoint = endpointGenerator.generateEndpoint( SingleSignOnService.DEFAULT_ELEMENT_NAME, target, openSAMLContext.assertionConsumerUri()); AuthnRequest authnRequest = authnRequestGenerator.generateAuthnRequest(target, openSAMLContext.assertionConsumerUri()); Client client = getClientByRequest(authState); String spEntityIdBy = client.getAttributes().get(CLIENT_SAML_ENTITY_NAME); if (StringUtils.isNotEmpty(spEntityIdBy)) { Scoping scoping = scopingBuilder.buildObject(); scoping.getRequesterIDs().add(createRequesterID(spEntityIdBy)); authnRequest.setScoping(scoping); } else { LOG.warn("For Client {} there is no key CLIENT_SAML_ENTITY_NAME configured to identify the SP entity name. NO SCOPING IS APPLIED", client.getClientId()); } CriteriaSet criteriaSet = new CriteriaSet(); criteriaSet.add(new EntityIDCriteria(openSAMLContext.entityId())); criteriaSet.add(new UsageCriteria(UsageType.SIGNING)); try { Credential signingCredential = openSAMLContext.keyStoreCredentialResolver().resolveSingle(criteriaSet); String relayState = authState; LOG.debug("Sending authnRequest to {}", target); openSAMLContext.samlMessageHandler().sendSAMLMessage(authnRequest, endpoint, response, relayState, signingCredential); } catch (MessageEncodingException mee) { LOG.error("Could not send authnRequest to Identity Provider.", mee); response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); } catch (org.opensaml.xml.security.SecurityException e) { throw new RuntimeException(e); } } private RequesterID createRequesterID(String id) { RequesterID requesterID = requesterIDBuilder.buildObject(); requesterID.setRequesterID(id); return requesterID; } /** * Get the Client */ protected Client getClientByRequest(String authState) { AuthorizationRequest authorizationRequest = authorizationRequestRepository.findByAuthState(authState); return authorizationRequest.getClient(); } }