上一章,讲解了SQL拦截器的,做了一个简单的SQL改造。本章要实现:
以下简单用一个图说明了整个处理过程,红色框住的部分,就是本章要实现的内容:
DataScope.java
DataScope对象里面设置了用于数据权限规则数组。在SQL拦截器中将从这些数据权限规则中获取条件表达式。
package com.luo.chengrui.labs.lab02.annotation;
import com.luo.chengrui.labs.lab02.datapermission.DataPermissionRule;
import java.lang.annotation.*;
/**
* 数据权限过滤注解
*
* @author ruoyi
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface DataScope
{
/**
* 当前类或方法是否开启数据权限
* 即使不添加 @DataPermission 注解,默认是开启状态
* 可通过设置 enable 为 false 禁用
*/
boolean enable() default true;
/**
* 生效的数据权限规则数组,为了以后方便扩展,所以定义为权限解析对象数组
*/
Class<? extends DataPermissionRule>[] includeRules() default {};
}
DataPermissionRule.java
这里定义了一个权限解析接口,如此即可以扩展很多不同类型的权限,如:对部门过滤,本章我们仅实现对部门权限过滤。(若要实现对如预算科目、项目等过滤,可实现该接口,另外还得单独实现对这些数据权限的分配功能。若依仅实现了对部门和人员控制)
package com.luo.chengrui.labs.lab02.datapermission;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import java.util.Set;
/**
* 数据权限规则接口
* 通过实现接口,自定义数据规则。例如说,
*
* @author yudao
*/
public interface DataPermissionRule {
/**
* 根据表名和别名,生成对应的 WHERE / OR 过滤条件
*
* @param tableName 表名
* @param tableAlias 别名,可能为空
* @return 过滤条件 Expression 表达式
*/
Expression getExpression(String tableName, Alias tableAlias);
}
DeptDataPermissionRule.java
部门权限解析实现类,只要SQL语句中有sys_dept表或者dept_id字段,则均可以通过getExpression(String tableName, Alias tableAlias)获取部门数据权限,并生成Express表达式。
ruoyi框架中数据权限分配和权限获取实现,可参考:Springboot管理系统数据权限过滤——ruoyi实现方案
package com.luo.chengrui.labs.lab02.datapermission.dept;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.luo.chengrui.labs.lab02.datapermission.DataPermissionRule;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Column;
import java.util.*;
import java.util.stream.Collectors;
/**
* 基于部门的 {@link DataPermissionRule} 数据权限规则实现
* <p>
* 注意,使用 DeptDataPermissionRule 时,需要保证表中有 dept_id 部门编号的字段,可自定义。
* <p>
*/
@AllArgsConstructor
@Slf4j
public class DeptDataPermissionRule implements DataPermissionRule {
static final Expression EXPRESSION_NULL = new NullValue();
/**
* 基于部门的表字段配置
* 一般情况下,每个表的部门编号字段是 dept_id,通过该配置自定义。
* <p>
* key:表名
* value:字段名
*/
private final Map<String, String> deptColumns = new HashMap<>();
/**
* 所有表名,需要进行权限过滤的表名
*/
private final Set<String> TABLE_NAMES = new HashSet<>();
/**
* 添加需要过滤部门权限的表名和部门ID字段名
* @param tableName
* @param columnName
*/
public void addDeptColumn(String tableName, String columnName) {
deptColumns.put(tableName, columnName);
TABLE_NAMES.add(tableName);
}
/**
* 获取所有需要按部门权限过滤的表
* @return
*/
@Override
public Set<String> getTableNames() {
return TABLE_NAMES;
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
// 情况三,拼接 Dept 和 User 的条件,最后组合
Set<Long> deptIds = new HashSet<>();
//模拟数据,实现业务中需要获取当前用户部门权限,获取到部门id。
deptIds.add(1L);
deptIds.add(2L);
// 配置中包含表时,进行过滤
if (Objects.nonNull(deptColumns.get(tableName))) {
return new InExpression(buildColumn(tableName, tableAlias, deptColumns.get(tableName)),
new ExpressionList(deptIds.stream().map(LongValue::new).collect(Collectors.toList())));
}
return EXPRESSION_NULL;
//
}
/**
* 构建 Column 对象
*
* @param tableName 表名
* @param tableAlias 别名
* @param column 字段名
* @return Column 对象
*/
public static Column buildColumn(String tableName, Alias tableAlias, String column) {
if (tableAlias != null) {
tableName = tableAlias.getName();
}
return new Column(tableName + StringPool.DOT + column);
}
}
DataPermissionContextHolder.java
package com.luo.chengrui.labs.lab02.interceptor;
import com.luo.chengrui.labs.lab02.annotation.DataScope;
import java.util.LinkedList;
import java.util.List;
/**
* {@link DataScope} 注解的 Context 上下文
* 将方法上的注解对象设置到 线程变量里面,在SQL执行拦截器中获取注解对象,根据内容生成相应的权限。
* 告诉SQL执行{DataPermissionDatabaseInterceptor}拦截器,如此方法需要添加权限。
*/
public class DataPermissionContextHolder {
/**
* 使用 List 的原因,可能存在方法的嵌套调用
*/
private static final ThreadLocal<LinkedList<DataScope>> DATA_PERMISSIONS =
ThreadLocal.withInitial(LinkedList::new);
/**
* 获得当前的 DataPermission 注解
*
* @return DataPermission 注解
*/
public static DataScope get() {
return DATA_PERMISSIONS.get().peekLast();
}
/**
* 入栈 DataPermission 注解
*
* @param dataPermission DataPermission 注解
*/
public static void add(DataScope dataPermission) {
DATA_PERMISSIONS.get().addLast(dataPermission);
}
/**
* 出栈 DataPermission 注解
*
* @return DataPermission 注解
*/
public static DataScope remove() {
DataScope dataPermission = DATA_PERMISSIONS.get().removeLast();
// 无元素时,清空 ThreadLocal
if (DATA_PERMISSIONS.get().isEmpty()) {
DATA_PERMISSIONS.remove();
}
return dataPermission;
}
/**
* 获得所有 DataPermission
*
* @return DataPermission 队列
*/
public static List<DataScope> getAll() {
return DATA_PERMISSIONS.get();
}
/**
* 清空上下文
* <p>
* 目前仅仅用于单测
*/
public static void clear() {
DATA_PERMISSIONS.remove();
}
}
DataScopeAnnotationInterceptor.java
方法拦截器:作用是拦截添加了@DataScope注解的方法,获取DataScope对象,放入线程变量。
package com.luo.chengrui.labs.lab02.interceptor;
import com.luo.chengrui.labs.lab02.annotation.DataScope;
import lombok.Getter;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotationUtils;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* {@link DataScope} 注解的拦截器
* 1. 在执行方法前,将 @DataPermission 注解入栈
* 2. 在执行方法后,将 @DataPermission 注解出栈
*
* @author yudao
*/
public class DataScopeAnnotationInterceptor implements MethodInterceptor {
/**
* DataPermission 空对象,用于方法无 {@link import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;} 注解时,使用 DATA_PERMISSION_NULL 进行占位
*/
static final DataScope DATA_PERMISSION_NULL = DataScopeAnnotationInterceptor.class.getAnnotation(DataScope.class);
@Getter
private final Map<MethodClassKey, DataScope> dataPermissionCache = new ConcurrentHashMap<>();
@Override
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
Logger log = LoggerFactory.getLogger(DataScopeAnnotationInterceptor.class);
log.debug("DataScopeAnnotationInterceptor 拦截器:" + methodInvocation.getMethod().getName());
// 入栈
DataScope dataPermission = this.findAnnotation(methodInvocation);
if (dataPermission != null) {
DataPermissionContextHolder.add(dataPermission);
}
try {
// 执行逻辑
return methodInvocation.proceed();
} finally {
// 出栈
if (dataPermission != null) {
DataPermissionContextHolder.remove();
}
}
}
private DataScope findAnnotation(MethodInvocation methodInvocation) {
// 1. 从缓存中获取
Method method = methodInvocation.getMethod();
Object targetObject = methodInvocation.getThis();
Class<?> clazz = targetObject != null ? targetObject.getClass() : method.getDeclaringClass();
MethodClassKey methodClassKey = new MethodClassKey(method, clazz);
DataScope dataPermission = dataPermissionCache.get(methodClassKey);
if (dataPermission != null) {
return dataPermission != DATA_PERMISSION_NULL ? dataPermission : null;
}
// 2.1 从方法中获取
dataPermission = AnnotationUtils.findAnnotation(method, DataScope.class);
// 2.2 从类上获取
if (dataPermission == null) {
dataPermission = AnnotationUtils.findAnnotation(clazz, DataScope.class);
}
// 2.3 添加到缓存中
dataPermissionCache.put(methodClassKey, dataPermission != null ? dataPermission : DATA_PERMISSION_NULL);
return dataPermission;
}
}
DataPermissionAnnotationAdvisor.java
定义拦截点,对添加了DataScope注解的类和方法都添加了拦截点。
package com.luo.chengrui.labs.lab02.interceptor;
import com.luo.chengrui.labs.lab02.annotation.DataScope;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.aopalliance.aop.Advice;
import org.springframework.aop.Pointcut;
import org.springframework.aop.support.AbstractPointcutAdvisor;
import org.springframework.aop.support.ComposablePointcut;
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
import org.springframework.context.annotation.Role;
import org.springframework.stereotype.Component;
/**
* {@link DataScope} 注解的 Advisor 实现类
*
*/
@Getter
@EqualsAndHashCode(callSuper = true)
public class DataPermissionAnnotationAdvisor extends AbstractPointcutAdvisor {
private final Advice advice;
private final Pointcut pointcut;
public DataPermissionAnnotationAdvisor() {
this.advice = new DataScopeAnnotationInterceptor();
this.pointcut = this.buildPointcut();
}
protected Pointcut buildPointcut() {
Pointcut classPointcut = new AnnotationMatchingPointcut(DataScope.class, true);
Pointcut methodPointcut = new AnnotationMatchingPointcut(null, DataScope.class, true);
return new ComposablePointcut(classPointcut).union(methodPointcut);
}
}
DataPermissionDatabaseInterceptor.java
代码比较多,主要是对sql各种情况的解析和处理,与实现业务关联性很弱。在实际业务扩展时,该类基本不需要进行修改。可能搜索对 getExpression 方法的调用,改造sql的位置。
package com.luo.chengrui.labs.lab02.datapermission;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
* 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
* 主要是在执行SQL前拦截器,在执行之前可重写SQL
*
* @author yudao
*/
@RequiredArgsConstructor
public class DataPermissionDatabaseInterceptor extends JsqlParserSupport implements InnerInterceptor {
private static final String MYSQL_ESCAPE_CHARACTER = "`";
private final List<DataPermissionRule> dataPermissionRule;
@Getter
private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 获得 Mapper 对应的数据权限的规则
if (mappedStatementCache.noRewritable(ms, dataPermissionRule)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
try {
// 初始化上下文
ContextHolder.init(dataPermissionRule);
// 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
@Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
// 获得 Mapper 对应的数据权限的规则
if (mappedStatementCache.noRewritable(ms, dataPermissionRule)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
try {
// 初始化上下文
ContextHolder.init(dataPermissionRule);
// 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
processSelectBody(select.getSelectBody());
List<WithItem> withItemsList = select.getWithItemsList();
if (!CollectionUtils.isEmpty(withItemsList)) {
withItemsList.forEach(this::processSelectBody);
}
}
/**
* update 语句处理
*/
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
final Table table = update.getTable();
update.setWhere(this.builderExpression(update.getWhere(), table));
}
/**
* delete 语句处理
*/
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
}
// ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(this::processSelectBody);
}
}
}
/**
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
mainTables = processJoins(mainTables, joins);
}
// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}
private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}
/**
* 处理where条件内的子查询
* <p>
* 支持如下:
* 1. in
* 2. =
* 3. >
* 4. <
* 5. >=
* 6. <=
* 7. <>
* 8. EXISTS
* 9. NOT EXISTS
* <p>
* 前提条件:
* 1. 子查询必须放在小括号中
* 2. 子查询一般放在比较操作符的右边
*
* @param where where 条件
*/
protected void processWhereSubSelect(Expression where) {
if (where == null) {
return;
}
if (where instanceof FromItem) {
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
// 有子查询
if (where instanceof BinaryExpression) {
// 比较符号 , and , or , 等等
BinaryExpression expression = (BinaryExpression) where;
processWhereSubSelect(expression.getLeftExpression());
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof InExpression) {
// in
InExpression expression = (InExpression) where;
Expression inExpression = expression.getRightExpression();
if (inExpression instanceof SubSelect) {
processSelectBody(((SubSelect) inExpression).getSelectBody());
}
} else if (where instanceof ExistsExpression) {
// exists
ExistsExpression expression = (ExistsExpression) where;
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof NotExpression) {
// not exists
NotExpression expression = (NotExpression) where;
processWhereSubSelect(expression.getExpression());
} else if (where instanceof Parenthesis) {
Parenthesis expression = (Parenthesis) where;
processWhereSubSelect(expression.getExpression());
}
}
}
protected void processSelectItem(SelectItem selectItem) {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
if (selectExpressionItem.getExpression() instanceof SubSelect) {
processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
} else if (selectExpressionItem.getExpression() instanceof Function) {
processFunction((Function) selectExpressionItem.getExpression());
}
}
}
/**
* 处理函数
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p>
*
* @param function
*/
protected void processFunction(Function function) {
ExpressionList parameters = function.getParameters();
if (parameters != null) {
parameters.getExpressions().forEach(expression -> {
if (expression instanceof SubSelect) {
processSelectBody(((SubSelect) expression).getSelectBody());
} else if (expression instanceof Function) {
processFunction((Function) expression);
}
});
}
}
/**
* 处理子查询等
*/
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
} else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subQuery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) {
SubSelect subSelect = lateralSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
}
}
}
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}
/**
* 处理 joins
*
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables == null) {
mainTables = new ArrayList<>();
} else if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem joinItem = join.getRightItem();
// 获取当前 join 的表,subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
onTables = Collections.singletonList(joinTable);
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
// 表名压栈,忽略的表压入 null,以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
processOtherFromItem(joinItem);
leftTable = null;
}
}
return mainTables;
}
// ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param table 单个表
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
return this.builderExpression(currentExpression, Collections.singletonList(table));
}
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param tables 多个表
*/
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 第一步,获得 Table 对应的数据权限条件
Expression dataPermissionExpression = null;
for (Table table : tables) {
// 构建每个表的权限 Expression 条件
Expression expression = buildDataPermissionExpression(table);
if (expression == null) {
continue;
}
// 合并到 dataPermissionExpression 中
dataPermissionExpression = dataPermissionExpression == null ? expression
: new AndExpression(dataPermissionExpression, expression);
}
// 第二步,合并多个 Expression 条件
if (dataPermissionExpression == null) {
return currentExpression;
}
if (currentExpression == null) {
return dataPermissionExpression;
}
// ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
}
// ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression
return new AndExpression(currentExpression, dataPermissionExpression);
}
/**
* 构建指定表的数据权限的 Expression 过滤条件
*
* @param table 表
* @return Expression 过滤条件
*/
private Expression buildDataPermissionExpression(Table table) {
// 生成条件
Expression allExpression = null;
for (DataPermissionRule rule : ContextHolder.getRules()) {
// 如果有匹配的规则,说明可重写。
// 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
// 这样导致第一次无 value,被标记成无需重写;但是第二次有 value,此时会需要重写。
ContextHolder.setRewrite(true);
// 单条规则的条件
String tableName = getTableName(table);
Expression oneExpress = rule.getExpression(tableName, table.getAlias());
if (oneExpress == null) {
continue;
}
// 拼接到 allExpression 中
allExpression = allExpression == null ? oneExpress
: new AndExpression(allExpression, oneExpress);
}
return allExpression;
}
/**
* 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
*
* @param ms MappedStatement
*/
private void addMappedStatementCache(MappedStatement ms) {
if (ContextHolder.getRewrite()) {
return;
}
// 无重写,进行添加
mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
}
/**
* SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
*
* @author yudao
*/
static final class ContextHolder {
/**
* 该 {@link MappedStatement} 对应的规则
*/
private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
/**
* SQL 是否进行重写
*/
private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
public static void init(List<DataPermissionRule> rules) {
RULES.set(rules);
REWRITE.set(false);
}
public static void clear() {
RULES.remove();
REWRITE.remove();
}
public static boolean getRewrite() {
return REWRITE.get();
}
public static void setRewrite(boolean rewrite) {
REWRITE.set(rewrite);
}
public static List<DataPermissionRule> getRules() {
return RULES.get();
}
}
/**
* {@link MappedStatement} 缓存
* 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
* 如果无效,则可以避免 SQL 的解析,加快速度
*
* @author yudao
*/
static final class MappedStatementCache {
/**
* 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
* <p>
* value:{@link MappedStatement#getId()} 编号
*/
@Getter
private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
/**
* 判断是否无需重写
* ps:虽然有点中文式英语,但是容易读懂即可
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
* @return 是否无需重写
*/
public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
// 如果规则为空,说明无需重写
if (org.springframework.util.CollectionUtils.isEmpty(rules)) {
return true;
}
// 任一规则不在 noRewritableMap 中,则说明可能需要重写
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (mappedStatementIds != null && !mappedStatementIds.stream().anyMatch(item -> item.equals(ms.getId()))) {
return false;
}
return false;
}
return true;
}
/**
* 添加无需重写的 MappedStatement
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
*/
public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (CollectionUtils.isEmpty(mappedStatementIds)) {
mappedStatementIds.add(ms.getId());
} else {
noRewritableMappedStatements.put(rule.getClass(), Arrays.stream(new String[]{ms.getId()}).collect(Collectors.toSet()));
}
}
}
/**
* 清空缓存
* 目前主要提供给单元测试
*/
public void clear() {
noRewritableMappedStatements.clear();
}
}
/**
* 获得 Table 对应的表名
* <p>
* 兼容 MySQL 转义表名 `t_xxx`
*
* @param table 表
* @return 去除转移字符后的表名
*/
public static String getTableName(Table table) {
String tableName = table.getName();
if (tableName.startsWith(MYSQL_ESCAPE_CHARACTER) && tableName.endsWith(MYSQL_ESCAPE_CHARACTER)) {
tableName = tableName.substring(1, tableName.length() - 1);
}
return tableName;
}
}
将三个Bean进行添加:
关于springboot 拦截器之Advisor不生效问题 可参考该文章。
package com.luo.chengrui.labs.lab02.config;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.luo.chengrui.labs.lab02.datapermission.*;
import com.luo.chengrui.labs.lab02.datapermission.dept.DeptDataPermissionRule;
import com.luo.chengrui.labs.lab02.interceptor.DataPermissionAnnotationAdvisor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Role;
import java.util.ArrayList;
import java.util.List;
/**
* 数据权限的自动配置类
*
* @author yudao
*/
@Configuration
public class DataPermissionConfiguration {
@Bean
public MybatisPlusInterceptor mybatisPlusInterceptor(List<DataPermissionRule> dataPermissionRule) {
MybatisPlusInterceptor mybatisPlusInterceptor = new MybatisPlusInterceptor();
// 分页插件
// mybatisPlusInterceptor.addInnerInterceptor(new PaginationInnerInterceptor());
//添加权限拦截器。
DataPermissionDatabaseInterceptor inner = new DataPermissionDatabaseInterceptor(dataPermissionRule);
List<InnerInterceptor> inners = new ArrayList<>(mybatisPlusInterceptor.getInterceptors());
inners.add(0, inner);
mybatisPlusInterceptor.setInterceptors(inners);
// MybatisDatabaseInterceptor mybatisDatabaseInterceptor = new MybatisDatabaseInterceptor();
// List<InnerInterceptor> inners = new ArrayList<>(mybatisPlusInterceptor.getInterceptors());
// inners.add(0, mybatisDatabaseInterceptor);
// mybatisPlusInterceptor.setInterceptors(inners);
return mybatisPlusInterceptor;
}
/**
* 初始化部门权限 bean 。
*
* @return
*/
@Bean
public DeptDataPermissionRule deptDataPermissionRule() {
// 创建 DeptDataPermissionRule 对象
DeptDataPermissionRule rule = new DeptDataPermissionRule();
// 用户表 需要作权限过滤,用户表中部门字段为dept_id。
rule.addDeptColumn("users","dept_id");
// 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
rule.addDeptColumn("wf_leave","deptid");
return rule;
}
/**
* 权限注解拦截器。
*
* @return
*/
@Bean
@Role(2)
public DataPermissionAnnotationAdvisor dataPermissionAnnotationAdvisor() {
return new DataPermissionAnnotationAdvisor();
}
}
UserService.java
定义了用户列表查询接口,添加了@DataScope注解
package com.luo.chengrui.labs.lab02.service;
import com.luo.chengrui.labs.lab02.annotation.DataScope;
import com.luo.chengrui.labs.lab02.dataobject.UserDO;
import com.luo.chengrui.labs.lab02.mapper.UserMapper;
import org.springframework.aop.framework.AopContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
/**
* @author
* @version 1.0.0
* @description
* @createTime 2023/07/21
*/
@Service
public class UserService {
@Autowired
private UserMapper userMapper;
private UserService self() {
return (UserService) AopContext.currentProxy();
}
@DataScope
public List<UserDO> selectList() {
return userMapper.selectList();
}
}
UserMapper.java
package com.luo.chengrui.labs.lab02.mapper;
import com.luo.chengrui.labs.lab02.dataobject.UserDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface UserMapper {
UserDO selectById(@Param("id") Integer id);
List<UserDO> selectList();
}
UserMapper.xml
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.luo.chengrui.labs.lab02.mapper.UserMapper">
<sql id="FIELDS">
id
, username
</sql>
<select id="selectById" parameterType="Integer" resultType="UserDO">
SELECT
<include refid="FIELDS"/>
FROM users
WHERE id = #{id}
</select>
<select id="selectList" resultType="UserDo">
SELECT
<include refid="FIELDS"/>
FROM users
</select>
</mapper>
DataPermissionConfiguration.java 类中 DeptDataPermissionRule 对象初始化设置用户表权限,配置如下:
@Bean
public DeptDataPermissionRule deptDataPermissionRule() {
// 创建 DeptDataPermissionRule 对象
DeptDataPermissionRule rule = new DeptDataPermissionRule();
// 用户表 需要作权限过滤,用户表中部门字段为dept_id。
rule.addDeptColumn("users","dept_id");
// 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
return rule;
}
运行结果如下:
不添加表字段,或者添加的表不是users表时,再执行用户查询。
@Bean
public DeptDataPermissionRule deptDataPermissionRule() {
// 创建 DeptDataPermissionRule 对象
DeptDataPermissionRule rule = new DeptDataPermissionRule();
// 用户表 需要作权限过滤,用户表中部门字段为dept_id。
//rule.addDeptColumn("users","dept_id");
// 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
rule.addDeptColumn("wf_leavel","deptid");
return rule;
}
运行结果如下:
以上几个步骤完成后,就可以对业务无入侵完成数据权限控制。
当有空时,再完善DeptDataPermissionRule中对ruoyi部门权限实现的代码。