package org.exist.xquery.modules.persistentlogin;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.exist.util.Base64Encoder;
import org.exist.xquery.XPathException;
import org.exist.xquery.value.DateTimeValue;
import org.exist.xquery.value.DurationValue;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.*;
/**
* A persistent login feature ("remember me") similar to the implementation in <a href="https://github.com/SpringSource/spring-security">Spring Security</a>,
* which is based on <a href="http://jaspan.com/improved_persistent_login_cookie_best_practice">Improved Persistent Login Cookie
* Best Practice</a>.
*
* The one-time tokens generated by this class are purely random and do not contain a user name or other information. For security reasons,
* tokens and user information are kept in memory only and are never persisted, so registered tokens are lost when the database is shut down.
*
* When sequential token checking is enabled (it is disabled by default), each one-time token is valid for a single request, so requests
* have to be made in sequence. This can be difficult for apps that issue concurrent AJAX requests, so a superseded token is still
* accepted for a short grace period. This is the price of sufficiently strong protection against cookie-stealing attacks.
*
* @author Wolfgang Meier
*/
public class PersistentLogin {

    // Declared before the singleton so it is initialized first, even if the constructor ever logs.
    private final static Logger LOG = LogManager.getLogger(PersistentLogin.class);

    private final static PersistentLogin instance = new PersistentLogin();

    /**
     * Returns the shared instance. Sessions registered with it are held in memory only.
     *
     * @return the singleton instance
     */
    public static PersistentLogin getInstance() {
        return instance;
    }

    /** Number of random bytes in a series identifier, before Base64 encoding. */
    public final static int DEFAULT_SERIES_LENGTH = 16;

    /** Number of random bytes in a one-time token, before Base64 encoding. */
    public final static int DEFAULT_TOKEN_LENGTH = 16;

    /**
     * Grace period in milliseconds during which a superseded token is still accepted,
     * so that requests issued concurrently with a token rotation do not fail.
     */
    public final static int INVALIDATION_TIMEOUT = 20000;

    /** Active sessions keyed by series identifier; synchronized because it may be accessed concurrently. */
    private final Map<String, LoginDetails> seriesMap = Collections.synchronizedMap(new HashMap<>());

    private final SecureRandom random;

    public PersistentLogin() {
        random = new SecureRandom();
    }

    /**
     * Registers the user and generates a first login token, valid for the next
     * call to {@link #lookup(String)}.
     * <p>
     * The string form of the returned details ({@link LoginDetails#toString()}) has the format
     * {@code series:token}. Both parts are Base64-encoded random bytes.
     *
     * @param user the user name
     * @param password the password
     * @param timeToLive how long the session stays valid, measured from now
     * @return the details of the newly registered session
     * @throws XPathException propagated from the date arithmetic used to compute the expiry time
     */
    public LoginDetails register(String user, String password, DurationValue timeToLive) throws XPathException {
        final DateTimeValue now = new DateTimeValue(new Date());
        final DateTimeValue expires = (DateTimeValue) now.plus(timeToLive);
        final LoginDetails login = new LoginDetails(user, password, timeToLive, expires.getTimeInMillis());
        seriesMap.put(login.getSeries(), login);
        return login;
    }

    /**
     * Looks up the given {@code series:token} string and returns the matching login details.
     * <p>
     * When sequential token checking is enabled for the session, the token part is verified and,
     * on success, replaced with a fresh token. The old token then stays valid for
     * {@link #INVALIDATION_TIMEOUT} ms. When sequential checking is disabled (the default),
     * only the series part is used.
     *
     * @param token the token string provided by the user; may be {@code null}
     * @return login details for the user, or {@code null} if the token is {@code null},
     *         no session was found, or the session has expired
     * @throws XPathException if sequential checking is enabled and the series matched but the token did not.
     *         This may indicate an out-of-sequence request or a cookie theft attack; the session is deleted.
     */
    public LoginDetails lookup(String token) throws XPathException {
        if (token == null) {
            LOG.debug("No persistent login token provided");
            return null;
        }
        // limit -1 keeps trailing empty parts, so the array always has at least one element
        final String[] tokens = token.split(":", -1);
        final String series = tokens[0];
        final LoginDetails data = seriesMap.get(series);
        if (data == null) {
            LOG.debug("No session found for series {}", series);
            return null;
        }
        final long now = System.currentTimeMillis();
        if (now > data.expires) {
            LOG.debug("Persistent session expired");
            seriesMap.remove(series);
            return null;
        }
        // sequential token checking is disabled by default
        if (data.seqBehavior) {
            LOG.debug("Using sequential tokens");
            // a missing token part is treated as a mismatch rather than crashing
            final String oneTimeToken = tokens.length > 1 ? tokens[1] : null;
            if (!data.checkAndUpdateToken(oneTimeToken)) {
                LOG.debug("Out-of-sequence request or cookie theft attack. Deleting session.");
                seriesMap.remove(series);
                throw new XPathException("Token mismatch. This may indicate an out-of-sequence request (likely) or a cookie theft attack. " +
                        "Session is deleted for security reasons.");
            }
        }
        return data;
    }

