MyBatis 拦截器实现审计字段自动填充

in Code with 0 comment

0x00 要解决的问题

在日常代码编写的过程中,经常需要为一些关键数据记录创建人、创建时间、修改人、修改时间等审计字段。每次手动赋值都是重复劳动,而且一不小心还会忘记赋值,就比较尴尬了。

0x01 基本环境

项目基于 Spring Boot 3.0.2 搭建,使用 mybatis-spring-boot-starter:3.0.0 类库。

0x02 拦截器实现

拦截器实现主要分为两种,原理是一样的,只是精细的程度不一样。

实现思路是编写一个基类(BaseEntity),在基类中添加相应的审计字段。具体实体类继承该基类。在拦截器中判断当前执行的操作,如果是插入或者更新,则获取执行的参数(存储着具体的实体),判断其是否为基类(继承了基类 BaseEntity),如果是则获取审计字段并赋值。

第一种,实现全面的拦截:拦截器会改写 SQL 并自动追加审计字段,因此在 mapper xml 配置文件中编写 insert 或 update 语句时,不需要添加审计字段。

第二种,只为实体对象的审计字段赋值,SQL 语句中的审计字段仍需要自己编写。

下面为具体代码实现:

基类

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;

import java.io.Serializable;
import java.time.LocalDateTime;

/**
 * Base class carrying the audit fields shared by persistent entities.
 *
 * <p>Concrete entities extend this class so that the audit interceptors can
 * recognise them with an {@code instanceof BaseEntity} check and populate the
 * fields below. Accessors, builder and constructors are generated by Lombok.
 */
@Data
@SuperBuilder
@AllArgsConstructor
@NoArgsConstructor
public class BaseEntity implements Serializable {
    /** Record identifier (presumably the primary key; confirm against the schema). */
    private Long id;
    /** Audit field: when the record was created. */
    private LocalDateTime createdAt;
    /** Audit field: who created the record. */
    private String createdBy;
    /** Audit field: when the record was last modified. */
    private LocalDateTime updatedAt;
    /** Audit field: who last modified the record. */
    private String updatedBy;
    /** Deletion marker (presumably a soft-delete flag; confirm with the mappers that read it). */
    private Boolean isDeleted;
}

第一种:全面接管,不需要在 SQL 中编写审计字段

import com.vkarz.provider.persistence.model.BaseEntity;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;

import java.lang.reflect.Field;
import java.time.LocalDateTime;
import java.util.*;

/**
 * MyBatis plugin that fills audit fields automatically and also appends the
 * audit columns to the SQL, so mapper statements do not have to list them.
 *
 * <p>It intercepts {@code Executor#update(MappedStatement, Object)}. For INSERT
 * and UPDATE statements it
 * <ol>
 *   <li>rewrites the statement's SQL once (per statement id) so that it contains
 *       the audit columns and their {@code ?} placeholders, and</li>
 *   <li>writes the audit values into the parameter object on every call.</li>
 * </ol>
 *
 * <p>Limitations: the SQL is obtained with {@code getBoundSql(null)} and rewritten
 * textually, so only simple single-row {@code insert into t (...) values (...)}
 * and {@code update t set ...} statements are supported. Parameters that are
 * neither a {@link BaseEntity} nor a {@code MapperMethod.ParamMap} are left
 * untouched, in which case the added placeholders have no value to bind.
 */
// Not registered as a Spring bean by default; uncomment to enable it via component scanning.
// @Component("AuditFieldInterceptor")
@Intercepts({@Signature(method = "update", type = Executor.class, args = {MappedStatement.class, Object.class})})
@Slf4j
public class AuditFieldInterceptor implements Interceptor {

    /** Placeholder operator id; replace with a lookup of the authenticated user. */
    private static final String CURRENT_USERNAME = "ZR21090131";

    /** Parameter names bound to the placeholders prepended to an INSERT, in placeholder order. */
    private static final List<String> INSERT_AUDIT_PROPERTIES =
            List.of("createdAt", "createdBy", "updatedAt", "updatedBy", "isDeleted");

    /** Parameter names bound to the placeholders prepended to an UPDATE, in placeholder order. */
    private static final List<String> UPDATE_AUDIT_PROPERTIES = List.of("updatedAt", "updatedBy");

    private static final String INSERT_AUDIT_COLUMNS =
            "created_at, created_by, updated_at, updated_by, is_deleted, ";
    private static final String INSERT_AUDIT_PLACEHOLDERS = "?,?,?,?,?,";
    private static final String UPDATE_AUDIT_ASSIGNMENTS = "updated_at=?, updated_by=?,";

    /** Ids of statements whose SqlSource has already been rewritten. Guarded by {@code synchronized (cache)}. */
    private final Set<String> cache = new HashSet<>();

