Springboot管理系统数据权限过滤(三)——0业务入侵实现部门数据权限过滤

发布时间:2023年12月19日

上一章,讲解了SQL拦截器的,做了一个简单的SQL改造。本章要实现:

  • 仅对指定的service方法内执行的SQL进行处理;
  • 完成对部门权限的过滤;

以下简单用一个图说明了整个处理过程,红色框住的部分,就是本章要实现的内容:

  1. Spring注解拦截器,该拦截器的目标是对添加了@DataScope注解的方法,作用是解析中DataScope对象,将放到线程变量的权限过滤对象栈中。以便在SQL拦截器中获取权限对象进行解析;
  2. DataScope栈,需要实现一个线程栈,没有业务逻辑,仅为传递作用;
  3. SQL拦截器内,解析DataScope对象获得有权限的部门ID,然后拼接过滤条件;
  4. 另外DataScope权限注解也需要添加;
    在这里插入图片描述
    直接上干货:

定义DataSopce对象

DataScope.java
DataScope对象里面设置了用于数据权限规则数组。在SQL拦截器中将从这些数据权限规则中获取条件表达式。

package com.luo.chengrui.labs.lab02.annotation;

import com.luo.chengrui.labs.lab02.datapermission.DataPermissionRule;

import java.lang.annotation.*;

/**
 * 数据权限过滤注解
 *
 * @author ruoyi
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface DataScope
{


    /**
     * 当前类或方法是否开启数据权限
     * 即使不添加 @DataPermission 注解,默认是开启状态
     * 可通过设置 enable 为 false 禁用
     */
    boolean enable() default true;

    /**
     * 生效的数据权限规则数组,为了以后方便扩展,所以定义为权限解析对象数组
     */
    Class<? extends DataPermissionRule>[] includeRules() default {};
}

DataPermissionRule.java
这里定义了一个权限解析接口,如此即可以扩展很多不同类型的权限,如:对部门过滤,本章我们仅实现对部门权限过滤。(若要实现对如预算科目、项目等过滤,可实现该接口,另外还得单独实现对这些数据权限的分配功能。若依仅实现了对部门和人员控制)

package com.luo.chengrui.labs.lab02.datapermission;

import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;

import java.util.Set;

/**
 * 数据权限规则接口
 * 通过实现接口,自定义数据规则。例如说,
 *
 * @author yudao
 */
public interface DataPermissionRule {


    /**
     * 根据表名和别名,生成对应的 WHERE / OR 过滤条件
     *
     * @param tableName 表名
     * @param tableAlias 别名,可能为空
     * @return 过滤条件 Expression 表达式
     */
    Expression getExpression(String tableName, Alias tableAlias);

}

DeptDataPermissionRule.java
部门权限解析实现类,只要SQL语句中有sys_dept表或者dept_id字段,则均可以通过getExpression(String tableName, Alias tableAlias)获取部门数据权限,并生成Express表达式。

ruoyi框架中数据权限分配和权限获取实现,可参考:Springboot管理系统数据权限过滤——ruoyi实现方案

package com.luo.chengrui.labs.lab02.datapermission.dept;

import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.luo.chengrui.labs.lab02.datapermission.DataPermissionRule;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Column;

import java.util.*;
import java.util.stream.Collectors;

/**
 * 基于部门的 {@link DataPermissionRule} 数据权限规则实现
 * <p>
 * 注意,使用 DeptDataPermissionRule 时,需要保证表中有 dept_id 部门编号的字段,可自定义。
 * <p>
 */
@AllArgsConstructor
@Slf4j
public class DeptDataPermissionRule implements DataPermissionRule {


    static final Expression EXPRESSION_NULL = new NullValue();


    /**
     * 基于部门的表字段配置
     * 一般情况下,每个表的部门编号字段是 dept_id,通过该配置自定义。
     * <p>
     * key:表名
     * value:字段名
     */
    private final Map<String, String> deptColumns = new HashMap<>();
    /**
     * 所有表名,需要进行权限过滤的表名
     */
    private final Set<String> TABLE_NAMES = new HashSet<>();
    /**
     * 添加需要过滤部门权限的表名和部门ID字段名
     * @param tableName
     * @param columnName
     */
    public void addDeptColumn(String tableName, String columnName) {
        deptColumns.put(tableName, columnName);
        TABLE_NAMES.add(tableName);
    }


