mybatis插件-乐观锁

1、业务背景

我司使用MySQL数据库的InnoDB引擎,在执行数据库更新操作时使用了select ... for update语句,在一定情况下可能导致行级锁升级为表级锁,在高并发的场景下导致性能低下,故而打算使用乐观锁解决部分性能问题。

系统已经上线,修改所有更新代码改动量大,故决定通过插件方式实现。

2、乐观锁简介

乐观锁通过在数据库表中增加锁字段(例如version)来实现,更新语句如下:

update TABLE1 set version = version + 1 where version = #{version}

每次更新时版本号字段都会加1,此时,将提交数据的版本数据与数据库表对应记录的当前版本信息进行比对,如果提交的数据版本号大于数据库表当前版本号,则予以更新,否则认为是过期数据不予更新。(在上面的更新语句中,这一比对体现为where条件:只有读取时拿到的版本号与表中当前版本号相等,更新才会生效。)

3、插件使用

  1.  使用说明。

  

<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration
  PUBLIC "-//mybatis.org//DTD Config 3.0//EN"
  "http://mybatis.org/dtd/mybatis-3-config.dtd">
<configuration>

    <!-- Registers the optimistic-lock interceptor. -->
    <plugins>
        <plugin interceptor="com.vi.optimistic.lock.interceptor.OptimisticLocker">
            <!-- Optional overrides; both names default to "version" when not set. -->
            <!-- versionField: name of the version property on the parameter object. -->
            <!--<property name="versionField" value="myVersion"/>-->
            <!-- versionColumn: name of the version column in the table. -->
            <!--<property name="versionColumn" value="my_version"/>-->
        </plugin>
    </plugins>
    
    <!-- Sample local data source; the credentials below are placeholders for a test database. -->
    <environments default="development">
        <environment id="development">
            <transactionManager type="JDBC" />
            <dataSource type="UNPOOLED">
                <property name="driver" value="com.mysql.jdbc.Driver" />
                <property name="url" value="jdbc:mysql://localhost:3306/test" />
                <property name="username" value="root" />
                <property name="password" value="123456" />
            </dataSource>
        </environment>
    </environments>
    <!-- Mapper XML files loaded by this configuration. -->
    <mappers>
        <mapper resource="mapper/UserDefaultMapper.xml" />
        <mapper resource="mapper/UserVersionMapper.xml" />
    </mappers>
</configuration>

 

加入plugins插件配置后,可以通过versionField指定实体类中的版本字段名,通过versionColumn指定表中的版本列名(两者默认均为version)。暂不支持批量更新,后续会完善。

4、插件原理简析

一、本插件通过拦截StatementHandler的prepare方法实现,默认只支持PreparedStatement。

@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})

  二、实现mybatis的Interceptor接口,主要拦截逻辑位于public Object intercept(Invocation invocation) throws Exception {}方法中。

  三、获取mybatis四大对象中的StatementHandler对象,通过SystemMetaObject工具类获得MetaObject对象,取出MappedStatement对象以获取sql类型,本插件只拦截更新操作。

  

        MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement");

 

  四、通过MetaObject对象获得mapper中的sql即BoundSql,这也是后续我们需要修改的sql,主要是在where条件后追加"version = 原版本值"。

        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");

  五、通过MetaObject获得原version值。

        Object originalVersion = metaObject.getValue("delegate.boundSql.parameterObject." + VERSION_FIELD);

  六、通过jsqlparser工具类修改boundSql中的sql。

  七、写回修改后的sql,并将参数对象中的锁值更新为originalVersion + 1(即version = version + 1),锁值类型默认为long(后续会支持int等类型)。

  metaObject.setValue("delegate.boundSql.sql", originalSql);
  metaObject.setValue("delegate.boundSql.parameterObject." + VERSION_FIELD, (Long) originalVersion + 1);

  八、默认的一些方法如生成代理对象。

@Override
    public void setProperties(Properties properties) {
        // Keep the most recent non-empty configuration; null or empty input leaves it untouched.
        boolean hasConfiguration = properties != null && !properties.isEmpty();
        if (hasConfiguration) {
            props = properties;
        }
        if (props == null) {
            return;
        }
        // Both names fall back to "version" when the corresponding property is not set.
        VERSION_COLUMN = props.getProperty("versionColumn", "version");
        VERSION_FIELD = props.getProperty("versionField", "version");
    }

 

  

