/****************************************************************************** * Copyright © 2013-2016 The Nxt Core Developers. * * * * See the AUTHORS.txt, DEVELOPER-AGREEMENT.txt and LICENSE.txt files at * * the top-level directory of this distribution for the individual copyright * * holder information and the developer policies on copyright and licensing. * * * * Unless otherwise agreed in a custom licensing agreement, no part of the * * Nxt software, including this file, may be copied, modified, propagated, * * or distributed except according to the terms contained in the LICENSE.txt * * file. * * * * Removal or modification of this copyright notice is prohibited. * * * ******************************************************************************/ package nxt.db; import nxt.Nxt; import nxt.util.Logger; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; public class TransactionalDb extends BasicDb { private static final DbFactory factory = new DbFactory(); private static final long stmtThreshold; private static final long txThreshold; private static final long txInterval; static { long temp; stmtThreshold = (temp=Nxt.getIntProperty("nxt.statementLogThreshold")) != 0 ? temp : 1000; txThreshold = (temp=Nxt.getIntProperty("nxt.transactionLogThreshold")) != 0 ? temp : 5000; txInterval = (temp=Nxt.getIntProperty("nxt.transactionLogInterval")) != 0 ? temp*60*1000 : 15*60*1000; } private final ThreadLocal<DbConnection> localConnection = new ThreadLocal<>(); private final ThreadLocal<Map<String,Map<DbKey,Object>>> transactionCaches = new ThreadLocal<>(); private final ThreadLocal<Set<TransactionCallback>> transactionCallback = new ThreadLocal<>(); private volatile long txTimes = 0; private volatile long txCount = 0; private volatile long statsTime = 0; public TransactionalDb(DbProperties dbProperties) { super(dbProperties); } @Override public Connection getConnection() throws SQLException { Connection con = localConnection.get(); if (con != null) { return con; } return new DbConnection(super.getConnection()); } public boolean isInTransaction() { return localConnection.get() != null; } public Connection beginTransaction() { if (localConnection.get() != null) { throw new IllegalStateException("Transaction already in progress"); } try { Connection con = getPooledConnection(); con.setAutoCommit(false); con = new DbConnection(con); ((DbConnection)con).txStart = System.currentTimeMillis(); localConnection.set((DbConnection)con); transactionCaches.set(new HashMap<>()); return con; } catch (SQLException e) { throw new RuntimeException(e.toString(), e); } } public void commitTransaction() { DbConnection con = localConnection.get(); if (con == null) { throw new IllegalStateException("Not in transaction"); } try { con.doCommit(); Set<TransactionCallback> callbacks = transactionCallback.get(); if (callbacks != null) { callbacks.forEach(TransactionCallback::commit); transactionCallback.set(null); } } catch (SQLException e) { throw new RuntimeException(e.toString(), e); } } public void rollbackTransaction() { DbConnection con = localConnection.get(); if (con == null) { throw new IllegalStateException("Not in transaction"); } try { con.doRollback(); } catch (SQLException e) { throw new RuntimeException(e.toString(), e); } finally { transactionCaches.get().clear(); Set<TransactionCallback> callbacks = transactionCallback.get(); if (callbacks != null) { callbacks.forEach(TransactionCallback::rollback); transactionCallback.set(null); } } } public void endTransaction() { Connection con = localConnection.get(); if (con == null) { throw new IllegalStateException("Not in transaction"); } localConnection.set(null); transactionCaches.set(null); long now = System.currentTimeMillis(); long elapsed = now - ((DbConnection)con).txStart; if (elapsed >= txThreshold) { logThreshold(String.format("Database transaction required %.3f seconds at height %d", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight())); } else { long count, times; boolean logStats = false; synchronized(this) { count = ++txCount; times = txTimes += elapsed; if (now - statsTime >= txInterval) { logStats = true; txCount = 0; txTimes = 0; statsTime = now; } } if (logStats) Logger.logDebugMessage(String.format("Average database transaction time is %.3f seconds", (double)times/1000.0/(double)count)); } DbUtils.close(con); } public void registerCallback(TransactionCallback callback) { Set<TransactionCallback> callbacks = transactionCallback.get(); if (callbacks == null) { callbacks = new HashSet<>(); transactionCallback.set(callbacks); } callbacks.add(callback); } Map<DbKey,Object> getCache(String tableName) { if (!isInTransaction()) { throw new IllegalStateException("Not in transaction"); } Map<DbKey,Object> cacheMap = transactionCaches.get().get(tableName); if (cacheMap == null) { cacheMap = new HashMap<>(); transactionCaches.get().put(tableName, cacheMap); } return cacheMap; } void clearCache(String tableName) { Map<DbKey,Object> cacheMap = transactionCaches.get().get(tableName); if (cacheMap != null) { cacheMap.clear(); } } public void clearCache() { transactionCaches.get().values().forEach(Map::clear); } private static void logThreshold(String msg) { StringBuilder sb = new StringBuilder(512); sb.append(msg).append('\n'); StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); boolean firstLine = true; for (int i=3; i<stackTrace.length; i++) { String line = stackTrace[i].toString(); if (!line.startsWith("nxt.")) break; if (firstLine) firstLine = false; else sb.append('\n'); sb.append(" ").append(line); } Logger.logDebugMessage(sb.toString()); } private final class DbConnection extends FilteredConnection { long txStart = 0; private DbConnection(Connection con) { super(con, factory); } @Override public void setAutoCommit(boolean autoCommit) throws SQLException { throw new UnsupportedOperationException("Use Db.beginTransaction() to start a new transaction"); } @Override public void commit() throws SQLException { if (localConnection.get() == null) { super.commit(); } else if (this != localConnection.get()) { throw new IllegalStateException("Previous connection not committed"); } else { commitTransaction(); } } private void doCommit() throws SQLException { super.commit(); } @Override public void rollback() throws SQLException { if (localConnection.get() == null) { super.rollback(); } else if (this != localConnection.get()) { throw new IllegalStateException("Previous connection not committed"); } else { rollbackTransaction(); } } private void doRollback() throws SQLException { super.rollback(); } @Override public void close() throws SQLException { if (localConnection.get() == null) { super.close(); } else if (this != localConnection.get()) { throw new IllegalStateException("Previous connection not committed"); } } } private static final class DbStatement extends FilteredStatement { private DbStatement(Statement stmt) { super(stmt); } @Override public boolean execute(String sql) throws SQLException { long start = System.currentTimeMillis(); boolean b = super.execute(sql); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), sql)); return b; } @Override public ResultSet executeQuery(String sql) throws SQLException { long start = System.currentTimeMillis(); ResultSet r = super.executeQuery(sql); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), sql)); return r; } @Override public int executeUpdate(String sql) throws SQLException { long start = System.currentTimeMillis(); int c = super.executeUpdate(sql); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), sql)); return c; } } private static final class DbPreparedStatement extends FilteredPreparedStatement { private DbPreparedStatement(PreparedStatement stmt, String sql) { super(stmt, sql); } @Override public boolean execute() throws SQLException { long start = System.currentTimeMillis(); boolean b = super.execute(); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), getSQL())); return b; } @Override public ResultSet executeQuery() throws SQLException { long start = System.currentTimeMillis(); ResultSet r = super.executeQuery(); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), getSQL())); return r; } @Override public int executeUpdate() throws SQLException { long start = System.currentTimeMillis(); int c = super.executeUpdate(); long elapsed = System.currentTimeMillis() - start; if (elapsed > stmtThreshold) logThreshold(String.format("SQL statement required %.3f seconds at height %d:\n%s", (double)elapsed/1000.0, Nxt.getBlockchain().getHeight(), getSQL())); return c; } } private static final class DbFactory implements FilteredFactory { @Override public Statement createStatement(Statement stmt) { return new DbStatement(stmt); } @Override public PreparedStatement createPreparedStatement(PreparedStatement stmt, String sql) { return new DbPreparedStatement(stmt, sql); } } /** * Transaction callback interface */ public interface TransactionCallback { /** * Transaction has been committed */ void commit(); /** * Transaction has been rolled back */ void rollback(); } }