springboot2.7集成sharding-jdbc4.1.1实现业务分表

发布时间:2024年01月10日

1、引入maven

        <dependency>
            <groupId>org.apache.shardingsphere</groupId>
            <artifactId>sharding-jdbc-spring-boot-starter</artifactId>
            <version>4.1.1</version>
        </dependency>

2、基本代码示例

基本逻辑:利用数据库中已有的租户 uuid 做租户级别的数据分表,如 user_${uuid}、order_${uuid} 等。由于 PostgreSQL 中未加双引号的表名不能包含“-”字符,所以需要把 uuid 中的“-”全部去掉。

DynamicTableConfig.java

用于初始化动态分表信息

MybatisPlusConfig.java

MyBatis-Plus 配置类,负责注册 SQL 拦截器;拦截器针对所有插入、查询和更新语句,判断所执行的 SQL 是否属于分表范围。

OrgAutoShardingSphereFixture.java

自定义分片算法,实现 HintShardingAlgorithm 接口(Hint 分片),用于动态分表

ShardingAlgorithmTool

动态分表工具类

ShardingTablesLoadRunner

项目启动后 读取已有分表 进行缓存

SqlParserHandler

sql工具类

TablesNamesConfig

分表信息和sql集合类,所需要做分表的,均需要配置在这里

EdgeUserAndOrderServiceInterceptor做sql拦截器

TablesNamesConfig类:
public class TablesNamesConfig {
    // Comma-separated logic table names that are sharded per tenant; every sharded table must be listed here.
    public static final String TABLES_NAMES = "edge_cs_user,edge_order_info";

    /**
     * Returns the DDL template statement(s) used to create a shard of the given logic table.
     * Callers rewrite the logic table name inside each template into the physical shard name.
     *
     * @param tableName logic table name; must not be null
     * @return a mutable list of DDL templates, empty for unknown tables
     */
    public static List<String> selectTableCreateSql(String tableName) {
        List<String> statements = new ArrayList<>();
        switch (tableName) {
            case "edge_cs_user":
                statements.add("替换成你要创建的sql语句");
                break;
            case "edge_order_info":
                statements.add("替换成你要创建的sql语句");
                break;
            default:
                // No template for tables that are not sharded.
                break;
        }
        return statements;
    }
}
SqlParserHandler类
import com.baomidou.mybatisplus.core.injector.methods.Insert;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.alter.Alter;
import net.sf.jsqlparser.statement.create.index.CreateIndex;
import net.sf.jsqlparser.statement.create.table.CreateTable;
import net.sf.jsqlparser.statement.create.view.CreateView;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.drop.Drop;
import net.sf.jsqlparser.statement.execute.Execute;
import net.sf.jsqlparser.statement.merge.Merge;
import net.sf.jsqlparser.statement.replace.Replace;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.truncate.Truncate;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.upsert.Upsert;
import net.sf.jsqlparser.util.TablesNamesFinder;

import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

/**
 * sql解析工具
 *
 * @Created by xieyaoyi
 * @Created_date 2023/12/12
 */
public class SqlParserHandler {

    /**
     * 由于jsqlparser没有获取SQL类型的原始工具,并且在下面操作时需要知道SQL类型,所以编写此工具方法
     *
     * @param sql sql语句
     * @return sql类型,
     * @throws JSQLParserException
     */
    public static String getSqlType(String sql) throws JSQLParserException {
        Statement sqlStmt = CCJSqlParserUtil.parse(new StringReader(sql));
        if (sqlStmt instanceof Alter) {
            return "ALTER";
        } else if (sqlStmt instanceof CreateIndex) {
            return "CREATEINDEX";
        } else if (sqlStmt instanceof CreateTable) {
            return "CREATETABLE";
        } else if (sqlStmt instanceof CreateView) {
            return "CREATEVIEW";
        } else if (sqlStmt instanceof Delete) {
            return "DELETE";
        } else if (sqlStmt instanceof Drop) {
            return "DROP";
        } else if (sqlStmt instanceof Execute) {
            return "EXECUTE";
        } else if (sqlStmt instanceof Insert) {
            return "INSERT";
        } else if (sqlStmt instanceof Merge) {
            return "MERGE";
        } else if (sqlStmt instanceof Replace) {
            return "REPLACE";
        } else if (sqlStmt instanceof Select) {
            return "SELECT";
        } else if (sqlStmt instanceof Truncate) {
            return "TRUNCATE";
        } else if (sqlStmt instanceof Update) {
            return "UPDATE";
        } else if (sqlStmt instanceof Upsert) {
            return "UPSERT";
        } else {
            return "NONE";
        }
    }

