上一篇文章我们分析了 sharding-jdbc 的 SQL 改写,今天我们分析一下 SQL 执行。
SQL 执行主要分为 2 部分:根据路由、改写的结果构建执行器(StatementExecutor),以及交给执行引擎(ExecutorEngine)执行。下面先从一个测试用例入手:
/**
 * Smoke test that sends a single-shard SELECT through {@code ShardingStatement}, so the
 * route -> execute path described in this article can be stepped through in a debugger.
 *
 * @Author serlain
 * @Date 2018/11/14 11:21 PM
 */
public class MyShardingStatementTest extends AbstractShardingDatabaseOnlyDBUnitTest {

    /** Sharding-key value the query filters on; the query also selects this same column. */
    private static final long ORDER_ID = 4L;

    /** Evaluates to "SELECT o.order_id FROM t_order o WHERE o.order_id = 4". */
    private static final String SQL = "SELECT o.order_id FROM t_order o WHERE o.order_id = " + ORDER_ID;

    private ShardingDataSource shardingDataSource;

    @Before
    public void init() throws SQLException {
        shardingDataSource = getShardingDataSource();
    }

    @Test
    public void testselect() throws SQLException {
        try (
                Connection connection = shardingDataSource.getConnection();
                Statement stmt = connection.createStatement();
                ResultSet resultSet = stmt.executeQuery(SQL)) {
            assertTrue(resultSet.next());
            // The query projects the very column it filters on, so the only value it can return is
            // ORDER_ID itself; the previous expectation of 40L could never be satisfied.
            assertThat(resultSet.getLong(1), is(ORDER_ID));
        }
    }
}
/**
 * Returns the cached sharding data source, building a fresh one on first use or after a shutdown.
 *
 * <p>The rule set is a single logical table {@code t_order}, split into {@code t_order_0} and
 * {@code t_order_1}. Both the database and the table are sharded on {@code order_id}, and
 * {@code order_id} is generated by {@code IncrementKeyGenerator}.
 *
 * @return the sharding data source shared by the tests
 */
protected final ShardingDataSource getShardingDataSource() {
    if (shardingDataSource == null || isShutdown) {
        isShutdown = false;
        // Presumably creates one data source per name derived from the "dataSource_%s" pattern — TODO confirm.
        DataSourceRule dataSourceRule = new DataSourceRule(createDataSourceMap("dataSource_%s"));
        TableRule orderTableRule = TableRule.builder("t_order")
                .dataSourceRule(dataSourceRule)
                .actualTables(Lists.newArrayList("t_order_0", "t_order_1"))
                .generateKeyColumn("order_id", IncrementKeyGenerator.class)
                .build();
        DatabaseShardingStrategy databaseStrategy = new DatabaseShardingStrategy(
                Collections.singletonList("order_id"), new MultipleKeysModuloDatabaseShardingAlgorithm());
        TableShardingStrategy tableStrategy = new TableShardingStrategy("order_id", new OrderShardingAlgorithm());
        ShardingRule shardingRule = ShardingRule.builder()
                .dataSourceRule(dataSourceRule)
                .tableRules(Arrays.asList(orderTableRule))
                .databaseShardingStrategy(databaseStrategy)
                .tableShardingStrategy(tableStrategy)
                .build();
        shardingDataSource = new ShardingDataSource(shardingRule);
    }
    return shardingDataSource;
}
复制代码
我们可以看到 ShardingDataSource 的构造函数:
ShardingRule:分片规则,由外部自定义;ExecutorEngine:执行引擎,在内部初始化;ShardingContext:数据源运行期上下文(这个类很关键,它把 SQL 执行期间需要的类都串联起来,需要时就从这个类里面取)。
/**
 * Creates a sharding data source with default (empty) sharding properties.
 *
 * @param shardingRule sharding rule; must not be {@code null}
 */
public ShardingDataSource(final ShardingRule shardingRule) {
    this(shardingRule, new Properties());
}

