package org.apereo.cas.adaptors.jdbc;
import com.google.common.base.Throwables;
import org.apache.shiro.crypto.hash.DefaultHashService;
import org.apache.shiro.crypto.hash.HashRequest;
import org.apache.shiro.util.ByteSource;
import org.apereo.cas.authentication.CoreAuthenticationTestUtils;
import org.apereo.cas.authentication.HandlerResult;
import org.apereo.cas.authentication.PreventedException;
import org.apereo.cas.authentication.UsernamePasswordCredential;
import org.apereo.cas.authentication.exceptions.AccountDisabledException;
import org.apereo.cas.authentication.exceptions.AccountPasswordMustChangeException;
import org.apereo.cas.util.transforms.PrefixSuffixPrincipalNameTransformer;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.autoconfigure.RefreshAutoConfiguration;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringRunner;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.security.auth.login.AccountNotFoundException;
import javax.security.auth.login.FailedLoginException;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.Statement;
import static org.junit.Assert.*;
/**
* @author Misagh Moayyed
* @since 4.0.0
*/
@RunWith(SpringRunner.class)
@SpringBootTest(classes = {RefreshAutoConfiguration.class})
@ContextConfiguration(locations = {"classpath:/jpaTestApplicationContext.xml"})
public class QueryAndEncodeDatabaseAuthenticationHandlerTests {
private static final String ALG_NAME = "SHA-512";
private static final String SQL = "SELECT * FROM users where %s";
private static final int NUM_ITERATIONS = 5;
private static final String STATIC_SALT = "STATIC_SALT";
private static final String PASSWORD_FIELD_NAME = "password";
private static final String EXPIRED_FIELD_NAME = "expired";
private static final String DISABLED_FIELD_NAME = "disabled";
private static final String NUM_ITERATIONS_FIELD_NAME = "numIterations";
@Rule
public ExpectedException thrown = ExpectedException.none();
@Autowired
@Qualifier("dataSource")
private DataSource dataSource;
@Before
public void setUp() throws Exception {
final Connection c = this.dataSource.getConnection();
final Statement s = c.createStatement();
c.setAutoCommit(true);
s.execute(getSqlInsertStatementToCreateUserAccount(0, Boolean.FALSE.toString(), Boolean.FALSE.toString()));
for (int i = 0; i < 10; i++) {
s.execute(getSqlInsertStatementToCreateUserAccount(i, Boolean.FALSE.toString(), Boolean.FALSE.toString()));
}
s.execute(getSqlInsertStatementToCreateUserAccount(20, Boolean.TRUE.toString(), Boolean.FALSE.toString()));
s.execute(getSqlInsertStatementToCreateUserAccount(21, Boolean.FALSE.toString(), Boolean.TRUE.toString()));
c.close();
}
private static String getSqlInsertStatementToCreateUserAccount(final int i, final String expired, final String disabled) {
final String psw = genPassword("user" + i, "salt" + i, NUM_ITERATIONS);
return String.format(
"insert into users (username, password, salt, numIterations, expired, disabled) values('%s', '%s', '%s', %s, '%s', '%s');",
"user" + i, psw, "salt" + i, NUM_ITERATIONS, expired, disabled);
}
@After
public void tearDown() throws Exception {
final Connection c = this.dataSource.getConnection();
final Statement s = c.createStatement();
c.setAutoCommit(true);
for (int i = 0; i < 5; i++) {
s.execute("delete from users;");
}
c.close();
}
@Test
public void verifyAuthenticationFailsToFindUser() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", null, null, "ops", 0, "");
this.thrown.expect(AccountNotFoundException.class);
this.thrown.expectMessage("test not found with SQL query");
q.authenticate(CoreAuthenticationTestUtils.getCredentialsWithSameUsernameAndPassword());
}
@Test
public void verifyAuthenticationInvalidSql() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql("makesNoSenseInSql"), PASSWORD_FIELD_NAME, "salt", null, null, "ops", 0, "");
this.thrown.expect(PreventedException.class);
this.thrown.expectMessage("SQL exception while executing query for test");
q.authenticate(CoreAuthenticationTestUtils.getCredentialsWithSameUsernameAndPassword());
}
@Test
public void verifyAuthenticationMultipleAccounts() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", null, null, "ops", 0, "");
this.thrown.expect(FailedLoginException.class);
this.thrown.expectMessage("Multiple records found for user0");
q.authenticate(CoreAuthenticationTestUtils.getCredentialsWithDifferentUsernameAndPassword("user0", "password0"));
}
@Test
public void verifyAuthenticationSuccessful() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", null, null, NUM_ITERATIONS_FIELD_NAME, 0, STATIC_SALT);
final UsernamePasswordCredential c = CoreAuthenticationTestUtils.getCredentialsWithSameUsernameAndPassword("user1");
final HandlerResult r = q.authenticate(c);
assertNotNull(r);
assertEquals(r.getPrincipal().getId(), "user1");
}
@Test
public void verifyAuthenticationWithExpiredField() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", EXPIRED_FIELD_NAME, null, NUM_ITERATIONS_FIELD_NAME, 0, STATIC_SALT);
this.thrown.expect(AccountPasswordMustChangeException.class);
this.thrown.expectMessage("Password has expired");
q.authenticate(CoreAuthenticationTestUtils.getCredentialsWithSameUsernameAndPassword("user20"));
fail("Shouldn't get here");
}
@Test
public void verifyAuthenticationWithDisabledField() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", null, DISABLED_FIELD_NAME, NUM_ITERATIONS_FIELD_NAME, 0, STATIC_SALT);
this.thrown.expect(AccountDisabledException.class);
this.thrown.expectMessage("Account has been disabled");
q.authenticate(CoreAuthenticationTestUtils.getCredentialsWithSameUsernameAndPassword("user21"));
fail("Shouldn't get here");
}
@Test
public void verifyAuthenticationSuccessfulWithAPasswordEncoder() throws Exception {
final QueryAndEncodeDatabaseAuthenticationHandler q = new QueryAndEncodeDatabaseAuthenticationHandler("", null, null, null, dataSource, ALG_NAME,
buildSql(), PASSWORD_FIELD_NAME, "salt", null, null, NUM_ITERATIONS_FIELD_NAME, 0, STATIC_SALT);
q.setPasswordEncoder(new PasswordEncoder() {
@Override
public String encode(final CharSequence password) {
return password.toString().concat("1");
}
@Override
public boolean matches(final CharSequence rawPassword, final String encodedPassword) {
return true;
}
});
q.setPrincipalNameTransformer(new PrefixSuffixPrincipalNameTransformer("user", null));
final HandlerResult r = q.authenticate(
CoreAuthenticationTestUtils.getCredentialsWithDifferentUsernameAndPassword("1", "user"));
assertNotNull(r);
assertEquals(r.getPrincipal().getId(), "user1");
}
private static String buildSql(final String where) {
return String.format(SQL, where);
}
private static String buildSql() {
return String.format(SQL, "username=?;");
}
private static String genPassword(final String psw, final String salt, final int iter) {
try {
final DefaultHashService hash = new DefaultHashService();
hash.setPrivateSalt(ByteSource.Util.bytes(STATIC_SALT));
hash.setHashIterations(iter);
hash.setGeneratePublicSalt(false);
hash.setHashAlgorithmName(ALG_NAME);
return hash.computeHash(new HashRequest.Builder().setSource(psw).setSalt(salt).setIterations(iter).build()).toHex();
} catch (final Exception e) {
throw Throwables.propagate(e);
}
}
@Entity(name = "users")
public static class UsersTable {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
private String username;
private String password;
private String salt;
private String expired;
private String disabled;
private long numIterations;
}
}