package org.apereo.cas.ticket.registry;
import com.mongodb.BasicDBObject;
import com.mongodb.DBCollection;
import org.apereo.cas.ticket.BaseTicketSerializers;
import org.apereo.cas.ticket.Ticket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.util.Assert;
import javax.annotation.PostConstruct;
import java.util.Collection;
import java.util.stream.Collectors;
/**
 * A Ticket Registry storage backend based on MongoDB.
 * <p>
 * Every ticket is serialized to JSON and persisted as a {@link TicketHolder} document inside the
 * configured collection. Documents are keyed by the <em>encoded</em> ticket id (the value returned
 * by {@code encodeTicketId}), and all read, write and delete operations use that same key.
 * An index with {@code expireAfterSeconds = 0} is created on the holder's expiration field at
 * startup so that documents can be auto-expired by the database.
 *
 * @author Misagh Moayyed
 * @since 5.1.0
 */
public class MongoDbTicketRegistry extends AbstractTicketRegistry {
    private static final Logger LOGGER = LoggerFactory.getLogger(MongoDbTicketRegistry.class);

    /** Name of the MongoDb collection that holds the ticket documents. */
    private final String collectionName;

    /** Whether the collection should be dropped (and re-created) during {@link #initialize()}. */
    private final boolean dropCollection;

    /** Template used for all MongoDb operations. */
    private final MongoOperations mongoTemplate;

    /**
     * Creates a registry that keeps any existing collection contents on startup.
     *
     * @param collectionName the collection used to store tickets
     * @param mongoTemplate  the MongoDb operations template
     */
    public MongoDbTicketRegistry(final String collectionName, final MongoOperations mongoTemplate) {
        this(collectionName, false, mongoTemplate);
    }

    /**
     * Creates a registry.
     *
     * @param collectionName the collection used to store tickets
     * @param dropCollection {@code true} to drop the collection when the registry is initialized
     * @param mongoTemplate  the MongoDb operations template
     */
    public MongoDbTicketRegistry(final String collectionName, final boolean dropCollection, final MongoOperations mongoTemplate) {
        this.collectionName = collectionName;
        this.dropCollection = dropCollection;
        this.mongoTemplate = mongoTemplate;
    }

    /**
     * Init registry.
     * <p>
     * Optionally drops the collection, creates it when missing, and ensures the
     * auto-expiration index exists on {@link TicketHolder#FIELD_NAME_EXPIRE_AT}.
     **/
    @PostConstruct
    public void initialize() {
        Assert.notNull(this.mongoTemplate, "mongoTemplate cannot be null");

        LOGGER.debug("Setting up MongoDb Ticket Registry instance [{}]", this.collectionName);
        if (this.dropCollection) {
            LOGGER.debug("Dropping database collection: [{}]", this.collectionName);
            this.mongoTemplate.dropCollection(this.collectionName);
        }

        if (!this.mongoTemplate.collectionExists(this.collectionName)) {
            LOGGER.debug("Creating database collection: [{}]", this.collectionName);
            this.mongoTemplate.createCollection(this.collectionName);
        }

        LOGGER.debug("Creating indices on collection [{}] to auto-expire documents...", this.collectionName);
        final DBCollection collection = this.mongoTemplate.getCollection(this.collectionName);
        // expireAfterSeconds = 0 makes each document expire at the instant stored in its expire-at field.
        collection.createIndex(new BasicDBObject(TicketHolder.FIELD_NAME_EXPIRE_AT, 1),
                new BasicDBObject("expireAfterSeconds", 0));
        LOGGER.info("Configured MongoDb Ticket Registry instance [{}]", this.collectionName);
    }

    /**
     * Replaces the stored JSON of the document that matches the given ticket.
     * Failures are logged and swallowed; the ticket passed in is always returned.
     *
     * @param ticket the ticket to update
     * @return the same ticket instance that was passed in
     */
    @Override
    public Ticket updateTicket(final Ticket ticket) {
        LOGGER.debug("Updating ticket [{}]", ticket);
        try {
            final TicketHolder holder = buildTicketAsDocument(ticket);
            final Query query = new Query(Criteria.where(TicketHolder.FIELD_NAME_ID).is(holder.getTicketId()));
            // NOTE(review): only the JSON field is refreshed; the expire-at field of the stored
            // document is left as it was when the ticket was first added - confirm this is intended.
            final Update update = Update.update(TicketHolder.FIELD_NAME_JSON, holder.getJson());
            this.mongoTemplate.updateFirst(query, update, this.collectionName);
        } catch (final Exception e) {
            // Throwable is passed as the trailing argument so the stack trace is logged.
            LOGGER.error("Failed updating [{}]", ticket, e);
        }
        return ticket;
    }

    /**
     * Inserts the given ticket as a new document. Failures are logged and swallowed.
     *
     * @param ticket the ticket to store
     */
    @Override
    public void addTicket(final Ticket ticket) {
        try {
            LOGGER.debug("Adding ticket [{}]", ticket);
            final TicketHolder holder = buildTicketAsDocument(ticket);
            this.mongoTemplate.insert(holder, this.collectionName);
        } catch (final Exception e) {
            LOGGER.error("Failed adding [{}]", ticket, e);
        }
    }