/**
 * Creates a sharding data source.
 *
 * <p>{@code getDatabaseProductName} (which inspects the configured data sources) is called before
 * the {@code ExecutorEngine} is constructed. If that probe fails, no executor engine is left behind
 * without an owner. The engine presumably holds the worker pool used for asynchronous execution.
 *
 * @param shardingRule sharding rule; must not be {@code null}
 * @param props sharding properties, from which the executor size is read; must not be {@code null}
 * @throws NullPointerException if {@code shardingRule} or {@code props} is {@code null}
 * @throws ShardingJdbcException wrapping any {@link SQLException} raised while building the sharding context
 */
public ShardingDataSource(final ShardingRule shardingRule, final Properties props) {
    Preconditions.checkNotNull(shardingRule);
    Preconditions.checkNotNull(props);
    shardingProperties = new ShardingProperties(props);
    int executorSize = shardingProperties.getValue(ShardingPropertiesConstant.EXECUTOR_SIZE);
    try {
        // Resolve the database type first: this is the step that touches the real data sources.
        DatabaseType databaseType = DatabaseType.valueFrom(getDatabaseProductName(shardingRule));
        executorEngine = new ExecutorEngine(executorSize);
        shardingContext = new ShardingContext(shardingRule, databaseType, executorEngine);
    } catch (final SQLException ex) {
        throw new ShardingJdbcException(ex);
    }
}
复制代码
/**
 * Opens a new logical sharding connection backed by this data source's shared context.
 * The metrics context is initialised from the sharding properties on every call.
 *
 * @return a new {@code ShardingConnection}
 * @throws SQLException as declared by the overridden method
 */
@Override
public ShardingConnection getConnection() throws SQLException {
    MetricsContext.init(shardingProperties);
    ShardingConnection connection = new ShardingConnection(shardingContext);
    return connection;
}
复制代码
/**
 * Creates a sharding-aware statement bound to this connection.
 *
 * @param resultSetType JDBC result set type
 * @param resultSetConcurrency JDBC result set concurrency
 * @return a new {@code ShardingStatement}
 * @throws SQLException as declared by the overridden method
 */
@Override
public Statement createStatement(final int resultSetType, final int resultSetConcurrency) throws SQLException {
    ShardingStatement statement = new ShardingStatement(this, resultSetType, resultSetConcurrency);
    return statement;
}
/**
 * Creates a statement whose result set holdability defaults to {@code HOLD_CURSORS_OVER_COMMIT}.
 *
 * @param shardingConnection owning sharding connection
 * @param resultSetType JDBC result set type
 * @param resultSetConcurrency JDBC result set concurrency
 */
public ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency) {
    this(shardingConnection, resultSetType, resultSetConcurrency, ResultSet.HOLD_CURSORS_OVER_COMMIT);
}

/**
 * Creates a statement with explicit JDBC result set settings. The settings are only stored here;
 * they are applied when the physical statements are created for each routed data source.
 *
 * @param shardingConnection owning sharding connection
 * @param resultSetType JDBC result set type
 * @param resultSetConcurrency JDBC result set concurrency
 * @param resultSetHoldability JDBC result set holdability
 */
public ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
    super(Statement.class);
    this.resultSetHoldability = resultSetHoldability;
    this.resultSetConcurrency = resultSetConcurrency;
    this.resultSetType = resultSetType;
    this.shardingConnection = shardingConnection;
}
复制代码
/**
 * Routes, executes and merges a query.
 *
 * <p>The per-shard execution is analysed in this article. Merging the per-shard result sets
 * ({@code ResultSetFactory.getResultSet}) is covered in the next one.
 *
 * @param sql logical SQL
 * @return merged result set, also recorded as the current result set
 * @throws SQLException if routing or execution fails
 */
