package org.apereo.cas.ticket.registry;
import com.google.common.base.Throwables;
import com.google.common.io.ByteSource;
import org.apache.commons.lang3.StringUtils;
import org.apereo.cas.CipherExecutor;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.ticket.ServiceTicket;
import org.apereo.cas.ticket.Ticket;
import org.apereo.cas.ticket.TicketGrantingTicket;
import org.apereo.cas.ticket.proxy.ProxyGrantingTicket;
import org.apereo.cas.util.DigestUtils;
import org.apereo.cas.util.serialization.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
* @author Scott Battaglia
* @since 3.0.0
* <p>
* This is a published and supported CAS Server API.
* </p>
*/
public abstract class AbstractTicketRegistry implements TicketRegistry {
private static final String MESSAGE = "Ticket encryption is not enabled. Falling back to default behavior";
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractTicketRegistry.class);
/**
* The cipher executor for ticket objects.
*/
protected CipherExecutor cipherExecutor;
/**
* Default constructor.
*/
@SuppressWarnings("unchecked")
public AbstractTicketRegistry() {
}
/**
*
*
* @return specified ticket from the registry
* @throws IllegalArgumentException if class is null.
* @throws ClassCastException if class does not match requested ticket
* class.
*/
@Override
public <T extends Ticket> T getTicket(final String ticketId, final Class<T> clazz) {
Assert.notNull(clazz, "clazz cannot be null");
final Ticket ticket = this.getTicket(ticketId);
if (ticket == null) {
return null;
}
if (!clazz.isAssignableFrom(ticket.getClass())) {
throw new ClassCastException("Ticket [" + ticket.getId()
+ " is of type " + ticket.getClass()
+ " when we were expecting " + clazz);
}
return (T) ticket;
}
@Override
public long sessionCount() {
try {
return getTickets().stream().filter(TicketGrantingTicket.class::isInstance).count();
} catch (final Throwable t) {
LOGGER.trace("sessionCount() operation is not implemented by the ticket registry instance [{}]. "
+ "Message is: [{}] Returning unknown as [{}]",
this.getClass().getName(), t.getMessage(), Long.MIN_VALUE);
return Long.MIN_VALUE;
}
}
@Override
public long serviceTicketCount() {
try {
return getTickets().stream().filter(ServiceTicket.class::isInstance).count();
} catch (final Throwable t) {
LOGGER.trace("serviceTicketCount() operation is not implemented by the ticket registry instance [{}]. "
+ "Message is: [{}] Returning unknown as [{}]",
this.getClass().getName(), t.getMessage(), Long.MIN_VALUE);
return Long.MIN_VALUE;
}
}
@Override
public int deleteTicket(final String ticketId) {
final AtomicInteger count = new AtomicInteger(0);
if (StringUtils.isBlank(ticketId)) {
return count.intValue();
}
final Ticket ticket = getTicket(ticketId);
if (ticket == null) {
return count.intValue();
}
if (ticket instanceof TicketGrantingTicket) {
if (ticket instanceof ProxyGrantingTicket) {
LOGGER.debug("Removing proxy-granting ticket [{}]", ticketId);
}
LOGGER.debug("Removing children of ticket [{}] from the registry.", ticket.getId());
final TicketGrantingTicket tgt = (TicketGrantingTicket) ticket;
count.addAndGet(deleteChildren(tgt));
final Collection<ProxyGrantingTicket> proxyGrantingTickets = tgt.getProxyGrantingTickets();
proxyGrantingTickets.stream().map(Ticket::getId).forEach(t -> count.addAndGet(this.deleteTicket(t)));
}
LOGGER.debug("Removing ticket [{}] from the registry.", ticket);
if (deleteSingleTicket(ticketId)) {
count.incrementAndGet();
}
return count.intValue();
}
/**
* Delete TGT's service tickets.
*
* @param ticket the ticket
* @return the count of tickets that were removed including child tickets and zero if the ticket was not deleted
*/
public int deleteChildren(final TicketGrantingTicket ticket) {
final AtomicInteger count = new AtomicInteger(0);
// delete service tickets
final Map<String, Service> services = ticket.getServices();
if (services != null && !services.isEmpty()) {
services.keySet().stream().forEach(ticketId -> {
if (deleteSingleTicket(ticketId)) {
LOGGER.debug("Removed ticket [{}]", ticketId);
count.incrementAndGet();
} else {
LOGGER.debug("Unable to remove ticket [{}]", ticketId);
}
});
}
return count.intValue();
}
/**
* Delete a single ticket instance from the store.
*
* @param ticketId the ticket id
* @return true/false
*/
public boolean deleteSingleTicket(final Ticket ticketId) {
return deleteSingleTicket(ticketId.getId());
}
/**
* Delete a single ticket instance from the store.
*
* @param ticketId the ticket id
* @return true/false
*/
public abstract boolean deleteSingleTicket(String ticketId);
public void setCipherExecutor(final CipherExecutor<byte[], byte[]> cipherExecutor) {
this.cipherExecutor = cipherExecutor;
}
/**
* Encode ticket id into a SHA-512.
*
* @param ticketId the ticket id
* @return the ticket
*/
protected String encodeTicketId(final String ticketId) {
if (!isCipherExecutorEnabled()) {
LOGGER.trace(MESSAGE);
return ticketId;
}
if (StringUtils.isBlank(ticketId)) {
return ticketId;
}
final String encodedId = DigestUtils.sha512(ticketId);
LOGGER.debug("Encoded original ticket id [{}] to [{}]", ticketId, encodedId);
return encodedId;
}
/**
* Encode ticket.
*
* @param ticket the ticket
* @return the ticket
*/
protected Ticket encodeTicket(final Ticket ticket) {
if (!isCipherExecutorEnabled()) {
LOGGER.trace(MESSAGE);
return ticket;
}
if (ticket == null) {
LOGGER.debug("Ticket passed is null and cannot be encoded");
return null;
}
LOGGER.debug("Encoding ticket [{}]", ticket);
final byte[] encodedTicketObject = SerializationUtils.serializeAndEncodeObject(this.cipherExecutor, ticket);
final String encodedTicketId = encodeTicketId(ticket.getId());
final Ticket encodedTicket = new EncodedTicket(ByteSource.wrap(encodedTicketObject), encodedTicketId);
LOGGER.debug("Created encoded ticket [{}]", encodedTicket);
return encodedTicket;
}
/**
* Decode ticket.
*
* @param result the result
* @return the ticket
*/
protected Ticket decodeTicket(final Ticket result) {
try {
if (!isCipherExecutorEnabled()) {
LOGGER.trace(MESSAGE);
return result;
}
if (result == null) {
LOGGER.debug("Ticket passed is null and cannot be decoded");
return null;
}
LOGGER.debug("Attempting to decode [{}]", result);
final EncodedTicket encodedTicket = (EncodedTicket) result;
final Ticket ticket = SerializationUtils.decodeAndDeserializeObject(
encodedTicket.getEncoded(), this.cipherExecutor, Ticket.class);
LOGGER.debug("Decoded ticket to [{}]", ticket);
return ticket;
} catch (final Exception e) {
throw Throwables.propagate(e);
}
}
/**
* Decode tickets.
*
* @param items the items
* @return the set
*/
protected Collection<Ticket> decodeTickets(final Collection<Ticket> items) {
if (!isCipherExecutorEnabled()) {
LOGGER.trace(MESSAGE);
return items;
}
return items.stream().map(this::decodeTicket).collect(Collectors.toSet());
}
protected boolean isCipherExecutorEnabled() {
return this.cipherExecutor != null && this.cipherExecutor.isEnabled();
}
}