/**
* Copyright (c) Codice Foundation
* <p>
* This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
* General Public License as published by the Free Software Foundation, either version 3 of the
* License, or any later version.
* <p>
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details. A copy of the GNU Lesser General Public License
* is distributed along with this program and can be found at
* <http://www.gnu.org/licenses/lgpl.html>.
*/
package org.codice.ddf.security.idp.client;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.xml.HasXPath.hasXPath;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.io.ByteArrayInputStream;
import java.util.Base64;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.Response;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.apache.http.HttpStatus;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.opensaml.saml.saml2.core.StatusCode;
import org.w3c.dom.Document;
import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import ddf.security.encryption.EncryptionService;
import ddf.security.http.SessionFactory;
import ddf.security.samlp.SimpleSign;
import ddf.security.samlp.SystemCrypto;
import ddf.security.samlp.impl.RelayStates;
/**
 * Unit tests for {@link AssertionConsumerService}.
 *
 * <p>Covers:
 * <ul>
 *   <li>the POST ({@code postSamlResponse}) and GET ({@code getSamlResponse}) response endpoints;</li>
 *   <li>{@code processSamlResponse} handling of successful, failed and malformed SAML responses
 *       and of relay states;</li>
 *   <li>metadata retrieval and the login-filter accessor.</li>
 * </ul>
 */
public class AssertionConsumerServiceTest {

    private static final String RELAY_STATE_VAL = "b0b4e449-7f69-413f-a844-61fe2256de19";

    private static final String LOCATION = "test";

    private static final String SIG_ALG_VAL = "http://www.w3.org/2000/09/xmldsig#rsa-sha1";

    private static final String SIGNATURE_VAL =
            "UTSaVBKoDCw7BM6gLtaJOU7xXo5G4oHaZwvUaHE2Cc48IiJ4nCJLlTamGnGbec/MQhTXv//yHpGQ/jFoasQ4cJ0kRomGItdBQZoOLcmtZ2bJc8V8yTKNctEYziIG9NTQevZOoRCiVzClOFhflTqc+kZ4FZBjLgLBdtySP0OL08Q=";

    private AssertionConsumerService assertionConsumerService;

    private SimpleSign simpleSign;

    private IdpMetadata idpMetadata;

    private SystemCrypto systemCrypto;

    private RelayStates<String> relayStates;

    private Filter loginFilter;

    private SessionFactory sessionFactory;

    private EncryptionService encryptionService;

    private String cannedResponse;

    private HttpServletRequest httpRequest;

    // Instance (not static) state: it is reloaded for every test in setUp().
    private String deflatedSamlResponse;

    @Before
    @SuppressWarnings("unchecked") // mock() cannot produce a parameterized RelayStates<String>
    public void setUp() throws Exception {
        encryptionService = mock(EncryptionService.class);
        systemCrypto = new SystemCrypto("encryption.properties",
                "signature.properties",
                encryptionService);
        simpleSign = new SimpleSign(systemCrypto);

        relayStates = (RelayStates<String>) mock(RelayStates.class);
        when(relayStates.encode("fubar")).thenReturn(RELAY_STATE_VAL);
        when(relayStates.decode(RELAY_STATE_VAL)).thenReturn(LOCATION);

        loginFilter = mock(javax.servlet.Filter.class);
        sessionFactory = mock(SessionFactory.class);

        httpRequest = mock(HttpServletRequest.class);
        when(httpRequest.getRequestURL()).thenReturn(new StringBuffer("fubar"));
        when(httpRequest.isSecure()).thenReturn(true);

        idpMetadata = new IdpMetadata();
        assertionConsumerService = new AssertionConsumerService(simpleSign,
                idpMetadata,
                systemCrypto,
                relayStates);
        assertionConsumerService.setRequest(httpRequest);
        assertionConsumerService.setLoginFilter(loginFilter);
        assertionConsumerService.setSessionFactory(sessionFactory);

        cannedResponse = readResource("/SAMLResponse.xml");
        String metadata = readResource("/IDPmetadata.xml");
        deflatedSamlResponse = readResource("/DeflatedSAMLResponse.txt");
        idpMetadata.setMetadata(metadata);
    }