@Override
public ResultSet executeQuery(final String sql) throws SQLException {
    ResultSet mergedResultSet = null;
    try {
        // generateExecutor refreshes routeResult, so it must run before routeResult is read.
        StatementExecutor executor = generateExecutor(sql);
        mergedResultSet = ResultSetFactory.getResultSet(executor.executeQuery(), routeResult.getSqlStatement());
    } finally {
        // Clear any previous result set regardless of outcome.
        setCurrentResultSet(null);
    }
    setCurrentResultSet(mergedResultSet);
    return mergedResultSet;
}
复制代码
根据路由改写后的结果(dataSource)创建 SQL Statement;构建 StatementUnit(SQL 语句到 Statement 的映射);构建 StatementExecutor(SQL 执行单元):
/**
 * Routes and rewrites {@code sql}, then creates one physical statement per routed execution unit.
 *
 * <p>Each physical statement comes from the connection of the unit's data source. It is created
 * with this statement's result set type, concurrency and holdability, and is recorded in
 * {@code routedStatements}.
 *
 * @param sql logical SQL
 * @return executor over the (execution unit, physical statement) pairs
 * @throws SQLException if a connection or statement cannot be obtained
 */
private StatementExecutor generateExecutor(final String sql) throws SQLException {
    clearPrevious();
    // Result of routing + rewriting; later read by the caller as well.
    routeResult = new StatementRoutingEngine(shardingConnection.getShardingContext()).route(sql);
    Collection<StatementUnit> units = new LinkedList<>();
    for (SQLExecutionUnit executionUnit : routeResult.getExecutionUnits()) {
        Statement physicalStatement = shardingConnection
                .getConnection(executionUnit.getDataSource(), routeResult.getSqlStatement().getType())
                .createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
        // Presumably re-applies calls recorded on this logical statement to the physical one — verify.
        replayMethodsInvocation(physicalStatement);
        units.add(new StatementUnit(executionUnit, physicalStatement));
        routedStatements.add(physicalStatement);
    }
    return new StatementExecutor(
            shardingConnection.getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), units);
}
/**
 * Returns the physical connection for a data source name, reusing a cached one when present.
 *
 * <p>For a {@code MasterSlaveDataSource}, the concrete data source is chosen by the SQL type. The
 * new connection is stored in {@code connectionMap} under the resolved (real) data source name.
 * The metrics context is stopped in a {@code finally} block, so a missing data source rule or a
 * failing {@code getConnection()} no longer leaves it running.
 *
 * @param dataSourceName logical data source name from the routing result
 * @param sqlType SQL type, used to pick master or slave
 * @return database connection
 * @throws SQLException if the underlying data source fails to open a connection
 * @throws IllegalStateException if {@code dataSourceName} has no entry in the {@code DataSourceRule}
 */
public Connection getConnection(final String dataSourceName, final SQLType sqlType) throws SQLException {
    Optional<Connection> connection = getCachedConnection(dataSourceName, sqlType);
    if (connection.isPresent()) {
        return connection.get();
    }
    Context metricsContext = MetricsContext.start(Joiner.on("-").join("ShardingConnection-getConnection", dataSourceName));
    String realDataSourceName;
    Connection result;
    try {
        // Look up the configured DataSource for this name.
        DataSource dataSource = shardingContext.getShardingRule().getDataSourceRule().getDataSource(dataSourceName);
        Preconditions.checkState(null != dataSource, "Missing the rule of %s in DataSourceRule", dataSourceName);
        if (dataSource instanceof MasterSlaveDataSource) {
            dataSource = ((MasterSlaveDataSource) dataSource).getDataSource(sqlType);
            realDataSourceName = MasterSlaveDataSource.getDataSourceName(dataSourceName, sqlType);
        } else {
            realDataSourceName = dataSourceName;
        }
        result = dataSource.getConnection();
    } finally {
        MetricsContext.stop(metricsContext);
    }
    connectionMap.put(realDataSourceName, result);
    replayMethodsInvocation(result);
    return result;
}
复制代码
/**
 * Runs the routed query on every statement unit.
 *
 * <p>Threading is left to the executor engine. The callback below only says how to run a single
 * unit: execute that unit's rewritten SQL on its physical statement.
 *
 * @return one result set per statement unit, as returned by the executor engine
 */
