package io.dropwizard.db; import com.codahale.metrics.MetricRegistry; import io.dropwizard.configuration.ResourceConfigurationSourceProvider; import io.dropwizard.configuration.YamlConfigurationFactory; import io.dropwizard.jackson.Jackson; import io.dropwizard.util.Duration; import io.dropwizard.validation.BaseValidator; import org.apache.tomcat.jdbc.pool.interceptor.ConnectionState; import org.apache.tomcat.jdbc.pool.interceptor.StatementFinalizer; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; public class DataSourceFactoryTest { private final MetricRegistry metricRegistry = new MetricRegistry(); private DataSourceFactory factory; private ManagedDataSource dataSource; @Before public void setUp() { factory = new DataSourceFactory(); factory.setUrl("jdbc:h2:mem:DbTest-" + System.currentTimeMillis() + ";user=sa"); factory.setDriverClass("org.h2.Driver"); factory.setValidationQuery("SELECT 1"); } @After public void tearDown() throws Exception { if (null != dataSource) { dataSource.stop(); } } private ManagedDataSource dataSource() throws Exception { dataSource = factory.build(metricRegistry, "test"); dataSource.start(); return dataSource; } @Test public void testInitialSizeIsZero() throws Exception { factory.setUrl("nonsense invalid url"); factory.setInitialSize(0); ManagedDataSource dataSource = factory.build(metricRegistry, "test"); dataSource.start(); } @Test public void buildsAConnectionPoolToTheDatabase() throws Exception { try (Connection connection = dataSource().getConnection()) { try (PreparedStatement statement = connection.prepareStatement("select 1")) { try (ResultSet set = statement.executeQuery()) { while (set.next()) { assertThat(set.getInt(1)).isEqualTo(1); } } } } } @Test public void testNoValidationQueryTimeout() throws Exception { try (Connection connection = dataSource().getConnection()) { try (PreparedStatement statement = connection.prepareStatement("select 1")) { assertThat(statement.getQueryTimeout()).isEqualTo(0); } } } @Test public void testValidationQueryTimeoutIsSet() throws Exception { factory.setValidationQueryTimeout(Duration.seconds(3)); try (Connection connection = dataSource().getConnection()) { try (PreparedStatement statement = connection.prepareStatement("select 1")) { assertThat(statement.getQueryTimeout()).isEqualTo(3); } } } @Test(expected = SQLException.class) public void invalidJDBCDriverClassThrowsSQLException() throws SQLException { final DataSourceFactory factory = new DataSourceFactory(); factory.setDriverClass("org.example.no.driver.here"); factory.build(metricRegistry, "test").getConnection(); } @Test public void testCustomValidator() throws Exception { factory.setValidatorClassName(Optional.of(CustomConnectionValidator.class.getName())); try (Connection connection = dataSource().getConnection()) { try (PreparedStatement statement = connection.prepareStatement("select 1")) { try (ResultSet rs = statement.executeQuery()) { assertThat(rs.next()).isTrue(); assertThat(rs.getInt(1)).isEqualTo(1); } } } assertThat(CustomConnectionValidator.loaded).isTrue(); } @Test public void testJdbcInterceptors() throws Exception { factory.setJdbcInterceptors(Optional.of("StatementFinalizer;ConnectionState")); final ManagedPooledDataSource source = (ManagedPooledDataSource) dataSource(); assertThat(source.getPoolProperties().getJdbcInterceptorsAsArray()) .extracting("interceptorClass") .contains(StatementFinalizer.class, ConnectionState.class); } @Test public void createDefaultFactory() throws Exception { final DataSourceFactory factory = new YamlConfigurationFactory<>(DataSourceFactory.class, BaseValidator.newValidator(), Jackson.newObjectMapper(), "dw") .build(new ResourceConfigurationSourceProvider(), "yaml/minimal_db_pool.yml"); assertThat(factory.getDriverClass()).isEqualTo("org.postgresql.Driver"); assertThat(factory.getUser()).isEqualTo("pg-user"); assertThat(factory.getPassword()).isEqualTo("iAMs00perSecrEET"); assertThat(factory.getUrl()).isEqualTo("jdbc:postgresql://db.example.com/db-prod"); assertThat(factory.getValidationQuery()).isEqualTo("/* Health Check */ SELECT 1"); assertThat(factory.getValidationQueryTimeout()).isEqualTo(Optional.empty()); } }