/*
* Copyright 1999-2015 dangdang.com.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* </p>
*/
package com.dangdang.ddframe.rdb.sharding.jdbc;
import com.dangdang.ddframe.rdb.sharding.executor.StatementExecutor;
import com.dangdang.ddframe.rdb.sharding.executor.wrapper.StatementExecutorWrapper;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.AbstractStatementAdapter;
import com.dangdang.ddframe.rdb.sharding.merger.ResultSetFactory;
import com.dangdang.ddframe.rdb.sharding.parser.result.GeneratedKeyContext;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.MergeContext;
import com.dangdang.ddframe.rdb.sharding.router.SQLExecutionUnit;
import com.dangdang.ddframe.rdb.sharding.router.SQLRouteResult;
import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
 * Sharding-aware static (non-prepared) statement.
 *
 * <p>
 * Every {@code execute*} call routes the SQL, obtains one backend {@link Statement} per execution unit
 * (reusing a cached one when {@link BackendStatementWrapper#isBelongTo} matches), executes them through a
 * {@link StatementExecutor}, and finally rotates the statement cache via {@link #clearRouteContext()}.
 * </p>
 *
 * @author gaohongtao
 * @author caohao
 */
public class ShardingStatement extends AbstractStatementAdapter {
    
    /** Unwraps a cached {@link BackendStatementWrapper} to the backend {@link Statement} it holds. */
    private static final Function<BackendStatementWrapper, Statement> TRANSFORM_FUNCTION = new Function<BackendStatementWrapper, Statement>() {
        
        @Override
        public Statement apply(final BackendStatementWrapper input) {
            return input.getStatement();
        }
    };
    
    @Getter(AccessLevel.PROTECTED)
    private final ShardingConnection shardingConnection;
    
    @Getter
    private final int resultSetType;
    
    @Getter
    private final int resultSetConcurrency;
    
    @Getter
    private final int resultSetHoldability;
    
    /**
     * Two-slot cache of backend statements.
     * The first list holds statements available for reuse; the last list collects the statements obtained
     * by the execution in progress. {@link #clearRouteContext()} merges both back into the first slot.
     */
    private final Deque<List<BackendStatementWrapper>> cachedRoutedStatements = Lists.newLinkedList();
    
    @Getter(AccessLevel.PROTECTED)
    @Setter(AccessLevel.PROTECTED)
    private MergeContext mergeContext;
    
    /** Result set produced by the latest query or {@link #getResultSet()} call; cleared by {@link #clearRouteContext()}. */
    @Setter(AccessLevel.PROTECTED)
    private ResultSet currentResultSet;
    
    @Getter(AccessLevel.PROTECTED)
    @Setter(AccessLevel.PROTECTED)
    private GeneratedKeyContext generatedKeyContext;
    
    /** Cached result of {@link #getGeneratedKeys()}; reset by {@link #clearRouteContext()}. */
    private ResultSet generatedKeyResultSet;
    
    /**
     * Creates a forward-only, read-only statement that holds cursors over commit.
     *
     * @param shardingConnection owning sharding connection
     */
    ShardingStatement(final ShardingConnection shardingConnection) {
        this(shardingConnection, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    /**
     * Creates a statement with the given type and concurrency that holds cursors over commit.
     *
     * @param shardingConnection owning sharding connection
     * @param resultSetType JDBC result set type
     * @param resultSetConcurrency JDBC result set concurrency
     */
    ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency) {
        this(shardingConnection, resultSetType, resultSetConcurrency, ResultSet.HOLD_CURSORS_OVER_COMMIT);
    }
    
    /**
     * Creates a statement with explicit result set characteristics.
     *
     * @param shardingConnection owning sharding connection
     * @param resultSetType JDBC result set type
     * @param resultSetConcurrency JDBC result set concurrency
     * @param resultSetHoldability JDBC result set holdability; {@code 0} makes backend statements use the driver default
     */
    public ShardingStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
        super(Statement.class);
        this.shardingConnection = shardingConnection;
        this.resultSetType = resultSetType;
        this.resultSetConcurrency = resultSetConcurrency;
        this.resultSetHoldability = resultSetHoldability;
        // Slot 1: reusable statements; slot 2: statements used by the current execution.
        cachedRoutedStatements.add(new LinkedList<BackendStatementWrapper>());
        cachedRoutedStatements.add(new LinkedList<BackendStatementWrapper>());
    }
    
    @Override
    public Connection getConnection() throws SQLException {
        return shardingConnection;
    }
    
    /**
     * Routes and executes a query, merging the per-shard results with the route's {@link MergeContext}.
     */
    @Override
    public ResultSet executeQuery(final String sql) throws SQLException {
        ResultSet rs;
        try {
            rs = ResultSetFactory.getResultSet(generateExecutor(sql).executeQuery(), mergeContext);
        } finally {
            clearRouteContext();
        }
        // Set after clearRouteContext(), which resets currentResultSet to null.
        setCurrentResultSet(rs);
        return rs;
    }
    
    @Override
    public int executeUpdate(final String sql) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate();
        } finally {
            clearRouteContext();
        }
    }
    
    /**
     * Routes and executes an update, recording {@code autoGeneratedKeys} on the route's generated-key
     * context (when there is one) for later use by {@link #getGeneratedKeys()}.
     */
    @Override
    public int executeUpdate(final String sql, final int autoGeneratedKeys) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(autoGeneratedKeys);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setAutoGeneratedKeys(autoGeneratedKeys);
            }
            clearRouteContext();
        }
    }
    
    /**
     * Routes and executes an update, recording {@code columnIndexes} on the route's generated-key context
     * (when there is one) for later use by {@link #getGeneratedKeys()}.
     */
    @Override
    public int executeUpdate(final String sql, final int[] columnIndexes) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(columnIndexes);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnIndexes(columnIndexes);
            }
            clearRouteContext();
        }
    }
    
    /**
     * Routes and executes an update, recording {@code columnNames} on the route's generated-key context
     * (when there is one) for later use by {@link #getGeneratedKeys()}.
     */
    @Override
    public int executeUpdate(final String sql, final String[] columnNames) throws SQLException {
        try {
            return generateExecutor(sql).executeUpdate(columnNames);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnNames(columnNames);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql) throws SQLException {
        try {
            return generateExecutor(sql).execute();
        } finally {
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final int autoGeneratedKeys) throws SQLException {
        try {
            return generateExecutor(sql).execute(autoGeneratedKeys);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setAutoGeneratedKeys(autoGeneratedKeys);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final int[] columnIndexes) throws SQLException {
        try {
            return generateExecutor(sql).execute(columnIndexes);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnIndexes(columnIndexes);
            }
            clearRouteContext();
        }
    }
    
    @Override
    public boolean execute(final String sql, final String[] columnNames) throws SQLException {
        try {
            return generateExecutor(sql).execute(columnNames);
        } finally {
            if (null != generatedKeyContext) {
                generatedKeyContext.setColumnNames(columnNames);
            }
            clearRouteContext();
        }
    }
    
    /**
     * Returns the generated keys of the last execution, caching the result until the next {@link #clearRouteContext()}.
     *
     * <p>
     * If the route produced no generated-key context (or its column-name-to-index map is empty) and exactly one
     * backend statement is cached, that statement's own generated keys are returned. Otherwise the keys are built
     * from the {@link GeneratedKeyContext}. An empty result set is returned when there is no context to build
     * from, or when no keys were requested via {@code RETURN_GENERATED_KEYS}, column indexes or column names.
     * </p>
     */
    @Override
    public final ResultSet getGeneratedKeys() throws SQLException {
        if (null != generatedKeyResultSet) {
            return generatedKeyResultSet;
        }
        if (null == generatedKeyContext || generatedKeyContext.getColumnNameToIndexMap().isEmpty()) {
            Collection<? extends Statement> routedStatements = getRoutedStatements();
            if (1 == routedStatements.size()) {
                generatedKeyResultSet = routedStatements.iterator().next().getGeneratedKeys();
                return generatedKeyResultSet;
            }
            if (null == generatedKeyContext) {
                // Nothing to assemble keys from; previously this fell through and dereferenced the null context.
                generatedKeyResultSet = new GeneratedKeysResultSet();
                return generatedKeyResultSet;
            }
        }
        if (Statement.RETURN_GENERATED_KEYS != generatedKeyContext.getAutoGeneratedKeys() && null == generatedKeyContext.getColumnIndexes()
                && null == generatedKeyContext.getColumnNames()) {
            generatedKeyResultSet = new GeneratedKeysResultSet();
            return generatedKeyResultSet;
        }
        generatedKeyResultSet = new GeneratedKeysResultSet(generateAutoIncrementTable(), generatedKeyContext.getColumnNameToIndexMap(), this);
        return generatedKeyResultSet;
    }
    
    /**
     * Selects the generated-key values to expose: the requested column indexes (1-based), the requested
     * column names mapped to indexes (unknown names are skipped), or the whole value table when neither was given.
     */
    private Table<Integer, Integer, Object> generateAutoIncrementTable() {
        if (null != generatedKeyContext.getColumnIndexes()) {
            return subTable(generatedKeyContext.getColumnIndexes());
        } else if (null != generatedKeyContext.getColumnNames()) {
            List<Integer> columnIndexes = new ArrayList<>(generatedKeyContext.getColumnNames().length);
            for (String each : generatedKeyContext.getColumnNames()) {
                if (!generatedKeyContext.getColumnNameToIndexMap().containsKey(each)) {
                    continue;
                }
                // The map stores 0-based indexes; subTable expects 1-based JDBC indexes.
                columnIndexes.add(generatedKeyContext.getColumnNameToIndexMap().get(each) + 1);
            }
            int[] parameter = new int[columnIndexes.size()];
            int index = 0;
            for (Integer each : columnIndexes) {
                parameter[index++] = each;
            }
            return subTable(parameter);
        }
        return generatedKeyContext.getValueTable();
    }
    
    /**
     * Copies the given 1-based columns of the generated-key value table (stored 0-based) into a new sorted table.
     *
     * @param columnIndexes 1-based column indexes to keep
     * @return table keyed by the original row keys and 0-based column indexes
     */
    private Table<Integer, Integer, Object> subTable(final int[] columnIndexes) {
        Table<Integer, Integer, Object> result = TreeBasedTable.create();
        for (int each : columnIndexes) {
            for (Map.Entry<Integer, Object> eachEntry : generatedKeyContext.getValueTable().column(each - 1).entrySet()) {
                result.put(eachEntry.getKey(), each - 1, eachEntry.getValue());
            }
        }
        return result;
    }
    
    /**
     * Resets per-execution state and rotates the statement cache.
     *
     * <p>
     * Statements left unused in the reusable slot are appended to the statements used by this execution;
     * that combined list becomes the new reusable slot, and the emptied list becomes the new "in use" slot.
     * </p>
     */
    protected void clearRouteContext() throws SQLException {
        setCurrentResultSet(null);
        List<BackendStatementWrapper> firstList = cachedRoutedStatements.pollFirst();
        cachedRoutedStatements.getFirst().addAll(firstList);
        firstList.clear();
        cachedRoutedStatements.addLast(firstList);
        generatedKeyResultSet = null;
    }
    
    /**
     * Routes {@code sql} and builds an executor with one backend statement per execution unit.
     * Also stores the route's generated-key and merge contexts on this statement.
     */
    private StatementExecutor generateExecutor(final String sql) throws SQLException {
        StatementExecutor result = new StatementExecutor(shardingConnection.getShardingContext().getExecutorEngine());
        SQLRouteResult sqlRouteResult = shardingConnection.getShardingContext().getSqlRouteEngine().route(sql);
        generatedKeyContext = sqlRouteResult.getGeneratedKeyContext();
        mergeContext = sqlRouteResult.getMergeContext();
        for (SQLExecutionUnit each : sqlRouteResult.getExecutionUnits()) {
            Statement statement = getStatement(shardingConnection.getConnection(each.getDataSource(), sqlRouteResult.getSqlStatementType()), each.getSql());
            // Re-apply settings recorded on this facade (via the adapter) to the backend statement.
            replayMethodsInvocation(statement);
            result.addStatement(new StatementExecutorWrapper(statement, each));
        }
        return result;
    }
    
    /**
     * Returns a backend statement for {@code connection}/{@code sql}, reusing the first matching cached
     * statement or creating a new one, and registers it as used by the current execution.
     */
    protected Statement getStatement(final Connection connection, final String sql) throws SQLException {
        BackendStatementWrapper statement = null;
        for (Iterator<BackendStatementWrapper> iterator = cachedRoutedStatements.getFirst().iterator(); iterator.hasNext();) {
            BackendStatementWrapper each = iterator.next();
            if (each.isBelongTo(connection, sql)) {
                statement = each;
                iterator.remove();
                // Take exactly one match: removing further matches would drop them from the cache,
                // so they would never be reused or reported by getRoutedStatements().
                break;
            }
        }
        if (null == statement) {
            statement = generateStatement(connection, sql);
        }
        cachedRoutedStatements.getLast().add(statement);
        return statement.getStatement();
    }
    
    /**
     * Creates a new backend statement on {@code connection} with this statement's result set characteristics.
     * A holdability of {@code 0} uses the two-argument {@code createStatement} overload.
     */
    protected BackendStatementWrapper generateStatement(final Connection connection, final String sql) throws SQLException {
        Statement result;
        if (0 == resultSetHoldability) {
            result = connection.createStatement(resultSetType, resultSetConcurrency);
        } else {
            result = connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
        }
        return new BackendStatementWrapper(result);
    }
    
    /**
     * Returns the current result set, building it from the routed statements on first access.
     * A single routed statement's result set is returned as-is; several are merged with the {@link MergeContext}.
     */
    @Override
    public ResultSet getResultSet() throws SQLException {
        if (null != currentResultSet) {
            return currentResultSet;
        }
        // Take one snapshot: getRoutedStatements() builds a new list on every call.
        Collection<? extends Statement> routedStatements = getRoutedStatements();
        if (1 == routedStatements.size()) {
            currentResultSet = routedStatements.iterator().next().getResultSet();
            return currentResultSet;
        }
        // NOTE(review): after clearRouteContext() the routed statements also include cached statements the latest
        // execution did not use, so their (possibly stale) result sets are merged too — TODO confirm intended.
        List<ResultSet> resultSets = new ArrayList<>(routedStatements.size());
        for (Statement each : routedStatements) {
            resultSets.add(each.getResultSet());
        }
        currentResultSet = ResultSetFactory.getResultSet(resultSets, mergeContext);
        return currentResultSet;
    }
    
    /** Empties both slots of the statement cache. */
    @Override
    protected void clearRouteStatements() {
        cachedRoutedStatements.getFirst().clear();
        cachedRoutedStatements.getLast().clear();
    }
    
    /**
     * Returns a new list with every cached backend statement: the reusable slot followed by the "in use" slot.
     */
    @Override
    public Collection<? extends Statement> getRoutedStatements() {
        return Lists.newArrayList(Iterators.concat(Iterators.transform(cachedRoutedStatements.getFirst().iterator(), TRANSFORM_FUNCTION),
                Iterators.transform(cachedRoutedStatements.getLast().iterator(), TRANSFORM_FUNCTION)));
    }
}