/* * Copyright [2007] [University Corporation for Advanced Internet Development, Inc.] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.opensaml.saml2.binding.decoding; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; import java.security.KeyPair; import java.util.zip.Deflater; import java.util.zip.DeflaterOutputStream; import org.opensaml.common.BaseTestCase; import org.opensaml.common.SAMLObject; import org.opensaml.common.binding.BasicSAMLMessageContext; import org.opensaml.common.binding.decoding.SAMLMessageDecoder; import org.opensaml.saml2.core.AuthnRequest; import org.opensaml.saml2.core.RequestAbstractType; import org.opensaml.saml2.core.Response; import org.opensaml.util.URLBuilder; import org.opensaml.ws.message.decoder.MessageDecodingException; import org.opensaml.ws.message.encoder.MessageEncodingException; import org.opensaml.ws.transport.http.HttpServletRequestAdapter; import org.opensaml.xml.io.MarshallingException; import org.opensaml.xml.security.SecurityException; import org.opensaml.xml.security.SecurityHelper; import org.opensaml.xml.security.SecurityTestHelper; import org.opensaml.xml.security.credential.Credential; import org.opensaml.xml.signature.Signature; import org.opensaml.xml.signature.Signer; import org.opensaml.xml.util.Base64; import org.opensaml.xml.util.XMLHelper; import org.springframework.mock.web.MockHttpServletRequest; /** * Test case for HTTP POST decoders. */ public class HTTPPostDecoderTest extends BaseTestCase { private String authnRequestDestination = "https://idp.example.com/idp/sso"; private String expectedRelayValue = "relay"; private SAMLMessageDecoder decoder; private BasicSAMLMessageContext messageContext; private MockHttpServletRequest httpRequest; /** {@inheritDoc} */ protected void setUp() throws Exception { super.setUp(); httpRequest = new MockHttpServletRequest(); httpRequest.setMethod("POST"); httpRequest.setParameter("RelayState", expectedRelayValue); messageContext = new BasicSAMLMessageContext(); messageContext.setInboundMessageTransport(new HttpServletRequestAdapter(httpRequest)); decoder = new HTTPPostDecoder(); } /** * Test decoding a SAML httpRequest. */ public void testRequestDecoding() throws Exception { httpRequest.setParameter("SAMLRequest", "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPHNhbWxwOkF1dGhuUm" + "VxdWVzdCBJRD0iZm9vIiBJc3N1ZUluc3RhbnQ9IjE5NzAtMDEtMDFUMDA6MDA6MDAuMDAwWiIgVmVyc2lvbj0iMi4wIiB4bW" + "xuczpzYW1scD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIi8+"); decoder.decode(messageContext); assertTrue(messageContext.getInboundMessage() instanceof RequestAbstractType); assertTrue(messageContext.getInboundSAMLMessage() instanceof RequestAbstractType); assertEquals(expectedRelayValue, messageContext.getRelayState()); } /** * Test decoding a SAML response. */ public void testResponseDecoding() throws Exception { httpRequest.setParameter("SAMLResponse", "PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPHNhbWxwOlJlc3Bvbn" + "NlIElEPSJmb28iIElzc3VlSW5zdGFudD0iMTk3MC0wMS0wMVQwMDowMDowMC4wMDBaIiBWZXJzaW9uPSIyLjAiIHhtbG5zOnN" + "hbWxwPSJ1cm46b2FzaXM6bmFtZXM6dGM6U0FNTDoyLjA6cHJvdG9jb2wiPjxzYW1scDpTdGF0dXM+PHNhbWxwOlN0YXR1c0Nv" + "ZGUgVmFsdWU9InVybjpvYXNpczpuYW1lczp0YzpTQU1MOjIuMDpzdGF0dXM6U3VjY2VzcyIvPjwvc2FtbHA6U3RhdHVzPjwvc" + "2FtbHA6UmVzcG9uc2U+"); decoder.decode(messageContext); assertTrue(messageContext.getInboundMessage() instanceof Response); assertTrue(messageContext.getInboundSAMLMessage() instanceof Response); assertEquals(expectedRelayValue, messageContext.getRelayState()); } public void testMessageEndpointGood() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); String deliveredEndpointURL = samlRequest.getDestination(); httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); } catch (SecurityException e) { fail("Caught SecurityException: " + e.getMessage()); } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointGoodWithQueryParams() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); String deliveredEndpointURL = samlRequest.getDestination() + "?paramFoo=bar¶mBar=baz"; httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); } catch (SecurityException e) { fail("Caught SecurityException: " + e.getMessage()); } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointInvalidURI() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); String deliveredEndpointURL = samlRequest.getDestination() + "/some/other/endpointURI"; httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointInvalidHost() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); String deliveredEndpointURL = "https://bogusidp.example.com/idp/sso"; httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointMissingDestinationNotSigned() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); samlRequest.setDestination(null); String deliveredEndpointURL = authnRequestDestination; httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); } catch (SecurityException e) { fail("Caught SecurityException: " + e.getMessage()); } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointMissingDestinationSigned() throws Exception { AuthnRequest samlRequest = (AuthnRequest) unmarshallElement("/data/org/opensaml/saml2/binding/AuthnRequest.xml"); samlRequest.setDestination(null); Signature signature = (Signature) buildXMLObject(Signature.DEFAULT_ELEMENT_NAME); KeyPair kp = SecurityTestHelper.generateKeyPair("RSA", 1024, null); Credential signingCred = SecurityHelper.getSimpleCredential(kp.getPublic(), kp.getPrivate()); signature.setSigningCredential(signingCred); samlRequest.setSignature(signature); SecurityHelper.prepareSignatureParams(signature, signingCred, null, null); marshallerFactory.getMarshaller(samlRequest).marshall(samlRequest); Signer.signObject(signature); String deliveredEndpointURL = authnRequestDestination; httpRequest.setParameter("SAMLRequest", encodeMessage(samlRequest)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed, binding requires endpoint on signed message"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } private void populateRequestURL(MockHttpServletRequest request, String requestURL) { URL url = null; try { url = new URL(requestURL); } catch (MalformedURLException e) { fail("Malformed URL: " + e.getMessage()); } request.setScheme(url.getProtocol()); request.setServerName(url.getHost()); if (url.getPort() != -1) { request.setServerPort(url.getPort()); } else { if ("https".equalsIgnoreCase(url.getProtocol())) { request.setServerPort(443); } else if ("http".equalsIgnoreCase(url.getProtocol())) { request.setServerPort(80); } } request.setRequestURI(url.getPath()); request.setQueryString(url.getQuery()); } protected String encodeMessage(SAMLObject message) throws MessageEncodingException, MarshallingException { marshallerFactory.getMarshaller(message).marshall(message); String messageStr = XMLHelper.nodeToString(message.getDOM()); return Base64.encodeBytes(messageStr.getBytes(), Base64.DONT_BREAK_LINES); } }