package org.test4j.module.tracer.jdbc;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

import org.test4j.module.tracer.TracerManager;

/**
 * {@link InvocationHandler} that sits in front of a JDBC {@link Connection}.
 * <p>
 * Every call is forwarded to the wrapped connection. When the call is one of
 * {@code createStatement}, {@code prepareStatement} or {@code prepareCall} and
 * it returned a {@link Statement} that is not already marked with
 * {@link IProxyMarker}, the statement is wrapped in a dynamic proxy backed by
 * {@link StatementProxy}, and the SQL text (if the first argument is a
 * {@code String}) is handed to
 * {@link TracerManager#traceJdbcStatement(String)}.
 */
public class ConnectionProxy implements InvocationHandler {

    /** Interfaces implemented by a proxied plain statement. */
    private static final Class<?>[] STATEMENT_TYPES = { Statement.class, IProxyMarker.class };

    /** Interfaces implemented by a proxied prepared statement. */
    private static final Class<?>[] PREPARED_STATEMENT_TYPES = { PreparedStatement.class, IProxyMarker.class };

    /** Interfaces implemented by a proxied callable statement. */
    private static final Class<?>[] CALLABLE_STATEMENT_TYPES = { CallableStatement.class, IProxyMarker.class };

    /** Interfaces implemented by a proxied connection. */
    private static final Class<?>[] CONNECTION_TYPES = { Connection.class, IProxyMarker.class };

    /** Names of the connection methods whose result gets wrapped in a statement proxy. */
    private static final Set<String> CREATE_STATEMENT_METHODS = Collections.unmodifiableSet(
            new HashSet<String>(Arrays.asList("createStatement", "prepareStatement", "prepareCall")));

    /** The real connection every call is delegated to. */
    private final Connection connection;

    /** Class loader of the real connection's class; used to define the statement proxies. */
    private final ClassLoader cl;

    /**
     * Creates a handler that delegates to the given connection.
     *
     * @param connection the real connection; must not be {@code null}
     */
    public ConnectionProxy(final Connection connection) {
        this.connection = connection;
        this.cl = connection.getClass().getClassLoader();
    }

    /**
     * Forwards the call to the wrapped connection and, for statement-creating
     * methods, replaces the returned statement with a tracing proxy.
     *
     * @param proxy  the proxy instance the method was invoked on (unused)
     * @param method the interface method being invoked
     * @param args   the call arguments, or {@code null} when there are none
     * @return the delegate's result, or a proxy around it when it is a not yet
     *         proxied {@link Statement} produced by a statement-creating method
     * @throws Throwable whatever the wrapped connection's method threw
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Object result;
        try {
            result = method.invoke(connection, args);
        } catch (InvocationTargetException e) {
            // Rethrow the delegate's own exception (e.g. SQLException). Letting the
            // reflective wrapper escape would make the proxy surface it as an
            // UndeclaredThrowableException, which callers catching SQLException miss.
            Throwable cause = e.getCause();
            throw cause != null ? cause : e;
        }

        String methodName = method.getName();
        if (!CREATE_STATEMENT_METHODS.contains(methodName)) {
            return result;
        }

        boolean alreadyProxied = result instanceof IProxyMarker;
        boolean isStatement = result instanceof Statement;
        if (alreadyProxied || !isStatement) {
            return result;
        }

        Statement statement = (Statement) result;
        Object statementProxy = Proxy.newProxyInstance(cl, getStatementTypes(statement),
                new StatementProxy(statement));
        traceSql(args);
        return statementProxy;
    }

    /**
     * Picks the proxy interface set matching the most specific JDBC statement
     * interface the given statement implements.
     *
     * @param statement the real statement about to be proxied
     * @return the interfaces the proxy should implement
     */
    private static Class<?>[] getStatementTypes(Statement statement) {
        // CallableStatement extends PreparedStatement, so it has to be tested first.
        if (statement instanceof CallableStatement) {
            return CALLABLE_STATEMENT_TYPES;
        }
        if (statement instanceof PreparedStatement) {
            return PREPARED_STATEMENT_TYPES;
        }
        return STATEMENT_TYPES;
    }

    /**
     * Reports the SQL text of a statement-creating call to the tracer. Calls
     * with no arguments or with a non-{@code String} first argument are not
     * traced.
     *
     * @param args the arguments of the statement-creating call, may be {@code null}
     */
    private static void traceSql(Object[] args) {
        if (args == null || args.length < 1) {
            return;
        }
        Object sql = args[0];
        if (sql instanceof String) {
            TracerManager.traceJdbcStatement((String) sql);
        }
    }

    /**
     * Wraps a connection in a tracing proxy unless it already is one.
     *
     * @param conn the connection to wrap; must not be {@code null}
     * @return {@code conn} itself when it already implements
     *         {@link IProxyMarker}, otherwise a new proxy implementing
     *         {@link Connection} and {@link IProxyMarker}
     */
    public static final Connection getConnectionProxy(Connection conn) {
        if (conn instanceof IProxyMarker) {
            return conn;
        }
        ClassLoader loader = conn.getClass().getClassLoader();
        Object connectionProxy = Proxy.newProxyInstance(loader, CONNECTION_TYPES, new ConnectionProxy(conn));
        return (Connection) connectionProxy;
    }
}