    /**
     * 获取所有需要按部门权限过滤的表
     * @return
     */
    @Override
    public Set<String> getTableNames() {
        return TABLE_NAMES;
    }

    @Override
    public Expression getExpression(String tableName, Alias tableAlias) {

        // 情况三,拼接 Dept 和 User 的条件,最后组合
        Set<Long> deptIds = new HashSet<>();
        //模拟数据,实现业务中需要获取当前用户部门权限,获取到部门id。
        deptIds.add(1L);
        deptIds.add(2L);
        // 配置中包含表时,进行过滤
        if (Objects.nonNull(deptColumns.get(tableName))) {
            return new InExpression(buildColumn(tableName, tableAlias, deptColumns.get(tableName)),
                    new ExpressionList(deptIds.stream().map(LongValue::new).collect(Collectors.toList())));
        }
        return EXPRESSION_NULL;
        //


    }

    /**
     * 构建 Column 对象
     *
     * @param tableName  表名
     * @param tableAlias 别名
     * @param column     字段名
     * @return Column 对象
     */
    public static Column buildColumn(String tableName, Alias tableAlias, String column) {
        if (tableAlias != null) {
            tableName = tableAlias.getName();
        }
        return new Column(tableName + StringPool.DOT + column);
    }
}


数据权限线程变量

DataPermissionContextHolder.java

package com.luo.chengrui.labs.lab02.interceptor;

import com.luo.chengrui.labs.lab02.annotation.DataScope;

import java.util.LinkedList;
import java.util.List;

/**
 * {@link DataScope} 注解的 Context 上下文
 * 将方法上的注解对象设置到 线程变量里面,在SQL执行拦截器中获取注解对象,根据内容生成相应的权限。
 * 告诉SQL执行{DataPermissionDatabaseInterceptor}拦截器,如此方法需要添加权限。
 */
public class DataPermissionContextHolder {

    /**
     * 使用 List 的原因,可能存在方法的嵌套调用
     */
    private static final ThreadLocal<LinkedList<DataScope>> DATA_PERMISSIONS =
            ThreadLocal.withInitial(LinkedList::new);

    /**
     * 获得当前的 DataPermission 注解
     *
     * @return DataPermission 注解
     */
    public static DataScope get() {
        return DATA_PERMISSIONS.get().peekLast();
    }

    /**
     * 入栈 DataPermission 注解
     *
     * @param dataPermission DataPermission 注解
     */
    public static void add(DataScope dataPermission) {
        DATA_PERMISSIONS.get().addLast(dataPermission);
    }

    /**
     * 出栈 DataPermission 注解
     *
     * @return DataPermission 注解
     */
    public static DataScope remove() {
        DataScope dataPermission = DATA_PERMISSIONS.get().removeLast();
        // 无元素时,清空 ThreadLocal
        if (DATA_PERMISSIONS.get().isEmpty()) {
            DATA_PERMISSIONS.remove();
        }
        return dataPermission;
    }

    /**
     * 获得所有 DataPermission
     *
     * @return DataPermission 队列
     */
    public static List<DataScope> getAll() {
        return DATA_PERMISSIONS.get();
    }

    /**
     * 清空上下文
     * <p>
     * 目前仅仅用于单测
     */
    public static void clear() {
        DATA_PERMISSIONS.remove();
    }

}

定义DataScope注解拦截器

DataScopeAnnotationInterceptor.java
方法拦截器:作用是拦截添加了@DataScope注解的方法,获取DataScope对象,放入线程变量。

package com.luo.chengrui.labs.lab02.interceptor;

import com.luo.chengrui.labs.lab02.annotation.DataScope;
import lombok.Getter;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotationUtils;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@link DataScope} 注解的拦截器
 * 1. 在执行方法前,将 @DataPermission 注解入栈
 * 2. 在执行方法后,将 @DataPermission 注解出栈
 *
 * @author yudao
 */
public class DataScopeAnnotationInterceptor implements MethodInterceptor {

    /**
     * DataPermission 空对象,用于方法无 {@link import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;} 注解时,使用 DATA_PERMISSION_NULL 进行占位
     */
    static final DataScope DATA_PERMISSION_NULL = DataScopeAnnotationInterceptor.class.getAnnotation(DataScope.class);