public List<ResultSet> executeQuery() {
    Context context = MetricsContext.start("ShardingStatement-executeQuery");
    try {
        return executorEngine.executeStatement(sqlType, statementUnits, new ExecuteCallback<ResultSet>() {

            @Override
            public ResultSet execute(final BaseStatementUnit unit) throws Exception {
                return unit.getStatement().executeQuery(unit.getSqlExecutionUnit().getSql());
            }
        });
    } finally {
        MetricsContext.stop(context);
    }
}
复制代码
/**
 * Executes plain (non-prepared) statements, that is, with no parameter sets.
 *
 * @param sqlType SQL type
 * @param statementUnits statement execution units
 * @param executeCallback callback that runs a single unit
 * @param <T> result type of one unit
 * @return per-unit results
 */
public <T> List<T> executeStatement(final SQLType sqlType, final Collection<StatementUnit> statementUnits, final ExecuteCallback<T> executeCallback) {
    List<List<Object>> noParameterSets = Collections.emptyList();
    return execute(sqlType, statementUnits, noParameterSets, executeCallback);
}
复制代码
private <T> List<T> execute(
final SQLType sqlType, final Collection<? extends BaseStatementUnit> baseStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) {
//须要执行的SQL为空,直接返回空list
if (baseStatementUnits.isEmpty()) {
return Collections.emptyList();
}
Iterator<? extends BaseStatementUnit> iterator = baseStatementUnits.iterator();
//获取第一个须要执行的
BaseStatementUnit firstInput = iterator.next();
//剩下的异步执行
ListenableFuture<List<T>> restFutures = asyncExecute(sqlType, Lists.newArrayList(iterator), parameterSets, executeCallback);
T firstOutput;
List<T> restOutputs;
try {
//同步执行
firstOutput = syncExecute(sqlType, firstInput, parameterSets, executeCallback);
//等待
restOutputs = restFutures.get();
//CHECKSTYLE:OFF
} catch (final Exception ex) {
//CHECKSTYLE:ON
ExecutorExceptionHandler.handleException(ex);
return null;
}
返回
List<T> result = Lists.newLinkedList(restOutputs);
result.add(0, firstOutput);
return result;
}
复制代码
异步执行的逻辑(除第一个以外的执行单元都提交到线程池执行):
/**
 * Submits every unit to the executor service and returns a future over all of their results.
 *
 * <p>The exception-thrown flag and the data map are read on the calling thread before submission.
 * They are then handed to each task, so every worker runs with the caller's values.
 *
 * @param sqlType SQL type
 * @param baseStatementUnits units to run asynchronously
 * @param parameterSets parameter sets (empty for plain statements)
 * @param executeCallback callback that runs a single unit
 * @param <T> result type of one unit
 * @return future combining all per-unit futures via {@code Futures.allAsList}
 */
private <T> ListenableFuture<List<T>> asyncExecute(
        final SQLType sqlType, final Collection<BaseStatementUnit> baseStatementUnits, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) {
    final boolean callerExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
    final Map<String, Object> callerDataMap = ExecutorDataMap.getDataMap();
    List<ListenableFuture<T>> futures = new ArrayList<>(baseStatementUnits.size());
    for (final BaseStatementUnit unit : baseStatementUnits) {
        Callable<T> task = new Callable<T>() {

            @Override
            public T call() throws Exception {
                return executeInternal(sqlType, unit, parameterSets, executeCallback, callerExceptionThrown, callerDataMap);
            }
        };
        futures.add(executorService.submit(task));
    }
    return Futures.allAsList(futures);
}
复制代码
同步执行的逻辑(第一个执行单元在当前线程执行):
/**
 * Executes one unit on the calling thread, using the current thread's exception flag and data map.
 *
 * @param sqlType SQL type
 * @param baseStatementUnit unit to execute
 * @param parameterSets parameter sets (empty for plain statements)
 * @param executeCallback callback that runs the unit
 * @param <T> result type
 * @return the unit's result
 * @throws Exception whatever {@code executeInternal} propagates
 */
