/**
* Copyright (c) Codice Foundation
* <p>
* This is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
* General Public License as published by the Free Software Foundation, either version 3 of the
* License, or any later version.
* <p>
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details. A copy of the GNU Lesser General Public License
* is distributed along with this program and can be found at
* <http://www.gnu.org/licenses/lgpl.html>.
*/
package org.codice.ddf.security.handler.cas;
import static org.junit.Assert.assertEquals;
import static org.mockito.AdditionalMatchers.not;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.util.Arrays;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.codice.ddf.security.handler.api.HandlerResult;
import org.codice.ddf.security.handler.cas.filter.ProxyFilter;
import org.codice.ddf.security.handler.cas.filter.ProxyFilterChain;
import org.jasig.cas.client.authentication.AttributePrincipal;
import org.jasig.cas.client.util.AbstractCasFilter;
import org.jasig.cas.client.validation.Assertion;
import org.junit.Test;
import ddf.security.sts.client.configuration.STSClientConfiguration;
public class CasHandlerTest {
private static final String STS_ADDRESS = "http://localhost:8181/sts";
private static final String MOCK_TICKET = "ST-956-Lyg0BdLkgdrBO9W17bXS";
private static final String SESSION_ID = "12345678910";
/**
* Tests that the handler properly returns a NO_ACTION result if no assertion is available in
* the request and resolve is false.
*
* @throws ServletException
*/
@Test
public void testNoPrincipalNoResolve() throws ServletException {
CasHandler handler = createHandler();
HandlerResult result = handler.getNormalizedToken(createServletRequest(false),
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
false);
// NO_ACTION due to resolve being false
assertEquals(HandlerResult.Status.NO_ACTION, result.getStatus());
}
/**
* Tests that the handler properly returns a COMPLETED result if the assertion is in the session.
*
* @throws ServletException
*/
@Test
public void testPrincipalNoResolve() throws ServletException {
CasHandler handler = createHandler();
HandlerResult result = handler.getNormalizedToken(createServletRequest(true),
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
false);
assertEquals(HandlerResult.Status.COMPLETED, result.getStatus());
}
/**
* Tests that the handler properly returns a REDIRECTED result if the assertion is not in the
* session and resolve is true.
*
* @throws ServletException
* @throws IOException
*/
@Test
public void testNoPrincipalResolve() throws ServletException, IOException {
CasHandler handler = createHandler();
Filter testFilter = mock(Filter.class);
handler.setProxyFilter(new ProxyFilter(Arrays.asList(testFilter)));
HandlerResult result = handler.getNormalizedToken(createServletRequest(false),
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
true);
assertEquals(HandlerResult.Status.REDIRECTED, result.getStatus());
// verify that the filter was called once
verify(testFilter).doFilter(any(ServletRequest.class),
any(ServletResponse.class),
any(FilterChain.class));
}
/**
* Tests that the handler properly returns a COMPLETED result if the assertion is in the
* session and resolve is true.
*
* @throws ServletException
* @throws IOException
*/
@Test
public void testPrincipalResolve() throws ServletException, IOException {
CasHandler handler = createHandler();
HandlerResult result = handler.getNormalizedToken(createServletRequest(true),
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
true);
assertEquals(HandlerResult.Status.COMPLETED, result.getStatus());
}
/**
* Tests that the handler properly returns a COMPLETED result from having a cached session that
* contains the CAS assertion.
*
* @throws ServletException
* @throws IOException
*/
@Test
public void testCachedPrincipalResolve() throws ServletException, IOException {
CasHandler handler = createHandler();
HttpServletRequest servletRequest = createServletRequest(true);
HttpSession session = servletRequest.getSession();
HandlerResult result = handler.getNormalizedToken(servletRequest,
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
true);
assertEquals(HandlerResult.Status.COMPLETED, result.getStatus());
// now check for caching sessions
servletRequest = createServletRequest(false);
when(servletRequest.getSession()).thenReturn(session);
when(servletRequest.getSession(any(Boolean.class))).thenReturn(session);
result = handler.getNormalizedToken(servletRequest,
mock(HttpServletResponse.class),
new ProxyFilterChain(null),
true);
assertEquals(HandlerResult.Status.COMPLETED, result.getStatus());
}
private CasHandler createHandler() {
CasHandler handler = new CasHandler();
STSClientConfiguration clientConfiguration = mock(STSClientConfiguration.class);
when(clientConfiguration.getAddress()).thenReturn(STS_ADDRESS);
handler.setClientConfiguration(clientConfiguration);
Filter testFilter = mock(Filter.class);
handler.setProxyFilter(new ProxyFilter(Arrays.asList(testFilter)));
return handler;
}
private HttpServletRequest createServletRequest(boolean shouldAddCas) {
HttpServletRequest servletRequest = mock(HttpServletRequest.class);
HttpSession session = mock(HttpSession.class);
when(session.getId()).thenReturn(SESSION_ID);
when(servletRequest.getSession()).thenReturn(session);
when(servletRequest.getSession(any(Boolean.class))).thenReturn(session);
if (shouldAddCas) {
// Mock CAS items
Assertion assertion = mock(Assertion.class);
when(session.getAttribute(AbstractCasFilter.CONST_CAS_ASSERTION)).thenReturn(assertion);
AttributePrincipal principal = mock(AttributePrincipal.class);
when(principal.getProxyTicketFor(STS_ADDRESS)).thenReturn(MOCK_TICKET);
when(principal.getProxyTicketFor(not(eq(STS_ADDRESS)))).thenThrow(new RuntimeException(
"Tried to create ticket for incorrect service."));
when(assertion.getPrincipal()).thenReturn(principal);
}
return servletRequest;
}
}