package org.fanhongtao.mybatis.frame; import java.sql.Connection; import java.util.ArrayList; import java.util.List; import java.util.Properties; import org.apache.ibatis.builder.xml.dynamic.DynamicSqlSource; import org.apache.ibatis.builder.xml.dynamic.IfSqlNode; import org.apache.ibatis.builder.xml.dynamic.MixedSqlNode; import org.apache.ibatis.builder.xml.dynamic.SqlNode; import org.apache.ibatis.builder.xml.dynamic.TextSqlNode; import org.apache.ibatis.exceptions.TooManyResultsException; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.executor.SimpleExecutor; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.MappedStatement.Builder; import org.apache.ibatis.mapping.ResultMap; import org.apache.ibatis.mapping.ResultMapping; import org.apache.ibatis.mapping.SqlCommandType; import org.apache.ibatis.mapping.SqlSource; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import org.fanhongtao.lang.ReflectUtil; /** * @author Fan Hongtao * @created 2010-8-27 */ @Intercepts(@Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class })) public class PageQueryInterceptor implements Interceptor { private static int MAPPED_STATEMENT_INDEX = 0; private static int PARAMETER_INDEX = 1; private static int ROWBOUNDS_INDEX = 2; @Override public Object intercept(Invocation invocation) throws Throwable { Executor executor = (Executor) invocation.getTarget(); Connection connection = executor.getTransaction().getConnection(); if (MyBatisConfig.hasConnection(connection)) { queryCount(invocation, connection); } return invocation.proceed(); } @Override public Object plugin(Object target) { return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) { } @SuppressWarnings("unchecked") private void queryCount(Invocation invocation, Connection connection) throws Throwable { Object[] queryArgs = invocation.getArgs(); MappedStatement ms = (MappedStatement) queryArgs[MAPPED_STATEMENT_INDEX]; Object parameter = queryArgs[PARAMETER_INDEX]; RowBounds rowBounds = (RowBounds) queryArgs[ROWBOUNDS_INDEX]; // 将需要执行的查询语句修改成 select count(*) 的形式 MappedStatement queryCountMs = MyBatisConfig.getQueryCountStatement(ms); if (null == queryCountMs) { queryCountMs = createMappedStatement(ms); MyBatisConfig.registerQueryCountStatement(ms, queryCountMs); } Executor executor = (Executor) invocation.getTarget(); SimpleExecutor s = new SimpleExecutor(ms.getConfiguration(), executor.getTransaction()); List<Integer> list = (List<Integer>) s.doQuery(queryCountMs, parameter, RowBounds.DEFAULT, null); if (list.size() == 1) { int count = (Integer) list.get(0); MyBatisConfig.setRecordNum(connection, count); int offset = rowBounds.getOffset(); if ((offset > count) && (offset != RowBounds.NO_ROW_OFFSET)) { PageHelper helper = new PageHelper(count, rowBounds.getLimit()); helper.setCurrPage(helper.getMaxPage()); queryArgs[ROWBOUNDS_INDEX] = helper.getRowBounds(); } } else { throw new TooManyResultsException("Expected one result to be returned by PageQueryInterceptor.intercept()"); } } private MappedStatement createMappedStatement(MappedStatement ms) { MappedStatement queryMs = null; synchronized (ms) { queryMs = MyBatisConfig.getQueryCountStatement(ms); if (null != queryMs) { return queryMs; } } DynamicSqlSource sqlSource = (DynamicSqlSource) ms.getSqlSource(); Configuration configuration = (Configuration) ReflectUtil.getField(sqlSource, "configuration"); SqlNode rootSqlNode = (SqlNode) ReflectUtil.getField(sqlSource, "rootSqlNode"); MixedSqlNode newRootSqlNode = (MixedSqlNode) SqlNodeUtils.clone(rootSqlNode); @SuppressWarnings("unchecked") List<SqlNode> contents = (List<SqlNode>) ReflectUtil.getField(newRootSqlNode, "contents"); TextSqlNode firstNode = (TextSqlNode) contents.get(0); String firstSql = (String) ReflectUtil.getField(firstNode, "text"); String tmpSql = firstSql.toUpperCase(); int fromIndex = tmpSql.indexOf("FROM"); int orderByIndex = tmpSql.indexOf("ORDER BY"); if (orderByIndex > 0) { tmpSql = "select count(*) " + firstSql.substring(fromIndex, orderByIndex); ReflectUtil.setField(firstNode, "text", tmpSql); } else { tmpSql = "select count(*) " + firstSql.substring(fromIndex); ReflectUtil.setField(firstNode, "text", tmpSql); for (int i = 1; i < contents.size(); i++) { SqlNode node = contents.get(i); replaceSql(node); } } DynamicSqlSource dynamicSqlSource = new DynamicSqlSource(configuration, newRootSqlNode); queryMs = copyFromMappedStatement(ms, dynamicSqlSource); return queryMs; } private void replaceSql(SqlNode sqlNode) { if (sqlNode instanceof TextSqlNode) { TextSqlNode textSqlNode = (TextSqlNode) sqlNode; String sql = (String) ReflectUtil.getField(textSqlNode, "text"); int index = sql.toUpperCase().indexOf("ORDER BY"); if (index >= 0) { String tmpSql = sql.substring(0, index); ReflectUtil.setField(textSqlNode, "text", tmpSql); } } else if (sqlNode instanceof IfSqlNode) { IfSqlNode ifSqlNode = (IfSqlNode) sqlNode; SqlNode contents = (SqlNode) ReflectUtil.getField(ifSqlNode, "contents"); replaceSql(contents); } else if (sqlNode instanceof MixedSqlNode) { MixedSqlNode mixedSqlNode = (MixedSqlNode) sqlNode; @SuppressWarnings("unchecked") List<SqlNode> contents = (List<SqlNode>) ReflectUtil.getField(mixedSqlNode, "contents"); for (SqlNode node : contents) { replaceSql(node); } } } private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource sqlSource) { Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), sqlSource, SqlCommandType.SELECT); builder.resource(ms.getResource()); builder.fetchSize(ms.getFetchSize()); builder.statementType(ms.getStatementType()); builder.keyGenerator(ms.getKeyGenerator()); builder.keyProperty(ms.getKeyProperty()); builder.timeout(ms.getTimeout()); builder.parameterMap(ms.getParameterMap()); // 将返回值类型修改成 Integer ResultMap.Builder resultMapBuilder = new ResultMap.Builder(ms.getConfiguration(), "", Integer.class, new ArrayList<ResultMapping>()); List<ResultMap> resultMapList = new ArrayList<ResultMap>(); resultMapList.add(resultMapBuilder.build()); builder.resultMaps(resultMapList); builder.cache(ms.getCache()); MappedStatement newMs = builder.build(); return newMs; } }