package com.easyooo.framework.support.mybatis; import java.util.ArrayList; import java.util.List; import java.util.Properties; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ParameterMapping; import org.apache.ibatis.mapping.SqlSource; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import org.apache.ibatis.type.TypeHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.InitializingBean; import com.easyooo.framework.support.mybatis.util.DialectUtil; import com.google.common.base.Joiner; /** * <p>分页插件实现,扩展<code>MyBatis Interceptor</code> * 分页插件只在当<code>pagination</code>作为查询参数时有效. * </p> * <p> * example:public void selectAll(Pagination){}</p> * <p> * 当有参数时,将参数设置到Paginatoin#criteria * <p> * * * @see Pagination * * @author Killer */ @Intercepts({ @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}) }) public class PaginationPlugin implements Interceptor, InitializingBean { static final Logger logger = LoggerFactory.getLogger(PaginationPlugin.class); private static final char DEFAULT_SEPARATOR = ','; private static final Integer MAPPED_STATEMENT_INDEX = 0; private static final Integer PARAMETER_INDEX = 1; private static final Integer ROW_BOUNDS_INDEX = 2; private static final String OFFSET_PARAMETER = "_pagination_offset"; private static final String LIMIT_PARAMETER = "_pagination_limit"; // private static final String DEFAULT_DBMS = "ORACLE"; private Dialect dialect; // properties private String dbms; @Override public Object intercept(Invocation invocation) throws Throwable { Object[] args = invocation.getArgs(); // no paging if(!isPaging(args)){ return invocation.proceed(); } // process for paging InvocationContext context = getInvocationContext(invocation); InvocationContext newContext = processIntercept(context); swapParameter(newContext, args); Object result = invocation.proceed(); if(result != null && result instanceof List){ newContext.getPagination().setRecords((List<?>)result); } return result; } @Override public Object plugin(Object target) { if (target instanceof Executor) { return Plugin.wrap(target, this); } return target; } @Override public void setProperties(Properties properties) { // do nothing } /** * check method * @param ms * @param args * @return args was Pagination return true, else false */ private boolean isPaging(Object[] args){ return (args[PARAMETER_INDEX] instanceof Pagination); } /** * 封装代理上下文参数 * @param invocation * @return */ private InvocationContext getInvocationContext(Invocation invocation){ Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement)args[MAPPED_STATEMENT_INDEX]; Pagination ps = (Pagination)args[PARAMETER_INDEX]; return new InvocationContext(ms, ps); } /** * 交换<code>MyBatis</code> 参数列表 * * @param newContext * @param args */ private void swapParameter(InvocationContext newContext, Object[] args){ Pagination pg = newContext.getPagination(); // setter new MapperdStatement args[MAPPED_STATEMENT_INDEX] = newContext.getMappedStatement(); // Criteria property swap pagination object args[PARAMETER_INDEX] = pg.getCriteria(); //RowBounds rowBounds = new RowBounds(pg.getOffset(), pg.getLimit()); args[ROW_BOUNDS_INDEX] = new RowBounds(); } private InvocationContext processIntercept(InvocationContext context)throws Throwable { MappedStatement ms = context.getMappedStatement(); Pagination pagination = context.getPagination(); BoundSql boundSql = ms.getBoundSql(pagination.getCriteria()); // counting if(pagination.isNeedTotalCount()){ Integer counting = new CountingExecutor(ms, dialect, boundSql).execute(); pagination.setTotalCount(counting); } // paging String pagingSql = dialect.getPagingSQL(boundSql.getSql().trim()); // cpy mappings List<ParameterMapping> mappings = new ArrayList<ParameterMapping>(); if(boundSql.getParameterMappings() != null){ List<ParameterMapping> tmpMappings = boundSql.getParameterMappings(); for (int i = 0; i < tmpMappings.size(); i++) { mappings.add(tmpMappings.get(i)); } } BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), pagingSql, mappings, boundSql.getParameterObject()); cpyAndAppendParameters(ms, pagination, boundSql, newBoundSql); InvocationContext newContext = new InvocationContext(); MappedStatement newms = cloneMappedStatement(ms, newBoundSql); newContext.setMappedStatement(newms); newContext.setPagination(pagination); return newContext; } private void cpyAndAppendParameters(MappedStatement ms, Pagination pg, BoundSql boundSql, BoundSql newBoundSql){ // cpy old parameters for (ParameterMapping mapping : newBoundSql.getParameterMappings()) { String prop = mapping.getProperty(); if (boundSql.hasAdditionalParameter(prop)) { newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop)); } } // append pagination parameters Configuration cf = ms.getConfiguration(); TypeHandler<?> type = cf.getTypeHandlerRegistry().getTypeHandler(Integer.class); ParameterMapping offsetMapping = new ParameterMapping.Builder(cf, OFFSET_PARAMETER, type).build(); ParameterMapping limitMapping = new ParameterMapping.Builder(cf, LIMIT_PARAMETER, type).build(); ParameterMapping[] mappings = new ParameterMapping[]{offsetMapping, limitMapping}; for (Order order : dialect.order()) { newBoundSql.getParameterMappings().add(mappings[order.ordinal()]); } newBoundSql.setAdditionalParameter(OFFSET_PARAMETER, pg.getOffset()); // 如果是Oracle,第二个参数需要设置起始位置加Limit得到结束位置 // 与MySql是不一样 if(DBMS.ORACLE.name().equals(dbms)){ newBoundSql.setAdditionalParameter(LIMIT_PARAMETER, pg.getOffset() + pg.getLimit()); }else{ newBoundSql.setAdditionalParameter(LIMIT_PARAMETER, pg.getLimit()); } } private MappedStatement cloneMappedStatement(MappedStatement old, BoundSql boundSql){ MappedStatement.Builder builder = new MappedStatement.Builder( old.getConfiguration(), old.getId(), new AlwaySqlSource( boundSql), old.getSqlCommandType()); builder.cache(old.getCache()); builder.databaseId(old.getDatabaseId()); builder.fetchSize(old.getFetchSize()); builder.flushCacheRequired(old.isFlushCacheRequired()); builder.keyGenerator(old.getKeyGenerator()); builder.keyProperty(join(old.getKeyProperties())); builder.keyColumn(join(old.getKeyColumns())); builder.lang(old.getLang()); builder.resource(old.getResource()); builder.resultMaps(old.getResultMaps()); builder.resultSetType(old.getResultSetType()); builder.parameterMap(old.getParameterMap()); builder.statementType(old.getStatementType()); builder.timeout(old.getTimeout()); builder.useCache(old.isUseCache()); builder.resultOrdered(old.isResultOrdered()); builder.resulSets(join(old.getResulSets())); return builder.build(); } private String join(String[] strs){ if(strs == null || strs.length == 0 ){ return null; } return Joiner.on(DEFAULT_SEPARATOR).join(strs); } class AlwaySqlSource implements SqlSource{ private BoundSql bs; public AlwaySqlSource(BoundSql bs){ this.bs = bs; } @Override public BoundSql getBoundSql(Object po) { return bs; } }; public String getDbms() { return dbms; } public void setDbms(String dbms) { this.dbms = dbms; } @Override public void afterPropertiesSet() throws Exception { if(dbms == null || dbms.equals("")){ dbms = DEFAULT_DBMS; } this.dialect = new DialectUtil().switchDialect(dbms.toUpperCase()); } }