package org.apereo.cas; import org.apereo.cas.authentication.AuthenticationServiceSelectionPlan; import org.apereo.cas.authentication.DefaultAuthenticationServiceSelectionPlan; import org.apereo.cas.authentication.DefaultAuthenticationServiceSelectionStrategy; import org.apereo.cas.authentication.policy.AcceptAnyAuthenticationPolicyFactory; import org.apereo.cas.authentication.Authentication; import org.apereo.cas.authentication.AuthenticationHandler; import org.apereo.cas.authentication.AuthenticationResult; import org.apereo.cas.authentication.BasicCredentialMetaData; import org.apereo.cas.authentication.CredentialMetaData; import org.apereo.cas.authentication.DefaultHandlerResult; import org.apereo.cas.authentication.HandlerResult; import org.apereo.cas.authentication.principal.DefaultPrincipalFactory; import org.apereo.cas.authentication.principal.Service; import org.apereo.cas.authentication.principal.WebApplicationServiceFactory; import org.apereo.cas.logout.LogoutManager; import org.apereo.cas.services.DefaultRegisteredServiceAccessStrategy; import org.apereo.cas.services.DefaultRegisteredServiceUsernameProvider; import org.apereo.cas.services.RefuseRegisteredServiceProxyPolicy; import org.apereo.cas.services.RegexMatchingRegisteredServiceProxyPolicy; import org.apereo.cas.services.RegisteredService; import org.apereo.cas.services.RegisteredServiceProxyPolicy; import org.apereo.cas.services.RegisteredServiceTestUtils; import org.apereo.cas.services.ReturnAllAttributeReleasePolicy; import org.apereo.cas.services.ServicesManager; import org.apereo.cas.services.UnauthorizedProxyingException; import org.apereo.cas.services.UnauthorizedServiceException; import org.apereo.cas.ticket.ExpirationPolicy; import org.apereo.cas.ticket.InvalidTicketException; import org.apereo.cas.ticket.ServiceTicket; import org.apereo.cas.ticket.Ticket; import org.apereo.cas.ticket.TicketGrantingTicket; import org.apereo.cas.ticket.factory.DefaultProxyGrantingTicketFactory; import org.apereo.cas.ticket.factory.DefaultProxyTicketFactory; import org.apereo.cas.ticket.factory.DefaultServiceTicketFactory; import org.apereo.cas.ticket.factory.DefaultTicketFactory; import org.apereo.cas.ticket.factory.DefaultTicketGrantingTicketFactory; import org.apereo.cas.ticket.registry.TicketRegistry; import org.apereo.cas.ticket.support.NeverExpiresExpirationPolicy; import org.apereo.cas.validation.Assertion; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentMatcher; import org.springframework.context.ApplicationEventPublisher; import org.springframework.mock.web.MockHttpServletRequest; import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.IntStream; import static org.junit.Assert.*; import static org.mockito.Mockito.*; /** * Unit tests with the help of Mockito framework. * * @author Dmitriy Kopylenko * @since 3.0.0 */ public class CentralAuthenticationServiceImplWithMockitoTests { private static final String TGT_ID = "tgt-id"; private static final String TGT2_ID = "tgt2-id"; private static final String ST_ID = "st-id"; private static final String ST2_ID = "st2-id"; private static final String SVC1_ID = "test1"; private static final String SVC2_ID = "test2"; private static final String PRINCIPAL = "principal"; @Rule public ExpectedException thrown = ExpectedException.none(); private DefaultCentralAuthenticationService cas; private Authentication authentication; private TicketRegistry ticketRegMock; private static class VerifyServiceByIdMatcher implements ArgumentMatcher<Service> { private final String id; VerifyServiceByIdMatcher(final String id) { this.id = id; } @Override public boolean matches(final Service s) { return s != null && s.getId().equals(this.id); } } @Before public void prepareNewCAS() throws Exception { this.authentication = mock(Authentication.class); when(this.authentication.getAuthenticationDate()).thenReturn(ZonedDateTime.now(ZoneOffset.UTC)); final CredentialMetaData metadata = new BasicCredentialMetaData( RegisteredServiceTestUtils.getCredentialsWithSameUsernameAndPassword("principal")); final Map<String, HandlerResult> successes = new HashMap<>(); successes.put("handler1", new DefaultHandlerResult(mock(AuthenticationHandler.class), metadata)); when(this.authentication.getCredentials()).thenReturn(Arrays.asList(metadata)); when(this.authentication.getSuccesses()).thenReturn(successes); when(this.authentication.getPrincipal()).thenReturn(new DefaultPrincipalFactory().createPrincipal(PRINCIPAL)); final Service service1 = getService(SVC1_ID); final ServiceTicket stMock = createMockServiceTicket(ST_ID, service1); final TicketGrantingTicket tgtRootMock = createRootTicketGrantingTicket(); final TicketGrantingTicket tgtMock = createMockTicketGrantingTicket(TGT_ID, stMock, false, tgtRootMock, new ArrayList<>()); when(tgtMock.getProxiedBy()).thenReturn(getService("proxiedBy")); final List<Authentication> authnListMock = mock(List.class); //Size is required to be 2, so that we can simulate proxying capabilities when(authnListMock.size()).thenReturn(2); when(authnListMock.get(anyInt())).thenReturn(this.authentication); when(tgtMock.getChainedAuthentications()).thenReturn(authnListMock); when(stMock.getGrantingTicket()).thenReturn(tgtMock); final Service service2 = getService(SVC2_ID); final ServiceTicket stMock2 = createMockServiceTicket(ST2_ID, service2); final TicketGrantingTicket tgtMock2 = createMockTicketGrantingTicket(TGT2_ID, stMock2, false, tgtRootMock, authnListMock); //Mock TicketRegistry mockTicketRegistry(stMock, tgtMock, stMock2, tgtMock2); //Mock ServicesManager final ServicesManager smMock = getServicesManager(service1, service2); final DefaultTicketFactory factory = new DefaultTicketFactory( new DefaultProxyGrantingTicketFactory(null, null, null), new DefaultTicketGrantingTicketFactory(null, null, null), new DefaultServiceTicketFactory(new NeverExpiresExpirationPolicy(), Collections.emptyMap(), false, null), new DefaultProxyTicketFactory(null, Collections.emptyMap(), null, true)); final AuthenticationServiceSelectionPlan authenticationRequestServiceSelectionStrategies = new DefaultAuthenticationServiceSelectionPlan(new DefaultAuthenticationServiceSelectionStrategy()); this.cas = new DefaultCentralAuthenticationService(ticketRegMock, factory, smMock, mock(LogoutManager.class), authenticationRequestServiceSelectionStrategies, new AcceptAnyAuthenticationPolicyFactory(), new DefaultPrincipalFactory(), null); this.cas.setApplicationEventPublisher(mock(ApplicationEventPublisher.class)); } private AuthenticationResult getAuthenticationContext() { final AuthenticationResult ctx = mock(AuthenticationResult.class); when(ctx.getAuthentication()).thenReturn(this.authentication); return ctx; } private static ServicesManager getServicesManager(final Service service1, final Service service2) { final RegisteredService mockRegSvc1 = createMockRegisteredService(service1.getId(), true, getServiceProxyPolicy(false)); final RegisteredService mockRegSvc2 = createMockRegisteredService("test", false, getServiceProxyPolicy(true)); final RegisteredService mockRegSvc3 = createMockRegisteredService(service2.getId(), true, getServiceProxyPolicy(true)); final ServicesManager smMock = mock(ServicesManager.class); when(smMock.findServiceBy(argThat(new VerifyServiceByIdMatcher(service1.getId())))).thenReturn(mockRegSvc1); when(smMock.findServiceBy(argThat(new VerifyServiceByIdMatcher("test")))).thenReturn(mockRegSvc2); when(smMock.findServiceBy(argThat(new VerifyServiceByIdMatcher(service2.getId())))).thenReturn(mockRegSvc3); return smMock; } private void mockTicketRegistry(final ServiceTicket stMock, final TicketGrantingTicket tgtMock, final ServiceTicket stMock2, final TicketGrantingTicket tgtMock2) { this.ticketRegMock = mock(TicketRegistry.class); when(ticketRegMock.getTicket(eq(tgtMock.getId()), eq(TicketGrantingTicket.class))).thenReturn(tgtMock); when(ticketRegMock.getTicket(eq(tgtMock2.getId()), eq(TicketGrantingTicket.class))).thenReturn(tgtMock2); when(ticketRegMock.getTicket(eq(stMock.getId()), eq(ServiceTicket.class))).thenReturn(stMock); when(ticketRegMock.getTicket(eq(stMock2.getId()), eq(ServiceTicket.class))).thenReturn(stMock2); when(ticketRegMock.getTickets()).thenReturn(Arrays.asList(tgtMock, tgtMock2, stMock, stMock2)); } @Test public void verifyNonExistentServiceWhenDelegatingTicketGrantingTicket() throws Exception { this.thrown.expect(InvalidTicketException.class); this.thrown.expectMessage("bad-st"); this.cas.createProxyGrantingTicket("bad-st", getAuthenticationContext()); } @Test public void verifyInvalidServiceWhenDelegatingTicketGrantingTicket() throws Exception { this.thrown.expect(UnauthorizedServiceException.class); this.cas.createProxyGrantingTicket(ST_ID, getAuthenticationContext()); } @Test public void disallowVendingServiceTicketsWhenServiceIsNotAllowedToProxyCAS1019() throws Exception { this.thrown.expect(UnauthorizedProxyingException.class); this.cas.grantServiceTicket(TGT_ID, RegisteredServiceTestUtils.getService(SVC1_ID), getAuthenticationContext()); } @Test public void getTicketGrantingTicketIfTicketIdIsNull() throws InvalidTicketException { this.thrown.expect(IllegalArgumentException.class); this.cas.getTicket(null, TicketGrantingTicket.class); } @Test public void getTicketGrantingTicketIfTicketIdIsMissing() throws InvalidTicketException { this.thrown.expect(InvalidTicketException.class); this.cas.getTicket("TGT-9000", TicketGrantingTicket.class); } @Test public void getTicketsWithNoPredicate() { final Collection<Ticket> c = this.cas.getTickets(ticket -> true); assertEquals(c.size(), this.ticketRegMock.getTickets().size()); } @Test public void verifyChainedAuthenticationsOnValidation() throws Exception { final Service svc = RegisteredServiceTestUtils.getService(SVC2_ID); final ServiceTicket st = this.cas.grantServiceTicket(TGT2_ID, svc, getAuthenticationContext()); assertNotNull(st); final Assertion assertion = this.cas.validateServiceTicket(st.getId(), svc); assertNotNull(assertion); assertEquals(assertion.getService(), svc); assertEquals(assertion.getPrimaryAuthentication().getPrincipal().getId(), PRINCIPAL); assertSame(2, assertion.getChainedAuthentications().size()); IntStream.range(0, assertion.getChainedAuthentications().size()) .forEach(i -> assertEquals(assertion.getChainedAuthentications().get(i), authentication)); } private TicketGrantingTicket createRootTicketGrantingTicket() { final TicketGrantingTicket tgtRootMock = mock(TicketGrantingTicket.class); when(tgtRootMock.isExpired()).thenReturn(false); when(tgtRootMock.getAuthentication()).thenReturn(this.authentication); return tgtRootMock; } private TicketGrantingTicket createMockTicketGrantingTicket(final String id, final ServiceTicket svcTicket, final boolean isExpired, final TicketGrantingTicket root, final List<Authentication> chainedAuthnList) { final TicketGrantingTicket tgtMock = mock(TicketGrantingTicket.class); when(tgtMock.isExpired()).thenReturn(isExpired); when(tgtMock.getId()).thenReturn(id); final String svcId = svcTicket.getService().getId(); when(tgtMock.grantServiceTicket(anyString(), argThat(new VerifyServiceByIdMatcher(svcId)), any(ExpirationPolicy.class), anyBoolean(), anyBoolean())).thenReturn(svcTicket); when(tgtMock.getRoot()).thenReturn(root); when(tgtMock.getChainedAuthentications()).thenReturn(chainedAuthnList); when(tgtMock.getAuthentication()).thenReturn(this.authentication); when(svcTicket.getGrantingTicket()).thenReturn(tgtMock); return tgtMock; } private static ServiceTicket createMockServiceTicket(final String id, final Service svc) { final ServiceTicket stMock = mock(ServiceTicket.class); when(stMock.getService()).thenReturn(svc); when(stMock.getId()).thenReturn(id); when(stMock.isValidFor(svc)).thenReturn(true); return stMock; } private static RegisteredServiceProxyPolicy getServiceProxyPolicy(final boolean canProxy) { if (!canProxy) { return new RefuseRegisteredServiceProxyPolicy(); } return new RegexMatchingRegisteredServiceProxyPolicy(".*"); } private static RegisteredService createMockRegisteredService(final String svcId, final boolean enabled, final RegisteredServiceProxyPolicy proxy) { final RegisteredService mockRegSvc = mock(RegisteredService.class); when(mockRegSvc.getServiceId()).thenReturn(svcId); when(mockRegSvc.getProxyPolicy()).thenReturn(proxy); when(mockRegSvc.getName()).thenReturn(svcId); when(mockRegSvc.matches(argThat(new VerifyServiceByIdMatcher(svcId)))).thenReturn(true); when(mockRegSvc.getAttributeReleasePolicy()).thenReturn(new ReturnAllAttributeReleasePolicy()); when(mockRegSvc.getUsernameAttributeProvider()).thenReturn(new DefaultRegisteredServiceUsernameProvider()); when(mockRegSvc.getAccessStrategy()).thenReturn(new DefaultRegisteredServiceAccessStrategy(enabled, true)); return mockRegSvc; } private static Service getService(final String name) { final MockHttpServletRequest request = new MockHttpServletRequest(); request.addParameter(CasProtocolConstants.PARAMETER_SERVICE, name); return new WebApplicationServiceFactory().createService(request); } }