    /**
     * Looks a ticket up by id.
     *
     * @param ticketId the (un-encoded) ticket id
     * @return the deserialized ticket, or {@code null} when the id cannot be encoded, no document
     * matches, or the lookup fails
     */
    @Override
    public Ticket getTicket(final String ticketId) {
        try {
            LOGGER.debug("Locating ticket ticketId [{}]", ticketId);
            final String encTicketId = encodeTicketId(ticketId);
            if (encTicketId == null) {
                LOGGER.debug("Ticket ticketId [{}] could not be found", ticketId);
                return null;
            }

            final Query query = new Query(Criteria.where(TicketHolder.FIELD_NAME_ID).is(encTicketId));
            final TicketHolder holder = this.mongoTemplate.findOne(query, TicketHolder.class, this.collectionName);
            if (holder != null) {
                return deserializeTicketFromMongoDocument(holder);
            }
        } catch (final Exception e) {
            LOGGER.error("Failed fetching [{}]", ticketId, e);
        }
        return null;
    }

    /**
     * Loads and deserializes every document in the collection.
     *
     * @return the set of all stored tickets
     */
    @Override
    public Collection<Ticket> getTickets() {
        final Collection<TicketHolder> holders = this.mongoTemplate.findAll(TicketHolder.class, this.collectionName);
        return holders.stream()
                .map(MongoDbTicketRegistry::deserializeTicketFromMongoDocument)
                .collect(Collectors.toSet());
    }

    /**
     * Not implemented by this registry.
     *
     * @return always {@code 0}
     */
    @Override
    public long sessionCount() {
        return 0;
    }

    /**
     * Not implemented by this registry.
     *
     * @return always {@code 0}
     */
    @Override
    public long serviceTicketCount() {
        return 0;
    }

    /**
     * Removes the document stored for the given ticket id.
     * <p>
     * The id is encoded with {@code encodeTicketId} first so that the delete query targets the
     * same key that {@link #getTicket(String)} and {@link #addTicket(Ticket)} use.
     *
     * @param ticketId the (un-encoded) ticket id
     * @return {@code true} when the remove operation completed without error; {@code false} when
     * the id cannot be encoded or the operation failed
     */
    @Override
    public boolean deleteSingleTicket(final String ticketId) {
        LOGGER.debug("Deleting ticket [{}]", ticketId);
        try {
            final String encTicketId = encodeTicketId(ticketId);
            if (encTicketId == null) {
                LOGGER.debug("Ticket ticketId [{}] could not be found", ticketId);
                return false;
            }
            final Query query = new Query(Criteria.where(TicketHolder.FIELD_NAME_ID).is(encTicketId));
            this.mongoTemplate.remove(query, this.collectionName);
            return true;
        } catch (final Exception e) {
            LOGGER.error("Failed deleting [{}]", ticketId, e);
        }
        return false;
    }

    /**
     * Removes every document whose id field is a non-empty string.
     *
     * @return the number of matching documents counted immediately before removal
     */
    @Override
    public long deleteAll() {
        final Query query = new Query(Criteria.where(TicketHolder.FIELD_NAME_ID).regex(".+"));
        // Count first: the number returned is a snapshot taken just before the removal runs.
        final long count = this.mongoTemplate.count(query, this.collectionName);
        this.mongoTemplate.remove(query, this.collectionName);
        return count;
    }

    /**
     * Reads the time-to-live from the ticket's expiration policy as an {@code int}.
     * Values outside the {@code int} range are clamped instead of being silently truncated
     * (e.g. {@code Long.MAX_VALUE} would otherwise wrap to {@code -1}).
     *
     * @param ticket the ticket whose expiration policy is consulted
     * @return the time-to-live, clamped to {@code [Integer.MIN_VALUE, Integer.MAX_VALUE]}
     */
    private static int getTimeToLive(final Ticket ticket) {
        final long timeToLive = ticket.getExpirationPolicy().getTimeToLive().longValue();
        if (timeToLive > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (timeToLive < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int) timeToLive;
    }

    /**
     * Serializes a ticket to the JSON string stored in the document.
     *
     * @param ticket the ticket to serialize
     * @return the serialized form produced by {@link BaseTicketSerializers}
     */
    private static String serializeTicketForMongoDocument(final Ticket ticket) {
        return BaseTicketSerializers.serializeTicket(ticket);
    }

    /**
     * Rebuilds a ticket from a stored document using its JSON payload and recorded type.
     *
     * @param holder the stored document
     * @return the deserialized ticket
     */
    private static Ticket deserializeTicketFromMongoDocument(final TicketHolder holder) {
        return BaseTicketSerializers.deserializeTicket(holder.getJson(), holder.getType());
    }

    /**
     * Converts a ticket into the document persisted in MongoDb.
     * <p>
     * The document id is the encoded ticket id so that it matches the key used by
     * {@link #getTicket(String)} and {@link #deleteSingleTicket(String)}.
     *
     * @param ticket the ticket to convert
     * @return the document holding the serialized ticket, its encoded id, its class name and its time-to-live
     */
    private TicketHolder buildTicketAsDocument(final Ticket ticket) {
        final String json = serializeTicketForMongoDocument(ticket);
        final String encTicketId = encodeTicketId(ticket.getId());
        return new TicketHolder(json, encTicketId, ticket.getClass().getName(), getTimeToLive(ticket));
    }
}