/* * Copyright 2014-2015 the original author or authors * * Licensed under the Apache License, Version 2.0 (the “License”); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an “AS IS” BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.wplatform.ddal.excutor.dml; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import com.wplatform.ddal.command.Prepared; import com.wplatform.ddal.dbobject.index.IndexCondition; import com.wplatform.ddal.dbobject.table.Column; import com.wplatform.ddal.dbobject.table.TableMate; import com.wplatform.ddal.dispatch.RoutingHandler; import com.wplatform.ddal.dispatch.rule.RoutingResult; import com.wplatform.ddal.dispatch.rule.TableNode; import com.wplatform.ddal.excutor.CommonPreparedExecutor; import com.wplatform.ddal.excutor.JdbcWorker; import com.wplatform.ddal.message.DbException; import com.wplatform.ddal.result.Row; import com.wplatform.ddal.result.SearchRow; import com.wplatform.ddal.util.New; import com.wplatform.ddal.util.StatementBuilder; import com.wplatform.ddal.value.Value; import com.wplatform.ddal.value.ValueNull; /** * @author <a href="mailto:jorgie.mail@gmail.com">jorgie li</a> * */ public abstract class PreparedRoutingExecutor<T extends Prepared> extends CommonPreparedExecutor<T> { protected final RoutingHandler routingHandler; /** * @param prepared */ public PreparedRoutingExecutor(T prepared) { super(prepared); this.routingHandler = database.getRoutingHandler(); } protected int updateRow(TableMate table, Row row) { session.checkCanceled(); RoutingResult result = routingHandler.doRoute(table, row); return invokeUpdateRow(result, row); } protected int updateRow(TableMate table, Row row, List<IndexCondition> where) { session.checkCanceled(); RoutingResult result = routingHandler.doRoute(table, session, where); return invokeUpdateRow(result, row); } protected int updateRows(TableMate table, List<Row> rows) { Map<BatchKey, List<List<Value>>> batches = New.hashMap(); session.checkCanceled(); for (Row row : rows) { RoutingResult result = routingHandler.doRoute(table, row); TableNode[] selectNodes = result.getSelectNodes(); for (TableNode node : selectNodes) { StatementBuilder sqlBuff = new StatementBuilder(); List<Value> params = doTranslate(node, row, sqlBuff); BatchKey batchKey = new BatchKey(node.getShardName(), sqlBuff.toString()); List<List<Value>> batchArgs = batches.get(batchKey); if (batchArgs == null) { batchArgs = New.arrayList(10); batches.put(batchKey, batchArgs); } batchArgs.add(params); } } List<JdbcWorker<Integer[]>> workers = New.arrayList(batches.size()); for (Map.Entry<BatchKey, List<List<Value>>> entry : batches.entrySet()) { String shardName = entry.getKey().shardName; String sql = entry.getKey().sql; List<List<Value>> array = entry.getValue(); workers.add(createBatchUpdateWorker(shardName, sql, array)); } try { addRuningJdbcWorkers(workers); int affectRows = 0; if (workers.size() > 1) { int queryTimeout = getQueryTimeout();//MILLISECONDS List<Future<Integer[]>> invokeAll; if(queryTimeout > 0) { invokeAll = jdbcExecutor.invokeAll(workers,queryTimeout,TimeUnit.MILLISECONDS); } else { invokeAll = jdbcExecutor.invokeAll(workers); } for (Future<Integer[]> future : invokeAll) { Integer[] integers = future.get(); for (Integer integer : integers) { affectRows += integer; } } } else if (workers.size() == 1) { Integer[] integers = workers.get(0).doWork(); for (Integer integer : integers) { affectRows += integer; } } return affectRows; } catch (InterruptedException e) { throw DbException.convert(e); } catch (ExecutionException e) { throw DbException.convert(e.getCause()); } finally { removeRuningJdbcWorkers(workers); for (JdbcWorker<Integer[]> jdbcWorker : workers) { jdbcWorker.closeResource(); } } } protected abstract List<Value> doTranslate(TableNode node, SearchRow row, StatementBuilder buff); protected static boolean isNull(Value v) { return v == null || v == ValueNull.INSTANCE; } /** * @param result * @param row * @return */ private int invokeUpdateRow(RoutingResult result, Row row) { List<JdbcWorker<Integer>> workers = New.arrayList(result.tableNodeCount()); TableNode[] selectNodes = result.getSelectNodes(); for (TableNode node : selectNodes) { StatementBuilder sqlBuff = new StatementBuilder(); List<Value> params = doTranslate(node, row, sqlBuff); workers.add(createUpdateWorker(node.getShardName(), sqlBuff.toString(), params)); } try { addRuningJdbcWorkers(workers); int affectRows = 0; if (workers.size() > 1) { int queryTimeout = getQueryTimeout();//MILLISECONDS List<Future<Integer>> invokeAll; if(queryTimeout > 0) { invokeAll = jdbcExecutor.invokeAll(workers,queryTimeout,TimeUnit.MILLISECONDS); } else { invokeAll = jdbcExecutor.invokeAll(workers); } for (Future<Integer> future : invokeAll) { affectRows += future.get(); } } else if (workers.size() == 1) { affectRows = workers.get(0).doWork(); } return affectRows; } catch (InterruptedException e) { throw DbException.convert(e); } catch (ExecutionException e) { throw DbException.convert(e.getCause()); } finally { removeRuningJdbcWorkers(workers); for (JdbcWorker<Integer> jdbcWorker : workers) { jdbcWorker.closeResource(); } } } /** * build insert statement * @param forTable * @param columns * @param row * @param buff * @return */ protected List<Value> buildInsert(String forTable, Column[] columns, SearchRow row, StatementBuilder buff) { ArrayList<Value> params = New.arrayList(); buff.append("INSERT INTO "); buff.append(identifier(forTable)).append('('); for (Column c : columns) { buff.appendExceptFirst(", "); buff.append(c.getSQL()); } buff.append(") "); buff.resetCount(); buff.append("VALUES( "); for (int i = 0; i < columns.length; i++) { Value v = row.getValue(i); buff.appendExceptFirst(", "); if (v == null) { buff.append("DEFAULT"); } else if (isNull(v)) { buff.append("NULL"); } else { buff.append('?'); params.add(v); } } buff.append(")"); return params; } private static class BatchKey implements Serializable { private static final long serialVersionUID = 1L; private final String shardName; private final String sql; /** * @param shardName * @param sql */ private BatchKey(String shardName, String sql) { super(); this.shardName = shardName; this.sql = sql; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + ((shardName == null) ? 0 : shardName.hashCode()); result = prime * result + ((sql == null) ? 0 : sql.hashCode()); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; BatchKey other = (BatchKey) obj; if (shardName == null) { if (other.shardName != null) return false; } else if (!shardName.equals(other.shardName)) return false; if (sql == null) { if (other.sql != null) return false; } else if (!sql.equals(other.sql)) return false; return true; } @Override public String toString() { return "BatchKey [shardName=" + shardName + ", sql=" + sql + "]"; } } }