/* * Copyright 2010 Proofpoint, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.airlift.dbpool; import com.google.common.primitives.Ints; import io.airlift.units.Duration; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; import javax.sql.ConnectionEvent; import javax.sql.ConnectionEventListener; import javax.sql.DataSource; import javax.sql.PooledConnection; import java.io.PrintWriter; import java.sql.Connection; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; import static io.airlift.units.Duration.nanosSince; import static java.lang.Math.ceil; public abstract class ManagedDataSource implements DataSource { private final ManagedSemaphore semaphore; private final AtomicInteger maxConnectionWaitMillis = new AtomicInteger(100); private final ManagedDataSourceStats stats = new ManagedDataSourceStats(); protected ManagedDataSource(int maxConnections, Duration maxConnectionWait) { if (maxConnections < 1) { throw new IllegalArgumentException("maxConnections must be at least 1: maxConnections=" + maxConnections); } if (maxConnectionWait == null) { throw new NullPointerException("maxConnectionWait is null"); } semaphore = new ManagedSemaphore(maxConnections); maxConnectionWaitMillis.set(Ints.checkedCast(maxConnectionWait.toMillis())); } @Override public Connection getConnection() throws SQLException { long start = System.nanoTime(); try { acquirePermit(); boolean checkedOut = false; try { Connection connection = createConnection(); checkedOut = true; return connection; } finally { if (!checkedOut) { semaphore.release(); } } } finally { stats.connectionCheckedOut(nanosSince(start)); } } protected Connection createConnection() throws SQLException { boolean success = false; try { // todo do not create on caller's thread long start = System.nanoTime(); PooledConnection pooledConnection = createConnectionInternal(); Connection connection = prepareConnection(pooledConnection); stats.connectionCreated(nanosSince(start)); success = true; return connection; } finally { if (!success) { stats.creationErrorOccurred(); } } } protected abstract PooledConnection createConnectionInternal() throws SQLException; protected Connection prepareConnection(PooledConnection pooledConnection) throws SQLException { Connection connection = pooledConnection.getConnection(); pooledConnection.addConnectionEventListener(new NoPoolConnectionEventListener()); return connection; } protected void connectionReturned(PooledConnection pooledConnection, long checkoutTime) { try { // todo do not close on caller's thread pooledConnection.close(); } catch (SQLException ignored) { // hey we tried } } protected void connectionDestroyed(PooledConnection pooledConnection, long checkoutTime) { } @Managed public int getMaxConnectionWaitMillis() { return maxConnectionWaitMillis.get(); } @Managed public void setMaxConnectionWaitMillis(Duration maxConnectionWait) throws IllegalArgumentException { if (maxConnectionWait == null) { throw new NullPointerException("maxConnectionWait is null"); } int millis = Ints.checkedCast(maxConnectionWait.toMillis()); if (millis < 1) { throw new IllegalArgumentException("maxConnectionWait must be greater than 1 millisecond"); } this.maxConnectionWaitMillis.set(millis); } @Managed public long getConnectionsActive() { return semaphore.getActivePermits(); } @Managed public int getMaxConnections() { return semaphore.getPermits(); } @Managed public void setMaxConnections(int maxConnections) { if (maxConnections < 1) { throw new IllegalArgumentException("maxConnections must be at least 1: maxConnections=" + maxConnections); } semaphore.setPermits(maxConnections); } @Managed @Flatten public ManagedDataSourceStats getStats() { return stats; } @Override public PrintWriter getLogWriter() throws SQLException { return null; } @Override public void setLogWriter(PrintWriter out) throws SQLException { } @Override public Logger getParentLogger() throws SQLFeatureNotSupportedException { throw new SQLFeatureNotSupportedException("java.util.logging not supported"); } @Override public int getLoginTimeout() throws SQLException { return (int) ceil(getMaxConnectionWaitMillis() / 1000.0); } @Override public void setLoginTimeout(int seconds) throws SQLException { } @Override public boolean isWrapperFor(Class<?> iface) throws SQLException { if (iface == null) { throw new SQLException("iface is null"); } return iface.isInstance(this); } @Override public <T> T unwrap(Class<T> iface) throws SQLException { if (iface == null) { throw new SQLException("iface is null"); } if (iface.isInstance(this)) { return iface.cast(this); } throw new SQLException(getClass().getName() + " does not implement " + iface.getName()); } /** * Not supported. * * @throws UnsupportedOperationException always */ @Override public final Connection getConnection(String username, String password) throws SQLException { throw new UnsupportedOperationException(); } private void acquirePermit() throws SQLException { int timeout = maxConnectionWaitMillis.get(); try { if (!semaphore.tryAcquire(timeout, TimeUnit.MILLISECONDS)) { throw new SqlTimeoutException("Could not acquire a connection within " + timeout + " msec"); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new SqlTimeoutException("Interrupted while waiting for connection", e); } } protected class NoPoolConnectionEventListener implements ConnectionEventListener { private final long checkoutTime = System.nanoTime(); private final AtomicBoolean returned = new AtomicBoolean(); @Override public void connectionClosed(ConnectionEvent event) { // was the connection already returned if (!returned.compareAndSet(false, true)) { return; } PooledConnection pooledConnection = null; try { pooledConnection = (PooledConnection) event.getSource(); pooledConnection.removeConnectionEventListener(this); stats.connectionReturned(nanosSince(checkoutTime)); } finally { semaphore.release(); if (pooledConnection != null) { connectionReturned(pooledConnection, checkoutTime); } } } @Override public void connectionErrorOccurred(ConnectionEvent event) { // was the connection already returned if (!returned.compareAndSet(false, true)) { return; } PooledConnection pooledConnection = null; try { pooledConnection = (PooledConnection) event.getSource(); pooledConnection.removeConnectionEventListener(this); stats.connectionErrorOccurred(); } finally { semaphore.release(); if (pooledConnection != null) { connectionDestroyed(pooledConnection, checkoutTime); } } } } }