@Override
    public Object plugin(Object target) {
        // Only statement and parameter handlers are handed to Plugin.wrap; anything else is returned as-is.
        boolean wrappable = target instanceof StatementHandler || target instanceof ParameterHandler;
        return wrappable ? Plugin.wrap(target, this) : target;
    }

  九、初始化配置文件。

    @Override
    public void setProperties(Properties properties) {
        // Keep the most recent non-empty configuration; null or empty input leaves it untouched.
        boolean hasConfiguration = properties != null && !properties.isEmpty();
        if (hasConfiguration) {
            props = properties;
        }
        if (props == null) {
            return;
        }
        // Both names fall back to "version" when the corresponding property is not set.
        VERSION_COLUMN = props.getProperty("versionColumn", "version");
        VERSION_FIELD = props.getProperty("versionField", "version");
    }

  十、主要功能代码

package com.vi.optimistic.lock.interceptor;

import com.vi.optimistic.lock.util.PluginUtil;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.binding.BindingException;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.sql.Connection;
import java.util.Collection;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis optimistic-lock plugin.
 *
 * <p>Intercepts {@code StatementHandler.prepare(Connection, Integer)}. For UPDATE statements it
 * appends {@code <versionColumn> = <currentVersion>} to the WHERE clause, adds
 * {@code <versionColumn> = <versionColumn> + 1} to the SET clause when the statement does not
 * already assign that column, and advances the version property of the parameter object by one.
 *
 * <p>Batch updates (list parameters) are not supported yet.
 *
 * @author vi
 * @version 0.0.1
 * @date 2018-04-01
 * @since JDK1.8
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class OptimisticLocker implements Interceptor {
    private static final Log log = LogFactory.getLog(OptimisticLocker.class);
    // Name of the version column in the database table.
    // NOTE: static, so the value is shared by every OptimisticLocker instance in the JVM.
    private static String VERSION_COLUMN = "version";
    // Name of the version property on the parameter object (shared the same way).
    private static String VERSION_FIELD = "version";
    // Name of the only intercepted StatementHandler method.
    private static final String METHOD_TYPE = "prepare";

    // Last non-empty configuration passed to setProperties.
    private static Properties props = null;

    /**
     * Rewrites UPDATE statements to carry the optimistic-lock check, then proceeds.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of the intercepted call
     * @throws BindingException if the version property is {@code null} or not greater than zero
     * @throws Exception        whatever the intercepted call throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Exception {
        if (!METHOD_TYPE.equals(invocation.getMethod().getName())) {
            return invocation.proceed();
        }
        StatementHandler handler = (StatementHandler) PluginUtil.processTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(handler);
        MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        if (ms.getSqlCommandType() != SqlCommandType.UPDATE) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        // TODO batch updates need to read the version from each list element; not supported yet.
        String versionPath = "delegate.boundSql.parameterObject." + VERSION_FIELD;
        // Version value the caller read before issuing this update.
        Object originalVersion = metaObject.getValue(versionPath);
        if (originalVersion == null || Long.parseLong(originalVersion.toString()) <= 0) {
            throw new BindingException("value of version field[" + VERSION_FIELD + "]can not be empty");
        }
        String originalSql = boundSql.getSql();
        if (log.isDebugEnabled()) {
            log.debug("originalSql: " + originalSql);
        }
        originalSql = addVersionToSql(originalSql, VERSION_COLUMN, originalVersion);
        Object nextVersion = nextVersion(originalVersion);
        metaObject.setValue("delegate.boundSql.sql", originalSql);
        metaObject.setValue(versionPath, nextVersion);
        if (log.isDebugEnabled()) {
            log.debug("originalSql after add version: " + originalSql);
            log.debug(versionPath + ": " + nextVersion);
        }
        return invocation.proceed();
    }

    /**
     * Returns {@code originalVersion + 1} while keeping the boxed type for {@code Integer}
     * values, so the result can be written back to an {@code Integer} property. Every other
     * value is parsed from its string form and returned as a {@code Long}.
     */
    private static Object nextVersion(Object originalVersion) {
        if (originalVersion instanceof Integer) {
            return (Integer) originalVersion + 1;
        }
        return Long.parseLong(originalVersion.toString()) + 1;
    }

    /**
     * Adds the version assignment and the version predicate to an UPDATE statement.
     *
     * @param originalSql       SQL taken from the bound statement
     * @param versionColumnName name of the version column
     * @param originalVersion   version value the row is expected to still have
     * @return the rewritten SQL, or {@code originalSql} unchanged when it is not an UPDATE or
     *         cannot be rewritten
     */
    private String addVersionToSql(String originalSql, String versionColumnName, Object originalVersion) {
        try {
            Statement stmt = CCJSqlParserUtil.parse(originalSql);
            if (!(stmt instanceof Update)) {
                return originalSql;
            }
            Update update = (Update) stmt;
            if (!contains(update, versionColumnName)) {
                buildVersionExpression(update, versionColumnName);
            }
            Expression versionEquals = buildVersionEquals(versionColumnName, originalVersion);
            Expression where = update.getWhere();
            if (where == null) {
                update.setWhere(versionEquals);
            } else {
                // Parenthesize the existing condition: without it "a OR b" would render as
                // "a OR b AND version = n", letting rows matched by "a" skip the version check.
                update.setWhere(new AndExpression(new Parenthesis(where), versionEquals));
            }
            return stmt.toString();
        } catch (Exception e) {
            // Best effort: keep the statement as written, but report the failure through the logger.
            log.error("failed to add version condition to sql: " + originalSql, e);
            return originalSql;
        }
    }

    /** Returns whether the SET clause already assigns the version column (case-insensitive). */
    private boolean contains(Update update, String versionColumnName) {
        List<Column> columns = update.getColumns();
        for (Column column : columns) {
            if (column.getColumnName().equalsIgnoreCase(versionColumnName)) {
                return true;
            }
        }
        return false;
    }

    /** Appends {@code <versionColumn> = <versionColumn> + 1} to the SET clause. */
    private void buildVersionExpression(Update update, String versionColumnName) {
        Column versionColumn = new Column();
        versionColumn.setColumnName(versionColumnName);
        update.getColumns().add(versionColumn);

        Addition increment = new Addition();
        increment.setLeftExpression(versionColumn);
        increment.setRightExpression(new LongValue(1));
        update.getExpressions().add(increment);
    }

    /** Builds the predicate {@code <versionColumn> = <originalVersion>}. */
    private Expression buildVersionEquals(String versionColumnName, Object originalVersion) {
        Column column = new Column();
        column.setColumnName(versionColumnName);

        EqualsTo equal = new EqualsTo();
        equal.setLeftExpression(column);
        equal.setRightExpression(new LongValue(originalVersion.toString()));
        return equal;
    }

    /**
     * Looks up the registered mapper interface whose class name equals the statement's
     * namespace. Not called anywhere in this class at the moment.
     *
     * @return the mapper interface, or {@code null} when none matches
     */
    private Class<?> getMapper(MappedStatement ms) {
        String namespace = getMapperNamespace(ms);
        Collection<Class<?>> mappers = ms.getConfiguration().getMapperRegistry().getMappers();
        for (Class<?> clazz : mappers) {
            if (clazz.getName().equals(namespace)) {
                return clazz;
            }
        }
        return null;
    }

    /** Returns the part of the statement id before the last {@code '.'}. */
    private String getMapperNamespace(MappedStatement ms) {
        String id = ms.getId();
        int pos = id.lastIndexOf(".");
        return id.substring(0, pos);
    }

    /** Returns the part of the statement id after the last {@code '.'}. */
    private String getMapperShortId(MappedStatement ms) {
        String id = ms.getId();
        int pos = id.lastIndexOf(".");
        return id.substring(pos + 1);
    }

    /**
     * Hands statement and parameter handlers to {@code Plugin.wrap}; every other target is
     * returned unchanged.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler || target instanceof ParameterHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    /**
     * Reads {@code versionColumn} and {@code versionField} from the plugin configuration, each
     * defaulting to {@code "version"}. A {@code null} or empty argument keeps the previously
     * stored configuration.
     */
    @Override
    public void setProperties(Properties properties) {
        if (null != properties && !properties.isEmpty()) {
            props = properties;
        }
        if (props != null) {
            VERSION_COLUMN = props.getProperty("versionColumn", "version");
            VERSION_FIELD = props.getProperty("versionField", "version");
        }
    }
}

   5、源码说明

    一、源码中附有测试案例和使用教程,还有一些功能需要后期完善,如使用h2内存数据库方便测试。

    二、最近在看一些mybatis的源码。想要理解插件的工作原理,需要熟悉mybatis的运行流程,否则插件可能会破坏mybatis的功能,之后会带来一些mybatis的源码分析。

    源码地址:https://github.com/binary-vi/binary.github.io/tree/master/locker