    /**
     * Invalidates the session associated with the token string by removing its series.
     * A {@code null} token is ignored.
     *
     * @param token token string provided by the user, in the form {@code series:token}
     */
    public void invalidate(String token) {
        if (token == null) {
            return;
        }
        seriesMap.remove(token.split(":", -1)[0]);
    }

    /** Generates a new random series identifier. */
    private String generateSeriesToken() {
        return generateRandomString(DEFAULT_SERIES_LENGTH);
    }

    /** Generates a new random one-time token. */
    private String generateToken() {
        return generateRandomString(DEFAULT_TOKEN_LENGTH);
    }

    /** Returns {@code numBytes} bytes from the secure random source, Base64-encoded. */
    private String generateRandomString(final int numBytes) {
        final byte[] bytes = new byte[numBytes];
        random.nextBytes(bytes);
        final Base64Encoder encoder = new Base64Encoder();
        encoder.translate(bytes);
        return new String(encoder.getCharArray());
    }

    /** Compares two secrets in time independent of where they first differ, to avoid timing attacks. */
    private static boolean constantTimeEquals(final String a, final String b) {
        return MessageDigest.isEqual(a.getBytes(StandardCharsets.UTF_8), b.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * State of one persistent-login session: the user's credentials, the series identifier,
     * the current one-time token, and the expiry time.
     */
    public class LoginDetails {

        private final String userName;
        private final String password;
        // volatile: read without locking by getToken()/toString() while update() may replace it
        private volatile String token;
        private final String series;
        // expiry time in milliseconds since the epoch
        private final long expires;
        private final DurationValue timeToLive;
        // disable sequential token checking by default
        private boolean seqBehavior = false;
        // superseded token -> time (ms since epoch) until which it is still accepted;
        // only accessed while holding this object's monitor
        private final Map<String, Long> invalidatedTokens = new HashMap<>();

        public LoginDetails(String user, String password, DurationValue timeToLive, long expires) {
            this.userName = user;
            this.password = password;
            this.timeToLive = timeToLive;
            this.expires = expires;
            this.token = generateToken();
            this.series = generateSeriesToken();
        }

        public String getToken() {
            return this.token;
        }

        public String getSeries() {
            return this.series;
        }

        public String getUser() {
            return this.userName;
        }

        public String getPassword() {
            return this.password;
        }

        public DurationValue getTimeToLive() {
            return timeToLive;
        }

        /**
         * Checks a presented one-time token against this session.
         * <p>
         * If it equals the current token, the token is rotated (see {@link #update()}) and
         * {@code true} is returned. If it is a recently superseded token still inside its grace
         * window, {@code true} is returned without rotating.
         *
         * @param token the presented token; {@code null} never matches
         * @return whether the token is acceptable
         */
        public synchronized boolean checkAndUpdateToken(String token) {
            if (token == null) {
                return false;
            }
            if (constantTimeEquals(this.token, token)) {
                update();
                return true;
            }
            // check map of invalidating tokens
            final Long timeout = invalidatedTokens.get(token);
            if (timeout == null) {
                return false;
            }
            // timed out: remove
            if (System.currentTimeMillis() > timeout) {
                invalidatedTokens.remove(token);
                return false;
            }
            // found invalidating token: return true but do not replace token
            return true;
        }

        /**
         * Replaces the current token with a freshly generated one. The previous token
         * stays acceptable for {@link #INVALIDATION_TIMEOUT} ms. Superseded tokens whose
         * grace period has already passed are purged first.
         *
         * @return the new token
         */
        public synchronized String update() {
            timeoutCheck();
            // leave a small time window until previous token is deleted
            // to allow for concurrent requests
            invalidatedTokens.put(this.token, System.currentTimeMillis() + INVALIDATION_TIMEOUT);
            this.token = generateToken();
            return this.token;
        }

        /** Removes superseded tokens whose grace period has ended; callers hold this object's monitor. */
        private void timeoutCheck() {
            final long now = System.currentTimeMillis();
            invalidatedTokens.values().removeIf(expiry -> expiry < now);
        }

        /** @return the session's token string in the form {@code series:token} */
        @Override
        public String toString() {
            return this.series + ":" + this.token;
        }
    }
}