    @Getter
    private final Map<MethodClassKey, DataScope> dataPermissionCache = new ConcurrentHashMap<>();

    @Override
    public Object invoke(MethodInvocation methodInvocation) throws Throwable {
        Logger log = LoggerFactory.getLogger(DataScopeAnnotationInterceptor.class);
        log.debug("DataScopeAnnotationInterceptor 拦截器:" + methodInvocation.getMethod().getName());
        // 入栈
        DataScope dataPermission = this.findAnnotation(methodInvocation);
        if (dataPermission != null) {
            DataPermissionContextHolder.add(dataPermission);
        }
        try {
            // 执行逻辑
            return methodInvocation.proceed();
        } finally {
            // 出栈
            if (dataPermission != null) {
                DataPermissionContextHolder.remove();
            }
        }
    }

    private DataScope findAnnotation(MethodInvocation methodInvocation) {
        // 1. 从缓存中获取
        Method method = methodInvocation.getMethod();
        Object targetObject = methodInvocation.getThis();
        Class<?> clazz = targetObject != null ? targetObject.getClass() : method.getDeclaringClass();
        MethodClassKey methodClassKey = new MethodClassKey(method, clazz);
        DataScope dataPermission = dataPermissionCache.get(methodClassKey);
        if (dataPermission != null) {
            return dataPermission != DATA_PERMISSION_NULL ? dataPermission : null;
        }

        // 2.1 从方法中获取
        dataPermission = AnnotationUtils.findAnnotation(method, DataScope.class);
        // 2.2 从类上获取
        if (dataPermission == null) {
            dataPermission = AnnotationUtils.findAnnotation(clazz, DataScope.class);
        }
        // 2.3 添加到缓存中
        dataPermissionCache.put(methodClassKey, dataPermission != null ? dataPermission : DATA_PERMISSION_NULL);
        return dataPermission;
    }

}

DataPermissionAnnotationAdvisor.java
定义拦截点,对添加了DataScope注解的类和方法都添加了拦截点。

package com.luo.chengrui.labs.lab02.interceptor;

import com.luo.chengrui.labs.lab02.annotation.DataScope;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.aopalliance.aop.Advice;
import org.springframework.aop.Pointcut;
import org.springframework.aop.support.AbstractPointcutAdvisor;
import org.springframework.aop.support.ComposablePointcut;
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
import org.springframework.context.annotation.Role;
import org.springframework.stereotype.Component;

/**
 * {@link DataScope} 注解的 Advisor 实现类
 *
 */
@Getter
@EqualsAndHashCode(callSuper = true)
public class DataPermissionAnnotationAdvisor extends AbstractPointcutAdvisor {

    private final Advice advice;

    private final Pointcut pointcut;

    public DataPermissionAnnotationAdvisor() {
        this.advice = new DataScopeAnnotationInterceptor();
        this.pointcut = this.buildPointcut();
    }

    protected Pointcut buildPointcut() {
        Pointcut classPointcut = new AnnotationMatchingPointcut(DataScope.class, true);
        Pointcut methodPointcut = new AnnotationMatchingPointcut(null, DataScope.class, true);
        return new ComposablePointcut(classPointcut).union(methodPointcut);
    }

}

SQL拦截器改造sql。

DataPermissionDatabaseInterceptor.java
代码比较多,主要是对sql各种情况的解析和处理,与实现业务关联性很弱。在实际业务扩展时,该类基本不需要进行修改。可能搜索对 getExpression 方法的调用,改造sql的位置。

package com.luo.chengrui.labs.lab02.datapermission;

import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
 * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
 * 主要是在执行SQL前拦截器,在执行之前可重写SQL
 *
 * @author yudao
 */
@RequiredArgsConstructor
public class DataPermissionDatabaseInterceptor extends JsqlParserSupport implements InnerInterceptor {

    private static final String MYSQL_ESCAPE_CHARACTER = "`";
    private final List<DataPermissionRule> dataPermissionRule;

    @Getter
    private final MappedStatementCache mappedStatementCache = new MappedStatementCache();