    /**
     * 获取sql操作接口,与上面类型判断结合使用
     * example:
     * String sql = "create table a(a string)";
     * SqlType sqlType = SqlParserTool.getSqlType(sql);
     * if(sqlType.equals(SqlType.SELECT)){
     * Select statement = (Select) SqlParserTool.getStatement(sql);
     * }
     *
     * @param sql
     * @return
     * @throws JSQLParserException
     */
    public static Statement getStatement(String sql) throws JSQLParserException {
        Statement sqlStmt = CCJSqlParserUtil.parse(new StringReader(sql));
        return sqlStmt;
    }

    /**
     * 获取tables的表名
     *
     * @param statement
     * @return
     */
    public static <T> List<String> getTableList(T statement) {
        TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
        List<String> tableList = tablesNamesFinder.getTableList((Statement) statement);
        return tableList;
    }

    /**
     * 获取join层级
     *
     * @param selectBody
     * @return
     */
    public static List<Join> getJoins(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            List<Join> joins = ((PlainSelect) selectBody).getJoins();
            return joins;
        }
        return new ArrayList<Join>();
    }

    /**
     * @param selectBody
     * @return
     */
    public static List<Table> getIntoTables(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            List<Table> tables = ((PlainSelect) selectBody).getIntoTables();
            return tables;
        }
        return new ArrayList<Table>();
    }

    /**
     * @param selectBody
     * @return
     */
    public static void setIntoTables(SelectBody selectBody, List<Table> tables) {
        if (selectBody instanceof PlainSelect) {
            ((PlainSelect) selectBody).setIntoTables(tables);
        }
    }

    /**
     * 获取limit值
     *
     * @param selectBody
     * @return
     */
    public static Limit getLimit(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            Limit limit = ((PlainSelect) selectBody).getLimit();
            return limit;
        }
        return null;
    }

    /**
     * 为SQL增加limit值
     *
     * @param selectBody
     * @param l
     */
    public static void setLimit(SelectBody selectBody, long l) {
        if (selectBody instanceof PlainSelect) {
            Limit limit = new Limit();
            limit.setRowCount(new LongValue(String.valueOf(l)));
            ((PlainSelect) selectBody).setLimit(limit);
        }
    }

    /**
     * 获取FromItem不支持子查询操作
     *
     * @param selectBody
     * @return
     */
    public static FromItem getFromItem(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
            return fromItem;
        } else if (selectBody instanceof WithItem) {
            getFromItem(selectBody);
        }
        return null;
    }

    /**
     * 获取子查询
     *
     * @param selectBody
     * @return
     */
    public static SubSelect getSubSelect(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
            if (fromItem instanceof SubSelect) {
                return ((SubSelect) fromItem);
            }
        } else if (selectBody instanceof WithItem) {
            getSubSelect(selectBody);
        }
        return null;
    }

    /**
     * 判断是否为多级子查询
     *
     * @param selectBody
     * @return
     */
    public static boolean isMultiSubSelect(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
            if (fromItem instanceof SubSelect) {
                SelectBody subBody = ((SubSelect) fromItem).getSelectBody();
                if (subBody instanceof PlainSelect) {
                    FromItem subFromItem = ((PlainSelect) subBody).getFromItem();
                    if (subFromItem instanceof SubSelect) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * 获取查询字段
     *
     * @param selectBody
     * @return
     */
    public static List<SelectItem> getSelectItems(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            List<SelectItem> selectItems = ((PlainSelect) selectBody).getSelectItems();
            return selectItems;
        }
        return null;
    }

    public static void main(String[] args) throws JSQLParserException {
        String sql = "SELECT table_name FROM information_schema.tables  WHERE table_name  like concat('edge_cs_user','%')";
        Statement statement = getStatement(sql);
        List<String> tableList = getTableList(statement);

        String sqlType = getSqlType(sql);
        System.out.println(sqlType);
        for (String s : tableList) {
            System.out.println(s);
        }
    }

}

ShardingAlgorithmTool 工具类:在指定数据库中创建所需的分表,并将已创建的表名缓存起来,下次就不用重复创建。


import cn.hutool.core.io.resource.ClassPathResource;
import cn.hutool.core.util.StrUtil;
import com.youxin.commons.commonsdata.service.EdgeOrgData;
import com.youxin.edge_service.ifoodapi.utils.IfoodApiUtils;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;

@Slf4j
public class ShardingAlgorithmTool {
    private static final HashSet<String> tableNameCache = new HashSet<>();

    /**
     * 判断 分表获取的表名是否存在 不存在则自动建表
     *
     * @param logicTableName  逻辑表名(表头)
     * @param resultTableName 真实表名
     * @return 确认存在于数据库中的真实表名
     */
    public static String shardingTablesCheckAndCreatAndReturn(String logicTableName, String resultTableName) {
        log.error(String.valueOf(EdgeOrgData.orgListMap));
        synchronized (logicTableName.intern()) {
            // 缓存中有此表 返回
            if (tableNameCache.contains(resultTableName)) {
                return resultTableName;
            }
            // 缓存中无此表 建表 并添加缓存
            List<String> sqlList = TablesNamesConfig.selectTableCreateSql(logicTableName);
            for (int i = 0; i < sqlList.size(); i++) {
                sqlList.set(i, sqlList.get(i).replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS").replace(logicTableName, resultTableName));
            }
            if (executeSql(sqlList)){
                tableNameCache.add(resultTableName);
            }
        }
        return resultTableName;
    }

    /**
     * 缓存重载方法
     */
    public static void tableNameCacheReload(String active) {
        // 读取数据库中所有表名
        List<String> tableNameList = getAllTableNameBySchema(active);
        // 删除旧的缓存(如果存在)
        ShardingAlgorithmTool.tableNameCache.clear();
        // 写入新的缓存
        ShardingAlgorithmTool.tableNameCache.addAll(tableNameList);
    }


    private static boolean executeSql(List<String> sqlList) {
        final ClassPathResource resource = new ClassPathResource("application.yml");
        Properties properties = new Properties();
        try {
            properties.load(resource.getStream());
            String active = properties.getProperty("active");
            String propertiesname = "application.properties";
            switch (active) {
                case "dev":
                    propertiesname = "application-dev.properties";
                    break;
                case "test":
                    propertiesname = "application-test.properties";
                    break;
                case "prod":
                    propertiesname = "application-prod.properties";
                    break;
                default:
                    break;
            }
            final ClassPathResource resource1 = new ClassPathResource(propertiesname);
            properties.load(resource1.getStream());
        } catch (IOException e) {
            log.error("读取sharding.yaml文件失败{}",e);
            return false;
        }
        try (Connection conn1 = DriverManager.getConnection(properties.getProperty("spring.shardingsphere.datasource.ds1.jdbc-url"),
                properties.getProperty("spring.shardingsphere.datasource.ds1.username"),
                properties.getProperty("spring.shardingsphere.datasource.ds1.password"))) {
            try (Statement st = conn1.createStatement()) {
                conn1.setAutoCommit(false);
                for (String sql : sqlList) {
                    st.execute(sql);
                }
                conn1.commit();
            } catch (Exception ex) {
                log.error("执行sql失败,原因:{}", ex);
                conn1.rollback();
                return false;
            }
        } catch (Exception ex) {
            log.error("手动链接失败失败,原因:{}", ex);
            return false;
        }
        return true;
    }


    public static List<String> getAllTableNameBySchema(String active) {
        String propertiesname = "application.properties";
        switch (active) {
            case "dev":
                propertiesname = "application-dev.properties";
                break;
            case "test":
                propertiesname = "application-test.properties";
                break;
            case "prod":
                propertiesname = "application-prod.properties";
                break;
            default:
                break;
        }
        List<String> res = new ArrayList<>();
        final ClassPathResource resource = new ClassPathResource(propertiesname);
        Properties properties = new Properties();
        try {
            properties.load(resource.getStream());
        } catch (IOException e) {
            log.error("读取sharding.yaml文件失败");
            throw new RuntimeException(e);
        }
        String[] tablesNames = TablesNamesConfig.TABLES_NAMES.split(StrUtil.COMMA);
        for (String table_name : tablesNames) {
            String sql = "SELECT table_name FROM information_schema.tables  WHERE table_name  like concat(" + "'" + table_name + "'" + ",'%')";
            try (Connection connection = DriverManager.getConnection(properties.getProperty("spring.shardingsphere.datasource.ds1.jdbc-url"),
                    properties.getProperty("spring.shardingsphere.datasource.ds1.username"),
                    properties.getProperty("spring.shardingsphere.datasource.ds1.password"));
                 Statement st = connection.createStatement()) {
                try (ResultSet rs = st.executeQuery(sql)) {
                    while (rs.next()) {
                        res.add(rs.getString(1));
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        return res;
    }

    public static HashSet<String> cacheTableNames() {
        return tableNameCache;
    }
}

DynamicTableConfig初始化

import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.api.config.sharding.ShardingRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.TableRuleConfiguration;
import org.apache.shardingsphere.api.config.sharding.strategy.HintShardingStrategyConfiguration;
import org.apache.shardingsphere.core.rule.ShardingDataSourceNames;
import org.apache.shardingsphere.core.rule.ShardingRule;
import org.apache.shardingsphere.core.rule.TableRule;
import org.apache.shardingsphere.shardingjdbc.jdbc.core.datasource.ShardingDataSource;
import org.springframework.context.EnvironmentAware;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.*;
import java.util.*;

/**
 * Registers hint-based sharding rules for the dynamic (per-org) tables at startup.
 *
 * @author menshaojing
 */
@Component
@Slf4j
public class DynamicTableConfig implements EnvironmentAware {
    /**
     * Property key holding the comma-separated data source names.
     **/
    private final static String DBS_KEY = "spring.shardingsphere.datasource.names";
    /**
     * Property key for a custom horizontal-sharding table list (not read here; the table list
     * comes from TablesNamesConfig.TABLES_NAMES).
     **/
    private final static String TABLES_KEY = "shardingsphere.sharding.tables";

    /** Prefix of the per-data-source connection properties. */
    private final static String DATASOURCE_KEY = "spring.shardingsphere.datasource";

    /**
     * Comma-separated data source names.
     **/
    private String dbs;
    /**
     * Comma-separated logic tables to shard dynamically.
     **/
    private String tables;

    private String driverClassName;

    private String jdbcUrl;

    private String userName;

    private String password;

    /** Direct JDBC connection opened by {@link #loadDatabase()}. */
    private Connection connection;

    @Resource(name = "shardingDataSource")
    private DataSource dataSource;

    /**
     * Adds one hint-sharded TableRule per configured table to the live ShardingRule,
     * e.g. actual nodes like ds0.edge_cs_user_orgid.
     * NOTE(review): relies on getTableRules() returning a mutable collection — confirm for the
     * ShardingSphere version in use.
     */
    @PostConstruct
    public void initDynamicTable() throws SQLException, ClassNotFoundException {
        log.info("动态初始化添加分库分表策略.....");
        ShardingDataSource shardingDataSource = (ShardingDataSource) this.dataSource;
        ShardingRule shardingRule = shardingDataSource.getRuntimeContext().getRule();
        final Collection<TableRule> tableRules = shardingRule.getTableRules();
        for (String table : this.tables.split(StrUtil.COMMA)) {
            log.info("开始动态初始化添加[{}]分表策略", table);
            addDefaultHitAlgorithm(table, tableRules);
            log.info("结束动态初始化添加[{}]分表策略", table);
        }
    }

    /**
     * Loads the JDBC driver and opens {@link #connection} using the first data source's settings.
     */
    public void loadDatabase() throws ClassNotFoundException, SQLException {
        Class.forName(this.driverClassName);
        connection = DriverManager.getConnection(this.jdbcUrl, this.userName, this.password);
    }

    /**
     * Builds the actual-data-nodes expression for {@code table}: one node per data source and
     * active org, named {@code db.table_orgIdWithoutDashes}, comma-separated.
     *
     * @param table logic table name
     * @return the node list; an empty string when there are no active orgs (previously this
     *         threw StringIndexOutOfBoundsException)
     */
    public String addDefaultActualDataNodes(String table) throws SQLException, ClassNotFoundException {
        loadDatabase();
        List<String> allOrg;
        try {
            allOrg = getAllOrg();
        } finally {
            // Close even when the query fails, so the startup connection is never leaked.
            connection.close();
        }
        StringJoiner nodes = new StringJoiner(StrUtil.COMMA);
        for (String db : this.dbs.split(StrUtil.COMMA)) {
            for (String orgId : allOrg) {
                // Dashes are removed because they are not valid in unquoted PostgreSQL table names.
                nodes.add(db + StrUtil.DOT + table + StrUtil.UNDERLINE + orgId.replaceAll("-", ""));
            }
        }
        String actualDataNodes = nodes.toString();
        log.info("添加实际数据节点[{}] :{}", table, actualDataNodes);
        return actualDataNodes;
    }

    /**
     * Creates a TableRule for {@code table} that uses {@link OrgAutoShardingSphereFixture} as both
     * the database and the table hint-sharding algorithm, and appends it to {@code tableRules}.
     *
     * @param table      logic table name
     * @param tableRules live rule collection to append to
     */
    public void addDefaultHitAlgorithm(String table, Collection<TableRule> tableRules) throws SQLException, ClassNotFoundException {
        TableRuleConfiguration tableRuleConfiguration = new TableRuleConfiguration(table, addDefaultActualDataNodes(table));

        HintShardingStrategyConfiguration hintShardingStrategyConfiguration = new HintShardingStrategyConfiguration(new OrgAutoShardingSphereFixture());
        tableRuleConfiguration.setDatabaseShardingStrategyConfig(hintShardingStrategyConfiguration);
        log.info("添加默认数据库hit算法[{}]策略:{}", table, hintShardingStrategyConfiguration);

        hintShardingStrategyConfiguration = new HintShardingStrategyConfiguration(new OrgAutoShardingSphereFixture());
        tableRuleConfiguration.setTableShardingStrategyConfig(hintShardingStrategyConfiguration);
        log.info("添加默认分表hit算法[{}]策略:{}", table, hintShardingStrategyConfiguration);

        Collection<String> rawDataSourceNames = new ArrayList<>(Arrays.asList(this.dbs.split(StrUtil.COMMA)));
        ShardingDataSourceNames shardingDataSourceNames = new ShardingDataSourceNames(new ShardingRuleConfiguration(), rawDataSourceNames);

        TableRule tableRule = new TableRule(tableRuleConfiguration, shardingDataSourceNames, null);
        tableRules.add(tableRule);
    }

    /**
     * Reads the org ids of all orgs whose {@code activity} is not -1, skipping null ids.
     *
     * @return org ids in result-set order
     */
    private List<String> getAllOrg() throws SQLException {
        List<String> list = new ArrayList<>();
        try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT org_id from edge_org where activity != -1");
             ResultSet resultSet = preparedStatement.executeQuery()) {
            while (resultSet.next()) {
                String orgId = resultSet.getObject(1, String.class);
                if (orgId != null) {
                    list.add(orgId);
                }
            }
        }
        return list;
    }

    /**
     * Captures the data source names, the table list, and the JDBC settings of the first data source.
     */
    @Override
    public void setEnvironment(Environment environment) {
        this.dbs = environment.getProperty(DBS_KEY);
        this.tables = TablesNamesConfig.TABLES_NAMES;
        String db = this.dbs.split(StrUtil.COMMA)[0];
        String prefix = DATASOURCE_KEY + StrUtil.DOT + db + StrUtil.DOT;
        this.driverClassName = environment.getProperty(prefix + "driver-class-name");
        this.jdbcUrl = environment.getProperty(prefix + "jdbc-url");
        this.userName = environment.getProperty(prefix + "username");
        this.password = environment.getProperty(prefix + "password");
    }

}

OrgAutoShardingSphereFixture类,实现Hint分片,这里是最主要的业务核心


import com.google.common.collect.Range;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.api.sharding.hint.HintShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.hint.HintShardingValue;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Custom sharding algorithm implementing ShardingSphere's hint sharding for dynamic per-org tables.
 *
 * @Created by xieyaoyi
 * @Created_date 2023/12/1
 */
@Slf4j
@Component
public class OrgAutoShardingSphereFixture implements HintShardingAlgorithm<String> {

    /**
     * Routes to {@code <logicTable>_<hintValue>} for every hint value, creating the physical table
     * on first use.
     *
     * @param collection        available targets: data source names when sharding databases,
     *                          the table's configured actual tables when sharding tables
     * @param hintShardingValue the logic table name plus values supplied via
     *                          {@code hintManager.addTableShardingValue("edge_cs_user", "003538e3...")}
     *                          instead of being parsed from the SQL
     * @return physical table names to route to
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, HintShardingValue<String> hintShardingValue) {
        String logicTableName = hintShardingValue.getLogicTableName();
        Collection<String> values = hintShardingValue.getValues();
        Collection<String> result = new ArrayList<>(values.size());
        for (String shardingValue : values) {
            String physicalTable = logicTableName + "_" + shardingValue;
            // `collection` is intentionally not consulted: tables for orgs created after startup are
            // not among the configured nodes, so known and unknown tables are handled the same way
            // (the previous if/else had two identical branches).
            // NOTE(review): this class is also registered as the database strategy; a database hint
            // value would be treated as a table suffix here — confirm none is ever set.
            ShardingAlgorithmTool.shardingTablesCheckAndCreatAndReturn(logicTableName, physicalTable);
            result.add(physicalTable);
        }
        return result;
    }

}
ShardingTablesLoadRunner类


import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;

/**
 * Once the application has started, loads the names of the already-existing shard tables into
 * ShardingAlgorithmTool's cache.
 */
@Slf4j
@Order
@Component
public class ShardingTablesLoadRunner implements CommandLineRunner {

    /** Active Spring profile (defaults to "prod"); selects which properties file holds the DB settings. */
    @Value("${spring.profiles.active:prod}")
    private String active;

    @Override
    public void run(String... args) {
        final String profile = this.active;
        ShardingAlgorithmTool.tableNameCacheReload(profile);
    }
}

EdgeUserAndOrderServiceInterceptor类,做sql拦截,用于sql进行分表创建和查询


import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.youxin.shardingsphere.SqlParserHandler;
import com.youxin.shardingsphere.TablesNamesConfig;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.statement.Statement;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.shardingsphere.api.hint.HintManager;
import org.springframework.stereotype.Component;


import java.sql.SQLException;
import java.util.List;

/**
 * 用户表、订单表分表
 *
 * @Created by xieyaoyi
 * @Created_date 2023/12/4
 */
@Component
@Slf4j
public class EdgeUserAndOrderServiceInterceptor implements InnerInterceptor {


    /**
     * 判断是否符合规则
     *
     * @param sql
     * @return
     */
    public boolean judgmentSql(String sql) {
        String[] tableames = TablesNamesConfig.TABLES_NAMES.split(StrUtil.COMMA);
        for (String table_name : tableames) {
            if (sql.toLowerCase().indexOf("from " + table_name) > 0
                    || sql.toLowerCase().indexOf("update " + table_name) >= 0
                    || sql.toLowerCase().indexOf("into " + table_name) > 0) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        BoundSql boundSql = ms.getBoundSql(parameter);
        String sql = boundSql.getSql();
        if (!judgmentSql(sql)) {
            return;
        }
        String state = JSONObject.parseObject(JSON.toJSON(parameter).toString()).getString("state");
        if (parameter != null && StringUtils.isNotBlank(state)) {
            try {
                Statement statement = SqlParserHandler.getStatement(sql);
                final List<String> tableList = SqlParserHandler.getTableList(statement);
                //清除历史规则
                HintManager.clear();
                //获取对应的实例
                HintManager hintManager = HintManager.getInstance();
                for (String table : tableList) {
                    //设置表的分片键值,value是用于表分片
                    hintManager.addTableShardingValue(table, state);
                }
                log.info("解析SQL表名:{}", tableList);
            } catch (JSQLParserException e) {
                log.error("解析SQL表名失败:{}", e);
                throw new RuntimeException(e);
            }
        }
    }

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        boundSql = ms.getBoundSql(parameter);
        String sql = boundSql.getSql();
        if (!judgmentSql(sql)) {
            return;
        }
        String state = JSONObject.parseObject(JSON.toJSON(parameter).toString()).getString("state");
        if (parameter != null && StringUtils.isNotBlank(state)) {
            try {
                Statement statement = SqlParserHandler.getStatement(sql);
                final List<String> tableList = SqlParserHandler.getTableList(statement);
                //清除历史规则
                HintManager.clear();
                //获取对应的实例
                HintManager hintManager = HintManager.getInstance();
                for (String table : tableList) {
                    //设置表的分片键值,value是用于表分片
                    hintManager.addTableShardingValue(table, state);
                }
                log.info("解析SQL表名:{}", tableList);
            } catch (JSQLParserException e) {
                log.error("解析SQL表名失败:{}", e);
                throw new RuntimeException(e);
            }
        }
    }
}
MybatisPlusConfig 类:将上面的拦截器注册到 MyBatis-Plus 中


import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.youxin.commons.interceptor.service.EdgeUserAndOrderServiceInterceptor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Registers the MyBatis-Plus inner interceptors.
 *
 * @Created by xieyaoyi
 * @Created_date 2023/12/4
 */
@Slf4j
@Configuration
public class MybatisPlusConfig {

    /**
     * Builds the MyBatis-Plus interceptor chain.
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        addPaginationInnerInterceptor(interceptor);
        return interceptor;
    }

    /**
     * Adds the sharding-hint interceptor and the PostgreSQL pagination interceptor, in that order.
     * <p>
     * Order matters: MybatisPlusInterceptor calls willDoQuery then beforeQuery on each inner
     * interceptor in registration order, and pagination runs its COUNT query inside willDoQuery.
     * The hint must therefore be set first, or paginated COUNTs on sharded tables are routed
     * without the tenant hint.
     */
    private void addPaginationInnerInterceptor(MybatisPlusInterceptor interceptor) {
        // Dynamic table (sharding hint) interceptor
        interceptor.addInnerInterceptor(new EdgeUserAndOrderServiceInterceptor());
        // Pagination
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.POSTGRE_SQL));
    }

}

以上就是主要的分表业务代码。需要注意的是,以上只展示了基本框架:在插入或查询时,必须带上分表关键字段的值,才能真正切换到对应的分表。

例如:我这里使用 state 字段作为分表关键字段,所以在插入、查询、更新或删除时都必须传入这个参数,即使 mapper 的 SQL 本身用不到它。

示例:

    public EdgeOrderInfo getEdgeOrderInfo(String order_no, String state) {
        // `state` is forwarded so the SQL interceptor can read it as the shard key
        // (presumably exposed by the mapper as a "state" parameter — verify the @Param name).
        EdgeOrderInfo found = edgeOrderMapper.getOrderInfoByOrderNo(order_no, state);
        return found;
    }


    public EdgeOrderInfo addOrUpdateEdgeOrderInfo(EdgeOrderInfo orderInfo) {
        // Assign a fresh order number when the caller did not provide one.
        if (StringUtils.isBlank(orderInfo.getOrder_no())) {
            orderInfo.setOrder_no(CreateNoUtils.getCreateOrderNo(6));
        }
        // Default the shard key to the org uuid without dashes.
        if (StringUtils.isBlank(orderInfo.getState())) {
            String orgIdWithoutDashes = orderInfo.getOrg_id().replaceAll("-", "");
            orderInfo.setState(orgIdWithoutDashes);
        }
        EdgeOrderInfo existing = getEdgeOrderInfo(orderInfo.getOrder_no(), orderInfo.getState());
        if (existing == null) {
            edgeOrderMapper.insert(orderInfo);
            return orderInfo;
        }
        orderInfo.setUpdated_date(new Date());
        edgeOrderMapper.updateByOrderNo(orderInfo, orderInfo.getState());
        return orderInfo;
    }

这样 SQL 被拦截后,就会由 OrgAutoShardingSphereFixture 类将 edge_order_info 替换为 edge_order_info_${uuid}。

数据库效果展示:

我也是参考别人的博客写的,推荐一下写得可以的博客一起学习一下

参考博客:

https://blog.csdn.net/weixin_39403349/article/details/130264892

文章来源:https://blog.csdn.net/qq_32003379/article/details/135498902
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。