/***************************************************************** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. ****************************************************************/ package org.apache.cayenne.datasource; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.Collection; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import javax.sql.DataSource; import org.apache.cayenne.unit.di.server.CayenneProjects; import org.apache.cayenne.unit.di.server.UseServerRuntime; import org.slf4j.Logger; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.OngoingStubbing; import org.slf4j.LoggerFactory; @UseServerRuntime(CayenneProjects.TESTMAP_PROJECT) public class ManagedPoolingDataSourceIT { private static final Logger LOGGER = LoggerFactory.getLogger(ManagedPoolingDataSourceIT.class); private int poolSize; private OnOffDataSourceManager dataSourceManager; private UnmanagedPoolingDataSource unmanagedPool; private ManagedPoolingDataSource managedPool; @Before public void before() throws SQLException { this.poolSize = 4; this.dataSourceManager = new OnOffDataSourceManager(); PoolingDataSourceParameters parameters = new PoolingDataSourceParameters(); parameters.setMaxConnections(poolSize); parameters.setMinConnections(poolSize / 2); parameters.setMaxQueueWaitTime(1000); parameters.setValidationQuery("SELECT 1"); this.unmanagedPool = new UnmanagedPoolingDataSource(dataSourceManager.mockDataSource, parameters); this.managedPool = new ManagedPoolingDataSource(unmanagedPool, 10000); } @After public void after() { if (managedPool != null) { managedPool.close(); } } private Collection<PoolTask> createTasks(int size) { Collection<PoolTask> tasks = new ArrayList<>(); for (int i = 0; i < size; i++) { tasks.add(new PoolTask()); } return tasks; } @Test public void testGetConnection_OnBackendShutdown() throws SQLException, InterruptedException { // note that this assertion can only work reliably when the pool is inactive... assertEquals(poolSize, managedPool.poolSize() + managedPool.canExpandSize()); Collection<PoolTask> tasks = createTasks(4); ExecutorService executor = Executors.newFixedThreadPool(4); for (int j = 0; j < 10; j++) { for (PoolTask task : tasks) { executor.submit(task); } } dataSourceManager.off(); Thread.sleep(500); for (int j = 0; j < 10; j++) { for (PoolTask task : tasks) { executor.submit(task); } } Thread.sleep(100); dataSourceManager.on(); for (int j = 0; j < 10; j++) { for (PoolTask task : tasks) { executor.submit(task); } } executor.shutdown(); executor.awaitTermination(2, TimeUnit.SECONDS); // note that this assertion can only work reliably when the pool is inactive... assertEquals(poolSize, managedPool.poolSize() + managedPool.canExpandSize()); } class PoolTask implements Runnable { @Override public void run() { try (Connection c = managedPool.getConnection();) { try (Statement s = c.createStatement()) { try { Thread.sleep(40); } catch (InterruptedException e) { // ignore } } } catch (SQLException e) { if (OnOffDataSourceManager.NO_CONNECTIONS_MESSAGE.equals(e.getMessage())) { LOGGER.info("db down..."); } else { LOGGER.warn("error getting connection", e); } } } } static class OnOffDataSourceManager { static final String NO_CONNECTIONS_MESSAGE = "no connections at the moment"; private DataSource mockDataSource; private OngoingStubbing<Connection> createConnectionMock; OnOffDataSourceManager() throws SQLException { this.mockDataSource = mock(DataSource.class); this.createConnectionMock = when(mockDataSource.getConnection()); on(); } void off() throws SQLException { createConnectionMock.thenAnswer(new Answer<Connection>() { @Override public Connection answer(InvocationOnMock invocation) throws Throwable { throw new SQLException(NO_CONNECTIONS_MESSAGE); } }); } void on() throws SQLException { createConnectionMock.thenAnswer(new Answer<Connection>() { @Override public Connection answer(InvocationOnMock invocation) throws Throwable { Connection c = mock(Connection.class); when(c.createStatement()).thenAnswer(new Answer<Statement>() { @Override public Statement answer(InvocationOnMock invocation) throws Throwable { ResultSet mockRs = mock(ResultSet.class); when(mockRs.next()).thenReturn(true, false, false, false); Statement mockStatement = mock(Statement.class); when(mockStatement.executeQuery(anyString())).thenReturn(mockRs); return mockStatement; } }); return c; } }); } } }