private <T> T syncExecute(final SQLType sqlType, final BaseStatementUnit baseStatementUnit, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback) throws Exception {
    boolean exceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
    Map<String, Object> dataMap = ExecutorDataMap.getDataMap();
    return executeInternal(sqlType, baseStatementUnit, parameterSets, executeCallback, exceptionThrown, dataMap);
}
复制代码
executeInternal 的逻辑(同步和异步执行最终都会调用它):
/**
 * Executes one statement unit while holding the monitor of its physical connection, and posts
 * execution events before and after the call.
 *
 * <p>The exception-thrown flag and the data map are passed in explicitly and installed here,
 * because {@code asyncExecute} calls this method from pool threads with values captured on the
 * caller's thread.
 *
 * @param sqlType SQL type
 * @param baseStatementUnit unit to execute
 * @param parameterSets parameter sets; one event is posted per set, or a single event when empty
 * @param executeCallback callback that runs the unit
 * @param isExceptionThrown flag to install into {@code ExecutorExceptionHandler}
 * @param dataMap data map to install into {@code ExecutorDataMap}
 * @param <T> result type
 * @return the callback's result, or {@code null} if a {@link SQLException} was handled without being rethrown
 * @throws Exception anything the callback throws other than {@link SQLException}, or whatever
 *         {@code ExecutorExceptionHandler.handleException} rethrows
 */
private <T> T executeInternal(final SQLType sqlType, final BaseStatementUnit baseStatementUnit, final List<List<Object>> parameterSets, final ExecuteCallback<T> executeCallback,
                              final boolean isExceptionThrown, final Map<String, Object> dataMap) throws Exception {
    synchronized (baseStatementUnit.getStatement().getConnection()) {
        T result;
        ExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
        ExecutorDataMap.setDataMap(dataMap);
        List<AbstractExecutionEvent> events = new LinkedList<>();
        // Build the "before execution" events: a single one when there are no parameter sets,
        // otherwise one per parameter set.
        if (parameterSets.isEmpty()) {
            events.add(getExecutionEvent(sqlType, baseStatementUnit, Collections.emptyList()));
        }
        for (List<Object> each : parameterSets) {
            events.add(getExecutionEvent(sqlType, baseStatementUnit, each));
        }
        // Publish them on the singleton event bus.
        for (AbstractExecutionEvent event : events) {
            EventBusInstance.getInstance().post(event);
        }
        // Run the SQL through the caller-supplied callback.
        try {
            result = executeCallback.execute(baseStatementUnit);
        } catch (final SQLException ex) {
            // Mark and post every event as failed first...
            for (AbstractExecutionEvent each : events) {
                each.setEventExecutionType(EventExecutionType.EXECUTE_FAILURE);
                each.setException(Optional.of(ex));
                EventBusInstance.getInstance().post(each);
            }
            // ...then handle the exception exactly once. This used to sit inside the loop above: the
            // exception was handled once per event, and if the handler rethrew, the remaining events
            // were never posted as failed.
            ExecutorExceptionHandler.handleException(ex);
            return null;
        }
        // Post the same events again, now marked as successful.
        for (AbstractExecutionEvent each : events) {
            each.setEventExecutionType(EventExecutionType.EXECUTE_SUCCESS);
            EventBusInstance.getInstance().post(each);
        }
        return result;
    }
}
复制代码
小尾巴走一波,欢迎关注我的公众号,不定期分享编程、投资、生活方面的感悟 :)