/* * Copyright [2007] [University Corporation for Advanced Internet Development, Inc.] * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.opensaml.saml1.binding.decoding; import java.net.MalformedURLException; import java.net.URL; import java.security.KeyPair; import org.opensaml.common.BaseTestCase; import org.opensaml.common.SAMLObject; import org.opensaml.common.binding.BasicSAMLMessageContext; import org.opensaml.common.binding.decoding.SAMLMessageDecoder; import org.opensaml.saml1.core.Response; import org.opensaml.ws.message.decoder.MessageDecodingException; import org.opensaml.ws.message.encoder.MessageEncodingException; import org.opensaml.ws.transport.http.HttpServletRequestAdapter; import org.opensaml.xml.io.MarshallingException; import org.opensaml.xml.security.SecurityException; import org.opensaml.xml.security.SecurityHelper; import org.opensaml.xml.security.SecurityTestHelper; import org.opensaml.xml.security.credential.Credential; import org.opensaml.xml.signature.Signature; import org.opensaml.xml.signature.Signer; import org.opensaml.xml.util.Base64; import org.opensaml.xml.util.XMLHelper; import org.springframework.mock.web.MockHttpServletRequest; /** * Test case for SAML 1 HTTP POST decoding. */ public class HTTPPostDecoderTest extends BaseTestCase { private String responseRecipient = "https://sp.example.org/sso/acs"; private String expectedRelayValue = "relay"; private SAMLMessageDecoder decoder; private BasicSAMLMessageContext messageContext; private MockHttpServletRequest httpRequest; /** {@inheritDoc} */ protected void setUp() throws Exception { super.setUp(); httpRequest = new MockHttpServletRequest(); httpRequest.setMethod("POST"); httpRequest.setParameter("TARGET", expectedRelayValue); messageContext = new BasicSAMLMessageContext(); messageContext.setInboundMessageTransport(new HttpServletRequestAdapter(httpRequest)); decoder = new HTTPPostDecoder(null); } /** Test decoding message. */ public void testDecode() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); String deliveredEndpointURL = samlResponse.getRecipient(); httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); decoder.decode(messageContext); assertTrue(messageContext.getInboundMessage() instanceof Response); assertTrue(messageContext.getInboundSAMLMessage() instanceof Response); assertEquals(expectedRelayValue, messageContext.getRelayState()); } public void testMessageEndpointGood() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); String deliveredEndpointURL = samlResponse.getRecipient(); httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); } catch (SecurityException e) { fail("Caught SecurityException: " + e.getMessage()); } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointGoodWithQueryParams() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); String deliveredEndpointURL = samlResponse.getRecipient() + "?paramFoo=bar¶mBar=baz"; httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); } catch (SecurityException e) { fail("Caught SecurityException: " + e.getMessage()); } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointInvalidURI() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); String deliveredEndpointURL = samlResponse.getRecipient() + "/some/other/endpointURI"; httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointInvalidHost() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); String deliveredEndpointURL = "https://bogus-sp.example.com/sso/acs"; httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointMissingDestinationNotSigned() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); samlResponse.setRecipient(null); String deliveredEndpointURL = responseRecipient; httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed, binding requires endpoint on unsigned message"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } public void testMessageEndpointMissingDestinationSigned() throws Exception { Response samlResponse = (Response) unmarshallElement("/data/org/opensaml/saml1/binding/Response.xml"); samlResponse.setRecipient(null); Signature signature = (Signature) buildXMLObject(Signature.DEFAULT_ELEMENT_NAME); KeyPair kp = SecurityTestHelper.generateKeyPair("RSA", 1024, null); Credential signingCred = SecurityHelper.getSimpleCredential(kp.getPublic(), kp.getPrivate()); signature.setSigningCredential(signingCred); samlResponse.setSignature(signature); SecurityHelper.prepareSignatureParams(signature, signingCred, null, null); marshallerFactory.getMarshaller(samlResponse).marshall(samlResponse); Signer.signObject(signature); String deliveredEndpointURL = responseRecipient; httpRequest.setParameter("SAMLResponse", encodeMessage(samlResponse)); populateRequestURL(httpRequest, deliveredEndpointURL); try { decoder.decode(messageContext); fail("Passed delivered endpoint check, should have failed, binding requires endpoint on signed message"); } catch (SecurityException e) { // do nothing, failure expected } catch (MessageDecodingException e) { fail("Caught MessageDecodingException: " + e.getMessage()); } } private void populateRequestURL(MockHttpServletRequest request, String requestURL) { URL url = null; try { url = new URL(requestURL); } catch (MalformedURLException e) { fail("Malformed URL: " + e.getMessage()); } request.setScheme(url.getProtocol()); request.setServerName(url.getHost()); if (url.getPort() != -1) { request.setServerPort(url.getPort()); } else { if ("https".equalsIgnoreCase(url.getProtocol())) { request.setServerPort(443); } else if ("http".equalsIgnoreCase(url.getProtocol())) { request.setServerPort(80); } } request.setRequestURI(url.getPath()); request.setQueryString(url.getQuery()); } protected String encodeMessage(SAMLObject message) throws MessageEncodingException, MarshallingException { marshallerFactory.getMarshaller(message).marshall(message); String messageStr = XMLHelper.nodeToString(message.getDOM()); return Base64.encodeBytes(messageStr.getBytes(), Base64.DONT_BREAK_LINES); } }