    /**
     * Rewrites the statement (first call only) and fills the audit values, then proceeds.
     *
     * @param invocation the intercepted {@code Executor.update} call; {@code args[0]} is the
     *                   {@link MappedStatement}, {@code args[1]} the mapper parameter object
     * @return whatever the intercepted call returns
     * @throws Throwable any failure of the reflective rewrite or of the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object params = args[1];
        SqlCommandType commandType = ms.getSqlCommandType();

        if (commandType == SqlCommandType.INSERT) {
            rewriteSqlSourceOnce(ms);
            fillInsertAudit(params);
        } else if (commandType == SqlCommandType.UPDATE) {
            rewriteSqlSourceOnce(ms);
            fillUpdateAudit(params);
        }

        // The parameter object is mutated in place, so the argument array needs no update.
        return invocation.proceed();
    }

    /**
     * Replaces the statement's SqlSource with one that includes the audit columns.
     * Runs at most once per statement id; the whole check-and-rewrite is done under a
     * lock so that a concurrent caller never proceeds with a half-rewritten statement.
     */
    private void rewriteSqlSourceOnce(MappedStatement ms) throws ReflectiveOperationException {
        synchronized (cache) {
            if (cache.contains(ms.getId())) {
                return;
            }
            boolean insert = ms.getSqlCommandType() == SqlCommandType.INSERT;
            var boundSql = ms.getSqlSource().getBoundSql(null);
            String newSql = insert ? insertAudit(boundSql.getSql()) : updateAudit(boundSql.getSql());

            // Audit placeholders are injected in front of the original ones, so their
            // mappings must come first as well.
            List<ParameterMapping> newParameterMappings = new ArrayList<>();
            for (String property : insert ? INSERT_AUDIT_PROPERTIES : UPDATE_AUDIT_PROPERTIES) {
                newParameterMappings.add(
                        new ParameterMapping.Builder(ms.getConfiguration(), property, Object.class).build());
            }
            newParameterMappings.addAll(boundSql.getParameterMappings());

            SqlSource newSqlSource = new StaticSqlSource(ms.getConfiguration(), newSql, newParameterMappings);

            // MappedStatement exposes no setter for its SqlSource, hence reflection.
            Field sqlSourceField = MappedStatement.class.getDeclaredField("sqlSource");
            sqlSourceField.setAccessible(true);
            sqlSourceField.set(ms, newSqlSource);
            cache.add(ms.getId());
        }
    }

    /** Sets all audit values for an INSERT, using one timestamp for both created and updated. */
    private void fillInsertAudit(Object params) {
        LocalDateTime now = LocalDateTime.now();
        if (params instanceof BaseEntity baseEntity) {
            baseEntity.setCreatedAt(now);
            baseEntity.setCreatedBy(CURRENT_USERNAME);
            baseEntity.setUpdatedAt(now);
            baseEntity.setUpdatedBy(CURRENT_USERNAME);
            baseEntity.setIsDeleted(false);
        } else if (params instanceof MapperMethod.ParamMap) {
            @SuppressWarnings("unchecked")
            MapperMethod.ParamMap<Object> map = (MapperMethod.ParamMap<Object>) params;
            map.put("createdAt", now);
            map.put("createdBy", CURRENT_USERNAME);
            map.put("updatedAt", now);
            map.put("updatedBy", CURRENT_USERNAME);
            map.put("isDeleted", false);
        }
    }

    /** Sets the "updated" audit values for an UPDATE; other parameter types are ignored. */
    private void fillUpdateAudit(Object params) {
        LocalDateTime now = LocalDateTime.now();
        if (params instanceof BaseEntity baseEntity) {
            baseEntity.setUpdatedAt(now);
            baseEntity.setUpdatedBy(CURRENT_USERNAME);
        } else if (params instanceof MapperMethod.ParamMap) {
            @SuppressWarnings("unchecked")
            MapperMethod.ParamMap<Object> map = (MapperMethod.ParamMap<Object>) params;
            map.put("updatedAt", now);
            map.put("updatedBy", CURRENT_USERNAME);
        }
    }

    /**
     * Adds the audit columns to an {@code insert into t (cols) values (vals)} statement.
     *
     * <p>The column names are inserted right after the first {@code (}, and five
     * placeholders right after the first {@code (} that follows the first {@code )}.
     * Later parentheses (for example a function call inside the value list) are left
     * alone, because only one set of audit parameter mappings is registered.
     *
     * @param source the original INSERT statement
     * @return the statement with audit columns and placeholders prepended
     */
    public String insertAudit(String source) {
        StringBuilder newSql = new StringBuilder(
                source.length() + INSERT_AUDIT_COLUMNS.length() + INSERT_AUDIT_PLACEHOLDERS.length());
        boolean columnListClosed = false;
        boolean columnsInjected = false;
        boolean placeholdersInjected = false;
        for (char c : source.toCharArray()) {
            newSql.append(c);
            if (c == '(') {
                if (!columnListClosed) {
                    if (!columnsInjected) {
                        newSql.append(INSERT_AUDIT_COLUMNS);
                        columnsInjected = true;
                    }
                } else if (!placeholdersInjected) {
                    newSql.append(INSERT_AUDIT_PLACEHOLDERS);
                    placeholdersInjected = true;
                }
            } else if (c == ')') {
                columnListClosed = true;
            }
        }
        return newSql.toString();
    }

