本文主要研究一下PowerJob的AbstractSqlProcessor
tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java
@Slf4j
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {

    /**
     * Default query timeout, used when {@code SqlParams#timeout} is not set
     * (passed to {@code Statement#setQueryTimeout}, i.e. seconds).
     */
    protected static final int DEFAULT_TIMEOUT = 60;

    /**
     * Registered SQL validators, keyed by validator name.
     * Notes:
     * - a predicate returning {@code true} means the SQL passes the check;
     * - a predicate returning {@code false} means the SQL is illegal and will be rejected.
     */
    protected final Map<String, Predicate<String>> sqlValidatorMap = Maps.newConcurrentMap();

    /**
     * Optional custom SQL parser, applied before validation; {@code null} means no parsing.
     */
    protected SqlParser sqlParser;

    /**
     * Joins column names / cell values with "|" and renders {@code null} as "-".
     */
    private static final Joiner JOINER = Joiner.on("|").useForNull("-");

    /**
     * Runs a SQL task. It extracts and validates the parameters, optionally rewrites the SQL with
     * {@link #sqlParser}, checks it against every registered validator, then executes it and
     * reports the time spent in each phase.
     *
     * @param taskContext current task context
     * @return a successful result carrying the total elapsed time; failures surface as exceptions
     */
    @Override
    public ProcessResult process0(TaskContext taskContext) {
        OmsLogger omsLogger = taskContext.getOmsLogger();
        // extract parameters
        SqlParams sqlParams = extractParams(taskContext);
        omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams));
        // validate parameters (subclass-specific; throws when invalid)
        validateParams(sqlParams);

        StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName());
        // parse
        stopWatch.start("Parse SQL");
        if (sqlParser != null) {
            omsLogger.info("before parse sql: {}", sqlParams.getSql());
            String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext);
            sqlParams.setSql(newSQL);
            omsLogger.info("after parse sql: {}", newSQL);
        }
        stopWatch.stop();

        // validate the (possibly rewritten) SQL
        stopWatch.start("Validate SQL");
        validateSql(sqlParams.getSql(), omsLogger);
        stopWatch.stop();

        // execute
        stopWatch.start("Execute SQL");
        omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams));
        executeSql(sqlParams, taskContext);
        stopWatch.stop();

        // Pass the report as an argument rather than as the pattern, so any placeholder-like text in it
        // is printed verbatim (assumes OmsLogger uses the same {}-style patterns as the calls above).
        omsLogger.info("{}", stopWatch.prettyPrint());
        String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
        return new ProcessResult(true, message);
    }

    /**
     * Obtains the JDBC connection used to execute the SQL; the caller closes it.
     *
     * @param sqlParams   task SQL parameters
     * @param taskContext current task context
     * @return an open connection
     * @throws SQLException if the connection cannot be obtained
     */
    abstract Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException;

    /**
     * Sets the custom SQL parser applied before validation.
     *
     * @param sqlParser parser to use, or {@code null} to disable parsing
     */
    public void setSqlParser(SqlParser sqlParser) {
        this.sqlParser = sqlParser;
    }

    /**
     * Registers (or replaces) a SQL validator under the given name.
     *
     * @param validatorName validator name, used in log and error messages
     * @param sqlValidator  predicate returning {@code true} when the SQL is acceptable
     * @throws NullPointerException if either argument is {@code null}
     */
    public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
        java.util.Objects.requireNonNull(validatorName, "SQL validator name must not be null");
        java.util.Objects.requireNonNull(sqlValidator, "SQL validator must not be null");
        sqlValidatorMap.put(validatorName, sqlValidator);
        log.info("register sql validator({}) successfully.", validatorName);
    }
    //......
}
AbstractSqlProcessor继承了CommonBasicProcessor。其process0先将入参解析为SqlParams,然后调用validateParams进行参数校验;若sqlParser不为null,则通过sqlParser解析SQL;接着通过validateSql校验SQL,最后通过executeSql执行SQL。它定义了getConnection抽象方法,并提供了setSqlParser、registerSqlValidator方法。
/**
 * SQL task parameters, extracted from the task context.
 */
@Data
public static class SqlParams {
    /**
     * Name of the data source to use.
     */
    private String dataSourceName;
    /**
     * SQL to execute.
     */
    private String sql;
    /**
     * Query timeout; when {@code null}, DEFAULT_TIMEOUT is used.
     */
    private Integer timeout;
    /**
     * JDBC URL.
     * See https://www.baeldung.com/java-jdbc-url-format for the format.
     */
    private String jdbcUrl;
    /**
     * Whether to log the SQL execution result.
     */
    private boolean showResult;
}
SqlParams定义了dataSourceName、sql、timeout、jdbcUrl、showResult属性
/**
 * Checks the SQL against every registered validator.
 * <p>
 * Validators are consulted one after another. The first one whose predicate returns {@code false}
 * has its name logged as an error, and an {@link IllegalArgumentException} is thrown, so the SQL
 * is never executed. With no validators registered, any SQL is accepted.
 *
 * @param sql       the (possibly parsed) SQL to check
 * @param omsLogger task logger that receives the rejection message
 * @throws IllegalArgumentException if any validator rejects the SQL
 */
private void validateSql(String sql, OmsLogger omsLogger) {
    sqlValidatorMap.entrySet().stream()
            // short-circuits at the first validator that rejects the SQL
            .filter(candidate -> !candidate.getValue().test(sql))
            .findFirst()
            .ifPresent(rejected -> {
                String validatorName = rejected.getKey();
                omsLogger.error("validate sql by validator[{}] failed, skip to process!", validatorName);
                throw new IllegalArgumentException("illegal sql, can't pass the validation of " + validatorName);
            });
}
validateSql遍历sqlValidatorMap,依次调用每个校验器的test方法;只要有一个校验器返回false,就记录错误日志并抛出IllegalArgumentException,该SQL不会被执行。
/**
 * Executes the SQL inside a manually committed transaction.
 * <p>
 * The connection's auto-commit flag is switched off for the duration of the call and restored
 * in {@code finally}. On any failure the transaction is rolled back and the original exception
 * is rethrown. If the rollback itself fails, that failure is attached as a suppressed exception
 * so that the root cause is not lost.
 *
 * @param sqlParams SQL parameters (SQL text, timeout, whether to log results)
 * @param ctx       task context, used to obtain the connection and the task logger
 */
@SneakyThrows
private void executeSql(SqlParams sqlParams, TaskContext ctx) {
    OmsLogger omsLogger = ctx.getOmsLogger();
    try (Connection connection = getConnection(sqlParams, ctx)) {
        boolean originAutoCommitFlag = connection.getAutoCommit();
        connection.setAutoCommit(false);
        try (Statement statement = connection.createStatement()) {
            int queryTimeout = sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout();
            statement.setQueryTimeout(queryTimeout);
            statement.execute(sqlParams.getSql());
            connection.commit();
            if (sqlParams.isShowResult()) {
                outputSqlResult(statement, omsLogger);
            }
        } catch (Throwable e) {
            omsLogger.error("execute sql failed, try to rollback", e);
            try {
                connection.rollback();
            } catch (SQLException rollbackFailure) {
                // keep the execution failure as the primary exception
                e.addSuppressed(rollbackFailure);
            }
            throw e;
        } finally {
            connection.setAutoCommit(originAutoCommitFlag);
        }
    }
}
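executeSql先通过getConnection获取连接并设置为手动提交,然后创建Statement。未指定timeout时,queryTimeout取DEFAULT_TIMEOUT(即60秒)。执行SQL成功后提交事务;若showResult为true,则调用outputSqlResult输出执行结果。执行过程中出现异常时会回滚并重新抛出,最后在finally中恢复连接原有的autoCommit设置。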
/**
 * Logs every result produced by the statement's last execution.
 * <p>
 * A statement may yield several results in sequence. A result set is logged as a column-name
 * header followed by one line per row, with cells joined by "|" and {@code null} shown as "-".
 * Any other result has its update count logged when that count is not -1. The loop ends once
 * {@link Statement#getMoreResults()} returns {@code false} and the update count is -1.
 *
 * @param statement an already executed statement
 * @param omsLogger task logger that receives the output
 * @throws SQLException if reading results or metadata fails
 */
private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException {
    omsLogger.info("====== SQL EXECUTE RESULT ======");
    for (int index = 0; index < Integer.MAX_VALUE; index++) {
        // the current result: either a result set or an update count
        ResultSet resultSet = statement.getResultSet();
        if (resultSet != null) {
            try (ResultSet rs = resultSet) {
                int columnCount = rs.getMetaData().getColumnCount();
                List<String> columnNames = Lists.newArrayListWithCapacity(columnCount);
                // JDBC columns are 1-based
                for (int i = 1; i <= columnCount; i++) {
                    columnNames.add(rs.getMetaData().getColumnName(i));
                }
                omsLogger.info("[Result-{}] [Columns] {}" + System.lineSeparator(), index, JOINER.join(columnNames));
                int rowIndex = 0;
                while (rs.next()) {
                    // a fresh list per row; reusing one list would repeat earlier rows' cells on every line
                    List<Object> row = Lists.newArrayListWithCapacity(columnCount);
                    for (int i = 1; i <= columnCount; i++) {
                        row.add(rs.getObject(i));
                    }
                    omsLogger.info("[Result-{}] [Row-{}] {}" + System.lineSeparator(), index, rowIndex++, JOINER.join(row));
                }
            }
        } else {
            int updateCount = statement.getUpdateCount();
            if (updateCount != -1) {
                omsLogger.info("[Result-{}] update count: {}", index, updateCount);
            }
        }
        // no further result set and no further update count => all results consumed
        if (!statement.getMoreResults() && statement.getUpdateCount() == -1) {
            break;
        }
    }
    omsLogger.info("====== SQL EXECUTE RESULT ======");
}
outputSqlResult循环遍历statement的每一个执行结果:若为ResultSet,则先打印列名,再逐行打印每行数据;否则(如更新操作)打印updateCount。当getMoreResults返回false且getUpdateCount为-1时,表示已没有更多结果,循环结束。
/**
 * Hook that rewrites the SQL before it is validated and executed.
 */
@FunctionalInterface
public interface SqlParser {
    /**
     * Custom SQL parsing logic.
     *
     * @param sql         the original SQL statement
     * @param taskContext the task context
     * @return the parsed SQL, which replaces the original SQL in the task parameters
     */
    String parse(String sql, TaskContext taskContext);
}
SqlParser接口定义了parse方法
tech/powerjob/official/processors/impl/sql/DynamicDatasourceSqlProcessor.java
/**
 * SQL processor that opens a fresh JDBC connection from the {@code jdbcUrl} given in the task
 * parameters, using {@link DriverManager}.
 */
public class DynamicDatasourceSqlProcessor extends AbstractSqlProcessor {

    /**
     * Requires a non-empty {@code jdbcUrl}.
     *
     * @param sqlParams SQL parameters
     * @throws IllegalArgumentException if {@code jdbcUrl} is empty
     */
    @Override
    protected void validateParams(SqlParams sqlParams) {
        if (StringUtils.isEmpty(sqlParams.getJdbcUrl())) {
            throw new IllegalArgumentException("jdbcUrl can't be empty in DynamicDatasourceSqlProcessor!");
        }
    }

    /**
     * Opens a connection to {@code jdbcUrl}. Every non-null entry of the task's JSON params is
     * passed to the driver as a connection property.
     * <p>
     * NOTE(review): this includes the other task params (e.g. "sql", "jdbcUrl"); drivers are
     * assumed to ignore unknown properties — confirm for the drivers in use.
     */
    @Override
    Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException {
        JSONObject params = JSONObject.parseObject(CommonUtils.parseParams(taskContext));
        Properties properties = new Properties();
        // normally at least a "user" and "password" property should be included
        if (params != null) {
            params.forEach((k, v) -> {
                // skip JSON nulls so they don't reach the driver as the literal string "null"
                if (v != null) {
                    properties.setProperty(k, String.valueOf(v));
                }
            });
        }
        return DriverManager.getConnection(sqlParams.getJdbcUrl(), properties);
    }

    /**
     * @return key of the switch that enables this processor
     */
    @Override
    protected String getSecurityDKey() {
        return SecurityUtils.ENABLE_DYNAMIC_SQL_PROCESSOR;
    }
}
DynamicDatasourceSqlProcessor继承了AbstractSqlProcessor,其validateParams要求jdbcUrl不能为空。其getConnection方法会把taskContext中的任务参数逐项转换为Properties,作为DriverManager.getConnection的连接属性(通常至少应包含user和password)。其getSecurityDKey返回SecurityUtils.ENABLE_DYNAMIC_SQL_PROCESSOR,对应`powerjob.official-processor.dynamic-datasource.enable`配置。
tech/powerjob/official/processors/impl/sql/SpringDatasourceSqlProcessor.java
/**
 * SQL processor that takes connections from named, pre-registered {@link DataSource}s.
 */
@Slf4j
public class SpringDatasourceSqlProcessor extends AbstractSqlProcessor {
    /**
     * Name of the default data source, used when a task does not name one.
     */
    private static final String DEFAULT_DATASOURCE_NAME = "default";
    /**
     * Registered data sources, keyed by name.
     */
    private final Map<String, DataSource> dataSourceMap;

    /**
     * Creates the processor and registers the default data source.
     *
     * @param defaultDataSource the default data source
     * @throws NullPointerException if {@code defaultDataSource} is {@code null}
     */
    public SpringDatasourceSqlProcessor(DataSource defaultDataSource) {
        dataSourceMap = Maps.newConcurrentMap();
        registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
    }

    /**
     * Obtains a connection from the data source named in {@code sqlParams}.
     *
     * @throws IllegalArgumentException if that data source is not (or no longer) registered
     */
    @Override
    Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException {
        String dataSourceName = sqlParams.getDataSourceName();
        DataSource dataSource = dataSourceMap.get(dataSourceName);
        if (dataSource == null) {
            // validateParams checked the name, but the data source may have been removed since then
            throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
        }
        return dataSource.getConnection();
    }

    /**
     * Validates the parameters and throws if they are invalid. An empty data source name falls
     * back to the default data source.
     *
     * @param sqlParams SQL parameters
     * @throws IllegalArgumentException if the named data source is not registered
     */
    @Override
    protected void validateParams(SqlParams sqlParams) {
        // check the data source
        if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
            // use the default data source when current data source name is empty
            sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
        }
        String dataSourceName = sqlParams.getDataSourceName();
        if (!dataSourceMap.containsKey(dataSourceName)) {
            throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
        }
    }

    /**
     * Registers (or replaces) a data source.
     *
     * @param dataSourceName data source name
     * @param dataSource     data source
     * @throws NullPointerException if either argument is {@code null}
     */
    public void registerDataSource(String dataSourceName, DataSource dataSource) {
        Objects.requireNonNull(dataSourceName, "DataSource name must not be null");
        Objects.requireNonNull(dataSource, "DataSource must not be null");
        dataSourceMap.put(dataSourceName, dataSource);
        log.info("register data source({}) successfully.", dataSourceName);
    }

    /**
     * Removes a data source; a warning is logged only if one was actually removed.
     *
     * @param dataSourceName data source name
     */
    public void removeDataSource(String dataSourceName) {
        DataSource remove = dataSourceMap.remove(dataSourceName);
        if (remove != null) {
            log.warn("remove data source({}) successfully.", dataSourceName);
        }
    }
}
SpringDatasourceSqlProcessor继承了AbstractSqlProcessor,其构造器将传入的DataSource注册为名为default的数据源;其getConnection根据sqlParams的dataSourceName获取连接。validateParams在dataSourceName为空时回退为default,并校验指定的数据源是否已注册。此外,它还提供了registerDataSource、removeDataSource方法。
AbstractSqlProcessor继承了CommonBasicProcessor。其process0先将入参解析为SqlParams,然后调用validateParams进行参数校验;若sqlParser不为null,则通过sqlParser解析SQL;接着通过validateSql校验SQL,最后通过executeSql执行SQL。它定义了getConnection抽象方法,并提供了setSqlParser、registerSqlValidator方法。它有两个实现类:DynamicDatasourceSqlProcessor(通过jdbcUrl构造连接)和SpringDatasourceSqlProcessor(通过给定的DataSource获取连接)。