Understanding MyBatis-Plus in One Article: Building a Custom BaseMapper with Commonly Used Methods

晚风
2022-02-07 (last updated 2022-02-08)

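Preface

Why do we need custom injected methods?
MyBatis-Plus (mp) ships with quite a few built-in methods, but they are not always convenient enough. Batch insert is one example: mp's native batch insert is very slow once you insert more than ten thousand rows. Querying by entity is another: an mp wrapper cannot be serialized for transmission, so a method that queries by entity is needed. The place to start is the SQL injector section of the mp documentation, which says you can extend the abstract class AbstractSqlInjector and then implement it.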

[Image: screenshot of the SQL injector section of the mp documentation]

Enough talk, let's go straight to the code.

1. First, define the constant utility class

package icu.mhb.base.service.core.mp.constant;


/**
 * @author mahuibo
 * @Title: SqlConditionConstant
 * @time 3/15/21 3:36 PM
 */
public class SqlConditionConstant {

    private static final String STR_CONDITION = " <if test=\"%s != null and %s != ''\"> and %s=#{%s}</if> ";

    private static final String NOT_NULL_CONDITION = " <if test=\"%s != null \"> and %s=#{%s}</if> ";

    private static final String SET_STR_CONDITION = " <if test=\"%s != null and %s != ''\"> %s=#{%s},</if> ";

    private static final String SET_NOT_NULL_CONDITION = " <if test=\"%s != null \"> %s=#{%s},</if> ";

    /**
     * Builds a WHERE condition for a String property (skipped when null or empty).
     *
     * @param property entity property name
     * @param field    column name
     * @return the script fragment
     */
    public static String getStrCondition(String property, String field) {
        return String.format(STR_CONDITION, property, property, field, property);
    }

    /**
     * Builds a WHERE condition that applies only when the property is not null.
     *
     * @param property entity property name
     * @param field    column name
     * @return the script fragment
     */
    public static String getNotNullCondition(String property, String field) {
        return String.format(NOT_NULL_CONDITION, property, field, property);
    }

    /**
     * Builds a SET assignment for a String property (skipped when null or empty).
     *
     * @param property entity property name
     * @param field    column name
     * @return the script fragment
     */
    public static String getSetStrCondition(String property, String field) {
        return String.format(SET_STR_CONDITION, property, property, field, property);
    }

    /**
     * Builds a SET assignment that applies only when the property is not null.
     *
     * @param property entity property name
     * @param field    column name
     * @return the script fragment
     */
    public static String getSetNotNullCondition(String property, String field) {
        return String.format(SET_NOT_NULL_CONDITION, property, field, property);
    }


}

These are used to build the dynamic SQL scripts.
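
To make the templates concrete, here is a minimal stand-alone sketch of how one fragment is rendered. It copies the NOT_NULL_CONDITION template from the class above; the class name ConditionDemo and the userName / user_name pair are made up for illustration.

```java
/** Stand-alone sketch: renders one conditional fragment with String.format. */
final class ConditionDemo {

    // Same template text as NOT_NULL_CONDITION in SqlConditionConstant.
    private static final String NOT_NULL_CONDITION =
            " <if test=\"%s != null \"> and %s=#{%s}</if> ";

    /** Renders the fragment for one property/column pair. */
    static String notNullCondition(String property, String column) {
        return String.format(NOT_NULL_CONDITION, property, column, property);
    }

    public static void main(String[] args) {
        // userName / user_name are hypothetical names, used only for this demo.
        System.out.println(notNullCondition("userName", "user_name"));
    }
}
```

The rendered text is ` <if test="userName != null "> and user_name=#{userName}</if> `. MyBatis only includes the content of an `<if>` element in the final SQL when its test evaluates to true, so the condition disappears when userName is null.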

2. Define the mapper

package icu.mhb.base.service.core.mp.mapper;
import icu.mhb.mybatisplus.plugln.base.mapper.JoinBaseMapper;
import org.apache.ibatis.annotations.Param;

import java.util.List;

/**
 * @author mahuibo
 * @Title: CustomizeMapper
 * @time 3/15/21 3:08 PM
 */
public interface CustomizeMapper<T> extends JoinBaseMapper<T> {

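    /**
     * Queries a list, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return List<T>
     */
    List<T> selectByEntity(T entity);

    /**
     * Queries a single object, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return T
     */
    T getByEntity(T entity);

    /**
     * Batch insert.
     *
     * @param list the entities to insert
     * @return number of affected rows
     */
    int saveLotSize(@Param("list") List<T> list);


    /**
     * Counts rows, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return the count
     */
    int getCountByEntity(T entity);

    /**
     * Deletes rows that match the entity.
     *
     * @param entity the entity
     * @return number of affected rows
     */
    int removeByEntity(T entity);

    /**
     * Updates only the non-null fields.
     *
     * @param entity the entity
     * @return number of affected rows
     */
    int updateIsNotEmpty(T entity);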

}

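When you use this yourself, replace the inherited JoinBaseMapper with BaseMapper. I extend JoinBaseMapper here because I use mybatis-plus-join, a very good multi-table plugin for mp.
Documentation link
You can also find it by searching for "mybatis-plus-join mp的优雅多表插件" (mybatis-plus-join, an elegant multi-table plugin for mp) in the mini program (马汇博的博客).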

3. Service

package icu.mhb.base.service.core.mp.service;

import icu.mhb.mybatisplus.plugln.base.service.JoinIService;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;

/**
 * @author mahuibo
 * @Title: CustomizeService
 * @time 3/15/21 3:08 PM
 */
public interface CustomizeService<T> extends JoinIService<T> {

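    /**
     * Queries a list, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return List<T>
     */
    List<T> listByEntity(T entity);

    /**
     * Batch insert.
     *
     * @param list the entities to insert
     * @return whether the operation succeeded
     */
    @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Exception.class)
    boolean saveLotSize(List<T> list);


    /**
     * Queries a single object, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return T
     */
    T getByEntity(T entity);

    /**
     * Counts rows, using the entity's non-null fields as conditions.
     *
     * @param entity the entity
     * @return the count
     */
    int getCountByEntity(T entity);

    /**
     * Deletes rows that match the entity.
     *
     * @param entity the entity
     * @return whether the operation succeeded
     */
    boolean removeByEntity(T entity);

    /**
     * Updates only the non-null fields.
     *
     * @param entity the entity
     * @return whether the operation succeeded
     */
    boolean updateIsNotEmpty(T entity);

    /**
     * Batch-updates only the non-null fields.
     *
     * @param entityList the entities to update
     * @param batchSize  number of entities per batch
     * @return whether the operation succeeded
     */
    boolean updateIsNotEmptyBach(List<T> entityList, int batchSize);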

    default boolean updateIsNotEmptyBach(List<T> entityList) {
        return updateIsNotEmptyBach(entityList, DEFAULT_BATCH_SIZE);
    }

}

serviceImpl

package icu.mhb.base.service.core.mp.service.impl;

import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import com.google.common.collect.Lists;
import icu.mhb.base.service.core.mp.mapper.CustomizeMapper;
import icu.mhb.base.service.core.mp.service.CustomizeService;
import icu.mhb.mybatisplus.plugln.base.service.impl.JoinServiceImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author mahuibo
 * @Title: CustomizeServiceImpl
 * @time 3/15/21 3:08 PM
 */
public class CustomizeServiceImpl<M extends CustomizeMapper<T>, T> extends JoinServiceImpl<M, T> implements CustomizeService<T> {

    @Autowired
    protected M customizeBaseMapper;

    private static final int BATCH_NUMBER = 200;

    @Override
    public M getBaseMapper() {
        return this.customizeBaseMapper;
    }

    @Override
    public List<T> listByEntity(T entity) {
        return customizeBaseMapper.selectByEntity(entity);
    }

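    @Override
    public boolean saveLotSize(List<T> list) {
        AtomicInteger count = new AtomicInteger();

        if (list.size() > 2000) {
            // Spring transactions are thread-bound, so chunks processed on
            // parallel-stream worker threads do not join the caller's transaction.
            Lists.partition(list, BATCH_NUMBER).parallelStream()
                    .forEach(ts -> count.addAndGet(customizeBaseMapper.saveLotSize(ts)));
        } else {
            for (List<T> ts : Lists.partition(list, BATCH_NUMBER)) {
                // Accumulate the affected rows instead of overwriting them per chunk.
                count.addAndGet(customizeBaseMapper.saveLotSize(ts));
            }
        }
        return SqlHelper.retBool(count.get());
    }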

    @Override
    public T getByEntity(T entity) {
        return customizeBaseMapper.getByEntity(entity);
    }

    @Override
    public int getCountByEntity(T entity) {
        return customizeBaseMapper.getCountByEntity(entity);
    }

    @Override
    public boolean removeByEntity(T entity) {

        if (ObjectUtils.isNull(entity)) {
            throw ExceptionUtils.mpe("entity must not be null");
        }

        customizeBaseMapper.removeByEntity(entity);
        return true;
    }

    @Override
    public boolean updateIsNotEmpty(T entity) {
        return SqlHelper.retBool(customizeBaseMapper.updateIsNotEmpty(entity));
    }

    @Override
    @Transactional(rollbackFor = Exception.class)
    public boolean updateIsNotEmptyBach(List<T> entityList, int batchSize) {
        String sqlStatement = mapperClass.getName() + StringPool.DOT + "updateIsNotEmpty";
        return executeBatch(entityList, batchSize, (sqlSession, entity) -> sqlSession.update(sqlStatement, entity));
    }

}
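
The batch insert above splits the list into chunks of BATCH_NUMBER and sends each chunk to the mapper as one multi-row statement. Below is a minimal sketch of that chunking idea using only the JDK (the original relies on Guava's Lists.partition). The class BatchInsertDemo and the inserter callback are hypothetical stand-ins for the service and for customizeBaseMapper.saveLotSize, and this sketch adds the per-chunk counts together to report a total.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToIntFunction;

/** Stand-alone sketch: chunk a list and sum the rows reported per chunk. */
final class BatchInsertDemo {

    /** Splits the list into consecutive chunks of at most batchSize elements. */
    static <T> List<List<T>> partition(List<T> list, int batchSize) {
        List<List<T>> chunks = new ArrayList<>();
        for (int from = 0; from < list.size(); from += batchSize) {
            int to = Math.min(from + batchSize, list.size());
            chunks.add(list.subList(from, to));
        }
        return chunks;
    }

    /** Calls the inserter once per chunk and returns the total affected rows. */
    static <T> int insertInBatches(List<T> list, int batchSize,
                                   ToIntFunction<List<T>> inserter) {
        int total = 0;
        for (List<T> chunk : partition(list, batchSize)) {
            total += inserter.applyAsInt(chunk);
        }
        return total;
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 450; i++) {
            rows.add(i);
        }
        // Fake inserter: pretends every row in the chunk was inserted.
        int affected = insertInBatches(rows, 200, List::size);
        System.out.println(affected); // chunks of 200, 200 and 50 rows
    }
}
```

Keeping the loop sequential, as in this sketch, also keeps every chunk on the calling thread, which matters when the surrounding method is expected to run in a single transaction.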

The most important part: the SQL injector methods

package icu.mhb.base.service.core.mp.injector;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import icu.mhb.base.datatype.StringUtil;
import icu.mhb.base.service.core.mp.constant.SqlConditionConstant;
import org.apache.ibatis.executor.keygen.Jdbc3KeyGenerator;
import org.apache.ibatis.executor.keygen.KeyGenerator;
import org.apache.ibatis.executor.keygen.NoKeyGenerator;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 * @author mahuibo
 * @Title: CustomizeSqlInjector
 * @time 3/15/21 3:40 PM
 */
public class CustomizeSqlInjector {

    public static List<AbstractMethod> getCustomizeMapper() {

        // Query a list by the entity's non-null fields
        AbstractMethod selectByEntity = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT %s FROM %s <where>%s</where> </script>", tableInfo.getAllSqlSelect(), tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "selectByEntity", sqlSource, modelClass);
            }
        };

        // Query a single row by the entity's non-null fields
        AbstractMethod getByEntity = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT %s FROM %s <where>%s</where> limit 1 </script>", tableInfo.getAllSqlSelect(), tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "getByEntity", sqlSource, modelClass);
            }
        };

        // Count rows by the entity's non-null fields
        AbstractMethod getCountByEntity = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT count(1) FROM %s <where>%s</where> </script>", tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "getCountByEntity", sqlSource, int.class);
            }
        };

        // Delete by the entity's non-null fields (if no field is set, <where> renders nothing and every row matches)
        AbstractMethod removeByEntity = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> DELETE FROM %s <where>%s</where> </script>", tableInfo.getTableName(), getConditionSql(tableInfo)
                ), modelClass);
                return this.addDeleteMappedStatement(mapperClass, "removeByEntity", sqlSource);
            }
        };

        // Update the entity's non-null fields, matching on the primary key
        AbstractMethod updateIsNotEmpty = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> UPDATE %s <set> %s </set> WHERE %s=#{%s}  </script>", tableInfo.getTableName(), getSetSql(tableInfo), tableInfo.getKeyColumn(), tableInfo.getKeyProperty()
                ), modelClass);


                return addUpdateMappedStatement(mapperClass, modelClass, "updateIsNotEmpty", sqlSource);
            }
        };


        // Batch insert
        AbstractMethod saveLotSize = new AbstractMethod() {

            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {

                // Column names
                List<String> columnList = tableInfo.getFieldList().stream().map(TableFieldInfo::getColumn).collect(Collectors.toList());
                // Value placeholders, one #{item.property} per column
                String values = tableInfo.getFieldList().stream().map(TableFieldInfo::getProperty)
                        .map(s -> StringUtil.format("#{item.%s}", s))
                        .collect(Collectors.joining(StringPool.COMMA, StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET));

                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> INSERT INTO `%s` %s VALUES <foreach collection=\"list\" item=\"item\" separator=\",\"> %s </foreach> </script>", tableInfo.getTableName(), StringUtil.backticksList(columnList), values), modelClass);
                KeyGenerator keyGenerator = new NoKeyGenerator();
                String keyProperty = null;
                String keyColumn = null;
                // Primary-key handling: only applies when the table has a primary key
                if (StringUtils.isNotBlank(tableInfo.getKeyProperty())) {
                    if (tableInfo.getIdType() == IdType.AUTO) {
                        /* auto-increment primary key */
                        keyGenerator = new Jdbc3KeyGenerator();
                    }
                    keyProperty = tableInfo.getKeyProperty();
                    keyColumn = tableInfo.getKeyColumn();
                }

                return this.addInsertMappedStatement(mapperClass, modelClass, "saveLotSize", sqlSource, keyGenerator, keyProperty, keyColumn);
            }
        };

        return new ArrayList<>(
                Arrays.asList(selectByEntity, getCountByEntity, getByEntity, saveLotSize, removeByEntity, updateIsNotEmpty)
        );
    }


    /**
     * Builds the query-condition part of the SQL script.
     */
    private static String getConditionSql(TableInfo tableInfo) {
        StringBuilder sb = new StringBuilder();

        if (StringUtil.isNotNull(tableInfo.getKeyProperty())) {
            // Is the primary key a String?
            if (tableInfo.getKeyType().equals(String.class)) {
                sb.append("\n").append(SqlConditionConstant.getStrCondition(tableInfo.getKeyProperty(), tableInfo.getKeyColumn()));
            } else {
                sb.append("\n").append(SqlConditionConstant.getNotNullCondition(tableInfo.getKeyProperty(), tableInfo.getKeyColumn()));
            }
        }

        for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {

            // Skip fields declared as Object
            if (fieldInfo.getPropertyType().equals(Object.class)) {
                continue;
            }
            // String fields also get an empty-string check
            if (fieldInfo.getPropertyType().equals(String.class)) {
                sb.append("\n").append(SqlConditionConstant.getStrCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
                continue;
            }

            sb.append("\n").append(SqlConditionConstant.getNotNullCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
        }

        return sb.toString();
    }

    /**
     * Builds the SET part of the UPDATE SQL script.
     */
    private static String getSetSql(TableInfo tableInfo) {
        StringBuilder sb = new StringBuilder();

        for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {

            // Skip fields declared as Object
            if (fieldInfo.getPropertyType().equals(Object.class)) {
                continue;
            }
            sb.append("\n").append(SqlConditionConstant.getSetNotNullCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
        }

        return sb.toString();
    }

}
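
To see what kind of script text these injected methods hand to MyBatis, the sketch below rebuilds two of them with plain JDK string handling. It is only an approximation under stated assumptions: ScriptDemo, the user table and its fields are hypothetical, the property-to-column map stands in for TableInfo, and the backticked column list is a guess at what the author's StringUtil.backticksList helper produces.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

/** Stand-alone sketch: builds script text similar to the injected methods. */
final class ScriptDemo {

    /** Select-by-entity script from a property -> column map. */
    static String selectScript(String table, Map<String, String> fields) {
        String columns = String.join(",", fields.values());
        String conditions = fields.entrySet().stream()
                .map(e -> String.format("<if test=\"%s != null\"> and %s=#{%s}</if>",
                        e.getKey(), e.getValue(), e.getKey()))
                .collect(Collectors.joining(" "));
        return String.format("<script> SELECT %s FROM %s <where> %s </where> </script>",
                columns, table, conditions);
    }

    /** Multi-row insert script; the column list format is an assumption. */
    static String batchInsertScript(String table, Map<String, String> fields) {
        String columns = fields.values().stream()
                .map(c -> "`" + c + "`")
                .collect(Collectors.joining(",", "(", ")"));
        String values = fields.keySet().stream()
                .map(p -> "#{item." + p + "}")
                .collect(Collectors.joining(",", "(", ")"));
        return String.format("<script> INSERT INTO `%s` %s VALUES "
                + "<foreach collection=\"list\" item=\"item\" separator=\",\"> %s </foreach> </script>",
                table, columns, values);
    }

    public static void main(String[] args) {
        // Hypothetical table and fields, used only for this demo.
        Map<String, String> fields = new LinkedHashMap<>();
        fields.put("userName", "user_name");
        fields.put("age", "age");
        System.out.println(selectScript("user", fields));
        System.out.println(batchInsertScript("user", fields));
    }
}
```

The `<foreach>` element repeats the value tuple once per list item, so a single statement carries many rows. That is the reason this style of batch insert needs far fewer round trips than inserting row by row.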

The final step: add our custom methods to mp's injector

package com.zzfy.main.config;

import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import icu.mhb.base.service.core.mp.injector.CustomizeSqlInjector;
import icu.mhb.mybatisplus.plugln.injector.JoinDefaultSqlInjector;
import org.springframework.context.annotation.Configuration;

import java.util.List;

@Configuration
public class MyBatisPlusConfig extends JoinDefaultSqlInjector {

    @Override
    public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) {
        List<AbstractMethod> methodList = super.getMethodList(mapperClass, tableInfo);
        methodList.addAll(CustomizeSqlInjector.getCustomizeMapper());
        return methodList;
    }

}

If you are not using mybatis plus join, extend DefaultSqlInjector instead.

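That completes the custom method injection based on mp! If you repost this article, please credit the original author and link to the original post.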