    @Test
    public void testPostSamlResponse() throws Exception {
        Response response = assertionConsumerService.postSamlResponse(base64Encode(cannedResponse),
                RELAY_STATE_VAL);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testPostSamlResponseNotSecure() throws Exception {
        when(httpRequest.isSecure()).thenReturn(false);
        Response response = assertionConsumerService.postSamlResponse(base64Encode(cannedResponse),
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    // NOTE(review): the reason this test is ignored is not recorded here; document it when known.
    @Ignore
    @Test
    public void testPostSamlResponseDoubleSignature() throws Exception {
        // The fixture text is posted unchanged (no byte round-trip through the platform charset).
        String doubleSignedResponse = readResource("/DoubleSignedSAMLResponse.txt");
        String relayStateValue = "a0552c29-8b2b-492c-87fb-17a20d22f887";
        when(relayStates.encode("fubar")).thenReturn(relayStateValue);
        when(relayStates.decode(relayStateValue)).thenReturn(LOCATION);

        Response response = assertionConsumerService.postSamlResponse(doubleSignedResponse,
                relayStateValue);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testGetSamlResponse() throws Exception {
        Response response = assertionConsumerService.getSamlResponse(deflatedSamlResponse,
                RELAY_STATE_VAL,
                SIG_ALG_VAL,
                SIGNATURE_VAL);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testGetSamlResponseInvalidSignature() throws Exception {
        // SIGNATURE_VAL contains lowercase 'z', so this replacement corrupts the signature.
        Response response = assertionConsumerService.getSamlResponse(deflatedSamlResponse,
                RELAY_STATE_VAL,
                SIG_ALG_VAL,
                SIGNATURE_VAL.replace('z', 'x'));
        assertServerError(response);
    }

    @Test
    public void testGetSamlResponseNoSignature() throws Exception {
        Response response = assertionConsumerService.getSamlResponse(deflatedSamlResponse,
                RELAY_STATE_VAL,
                SIG_ALG_VAL,
                null);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testGetSamlResponseNoSignatureAlgorithm() throws Exception {
        Response response = assertionConsumerService.getSamlResponse(deflatedSamlResponse,
                RELAY_STATE_VAL,
                null,
                SIGNATURE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessBadSamlResponse() throws Exception {
        // A SAML *request* document is not a valid SAML response.
        String badRequest = readResource("/SAMLRequest.xml");
        Response response = assertionConsumerService.processSamlResponse(badRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseAgainstLoginPage() throws Exception {
        when(relayStates.decode(RELAY_STATE_VAL)).thenReturn("https://test/login?prevurl=/newurl");
        Response response = assertionConsumerService.processSamlResponse(cannedResponse,
                RELAY_STATE_VAL);
        assertRedirectPath(response, "/newurl");
    }

    @Test
    public void testProcessSamlResponseAgainstLoginPage1() throws Exception {
        when(relayStates.decode(RELAY_STATE_VAL)).thenReturn("https://test/login/?prevurl=/newurl");
        Response response = assertionConsumerService.processSamlResponse(cannedResponse,
                RELAY_STATE_VAL);
        assertRedirectPath(response, "/newurl");
    }

    @Test
    public void testProcessSamlResponseAgainstLoginPageBadQuery() throws Exception {
        when(relayStates.decode(RELAY_STATE_VAL)).thenReturn("https://test/login?blah=/newurl");
        Response response = assertionConsumerService.processSamlResponse(cannedResponse,
                RELAY_STATE_VAL);
        assertRedirectPath(response, "/login");
    }

    @Test
    public void testProcessSamlResponseAuthnFailure() throws Exception {
        String failureRequest = cannedResponse.replace(StatusCode.SUCCESS, StatusCode.AUTHN_FAILED);
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseRequestDenied() throws Exception {
        String failureRequest = cannedResponse.replace(StatusCode.SUCCESS,
                StatusCode.REQUEST_DENIED);
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseRequestUnsupported() throws Exception {
        String failureRequest = cannedResponse.replace(StatusCode.SUCCESS,
                StatusCode.REQUEST_UNSUPPORTED);
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseUnsupportedBinding() throws Exception {
        String failureRequest = cannedResponse.replace(StatusCode.SUCCESS,
                StatusCode.UNSUPPORTED_BINDING);
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseNoAssertion() throws Exception {
        String failureRequest = readResource("/SAMLResponse-noAssertion.xml");
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseNullAssertion() throws Exception {
        String failureRequest = readResource("/SAMLResponse-nullAssertion.xml");
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testProcessSamlResponseMalformedAssertion() throws Exception {
        String failureRequest = readResource("/SAMLResponse-malformedAssertion.xml");
        Response response = assertionConsumerService.processSamlResponse(failureRequest,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseMultipleAssertion() throws Exception {
        String multipleAssertions = readResource("/SAMLResponse-multipleAssertions.xml");
        Response response = assertionConsumerService.processSamlResponse(multipleAssertions,
                RELAY_STATE_VAL);
        assertRedirectsTo(response, LOCATION);
    }

    @Test
    public void testProcessSamlResponseEmptySamlResponse() throws Exception {
        Response response = assertionConsumerService.processSamlResponse("", RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseLoginFail() throws Exception {
        doThrow(ServletException.class).when(loginFilter)
                .doFilter(any(ServletRequest.class),
                        isNull(ServletResponse.class),
                        any(FilterChain.class));
        Response response = assertionConsumerService.processSamlResponse(cannedResponse,
                RELAY_STATE_VAL);
        assertServerError(response);
    }

    @Test
    public void testProcessSamlResponseEmptyRelayState() throws Exception {
        Response response = assertionConsumerService.processSamlResponse(cannedResponse, "");
        assertServerError(response);
    }

    /*
     * We cannot assume the presence of the SingleLogout Service (DDF-1605).
     */
    @Ignore("DDF-1605: the presence of the SingleLogout Service cannot be assumed")
    @Test
    public void testRetrieveMetadata() throws Exception {
        Response response = assertionConsumerService.retrieveMetadata();
        // Check status and entity first so a missing entity fails clearly instead of throwing an NPE.
        assertThat("The http response was not 200 OK", response.getStatus(), is(HttpStatus.SC_OK));
        assertThat("Response entity was null", response.getEntity(), notNullValue());

        Document document = parse(response.getEntity()
                .toString());
        assertThat("SingleLogoutService Binding attribute was not the expected HTTP-Redirect",
                document,
                hasXPath("//urn:oasis:names:tc:SAML:2.0:metadata:SingleLogoutService/@Binding",
                        is(equalTo("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"))));
        assertThat(
                "SingleLogoutService Location attribute was not the expected https://localhost:8993/logout",
                document,
                hasXPath("//urn:oasis:names:tc:SAML:2.0:metadata:SingleLogoutService/@Location",
                        is(equalTo("https://localhost:8993/logout"))));
    }

    @Test
    public void testGetLoginFilter() throws Exception {
        Filter filter = assertionConsumerService.getLoginFilter();
        assertThat("Returned login filter was not the same as the one set",
                filter,
                equalTo(loginFilter));
    }

    /**
     * Parses {@code xml} into a namespace-aware DOM document.
     *
     * @param xml the XML text, encoded as UTF-8 bytes before parsing
     * @return the parsed document
     * @throws IllegalStateException wrapping any parser configuration, parse, or I/O failure
     */
    private static Document parse(String xml) {
        try {
            DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
            documentBuilderFactory.setNamespaceAware(true);
            DocumentBuilder documentBuilder = documentBuilderFactory.newDocumentBuilder();
            return documentBuilder.parse(new ByteArrayInputStream(xml.getBytes(Charsets.UTF_8)));
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Reads a classpath resource as a UTF-8 string.
     *
     * @param path absolute resource path, e.g. {@code "/SAMLResponse.xml"}
     * @return the resource contents
     * @throws java.io.IOException if the resource cannot be read
     */
    private String readResource(String path) throws java.io.IOException {
        return Resources.toString(Resources.getResource(getClass(), path), Charsets.UTF_8);
    }

    /** Base64-encodes the UTF-8 bytes of {@code text}, independent of the platform charset. */
    private static String base64Encode(String text) {
        return Base64.getEncoder()
                .encodeToString(text.getBytes(Charsets.UTF_8));
    }

    /** Asserts that {@code response} has status 500 Internal Server Error. */
    private static void assertServerError(Response response) {
        assertThat("The http response was not 500 SERVER ERROR",
                response.getStatus(),
                is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
    }

    /** Asserts a 303 See Other response whose Location header equals {@code expectedLocation}. */
    private static void assertRedirectsTo(Response response, String expectedLocation) {
        assertThat("The http response was not 303 SEE OTHER",
                response.getStatus(),
                is(HttpStatus.SC_SEE_OTHER));
        assertThat("Response had no LOCATION header", response.getLocation(), notNullValue());
        assertThat("Response LOCATION was " + response.getLocation() + " expected " + expectedLocation,
                response.getLocation()
                        .toString(),
                equalTo(expectedLocation));
    }

    /** Asserts a 303 See Other response whose Location URI path equals {@code expectedPath}. */
    private static void assertRedirectPath(Response response, String expectedPath) {
        assertThat("The http response was not 303 SEE OTHER",
                response.getStatus(),
                is(HttpStatus.SC_SEE_OTHER));
        assertThat("Response had no LOCATION header", response.getLocation(), notNullValue());
        assertThat("The response did not redirect to the correct location.",
                response.getLocation()
                        .getPath(),
                is(expectedPath));
    }
}