前言
我们为什么需要自定义注入方法?
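MP(MyBatis-Plus)内置了不少默认方法,但有些场景下这些方法还不够方便。
例如批量添加:MP 原生的批量插入在插入上万条数据时相当慢。
又如根据实体查询:MP 的 Wrapper 不能序列化传输(例如跨服务调用时),所以需要一种直接根据实体对象查询的方法。
这时先去看 MP 文档里“SQL 注入器”这一节。文档说明可以继承抽象类 AbstractSqlInjector 并实现对应方法,来注入自定义的通用方法。
本文的做法是继承它的子类 DefaultSqlInjector(或 mybatis-plus-join 提供的 JoinDefaultSqlInjector),并重写 getMethodList 方法。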
好了,废话不多说,直接上代码。
1. 首先定义常量与工具方法
package icu.mhb.base.service.core.mp.constant;
/**
 * Builders for the MyBatis dynamic-SQL fragments used by the custom SQL injector.
 *
 * <p>Each fragment is an {@code <if>} element that only contributes its condition (or SET
 * assignment) when the bound property is present. "Str" variants also require a non-empty
 * string; "NotNull" variants only require a non-null value.
 *
 * @author mahuibo
 * @Title: SqlConditionConstant
 * @time 3/15/21 3:36 PM
 */
public class SqlConditionConstant {

    /** WHERE fragment for a String property: requires non-null and non-empty. */
    private static final String STR_CONDITION = " <if test=\"%s != null and %s != ''\"> and %s=#{%s}</if> ";

    /** WHERE fragment for a non-String property: requires non-null. */
    private static final String NOT_NULL_CONDITION = " <if test=\"%s != null \"> and %s=#{%s}</if> ";

    /** SET fragment for a String property: requires non-null and non-empty; ends with a comma. */
    private static final String SET_STR_CONDITION = " <if test=\"%s != null and %s != ''\"> %s=#{%s},</if> ";

    /** SET fragment for a non-String property: requires non-null; ends with a comma. */
    private static final String SET_NOT_NULL_CONDITION = " <if test=\"%s != null \"> %s=#{%s},</if> ";

    /**
     * Fills a template whose test expression mentions the property twice (null check + empty check).
     */
    private static String fillBlankCheck(String template, String property, String field) {
        return String.format(template, property, property, field, property);
    }

    /**
     * Fills a template whose test expression mentions the property once (null check only).
     */
    private static String fillNullCheck(String template, String property, String field) {
        return String.format(template, property, field, property);
    }

    /**
     * Builds a WHERE fragment for a String property.
     *
     * @param property entity property name (used in the OGNL test and the {@code #{}} placeholder)
     * @param field    database column name
     * @return the {@code <if>} fragment
     */
    public static String getStrCondition(String property, String field) {
        return fillBlankCheck(STR_CONDITION, property, field);
    }

    /**
     * Builds a WHERE fragment for a property that only needs to be non-null.
     *
     * @param property entity property name
     * @param field    database column name
     * @return the {@code <if>} fragment
     */
    public static String getNotNullCondition(String property, String field) {
        return fillNullCheck(NOT_NULL_CONDITION, property, field);
    }

    /**
     * Builds a SET fragment for a String property.
     *
     * @param property entity property name
     * @param field    database column name
     * @return the {@code <if>} fragment
     */
    public static String getSetStrCondition(String property, String field) {
        return fillBlankCheck(SET_STR_CONDITION, property, field);
    }

    /**
     * Builds a SET fragment for a property that only needs to be non-null.
     *
     * @param property entity property name
     * @param field    database column name
     * @return the {@code <if>} fragment
     */
    public static String getSetNotNullCondition(String property, String field) {
        return fillNullCheck(SET_NOT_NULL_CONDITION, property, field);
    }
}
这些方法用于生成动态 SQL 脚本中的条件片段(WHERE 条件和 SET 赋值)。
2.定义mapper
package icu.mhb.base.service.core.mp.mapper;
import icu.mhb.mybatisplus.plugln.base.mapper.JoinBaseMapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/**
 * Mapper base interface adding entity-driven statements on top of the standard mapper methods.
 *
 * <p>No XML is needed for these methods: their SQL is generated per entity at startup by
 * {@code CustomizeSqlInjector}. For the query, count and delete methods, each non-null
 * property of the probe entity becomes an equality condition. String properties must also
 * be non-empty.
 *
 * @param <T> entity type
 * @author mahuibo
 * @Title: CustomizeMapper
 * @time 3/15/21 3:08 PM
 */
public interface CustomizeMapper<T> extends JoinBaseMapper<T> {

    // ---------------------------------------------------------------- reads

    /**
     * Lists every row matching the populated properties of {@code entity}.
     *
     * @param entity probe entity; its non-null properties act as filters
     * @return matching rows
     */
    List<T> selectByEntity(T entity);

    /**
     * Returns a single row matching the populated properties of {@code entity}.
     * The generated SQL appends {@code limit 1}.
     *
     * @param entity probe entity
     * @return the matching row, or {@code null} when nothing matches
     */
    T getByEntity(T entity);

    /**
     * Counts the rows matching the populated properties of {@code entity}.
     *
     * @param entity probe entity
     * @return number of matching rows
     */
    int getCountByEntity(T entity);

    // --------------------------------------------------------------- writes

    /**
     * Inserts all elements of {@code list} with one multi-row INSERT statement.
     * The {@code "list"} parameter name is referenced by the generated
     * {@code <foreach collection="list">}.
     *
     * @param list rows to insert (must not be empty, otherwise the generated SQL is invalid)
     * @return affected row count
     */
    int saveLotSize(@Param("list") List<T> list);

    /**
     * Updates the row identified by the entity's primary key. Only columns whose property is
     * non-null are set.
     *
     * @param entity entity carrying the key and the new values
     * @return affected row count
     */
    int updateIsNotEmpty(T entity);

    /**
     * Deletes the rows matching the populated properties of {@code entity}.
     *
     * <p>Warning: if every property is null, the generated {@code <where>} is empty and the
     * statement deletes every row of the table.
     *
     * @param entity probe entity
     * @return affected row count
     */
    int removeByEntity(T entity);
}
注意:你们使用时需要把继承的 JoinBaseMapper 换成 MP 自带的 BaseMapper。这是因为我这里继承的是 mybatis-plus-join 提供的接口,它是一款很不错的 MP 多表查询插件。
文档地址
也可以在小程序“马汇博的博客”中搜索“mybatis-plus-join mp的优雅多表插件”找到。
3.service
package icu.mhb.base.service.core.mp.service;
import icu.mhb.mybatisplus.plugln.base.service.JoinIService;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
/**
 * Service-layer counterpart of {@code CustomizeMapper}. It provides entity-driven query, count
 * and delete, selective (non-null) update, and chunked bulk insert.
 *
 * @param <T> entity type
 * @author mahuibo
 * @Title: CustomizeService
 * @time 3/15/21 3:08 PM
 */
public interface CustomizeService<T> extends JoinIService<T> {

    // ---------------------------------------------------------------- reads

    /**
     * Lists the rows matching the non-null properties of {@code entity}.
     *
     * @param entity probe entity
     * @return matching rows
     */
    List<T> listByEntity(T entity);

    /**
     * Fetches a single row matching the non-null properties of {@code entity}.
     *
     * @param entity probe entity
     * @return the matching row
     */
    T getByEntity(T entity);

    /**
     * Counts the rows matching the non-null properties of {@code entity}.
     *
     * @param entity probe entity
     * @return number of matching rows
     */
    int getCountByEntity(T entity);

    // --------------------------------------------------------------- writes

    /**
     * Bulk-inserts {@code list}.
     *
     * @param list rows to insert
     * @return whether the insert succeeded
     */
    @Transactional(propagation = Propagation.REQUIRED, rollbackFor = Exception.class)
    boolean saveLotSize(List<T> list);

    /**
     * Updates only the non-null columns of the row identified by the entity's key.
     *
     * @param entity entity carrying the key and the new values
     * @return whether any row was updated
     */
    boolean updateIsNotEmpty(T entity);

    /**
     * Applies {@link #updateIsNotEmpty(Object)} to every element, flushing in batches.
     * ("Bach" in the name means "batch"; the name is kept for compatibility.)
     *
     * @param entityList entities to update
     * @param batchSize  number of statements per flush
     * @return whether the batch succeeded
     */
    boolean updateIsNotEmptyBach(List<T> entityList, int batchSize);

    /**
     * Same as {@link #updateIsNotEmptyBach(List, int)} using the inherited
     * {@code DEFAULT_BATCH_SIZE}.
     *
     * @param entityList entities to update
     * @return whether the batch succeeded
     */
    default boolean updateIsNotEmptyBach(List<T> entityList) {
        final int size = DEFAULT_BATCH_SIZE;
        return this.updateIsNotEmptyBach(entityList, size);
    }

    /**
     * Deletes the rows matching the non-null properties of {@code entity}.
     *
     * @param entity probe entity
     * @return whether the deletion succeeded
     */
    boolean removeByEntity(T entity);
}
4. service 实现类(serviceImpl)
package icu.mhb.base.service.core.mp.service.impl;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import com.google.common.collect.Lists;
import icu.mhb.base.service.core.mp.mapper.CustomizeMapper;
import icu.mhb.base.service.core.mp.service.CustomizeService;
import icu.mhb.mybatisplus.plugln.base.service.impl.JoinServiceImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Generic implementation of {@link CustomizeService}. It delegates to the custom statements
 * injected into {@link CustomizeMapper}.
 *
 * @param <M> mapper type exposing the custom statements
 * @param <T> entity type
 * @author mahuibo
 * @Title: CustomizeService
 * @time 3/15/21 3:08 PM
 */
public class CustomizeServiceImpl<M extends CustomizeMapper<T>, T> extends JoinServiceImpl<M, T> implements CustomizeService<T> {

    /** Mapper carrying the custom statements; also returned from {@link #getBaseMapper()}. */
    @Autowired
    protected M customizeBaseMapper;

    /** Maximum number of rows sent in one multi-row INSERT by {@link #saveLotSize(List)}. */
    private static final int BATCH_NUMBER = 200;

    @Override
    public M getBaseMapper() {
        return this.customizeBaseMapper;
    }

    @Override
    public List<T> listByEntity(T entity) {
        return customizeBaseMapper.selectByEntity(entity);
    }

    /**
     * Inserts {@code list} in chunks of {@value #BATCH_NUMBER} rows, one multi-row INSERT per chunk.
     *
     * <p>Chunks run sequentially on the calling thread. Spring binds the transaction and its
     * connection to the current thread. Chunks dispatched to {@code parallelStream()} worker
     * threads would therefore run outside the transaction and would not be rolled back when
     * another chunk fails.
     *
     * @param list rows to insert; {@code null} or empty inserts nothing
     * @return {@code true} if at least one row was inserted
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public boolean saveLotSize(List<T> list) {
        if (list == null || list.isEmpty()) {
            return false;
        }
        int affected = 0;
        for (List<T> chunk : Lists.partition(list, BATCH_NUMBER)) {
            // sum across chunks so the result reflects every chunk, not just the last one
            affected += customizeBaseMapper.saveLotSize(chunk);
        }
        return SqlHelper.retBool(affected);
    }

    @Override
    public T getByEntity(T entity) {
        return customizeBaseMapper.getByEntity(entity);
    }

    @Override
    public int getCountByEntity(T entity) {
        return customizeBaseMapper.getCountByEntity(entity);
    }

    /**
     * Deletes the rows matching the non-null properties of {@code entity}.
     *
     * @param entity probe entity, must not be {@code null}
     * @return {@code true} if at least one row was deleted (consistent with {@link #updateIsNotEmpty})
     * @throws com.baomidou.mybatisplus.core.exceptions.MybatisPlusException if {@code entity} is null
     */
    @Override
    public boolean removeByEntity(T entity) {
        if (ObjectUtils.isNull(entity)) {
            throw ExceptionUtils.mpe("实体不能为空null");
        }
        // NOTE(review): an entity whose properties are all null yields an empty <where>,
        // i.e. an unconditional DELETE of the whole table.
        return SqlHelper.retBool(customizeBaseMapper.removeByEntity(entity));
    }

    @Override
    public boolean updateIsNotEmpty(T entity) {
        return SqlHelper.retBool(customizeBaseMapper.updateIsNotEmpty(entity));
    }

    /**
     * Runs the injected {@code updateIsNotEmpty} statement for every entity through a batch
     * SqlSession, flushing every {@code batchSize} statements.
     *
     * @param entityList entities to update
     * @param batchSize  statements per flush
     * @return whether the batch succeeded
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public boolean updateIsNotEmptyBach(List<T> entityList, int batchSize) {
        String sqlStatement = mapperClass.getName() + StringPool.DOT + "updateIsNotEmpty";
        return executeBatch(entityList, batchSize, (sqlSession, entity) -> sqlSession.update(sqlStatement, entity));
    }
}
最重要的SQL注入部分
package icu.mhb.base.service.core.mp.injector;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import icu.mhb.base.datatype.StringUtil;
import icu.mhb.base.service.core.mp.constant.SqlConditionConstant;
import org.apache.ibatis.executor.keygen.Jdbc3KeyGenerator;
import org.apache.ibatis.executor.keygen.KeyGenerator;
import org.apache.ibatis.executor.keygen.NoKeyGenerator;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Builds the MappedStatements backing the custom {@code CustomizeMapper} methods.
 *
 * @author mahuibo
 * @Title: CustomizeSqlInjector
 * @time 3/15/21 3:40 PM
 */
public class CustomizeSqlInjector {

    /**
     * Creates one {@link AbstractMethod} per custom mapper method. The SQL injector appends
     * these to its method list.
     *
     * @return a new mutable list of the custom methods
     */
    public static List<AbstractMethod> getCustomizeMapper() {
        // Query a list by the non-null properties of the entity
        AbstractMethod selectByEntity = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT %s FROM %s <where>%s</where> </script>", tableInfo.getAllSqlSelect(), tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "selectByEntity", sqlSource, modelClass);
            }
        };
        // Query a single row by the non-null properties of the entity
        AbstractMethod getByEntity = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT %s FROM %s <where>%s</where> limit 1 </script>", tableInfo.getAllSqlSelect(), tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "getByEntity", sqlSource, modelClass);
            }
        };
        // Count rows matching the non-null properties of the entity
        AbstractMethod getCountByEntity = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> SELECT count(1) FROM %s <where>%s</where> </script>", tableInfo.getTableName(), getConditionSql(tableInfo)), modelClass);
                return this.addSelectMappedStatementForOther(mapperClass, "getCountByEntity", sqlSource, int.class);
            }
        };
        // Delete rows matching the non-null properties of the entity.
        // NOTE(review): if every property is null the <where> element renders nothing and the
        // statement becomes an unconditional DELETE of the whole table.
        AbstractMethod removeByEntity = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> DELETE FROM %s <where>%s</where> </script>", tableInfo.getTableName(), getConditionSql(tableInfo)
                ), modelClass);
                return this.addDeleteMappedStatement(mapperClass, "removeByEntity", sqlSource);
            }
        };
        // Update the row identified by the primary key, setting only non-null properties
        AbstractMethod updateIsNotEmpty = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, StringUtil.format("<script> UPDATE %s <set> %s </set> WHERE %s=#{%s} </script>", tableInfo.getTableName(), getSetSql(tableInfo), tableInfo.getKeyColumn(), tableInfo.getKeyProperty()
                ), modelClass);
                return addUpdateMappedStatement(mapperClass, modelClass, "updateIsNotEmpty", sqlSource);
            }
        };
        // Multi-row insert
        AbstractMethod saveLotSize = new AbstractMethod() {
            @Override
            public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
                SqlSource sqlSource = languageDriver.createSqlSource(configuration, getBatchInsertSql(tableInfo), modelClass);
                KeyGenerator keyGenerator = new NoKeyGenerator();
                String keyProperty = null;
                String keyColumn = null;
                // If the table has a primary key, configure key handling; otherwise treat every column as a plain field
                if (StringUtils.isNotBlank(tableInfo.getKeyProperty())) {
                    if (tableInfo.getIdType() == IdType.AUTO) {
                        // auto-increment key: read the generated ids back into the entities
                        keyGenerator = new Jdbc3KeyGenerator();
                    }
                    keyProperty = tableInfo.getKeyProperty();
                    keyColumn = tableInfo.getKeyColumn();
                }
                return this.addInsertMappedStatement(mapperClass, modelClass, "saveLotSize", sqlSource, keyGenerator, keyProperty, keyColumn);
            }
        };
        return new ArrayList<>(
                Arrays.asList(selectByEntity, getCountByEntity, getByEntity, saveLotSize, removeByEntity, updateIsNotEmpty)
        );
    }

    /**
     * Builds the multi-row INSERT script used by {@code saveLotSize}.
     *
     * <p>{@code tableInfo.getFieldList()} does not contain the primary key, which is why
     * {@link #getConditionSql} handles the key separately. The key column is therefore added
     * explicitly unless the id is generated by the database ({@link IdType#AUTO}). Otherwise,
     * ids supplied by the caller (INPUT) or filled in by MyBatis-Plus (ASSIGN_ID / ASSIGN_UUID)
     * would never be written.
     *
     * @param tableInfo table metadata
     * @return the {@code <script>} text
     */
    private static String getBatchInsertSql(TableInfo tableInfo) {
        List<String> columnList = new ArrayList<>();
        List<String> propertyList = new ArrayList<>();
        if (StringUtils.isNotBlank(tableInfo.getKeyProperty()) && tableInfo.getIdType() != IdType.AUTO) {
            columnList.add(tableInfo.getKeyColumn());
            propertyList.add(tableInfo.getKeyProperty());
        }
        for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
            columnList.add(fieldInfo.getColumn());
            propertyList.add(fieldInfo.getProperty());
        }
        // value tuple, e.g. (#{item.id},#{item.name})
        String values = propertyList.stream()
                .map(s -> StringUtil.format("#{item.%s}", s))
                .collect(Collectors.joining(StringPool.COMMA, StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET));
        // NOTE(review): the backtick quoting of the table name is MySQL-specific
        return StringUtil.format("<script> INSERT INTO `%s` %s VALUES <foreach collection=\"list\" item=\"item\" separator=\",\"> %s </foreach> </script>",
                tableInfo.getTableName(), StringUtil.backticksList(columnList), values);
    }

    /**
     * Builds the WHERE-condition script: one {@code <if>} fragment for the primary key (when
     * present) and one per field. String properties must be non-null and non-empty; other
     * properties only non-null.
     */
    private static String getConditionSql(TableInfo tableInfo) {
        StringBuilder sb = new StringBuilder();
        if (StringUtil.isNotNull(tableInfo.getKeyProperty())) {
            // String key: also skip empty strings
            if (tableInfo.getKeyType().equals(String.class)) {
                sb.append("\n").append(SqlConditionConstant.getStrCondition(tableInfo.getKeyProperty(), tableInfo.getKeyColumn()));
            } else {
                sb.append("\n").append(SqlConditionConstant.getNotNullCondition(tableInfo.getKeyProperty(), tableInfo.getKeyColumn()));
            }
        }
        for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
            // skip fields declared as Object
            if (fieldInfo.getPropertyType().equals(Object.class)) {
                continue;
            }
            // String field: also skip empty strings
            if (fieldInfo.getPropertyType().equals(String.class)) {
                sb.append("\n").append(SqlConditionConstant.getStrCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
                continue;
            }
            sb.append("\n").append(SqlConditionConstant.getNotNullCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
        }
        return sb.toString();
    }

    /**
     * Builds the SET-clause script for the update statement: one non-null-guarded assignment
     * per field (the primary key is not part of the field list and is never updated).
     */
    private static String getSetSql(TableInfo tableInfo) {
        StringBuilder sb = new StringBuilder();
        for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
            // skip fields declared as Object
            if (fieldInfo.getPropertyType().equals(Object.class)) {
                continue;
            }
            sb.append("\n").append(SqlConditionConstant.getSetNotNullCondition(fieldInfo.getProperty(), fieldInfo.getColumn()));
        }
        return sb.toString();
    }
}
最后一步往mp的注入器中注入咱们自定义的方法
package com.zzfy.main.config;
import com.baomidou.mybatisplus.core.injector.AbstractMethod;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import icu.mhb.base.service.core.mp.injector.CustomizeSqlInjector;
import icu.mhb.mybatisplus.plugln.injector.JoinDefaultSqlInjector;
import org.springframework.context.annotation.Configuration;
import java.util.List;
/**
 * Registers the custom mapper methods with MyBatis-Plus. It keeps every method produced by
 * the parent injector and appends those returned by
 * {@code CustomizeSqlInjector.getCustomizeMapper()}.
 */
@Configuration
public class MyBatisPlusConfig extends JoinDefaultSqlInjector {

    @Override
    public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) {
        final List<AbstractMethod> inherited = super.getMethodList(mapperClass, tableInfo);
        final List<AbstractMethod> custom = CustomizeSqlInjector.getCustomizeMapper();
        inherited.addAll(custom);
        return inherited;
    }
}
如果没有使用 mybatis-plus-join,这里改为继承 MP 自带的 DefaultSqlInjector 即可。同理,上文的 JoinIService、JoinServiceImpl 也要分别换成 MP 自带的 IService、ServiceImpl。
到这里,基于 MP 的自定义 SQL 注入就完成了!转载请注明原作者及文章地址。
评论区