    /**
     * Adds {@code updated_at=?, updated_by=?,} in front of the first assignment of an
     * {@code update t set ...} statement.
     *
     * <p>The statement is split on whitespace and re-joined with single spaces (a
     * trailing space is kept). The {@code set} keyword is matched case-insensitively
     * and only its first occurrence is used, matching the single pair of parameter
     * mappings that is registered.
     *
     * @param source the original UPDATE statement
     * @return the statement with the audit assignments prepended to the SET clause
     */
    public String updateAudit(String source) {
        StringBuilder newSql = new StringBuilder();
        boolean injectIntoNextToken = false;
        boolean injected = false;
        for (String token : source.split("\\s+")) {
            if (injectIntoNextToken) {
                token = UPDATE_AUDIT_ASSIGNMENTS + token;
                injectIntoNextToken = false;
                injected = true;
            } else if (!injected && token.equalsIgnoreCase("set")) {
                injectIntoNextToken = true;
            }
            newSql.append(token).append(" ");
        }
        return newSql.toString();
    }
}

第二种:需要在 SQL 中编写审计字段,只处理实体

import com.vkarz.provider.persistence.model.BaseEntity;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;

/**
 * MyBatis plugin that only fills the audit values on the parameter object.
 *
 * <p>It intercepts {@code Executor#update(MappedStatement, Object)} and, for INSERT
 * and UPDATE statements, writes the audit values into a {@link BaseEntity} or a
 * {@code MapperMethod.ParamMap} parameter. The SQL is not modified, so the mapper
 * statements must list the audit columns themselves.
 */
@Component("AuditFieldLiteInterceptor")
@Intercepts({@Signature(
        type = Executor.class,
        method = "update",
        args = {MappedStatement.class, Object.class})})
@Slf4j
public class AuditFieldLiteInterceptor implements Interceptor {

    /** Placeholder operator id; replace with a lookup of the authenticated user. */
    private static final String CURRENT_USERNAME = "ZR21090131";

    /**
     * Fills the audit values for INSERT/UPDATE statements, then proceeds.
     *
     * @param invocation the intercepted {@code Executor.update} call; {@code args[0]} is the
     *                   {@link MappedStatement}, {@code args[1]} the mapper parameter object
     * @return whatever the intercepted call returns
     * @throws Throwable any failure of the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        Object parameter = invocation.getArgs()[1];

        if (sqlCommandType == SqlCommandType.INSERT) {
            log.debug("execute insert operation parameter: {}", parameter);
            fillInsertAudit(parameter);
        } else if (sqlCommandType == SqlCommandType.UPDATE) {
            log.debug("execute update operate parameter: {}", parameter);
            fillUpdateAudit(parameter);
        }

        return invocation.proceed();
    }

    /**
     * Sets all audit values for an INSERT. A single timestamp is used so that
     * createdAt and updatedAt are equal for a record that has never been modified.
     */
    private void fillInsertAudit(Object parameter) {
        LocalDateTime now = LocalDateTime.now();
        if (parameter instanceof BaseEntity baseEntity) {
            baseEntity.setCreatedAt(now);
            baseEntity.setCreatedBy(CURRENT_USERNAME);
            baseEntity.setUpdatedAt(now);
            baseEntity.setUpdatedBy(CURRENT_USERNAME);
            baseEntity.setIsDeleted(false);
        } else if (parameter instanceof MapperMethod.ParamMap) {
            @SuppressWarnings("unchecked")
            MapperMethod.ParamMap<Object> map = (MapperMethod.ParamMap<Object>) parameter;
            map.put("createdAt", now);
            map.put("createdBy", CURRENT_USERNAME);
            map.put("updatedAt", now);
            map.put("updatedBy", CURRENT_USERNAME);
            map.put("isDeleted", false);
        }
    }

    /** Sets the "updated" audit values for an UPDATE; other parameter types are ignored. */
    private void fillUpdateAudit(Object parameter) {
        LocalDateTime now = LocalDateTime.now();
        if (parameter instanceof BaseEntity baseEntity) {
            baseEntity.setUpdatedAt(now);
            baseEntity.setUpdatedBy(CURRENT_USERNAME);
        } else if (parameter instanceof MapperMethod.ParamMap) {
            @SuppressWarnings("unchecked")
            MapperMethod.ParamMap<Object> map = (MapperMethod.ParamMap<Object>) parameter;
            map.put("updatedAt", now);
            map.put("updatedBy", CURRENT_USERNAME);
        }
    }
}

0x03 配置拦截器

编写好拦截器后,还要使其生效,可以通过两种方式。

一种是通过配置 mybatis-config.xml 文件,在 plugins 节点处添加插件(通用性比较强)。

<configuration>
  <!-- Register the audit interceptor as a MyBatis plugin by its fully qualified class name. -->
  <plugins>
    <plugin interceptor="com.vkarz.provider.persistence.interceptor.AuditFieldInterceptor" />
  </plugins>
</configuration>

另一种是,如果项目中使用了 Spring Boot,则可以通过在 Interceptor 类上添加 @Component 注解来实现。

0x04 参考资料

Comments are closed.