    @Override // SELECT 场景
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        // 获得 Mapper 对应的数据权限的规则
        if (mappedStatementCache.noRewritable(ms, dataPermissionRule)) { // 如果无需重写,则跳过
            return;
        }

        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        try {
            // 初始化上下文
            ContextHolder.init(dataPermissionRule);
            // 处理 SQL
            mpBs.sql(parserSingle(mpBs.sql(), null));
        } finally {
            // 添加是否需要重写的缓存
            addMappedStatementCache(ms);
            // 清空上下文
            ContextHolder.clear();
        }
    }

    @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
            // 获得 Mapper 对应的数据权限的规则
            if (mappedStatementCache.noRewritable(ms, dataPermissionRule)) { // 如果无需重写,则跳过
                return;
            }

            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
            try {
                // 初始化上下文
                ContextHolder.init(dataPermissionRule);
                // 处理 SQL
                mpBs.sql(parserMulti(mpBs.sql(), null));
            } finally {
                // 添加是否需要重写的缓存
                addMappedStatementCache(ms);
                // 清空上下文
                ContextHolder.clear();
            }
        }
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        processSelectBody(select.getSelectBody());
        List<WithItem> withItemsList = select.getWithItemsList();
        if (!CollectionUtils.isEmpty(withItemsList)) {
            withItemsList.forEach(this::processSelectBody);
        }
    }

    /**
     * update 语句处理
     */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        final Table table = update.getTable();
        update.setWhere(this.builderExpression(update.getWhere(), table));
    }

    /**
     * delete 语句处理
     */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
    }

    // ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========

    protected void processSelectBody(SelectBody selectBody) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            processSelectBody(withItem.getSubSelect().getSelectBody());
        } else {
            SetOperationList operationList = (SetOperationList) selectBody;
            List<SelectBody> selectBodyList = operationList.getSelects();
            if (CollectionUtils.isNotEmpty(selectBodyList)) {
                selectBodyList.forEach(this::processSelectBody);
            }
        }
    }

    /**
     * 处理 PlainSelect
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        //#3087 github
        List<SelectItem> selectItems = plainSelect.getSelectItems();
        if (CollectionUtils.isNotEmpty(selectItems)) {
            selectItems.forEach(this::processSelectItem);
        }

        // 处理 where 中的子查询
        Expression where = plainSelect.getWhere();
        processWhereSubSelect(where);

        // 处理 fromItem
        FromItem fromItem = plainSelect.getFromItem();
        List<Table> list = processFromItem(fromItem);
        List<Table> mainTables = new ArrayList<>(list);

        // 处理 join
        List<Join> joins = plainSelect.getJoins();
        if (CollectionUtils.isNotEmpty(joins)) {
            mainTables = processJoins(mainTables, joins);
        }

        // 当有 mainTable 时,进行 where 条件追加
        if (CollectionUtils.isNotEmpty(mainTables)) {
            plainSelect.setWhere(builderExpression(where, mainTables));
        }
    }

    private List<Table> processFromItem(FromItem fromItem) {
        // 处理括号括起来的表达式
        while (fromItem instanceof ParenthesisFromItem) {
            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
        }

        List<Table> mainTables = new ArrayList<>();
        // 无 join 时的处理逻辑
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            mainTables.add(fromTable);
        } else if (fromItem instanceof SubJoin) {
            // SubJoin 类型则还需要添加上 where 条件
            List<Table> tables = processSubJoin((SubJoin) fromItem);
            mainTables.addAll(tables);
        } else {
            // 处理下 fromItem
            processOtherFromItem(fromItem);
        }
        return mainTables;
    }

    /**
     * 处理where条件内的子查询
     * <p>
     * 支持如下:
     * 1. in
     * 2. =
     * 3. >
     * 4. <
     * 5. >=
     * 6. <=
     * 7. <>
     * 8. EXISTS
     * 9. NOT EXISTS
     * <p>
     * 前提条件:
     * 1. 子查询必须放在小括号中
     * 2. 子查询一般放在比较操作符的右边
     *
     * @param where where 条件
     */
    protected void processWhereSubSelect(Expression where) {
        if (where == null) {
            return;
        }
        if (where instanceof FromItem) {
            processOtherFromItem((FromItem) where);
            return;
        }
        if (where.toString().indexOf("SELECT") > 0) {
            // 有子查询
            if (where instanceof BinaryExpression) {
                // 比较符号 , and , or , 等等
                BinaryExpression expression = (BinaryExpression) where;
                processWhereSubSelect(expression.getLeftExpression());
                processWhereSubSelect(expression.getRightExpression());
            } else if (where instanceof InExpression) {
                // in
                InExpression expression = (InExpression) where;
                Expression inExpression = expression.getRightExpression();
                if (inExpression instanceof SubSelect) {
                    processSelectBody(((SubSelect) inExpression).getSelectBody());
                }
            } else if (where instanceof ExistsExpression) {
                // exists
                ExistsExpression expression = (ExistsExpression) where;
                processWhereSubSelect(expression.getRightExpression());
            } else if (where instanceof NotExpression) {
                // not exists
                NotExpression expression = (NotExpression) where;
                processWhereSubSelect(expression.getExpression());
            } else if (where instanceof Parenthesis) {
                Parenthesis expression = (Parenthesis) where;
                processWhereSubSelect(expression.getExpression());
            }
        }
    }

    protected void processSelectItem(SelectItem selectItem) {
        if (selectItem instanceof SelectExpressionItem) {
            SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
            if (selectExpressionItem.getExpression() instanceof SubSelect) {
                processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
            } else if (selectExpressionItem.getExpression() instanceof Function) {
                processFunction((Function) selectExpressionItem.getExpression());
            }
        }
    }

    /**
     * 处理函数
     * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
     * <p> fixed gitee pulls/141</p>
     *
     * @param function
     */
    protected void processFunction(Function function) {
        ExpressionList parameters = function.getParameters();
        if (parameters != null) {
            parameters.getExpressions().forEach(expression -> {
                if (expression instanceof SubSelect) {
                    processSelectBody(((SubSelect) expression).getSelectBody());
                } else if (expression instanceof Function) {
                    processFunction((Function) expression);
                }
            });
        }
    }

    /**
     * 处理子查询等
     */
    protected void processOtherFromItem(FromItem fromItem) {
        // 去除括号
        while (fromItem instanceof ParenthesisFromItem) {
            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
        }

        if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            logger.debug("Perform a subQuery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
            if (lateralSubSelect.getSubSelect() != null) {
                SubSelect subSelect = lateralSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelectBody(subSelect.getSelectBody());
                }
            }
        }
    }

    /**
     * 处理 sub join
     *
     * @param subJoin subJoin
     * @return Table subJoin 中的主表
     */
    private List<Table> processSubJoin(SubJoin subJoin) {
        List<Table> mainTables = new ArrayList<>();
        if (subJoin.getJoinList() != null) {
            List<Table> list = processFromItem(subJoin.getLeft());
            mainTables.addAll(list);
            mainTables = processJoins(mainTables, subJoin.getJoinList());
        }
        return mainTables;
    }

    /**
     * 处理 joins
     *
     * @param mainTables 可以为 null
     * @param joins      join 集合
     * @return List<Table> 右连接查询的 Table 列表
     */
    private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
        // join 表达式中最终的主表
        Table mainTable = null;
        // 当前 join 的左表
        Table leftTable = null;

        if (mainTables == null) {
            mainTables = new ArrayList<>();
        } else if (mainTables.size() == 1) {
            mainTable = mainTables.get(0);
            leftTable = mainTable;
        }

        //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
        Deque<List<Table>> onTableDeque = new LinkedList<>();
        for (Join join : joins) {
            // 处理 on 表达式
            FromItem joinItem = join.getRightItem();

            // 获取当前 join 的表,subJoint 可以看作是一张表
            List<Table> joinTables = null;
            if (joinItem instanceof Table) {
                joinTables = new ArrayList<>();
                joinTables.add((Table) joinItem);
            } else if (joinItem instanceof SubJoin) {
                joinTables = processSubJoin((SubJoin) joinItem);
            }

            if (joinTables != null) {

                // 如果是隐式内连接
                if (join.isSimple()) {
                    mainTables.addAll(joinTables);
                    continue;
                }

                // 当前表是否忽略
                Table joinTable = joinTables.get(0);

                List<Table> onTables = null;
                // 如果不要忽略,且是右连接,则记录下当前表
                if (join.isRight()) {
                    mainTable = joinTable;
                    if (leftTable != null) {
                        onTables = Collections.singletonList(leftTable);
                    }
                } else if (join.isLeft()) {
                    onTables = Collections.singletonList(joinTable);
                } else if (join.isInner()) {
                    if (mainTable == null) {
                        onTables = Collections.singletonList(joinTable);
                    } else {
                        onTables = Arrays.asList(mainTable, joinTable);
                    }
                    mainTable = null;
                }

                mainTables = new ArrayList<>();
                if (mainTable != null) {
                    mainTables.add(mainTable);
                }

                // 获取 join 尾缀的 on 表达式列表
                Collection<Expression> originOnExpressions = join.getOnExpressions();
                // 正常 join on 表达式只有一个,立刻处理
                if (originOnExpressions.size() == 1 && onTables != null) {
                    List<Expression> onExpressions = new LinkedList<>();
                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
                    join.setOnExpressions(onExpressions);
                    leftTable = joinTable;
                    continue;
                }
                // 表名压栈,忽略的表压入 null,以便后续不处理
                onTableDeque.push(onTables);
                // 尾缀多个 on 表达式的时候统一处理
                if (originOnExpressions.size() > 1) {
                    Collection<Expression> onExpressions = new LinkedList<>();
                    for (Expression originOnExpression : originOnExpressions) {
                        List<Table> currentTableList = onTableDeque.poll();
                        if (CollectionUtils.isEmpty(currentTableList)) {
                            onExpressions.add(originOnExpression);
                        } else {
                            onExpressions.add(builderExpression(originOnExpression, currentTableList));
                        }
                    }
                    join.setOnExpressions(onExpressions);
                }
                leftTable = joinTable;
            } else {
                processOtherFromItem(joinItem);
                leftTable = null;
            }
        }

        return mainTables;
    }

    // ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========

    /**
     * 处理条件
     *
     * @param currentExpression 当前 where 条件
     * @param table             单个表
     */
    protected Expression builderExpression(Expression currentExpression, Table table) {
        return this.builderExpression(currentExpression, Collections.singletonList(table));
    }

    /**
     * 处理条件
     *
     * @param currentExpression 当前 where 条件
     * @param tables            多个表
     */
    protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
        // 没有表需要处理直接返回
        if (CollectionUtils.isEmpty(tables)) {
            return currentExpression;
        }

        // 第一步,获得 Table 对应的数据权限条件
        Expression dataPermissionExpression = null;
        for (Table table : tables) {
            // 构建每个表的权限 Expression 条件
            Expression expression = buildDataPermissionExpression(table);
            if (expression == null) {
                continue;
            }
            // 合并到 dataPermissionExpression 中
            dataPermissionExpression = dataPermissionExpression == null ? expression
                    : new AndExpression(dataPermissionExpression, expression);
        }

        // 第二步,合并多个 Expression 条件
        if (dataPermissionExpression == null) {
            return currentExpression;
        }
        if (currentExpression == null) {
            return dataPermissionExpression;
        }
        // ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression
        if (currentExpression instanceof OrExpression) {
            return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
        }
        // ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression
        return new AndExpression(currentExpression, dataPermissionExpression);
    }

    /**
     * 构建指定表的数据权限的 Expression 过滤条件
     *
     * @param table 表
     * @return Expression 过滤条件
     */
    private Expression buildDataPermissionExpression(Table table) {
        // 生成条件
        Expression allExpression = null;
        for (DataPermissionRule rule : ContextHolder.getRules()) {

            // 如果有匹配的规则,说明可重写。
            // 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
            // 这样导致第一次无 value,被标记成无需重写;但是第二次有 value,此时会需要重写。
            ContextHolder.setRewrite(true);

            // 单条规则的条件
            String tableName = getTableName(table);
            Expression oneExpress = rule.getExpression(tableName, table.getAlias());
            if (oneExpress == null) {
                continue;
            }
            // 拼接到 allExpression 中
            allExpression = allExpression == null ? oneExpress
                    : new AndExpression(allExpression, oneExpress);
        }

        return allExpression;
    }

    /**
     * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
     *
     * @param ms MappedStatement
     */
    private void addMappedStatementCache(MappedStatement ms) {
        if (ContextHolder.getRewrite()) {
            return;
        }
        // 无重写,进行添加
        mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
    }

    /**
     * SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
     *
     * @author yudao
     */
    static final class ContextHolder {

        /**
         * 该 {@link MappedStatement} 对应的规则
         */
        private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
        /**
         * SQL 是否进行重写
         */
        private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);

        public static void init(List<DataPermissionRule> rules) {
            RULES.set(rules);
            REWRITE.set(false);
        }

        public static void clear() {
            RULES.remove();
            REWRITE.remove();
        }

        public static boolean getRewrite() {
            return REWRITE.get();
        }

        public static void setRewrite(boolean rewrite) {
            REWRITE.set(rewrite);
        }

        public static List<DataPermissionRule> getRules() {
            return RULES.get();
        }

    }

    /**
     * {@link MappedStatement} 缓存
     * 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
     * 如果无效,则可以避免 SQL 的解析,加快速度
     *
     * @author yudao
     */
    static final class MappedStatementCache {

        /**
         * 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
         * <p>
         * value:{@link MappedStatement#getId()} 编号
         */
        @Getter
        private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();

        /**
         * 判断是否无需重写
         * ps:虽然有点中文式英语,但是容易读懂即可
         *
         * @param ms    MappedStatement
         * @param rules 数据权限规则数组
         * @return 是否无需重写
         */
        public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
            // 如果规则为空,说明无需重写
            if (org.springframework.util.CollectionUtils.isEmpty(rules)) {
                return true;
            }
            // 任一规则不在 noRewritableMap 中,则说明可能需要重写
            for (DataPermissionRule rule : rules) {
                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
                if (mappedStatementIds != null && !mappedStatementIds.stream().anyMatch(item -> item.equals(ms.getId()))) {
                    return false;
                }
                return false;
            }
            return true;
        }

        /**
         * 添加无需重写的 MappedStatement
         *
         * @param ms    MappedStatement
         * @param rules 数据权限规则数组
         */
        public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
            for (DataPermissionRule rule : rules) {
                Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
                if (CollectionUtils.isEmpty(mappedStatementIds)) {
                    mappedStatementIds.add(ms.getId());
                } else {
                    noRewritableMappedStatements.put(rule.getClass(), Arrays.stream(new String[]{ms.getId()}).collect(Collectors.toSet()));
                }
            }
        }

        /**
         * 清空缓存
         * 目前主要提供给单元测试
         */
        public void clear() {
            noRewritableMappedStatements.clear();
        }

    }

    /**
     * 获得 Table 对应的表名
     * <p>
     * 兼容 MySQL 转义表名 `t_xxx`
     *
     * @param table 表
     * @return 去除转移字符后的表名
     */
    public static String getTableName(Table table) {
        String tableName = table.getName();
        if (tableName.startsWith(MYSQL_ESCAPE_CHARACTER) && tableName.endsWith(MYSQL_ESCAPE_CHARACTER)) {
            tableName = tableName.substring(1, tableName.length() - 1);
        }
        return tableName;
    }


}

