package org.geoserver.security.onelogin.test; import static com.github.tomakehurst.wiremock.client.WireMock.*; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; import static org.junit.Assert.*; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.StringWriter; import java.util.List; import javax.servlet.ServletException; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.transform.Transformer; import javax.xml.transform.TransformerFactory; import javax.xml.transform.dom.DOMSource; import javax.xml.transform.stream.StreamResult; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathFactory; import org.apache.commons.io.IOUtils; import org.apache.http.NameValuePair; import org.apache.http.client.utils.URIBuilder; import org.geoserver.data.test.SystemTestData; import org.geoserver.security.GeoServerSecurityFilterChain; import org.geoserver.security.LogoutFilterChain; import org.geoserver.security.auth.AbstractAuthenticationProviderTest; import org.geoserver.security.config.PreAuthenticatedUserNameFilterConfig.PreAuthenticatedUserNameRoleSource; import org.geoserver.security.filter.GeoServerLogoutFilter; import org.geoserver.security.onelogin.OneloginAuthenticationFilter; import org.geoserver.security.onelogin.OneloginAuthenticationFilterConfig; import org.geotools.data.Base64; import org.hamcrest.CoreMatchers; import org.joda.time.DateTime; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import org.opensaml.common.SAMLObject; import org.springframework.http.MediaType; import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.saml.SAMLLogoutFilter; import org.springframework.security.saml.SAMLProcessingFilter; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.w3c.dom.Document; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import org.xml.sax.InputSource; import com.github.tomakehurst.wiremock.WireMockServer; public class OneloginAuthenticationTest extends AbstractAuthenticationProviderTest { private static final String METADATA_URL = "/saml/metadata"; private static final String REDIRECT_URL = "/trust/saml2/http-redirect/sso"; private static final Integer IDP_PORT = 8443; private static final String IDP_LOGIN_URL = "http://localhost:" + IDP_PORT + "/login"; private static OneloginAuthenticationFilterConfig config; private static WireMockServer idpSamlService; @Override protected void onSetUp(SystemTestData testData) throws Exception { super.onSetUp(testData); idpSamlService.stubFor( com.github.tomakehurst.wiremock.client.WireMock.get(urlEqualTo(METADATA_URL)) .willReturn(aResponse().withStatus(200) .withHeader("Content-Type", MediaType.APPLICATION_XML_VALUE) .withBodyFile("metadata.xml"))); idpSamlService.stubFor(com.github.tomakehurst.wiremock.client.WireMock .get(urlPathEqualTo(REDIRECT_URL)) .willReturn(aResponse().withStatus(302).withHeader("Location", IDP_LOGIN_URL))); } @BeforeClass public static void beforeClass() throws Exception { SSLUtilities.registerKeyStore("keystore"); idpSamlService = new WireMockServer(wireMockConfig().httpsPort(IDP_PORT)); idpSamlService.start(); } @Before public void before() throws Exception { SecurityContextHolder.getContext().setAuthentication(null); } @AfterClass public static void afterClass() throws Exception { idpSamlService.shutdown(); } @Test public void metadataDiscovery() throws Exception { confgiureFilter(PreAuthenticatedUserNameRoleSource.UserGroupService); verify(getRequestedFor(urlEqualTo(METADATA_URL)).withUrl(METADATA_URL)); } @Test public void notAuthenticatedRedirect() throws Exception { confgiureFilter(PreAuthenticatedUserNameRoleSource.UserGroupService); MockHttpServletRequest request = createRequest("/foo/bar"); MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain chain = new MockFilterChain(); getProxy().doFilter(request, response, chain); assertTrue(response.getStatus() == MockHttpServletResponse.SC_MOVED_TEMPORARILY); String redirectURL = response.getHeader("Location"); assertThat(redirectURL, CoreMatchers.containsString(REDIRECT_URL)); URIBuilder uriBuilder = new URIBuilder(redirectURL); List<NameValuePair> urlParameters = uriBuilder.getQueryParams(); String samlRequest = null; for (NameValuePair par : urlParameters) { if (par.getName().equals("SAMLRequest")) { samlRequest = par.getValue(); break; } } assertNotNull(samlRequest); StringSamlDecoder decoder = new StringSamlDecoder(); SAMLObject samlRequestObject = decoder.decode(samlRequest); assertNotNull(samlRequestObject); } @Test public void autorizationWithGroup() throws Exception { confgiureFilter(PreAuthenticatedUserNameRoleSource.UserGroupService); MockHttpServletRequest request = createRequest("/foo/bar"); MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain chain = new MockFilterChain(); getProxy().doFilter(request, response, chain); /* * Build POST request form IDP to GeoServer */ String encodedResponseMessage = buildSAMLRespons("abc@xyz.com"); request = createRequest(SAMLProcessingFilter.FILTER_URL); request.setMethod("POST"); request.addParameter("SAMLResponse", encodedResponseMessage); chain = new MockFilterChain(); response = new MockHttpServletResponse(); getProxy().doFilter(request, response, chain); /* * Check user */ SecurityContext ctx = (SecurityContext) request.getSession(false) .getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY); assertNotNull(ctx); Authentication auth = ctx.getAuthentication(); assertNotNull(auth); assertNull(SecurityContextHolder.getContext().getAuthentication()); checkForAuthenticatedRole(auth); assertEquals("abc@xyz.com", auth.getPrincipal()); } @Test public void authenticationWithRoles() throws Exception { confgiureFilter(PreAuthenticatedUserNameRoleSource.RoleService); MockHttpServletRequest request = createRequest("/foo/bar"); MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain chain = new MockFilterChain(); getProxy().doFilter(request, response, chain); /* * Build POST request form IDP to GeoServer */ String encodedResponseMessage = buildSAMLRespons(testUserName); request = createRequest(SAMLProcessingFilter.FILTER_URL); request.setMethod("POST"); request.addParameter("SAMLResponse", encodedResponseMessage); chain = new MockFilterChain(); response = new MockHttpServletResponse(); getProxy().doFilter(request, response, chain); /* * Check user */ SecurityContext ctx = (SecurityContext) request.getSession(false) .getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY); assertNotNull(ctx); Authentication auth = ctx.getAuthentication(); assertNotNull(auth); assertNull(SecurityContextHolder.getContext().getAuthentication()); checkForAuthenticatedRole(auth); boolean hasRootRole = false; for (GrantedAuthority a : auth.getAuthorities()) { if (a.getAuthority().equals(rootRole)) { hasRootRole = true; break; } } assertTrue(hasRootRole); assertEquals(testUserName, auth.getPrincipal()); } @Test public void logoutTest() throws Exception { LogoutFilterChain logoutchain = (LogoutFilterChain) getSecurityManager().getSecurityConfig() .getFilterChain().getRequestChainByName("webLogout"); confgiureFilter(PreAuthenticatedUserNameRoleSource.RoleService); MockHttpServletRequest request = createRequest("/foo/bar"); MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain chain = new MockFilterChain(); getProxy().doFilter(request, response, chain); /* * Build POST request form IDP to GeoServer */ String encodedResponseMessage = buildSAMLRespons(testUserName); request = createRequest(SAMLProcessingFilter.FILTER_URL); request.setMethod("POST"); request.addParameter("SAMLResponse", encodedResponseMessage); chain = new MockFilterChain(); response = new MockHttpServletResponse(); getProxy().doFilter(request, response, chain); /* * Check user */ SecurityContext ctx = (SecurityContext) request.getSession(false) .getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY); assertNotNull(ctx); Authentication auth = ctx.getAuthentication(); assertEquals(testUserName, auth.getPrincipal()); /* * Logout */ SecurityContextHolder.setContext(ctx); request = createRequest(logoutchain.getPatterns().get(0)); response = new MockHttpServletResponse(); chain = new MockFilterChain(); GeoServerLogoutFilter logoutFilter = (GeoServerLogoutFilter) getSecurityManager() .loadFilter(GeoServerSecurityFilterChain.FORM_LOGOUT_FILTER); logoutFilter.doFilter(request, response, chain); assertTrue(response.getStatus() == MockHttpServletResponse.SC_MOVED_TEMPORARILY); String redirectURL = response.getHeader("Location"); /* * Check if SAML logut URL will be called */ assertThat(redirectURL, CoreMatchers.containsString(SAMLLogoutFilter.FILTER_URL)); } private String buildSAMLRespons(String username) throws Exception { /* * Buld valid SAML response from template */ DateTime now = new DateTime(); String xml = IOUtils.toString(this.getClass().getResourceAsStream("/__files/response.xml"), "UTF-8"); DocumentBuilderFactory domFactory = DocumentBuilderFactory.newInstance(); Document doc = domFactory.newDocumentBuilder() .parse(new InputSource(new ByteArrayInputStream(xml.getBytes("utf-8")))); XPath xpath = XPathFactory.newInstance().newXPath(); NodeList nodes = (NodeList) xpath.evaluate("//@IssueInstant", doc, XPathConstants.NODESET); for (int idx = 0; idx < nodes.getLength(); idx++) { Node value = nodes.item(idx); value.setNodeValue(now.toString("yyyy-MM-dd'T'HH:mm:ssZ")); } nodes = (NodeList) xpath.evaluate("//@NotOnOrAfter", doc, XPathConstants.NODESET); for (int idx = 0; idx < nodes.getLength(); idx++) { Node value = nodes.item(idx); value.setNodeValue(now.toString("yyyy-MM-dd'T'HH:mm:ssZ")); } nodes = (NodeList) xpath.evaluate("//@NotBefore", doc, XPathConstants.NODESET); for (int idx = 0; idx < nodes.getLength(); idx++) { Node value = nodes.item(idx); value.setNodeValue(now.toString("yyyy-MM-dd'T'HH:mm:ssZ")); } nodes = (NodeList) xpath.evaluate("//@AuthnInstant", doc, XPathConstants.NODESET); for (int idx = 0; idx < nodes.getLength(); idx++) { Node value = nodes.item(idx); value.setNodeValue(now.toString("yyyy-MM-dd'T'HH:mm:ssZ")); } nodes = (NodeList) xpath.evaluate("//@SessionNotOnOrAfter", doc, XPathConstants.NODESET); for (int idx = 0; idx < nodes.getLength(); idx++) { Node value = nodes.item(idx); value.setNodeValue(now.plusDays(1).toString("yyyy-MM-dd'T'HH:mm:ssZ")); } Node node = (Node) xpath.evaluate("//*[local-name() = 'NameID']/text()", doc, XPathConstants.NODE); node.setNodeValue(username); Transformer xformer = TransformerFactory.newInstance().newTransformer(); StringWriter writer = new StringWriter(); xformer.transform(new DOMSource(doc), new StreamResult(writer)); String output = writer.getBuffer().toString(); String encodedResponseMessage = Base64 .encodeBytes(output.getBytes("UTF-8"), Base64.DONT_BREAK_LINES).trim(); return encodedResponseMessage; } private void confgiureFilter(PreAuthenticatedUserNameRoleSource serviceType) { try { String oneloginFilterName = "testOneloginFilter"; if (config == null) { config = new OneloginAuthenticationFilterConfig(); config.setWantAssertionSigned(false); config.setClassName(OneloginAuthenticationFilter.class.getName()); config.setName(oneloginFilterName); config.setEntityId("geoserver"); config.setMetadataURL( "https://localhost:" + idpSamlService.httpsPort() + METADATA_URL); } config.setUserGroupServiceName( serviceType == PreAuthenticatedUserNameRoleSource.RoleService ? "rs1" : "ug1"); config.setRoleSource(serviceType); getSecurityManager().saveFilter(config); prepareFilterChain(pattern, oneloginFilterName); modifyChain(pattern, false, true, null); } catch (Exception e) { } } }