/* * Copyright 2015 Liu Huanting. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package fm.liu.timo.route; import java.sql.SQLSyntaxErrorException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map.Entry; import java.util.Set; import org.pmw.tinylog.Logger; import fm.liu.timo.config.model.Database; import fm.liu.timo.config.model.Function; import fm.liu.timo.config.model.Table; import fm.liu.timo.config.model.Table.TableType; import fm.liu.timo.parser.ast.expression.primary.RowExpression; import fm.liu.timo.parser.ast.stmt.SQLStatement; import fm.liu.timo.parser.ast.stmt.ddl.DDLStatement; import fm.liu.timo.parser.ast.stmt.dml.DMLInsertReplaceStatement; import fm.liu.timo.parser.recognizer.SQLParserDelegate; import fm.liu.timo.parser.visitor.OutputVisitor; import fm.liu.timo.route.visitor.RouteVisitor; import fm.liu.timo.server.parser.ServerParse; /** * 路由计算器 * @author Liu Huanting 2015年5月10日 */ public class Router { // hint格式示例: // /*!timo:node1,3*/select * from table_a; 仅在节点1和节点3上执行该语句 // /*!timo:master*/select * from table_a; 读写分离时强制在主库上执行该语句 private static final String HINT = "/*!timo:"; private static final String NODE = "node"; private static final String MASTER = "master"; public static Outlets route(Database database, String sql, String charset, int type) throws SQLSyntaxErrorException { Outlets outlets = new Outlets(); sql = sql.trim(); if (sql.startsWith(HINT)) { int end = sql.indexOf("*/"); if (end > 0) { String hint = sql.substring(HINT.length(), end).trim(); sql = sql.substring(end + "*/".length()); type = ServerParse.parse(sql) & 0xff; if (hint.startsWith(NODE)) { String[] nodes = hint.substring(NODE.length()).split(","); for (String node : nodes) { int id = Integer.parseInt(node); if (!database.getNodes().contains(id)) { throw new IllegalArgumentException( "unknown datanoe" + id + " in hint:" + hint); } outlets.add(new Outlet(id, sql)); } return outlets; } else if (hint.startsWith(MASTER)) { outlets.setUsingMaster(true); } else { Logger.warn("unsupported hint: {}", sql); } } } sql = removeDB(sql, database.getName()); SQLStatement stmt = SQLParserDelegate.parse(sql, charset); RouteVisitor visitor = new RouteVisitor(database); stmt.accept(visitor); Table table = visitor.getTable(); if (table == null) { outlets.add(new Outlet(database.getRandomNode(), sql)); return outlets; } ArrayList<Object> values = visitor.getValues(); int info = visitor.getInfo(); outlets.setInfo(info); switch (type) { case ServerParse.SELECT: if (TableType.GLOBAL.equals(table.getType())) { outlets.add(new Outlet(table.getRandomNode(), sql)); return outlets; } if ((info & Info.HAS_GROUPBY) == Info.HAS_GROUPBY) { outlets.setGroupBy(visitor.getGroupBy()); } if ((info & Info.HAS_ORDERBY) == Info.HAS_ORDERBY) { outlets.setOrderBy(visitor.getOrderBy()); } if ((info & Info.HAS_LIMIT) == Info.HAS_LIMIT) { outlets.setLimit(visitor.getLimitSize(), visitor.getLimitOffset()); } break; case ServerParse.INSERT: case ServerParse.REPLACE: if (TableType.SPLIT.equals(table.getType())) { return routeBatch(table, (DMLInsertReplaceStatement) stmt, visitor.getBatchIndex(), type, outlets, values); } break; } if ((info & Info.TO_ALL_NODE) == Info.TO_ALL_NODE) { return toAllNode(outlets, table, sql); } return route(stmt, outlets, table, values, sql); } private static Outlets toAllNode(Outlets outlets, Table table, String sql) { table.getNodes().forEach(i -> outlets.add(new Outlet(i, sql))); return outlets; } /** * <pre> * turn * INSERT/REPLACE INTO TABLE_A(COL_1,COL2,...) VALUES (VAL_11,VAL12,...),(VAL_21,VAL22,...),(VAL_31,VAL32,...)...; * into something like * INSERT/REPLACE INTO TABLE_A(COL_1,COL2,...) VALUES (VAL_11,VAL12,...),(VAL_31,VAL_32,...)...; * INSERT/REPLACE INTO TABLE_A(COL_1,COL2,...) VALUES (VAL_21,VAL22,...),(VAL_41,VAL_42,...)...; * </pre> */ private static Outlets routeBatch(Table table, DMLInsertReplaceStatement stmt, int index, int type, Outlets outlets, ArrayList<Object> values) { List<RowExpression> rows = stmt.getRowList(); HashMap<Integer, List<RowExpression>> results = new HashMap<>(); if (values.isEmpty()) { throw new IllegalArgumentException("can't route without the value of split column"); } else { Function function = table.getRule().getFunction(); int i = 0; for (Object value : values) { int node = function.calcute(value); if (results.containsKey(node)) { results.get(node).add(rows.get(i++)); } else { List<RowExpression> exps = new ArrayList<>(); exps.add(rows.get(i++)); results.put(node, exps); } } for (Entry<Integer, List<RowExpression>> entry : results.entrySet()) { stmt.setReplaceRowList(entry.getValue()); outlets.add(new Outlet(entry.getKey(), updateSQL(stmt))); stmt.clearReplaceRowList(); } } return outlets; } private static String updateSQL(SQLStatement stmt) { OutputVisitor visitor = new OutputVisitor(new StringBuilder(), true); stmt.accept(visitor); return visitor.getSql(); } private static Outlets route(SQLStatement stmt, Outlets outlets, Table table, ArrayList<Object> values, String sql) { if (!(stmt instanceof DDLStatement)) { sql = updateSQL(stmt); } if (values.isEmpty()) { for (Integer id : table.getNodes()) { Outlet out = new Outlet(id, sql); outlets.add(out); } } else { Function function = table.getRule().getFunction(); Set<Integer> result = function.calcute(values); for (int id : result) { Outlet out = new Outlet(id, sql); outlets.add(out); } } return outlets; } /** * 数据库前缀移除 */ private static String removeDB(String sql, String database) { final String upSQL = sql.toUpperCase(); final String upDB = database.toUpperCase() + "."; int strtPos = 0; int indx = 0; boolean flag = false; indx = upSQL.indexOf(upDB, strtPos); if (indx < 0) { StringBuilder sb = new StringBuilder("`").append(database.toUpperCase()).append("`."); indx = upSQL.indexOf(sb.toString(), strtPos); flag = true; if (indx < 0) { return sql; } } StringBuilder sb = new StringBuilder(); while (indx > 0) { sb.append(sql.substring(strtPos, indx)); strtPos = indx + upDB.length(); if (flag) { strtPos += 2; } indx = upSQL.indexOf(upDB, strtPos); } sb.append(sql.substring(strtPos)); return sb.toString(); } }