配置SpringBean

将三个Bean进行添加:

  • MybatisPlusInterceptor ,Sql拦截器;
  • DataPermissionAnnotationAdvisor ,DataScope拦截器;
  • DeptDataPermissionRule ,部门权限解析对象,同时注册需要添加部门权限的表名和部门ID字段名,有新增表时需要在这里设置。

关于springboot 拦截器之Advisor不生效问题 可参考该文章。

package com.luo.chengrui.labs.lab02.config;

import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.luo.chengrui.labs.lab02.datapermission.*;
import com.luo.chengrui.labs.lab02.datapermission.dept.DeptDataPermissionRule;
import com.luo.chengrui.labs.lab02.interceptor.DataPermissionAnnotationAdvisor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Role;

import java.util.ArrayList;
import java.util.List;

/**
 * 数据权限的自动配置类
 *
 * @author yudao
 */
@Configuration
public class DataPermissionConfiguration {

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor(List<DataPermissionRule> dataPermissionRule) {
        MybatisPlusInterceptor mybatisPlusInterceptor = new MybatisPlusInterceptor();

        // 分页插件
//        mybatisPlusInterceptor.addInnerInterceptor(new PaginationInnerInterceptor());

        //添加权限拦截器。
        DataPermissionDatabaseInterceptor inner = new DataPermissionDatabaseInterceptor(dataPermissionRule);
        List<InnerInterceptor> inners = new ArrayList<>(mybatisPlusInterceptor.getInterceptors());
        inners.add(0, inner);
        mybatisPlusInterceptor.setInterceptors(inners);

//        MybatisDatabaseInterceptor mybatisDatabaseInterceptor = new MybatisDatabaseInterceptor();
//        List<InnerInterceptor> inners = new ArrayList<>(mybatisPlusInterceptor.getInterceptors());
//        inners.add(0, mybatisDatabaseInterceptor);
//        mybatisPlusInterceptor.setInterceptors(inners);

        return mybatisPlusInterceptor;
    }

