package org.apereo.cas.adaptors.jdbc;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apereo.cas.authentication.HandlerResult;
import org.apereo.cas.authentication.PreventedException;
import org.apereo.cas.authentication.UsernamePasswordCredential;
import org.apereo.cas.authentication.exceptions.AccountDisabledException;
import org.apereo.cas.authentication.exceptions.AccountPasswordMustChangeException;
import org.apereo.cas.authentication.principal.PrincipalFactory;
import org.apereo.cas.services.ServicesManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import javax.security.auth.login.AccountNotFoundException;
import javax.security.auth.login.FailedLoginException;
import javax.sql.DataSource;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* Class that if provided a query that returns a password (parameter of query
* must be username) will compare that password to a translated version of the
* password provided by the user. If they match, then authentication succeeds.
* Default password translator is plaintext translator.
*
* @author Scott Battaglia
* @author Dmitriy Kopylenko
* @author Marvin S. Addison
* @since 3.0.0
*/
public class QueryDatabaseAuthenticationHandler extends AbstractJdbcUsernamePasswordAuthenticationHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(QueryDatabaseAuthenticationHandler.class);

    /** SQL query executed with the username as its single bind parameter. */
    private final String sql;

    /** Name of the result column that holds the stored password. */
    private final String fieldPassword;

    /** Name of the result column flagging an expired password; blank skips the check. */
    private final String fieldExpired;

    /** Name of the result column flagging a disabled account; blank skips the check. */
    private final String fieldDisabled;

    /** Maps result column names (keys) to principal attribute names (values); never {@code null}. */
    private final Map<String, String> principalAttributeMap;

    /**
     * Creates the handler.
     *
     * @param name             passed through to the parent handler
     * @param servicesManager  passed through to the parent handler
     * @param principalFactory passed through to the parent handler
     * @param order            passed through to the parent handler
     * @param dataSource       passed through to the parent handler
     * @param sql              query to run for a username
     * @param fieldPassword    result column holding the stored password
     * @param fieldExpired     result column flagging an expired password; may be blank
     * @param fieldDisabled    result column flagging a disabled account; may be blank
     * @param attributes       result column name to principal attribute name mapping;
     *                         {@code null} is treated as an empty mapping
     */
    public QueryDatabaseAuthenticationHandler(final String name, final ServicesManager servicesManager,
                                              final PrincipalFactory principalFactory,
                                              final Integer order, final DataSource dataSource, final String sql,
                                              final String fieldPassword, final String fieldExpired, final String fieldDisabled,
                                              final Map<String, String> attributes) {
        super(name, servicesManager, principalFactory, order, dataSource);
        this.sql = sql;
        this.fieldPassword = fieldPassword;
        this.fieldExpired = fieldExpired;
        this.fieldDisabled = fieldDisabled;
        // A null mapping would otherwise fail with a NullPointerException on every authentication attempt.
        this.principalAttributeMap = attributes != null ? attributes : Collections.emptyMap();
    }

    /**
     * Runs the configured query for the credential's username, verifies the password
     * against the returned record, rejects disabled or expired accounts, and collects
     * the mapped columns as principal attributes.
     *
     * @param credential       credential supplying the username and password
     * @param originalPassword when not blank it is checked with {@code matches(..)};
     *                         otherwise the credential's password is compared for plain equality
     * @return the handler result built for the resolved principal
     * @throws GeneralSecurityException if the handler has no SQL or JDBC template, the password
     *                                  does not match, the account is disabled or expired, or the
     *                                  query did not return exactly one record
     * @throws PreventedException       if the query fails with any other data access error
     */
    @Override
    protected HandlerResult authenticateUsernamePasswordInternal(final UsernamePasswordCredential credential, final String originalPassword)
            throws GeneralSecurityException, PreventedException {
        if (StringUtils.isBlank(this.sql) || getJdbcTemplate() == null) {
            throw new GeneralSecurityException("Authentication handler is not configured correctly. "
                    + "No SQL statement or JDBC template is found.");
        }

        final Map<String, Object> attributes = new LinkedHashMap<>(this.principalAttributeMap.size());
        final String username = credential.getUsername();
        final String password = credential.getPassword();
        try {
            final Map<String, Object> dbFields = getJdbcTemplate().queryForMap(this.sql, username);
            final String dbPassword = (String) dbFields.get(this.fieldPassword);

            if (StringUtils.isNotBlank(originalPassword) && !matches(originalPassword, dbPassword)
                    || StringUtils.isBlank(originalPassword) && !StringUtils.equals(password, dbPassword)) {
                throw new FailedLoginException("Password does not match value on record.");
            }
            if (isFlagSet(dbFields, this.fieldDisabled)) {
                throw new AccountDisabledException("Account has been disabled");
            }
            if (isFlagSet(dbFields, this.fieldExpired)) {
                throw new AccountPasswordMustChangeException("Password has expired");
            }

            this.principalAttributeMap.forEach((column, principalAttrName) -> {
                final Object attribute = dbFields.get(column);
                if (attribute != null) {
                    LOGGER.debug("Found attribute [{}] from the query results", column);
                    attributes.put(principalAttrName, attribute.toString());
                } else {
                    LOGGER.warn("Requested attribute [{}] could not be found in the query results", column);
                }
            });
        } catch (final IncorrectResultSizeDataAccessException e) {
            // Distinguish "no such user" from an ambiguous query result.
            if (e.getActualSize() == 0) {
                throw new AccountNotFoundException(username + " not found with SQL query");
            }
            throw new FailedLoginException("Multiple records found for " + username);
        } catch (final DataAccessException e) {
            throw new PreventedException("SQL exception while executing query for " + username, e);
        }
        return createHandlerResult(credential, this.principalFactory.createPrincipal(username, attributes), null);
    }

    /**
     * Tells whether the given flag column is set in the query results.
     * The column value's string form is considered set when
     * {@link BooleanUtils#toBoolean(String)} accepts it or when it equals {@code "1"},
     * so the check does not depend on which numeric type the driver returns.
     *
     * @param dbFields query results keyed by column name
     * @param field    flag column name; blank means the check is not configured
     * @return {@code true} if the column is configured, present and set
     */
    private static boolean isFlagSet(final Map<String, Object> dbFields, final String field) {
        if (StringUtils.isBlank(field)) {
            return false;
        }
        final Object value = dbFields.get(field);
        if (value == null) {
            return false;
        }
        final String text = value.toString();
        return BooleanUtils.toBoolean(text) || "1".equals(text);
    }
}