/** * Copyright (c) 2011-2014, hubin (jobob@qq.com). * <p> * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * <p> * http://www.apache.org/licenses/LICENSE-2.0 * <p> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.baomidou.mybatisplus.plugins; import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.sql.Connection; import java.sql.Timestamp; import java.util.Date; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import org.apache.ibatis.binding.MapperMethod.ParamMap; import org.apache.ibatis.exceptions.ExceptionFactory; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ParameterMapping; import org.apache.ibatis.mapping.SqlCommandType; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.SystemMetaObject; import org.apache.ibatis.session.Configuration; import org.apache.ibatis.type.TypeException; import org.apache.ibatis.type.UnknownTypeHandler; import com.baomidou.mybatisplus.annotations.TableField; import com.baomidou.mybatisplus.annotations.Version; import com.baomidou.mybatisplus.mapper.EntityWrapper; import com.baomidou.mybatisplus.toolkit.PluginUtils; import com.baomidou.mybatisplus.toolkit.StringUtils; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.update.Update; /** * <p> * MyBatis乐观锁插件 * </p> * * <pre> * 之前:update user set name = ?, password = ? where id = ? * 之后:update user set name = ?, password = ?, version = version+1 where id = ? and version = ? * 对象上的version字段上添加{@link Version}注解 * sql可以不需要写version字段,只要对象version有值就会更新 * 支持,int Integer, long Long, Date,Timestamp * 其他类型可以自定义实现,注入versionHandlers,多个以逗号分隔 * </pre> * * @author TaoYu 小锅盖 * @since 2017-04-08 */ @Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class, Integer.class }) }) public final class OptimisticLockerInterceptor implements Interceptor { /** * 根据对象类型缓存version基本信息 */ private static final Map<Class<?>, LockerCache> versionCache = new ConcurrentHashMap<>(); /** * 根据version字段类型缓存的处理器 */ private static final Map<Type, VersionHandler<?>> typeHandlers = new HashMap<>(); private static final Expression RIGHT_EXPRESSION = new Column("?"); static { IntegerTypeHandler integerTypeHandler = new IntegerTypeHandler(); typeHandlers.put(int.class, integerTypeHandler); typeHandlers.put(Integer.class, integerTypeHandler); LongTypeHandler longTypeHandler = new LongTypeHandler(); typeHandlers.put(long.class, longTypeHandler); typeHandlers.put(Long.class, longTypeHandler); typeHandlers.put(Date.class, new DateTypeHandler()); typeHandlers.put(Timestamp.class, new TimestampTypeHandler()); } public Object intercept(Invocation invocation) throws Exception { StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget()); MetaObject metaObject = SystemMetaObject.forObject(statementHandler); // 先判断是不是真正的UPDATE操作 MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); if (!ms.getSqlCommandType().equals(SqlCommandType.UPDATE)) { return invocation.proceed(); } BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql"); // 获得参数类型,去缓存中快速判断是否有version注解才继续执行 Class<?> parameterClass = ms.getParameterMap().getType(); LockerCache lockerCache = versionCache.get(parameterClass); if (lockerCache != null) { if (lockerCache.lock) { processChangeSql(ms, boundSql, lockerCache); } } else { Field versionField = getVersionField(parameterClass); if (versionField != null) { Class<?> fieldType = versionField.getType(); if (!typeHandlers.containsKey(fieldType)) { throw new TypeException("乐观锁不支持" + fieldType.getName() + "类型,请自定义实现"); } final TableField tableField = versionField.getAnnotation(TableField.class); String versionColumn = versionField.getName(); if (tableField != null) { versionColumn = tableField.value(); } LockerCache lc = new LockerCache(true, versionColumn, versionField, typeHandlers.get(fieldType)); versionCache.put(parameterClass, lc); processChangeSql(ms, boundSql, lc); } else { versionCache.put(parameterClass, LockerCache.INSTANCE); } } return invocation.proceed(); } private Field getVersionField(Class<?> parameterClass) { if (parameterClass != Object.class) { for (Field field : parameterClass.getDeclaredFields()) { if (field.isAnnotationPresent(Version.class)) { field.setAccessible(true); return field; } } return getVersionField(parameterClass.getSuperclass()); } return null; } private void processChangeSql(MappedStatement ms, BoundSql boundSql, LockerCache lockerCache) throws Exception { Object parameterObject = boundSql.getParameterObject(); if (parameterObject instanceof ParamMap) { ParamMap<?> paramMap = (ParamMap<?>) parameterObject; parameterObject = paramMap.get("et"); EntityWrapper<?> entityWrapper = (EntityWrapper<?>) paramMap.get("ew"); if (entityWrapper != null) { Object entity = entityWrapper.getEntity(); if (entity != null && lockerCache.field.get(entity) == null) { changSql(ms, boundSql, parameterObject, lockerCache); } } } else { changSql(ms, boundSql, parameterObject, lockerCache); } } @SuppressWarnings("unchecked") private void changSql(MappedStatement ms, BoundSql boundSql, Object parameterObject, LockerCache lockerCache) throws Exception { Field versionField = lockerCache.field; String versionColumn = lockerCache.column; final Object versionValue = versionField.get(parameterObject); if (versionValue != null) {// 先判断传参是否携带version,没带跳过插件 Configuration configuration = ms.getConfiguration(); // 给字段赋新值 lockerCache.versionHandler.plusVersion(parameterObject, versionField, versionValue); // 处理where条件,添加? Update jsqlSql = (Update) CCJSqlParserUtil.parse(boundSql.getSql()); BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere(); if (expression != null && !expression.toString().contains(versionColumn)) { EqualsTo equalsTo = new EqualsTo(); equalsTo.setLeftExpression(new Column(versionColumn)); equalsTo.setRightExpression(RIGHT_EXPRESSION); jsqlSql.setWhere(new AndExpression(equalsTo, expression)); List<ParameterMapping> parameterMappings = new LinkedList<>(boundSql.getParameterMappings()); parameterMappings.add(jsqlSql.getExpressions().size(), getVersionMappingInstance(configuration)); MetaObject boundSqlMeta = configuration.newMetaObject(boundSql); boundSqlMeta.setValue("sql", jsqlSql.toString()); boundSqlMeta.setValue("parameterMappings", parameterMappings); } // 设置参数 boundSql.setAdditionalParameter("originVersionValue", versionValue); } } private volatile ParameterMapping parameterMapping; private ParameterMapping getVersionMappingInstance(Configuration configuration) { if (parameterMapping == null) { synchronized (OptimisticLockerInterceptor.class) { if (parameterMapping == null) { parameterMapping = new ParameterMapping.Builder(configuration, "originVersionValue", new UnknownTypeHandler(configuration.getTypeHandlerRegistry())).build(); } } } return parameterMapping; } @Override public Object plugin(Object target) { if (target instanceof StatementHandler) { return Plugin.wrap(target, this); } return target; } @Override public void setProperties(Properties properties) { String versionHandlers = properties.getProperty("versionHandlers"); if (StringUtils.isNotEmpty(versionHandlers)) { for (String handlerClazz : versionHandlers.split(",")) { try { registerHandler(Class.forName(handlerClazz)); } catch (Exception e) { throw ExceptionFactory.wrapException("乐观锁插件自定义处理器注册失败", e); } } } } /** * 注册处理器 */ private static void registerHandler(Class<?> versionHandlerClazz) throws Exception { ParameterizedType parameterizedType = (ParameterizedType) versionHandlerClazz.getGenericInterfaces()[0]; Object versionInstance = versionHandlerClazz.newInstance(); if (!(versionInstance instanceof VersionHandler)) { throw new TypeException("参数未实现VersionHandler,不能注入"); } else { Type[] actualTypeArguments = parameterizedType.getActualTypeArguments(); if (actualTypeArguments.length == 0) { throw new IllegalArgumentException("处理器泛型未定义"); } else if (Object.class.equals(actualTypeArguments[0])) { throw new IllegalArgumentException("处理器泛型不能为Object"); } else { typeHandlers.put(actualTypeArguments[0], (VersionHandler<?>) versionInstance); } } } // *****************************基本类型处理器***************************** private static class IntegerTypeHandler implements VersionHandler<Integer> { public void plusVersion(Object paramObj, Field field, Integer versionValue) throws Exception { field.set(paramObj, versionValue + 1); } } private static class LongTypeHandler implements VersionHandler<Long> { public void plusVersion(Object paramObj, Field field, Long versionValue) throws Exception { field.set(paramObj, versionValue + 1); } } // ***************************** 时间类型处理器***************************** private static class DateTypeHandler implements VersionHandler<Date> { public void plusVersion(Object paramObj, Field field, Date versionValue) throws Exception { field.set(paramObj, new Date()); } } private static class TimestampTypeHandler implements VersionHandler<Timestamp> { public void plusVersion(Object paramObj, Field field, Timestamp versionValue) throws Exception { field.set(paramObj, new Timestamp(new Date().getTime())); } } /** * 缓存对象 */ @SuppressWarnings("rawtypes") private static class LockerCache { public static final LockerCache INSTANCE = new LockerCache(); private boolean lock; private String column; private Field field; private VersionHandler versionHandler; public LockerCache() { } LockerCache(Boolean lock, String column, Field field, VersionHandler versionHandler) { this.lock = lock; this.column = column; this.field = field; this.versionHandler = versionHandler; } } }