    /**
     * 初始化部门权限 bean 。
     *
     * @return
     */
    @Bean
    public DeptDataPermissionRule deptDataPermissionRule() {
        // 创建 DeptDataPermissionRule 对象
        DeptDataPermissionRule rule = new DeptDataPermissionRule();
        // 用户表 需要作权限过滤,用户表中部门字段为dept_id。
        rule.addDeptColumn("users","dept_id");
        // 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
        rule.addDeptColumn("wf_leave","deptid");
        return rule;
    }

    /**
     * 权限注解拦截器。
     *
     * @return
     */
    @Bean
    @Role(2)
    public DataPermissionAnnotationAdvisor dataPermissionAnnotationAdvisor() {
        return new DataPermissionAnnotationAdvisor();
    }

}

数据权限应用

UserService.java
定义了用户列表查询接口,添加了@DataScope注解

package com.luo.chengrui.labs.lab02.service;

import com.luo.chengrui.labs.lab02.annotation.DataScope;
import com.luo.chengrui.labs.lab02.dataobject.UserDO;
import com.luo.chengrui.labs.lab02.mapper.UserMapper;
import org.springframework.aop.framework.AopContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

/**
 * @author
 * @version 1.0.0
 * @description
 * @createTime 2023/07/21
 */
@Service
public class UserService {

    @Autowired
    private UserMapper userMapper;

    private UserService self() {
        return (UserService) AopContext.currentProxy();
    }

    @DataScope
    public List<UserDO> selectList() {
        return userMapper.selectList();
    }
}

UserMapper.java

package com.luo.chengrui.labs.lab02.mapper;

import com.luo.chengrui.labs.lab02.dataobject.UserDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;

import java.util.List;


@Mapper
public interface UserMapper {

    UserDO selectById(@Param("id") Integer id);

    List<UserDO> selectList();
}

UserMapper.xml

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.luo.chengrui.labs.lab02.mapper.UserMapper">

    <sql id="FIELDS">
        id
        , username
    </sql>

    <select id="selectById" parameterType="Integer" resultType="UserDO">
        SELECT
        <include refid="FIELDS"/>
        FROM users
        WHERE id = #{id}
    </select>

    <select id="selectList" resultType="UserDo">
        SELECT
        <include refid="FIELDS"/>
        FROM users
    </select>

</mapper>

功能验证1:

DataPermissionConfiguration.java 类中 DeptDataPermissionRule 对象初始化设置用户表权限,配置如下:

 @Bean
    public DeptDataPermissionRule deptDataPermissionRule() {
        // 创建 DeptDataPermissionRule 对象
        DeptDataPermissionRule rule = new DeptDataPermissionRule();
        // 用户表 需要作权限过滤,用户表中部门字段为dept_id。
        rule.addDeptColumn("users","dept_id");
        // 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
        return rule;
    }

运行结果如下:
在这里插入图片描述

功能测试2:不添加表和字段

不添加表字段,或者添加的表不是users表时,再执行用户查询。

@Bean
    public DeptDataPermissionRule deptDataPermissionRule() {
        // 创建 DeptDataPermissionRule 对象
        DeptDataPermissionRule rule = new DeptDataPermissionRule();
        // 用户表 需要作权限过滤,用户表中部门字段为dept_id。
        //rule.addDeptColumn("users","dept_id");
        // 请假流程表,也需要按部门权限过滤;但wf_leave表中部门字段为:deptid,则应按如下配置
        
        rule.addDeptColumn("wf_leavel","deptid");
        return rule;
    }

运行结果如下:
在这里插入图片描述

以上几个步骤完成后,就可以对业务无入侵完成数据权限控制。

当有空时,再完善DeptDataPermissionRule中对ruoyi部门权限实现的代码。

文章来源:https://blog.csdn.net/ldz_wolf/article/details/135087144
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。