/* * Copyright 2014-2015 the original author or authors * * Licensed under the Apache License, Version 2.0 (the “License”); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an “AS IS” BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Created on 2014年4月24日 // $Id$ package com.wplatform.ddal.dispatch.rule; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; import com.wplatform.ddal.util.New; import com.wplatform.ddal.value.Value; /** * @author <a href="mailto:jorgie.mail@gmail.com">jorgie li</a> */ public class RoutingCalculatorImpl implements RoutingCalculator { private RuleEvaluator evaluator = new OgnlRuleEvaluator(); public RuleEvaluator getEvaluator() { return evaluator; } public void setEvaluator(RuleEvaluator evaluator) { this.evaluator = evaluator; } @Override public RoutingResult calculate(TableRouter tableRouter, Map<String, List<Value>> columnValue) { if (tableRouter == null) { throw new IllegalArgumentException("tableRule is null."); } if (columnValue == null) { throw new IllegalArgumentException("columnValue is null."); } RuleExpression expression = tableRouter.getRuleExpression(); List<TableNode> tableNode = null; if (canUseRule(expression, columnValue)) { tableNode = evaluateTableRule(tableRouter, columnValue); } else { // 无库规则,库的范围是TableRule配置的所有库 tableNode = tableRouter.getPartition(); } return new RoutingResult(tableRouter.getPartition(), tableNode); } /** * @param ruleToUse * @param columnValue * @param tableRule */ private List<TableNode> evaluateTableRule(TableRouter tr, Map<String, List<Value>> args) { List<TableNode> result = New.arrayList(); RuleExpression rule = tr.getRuleExpression(); List<TableNode> partion = tr.getPartition(); List<RuleColumn> ruleColumns = rule.getRuleColumns(); Map<String, List<Value>> paramCollections = New.hashMap(ruleColumns.size(), 1L); for (RuleColumn ruleColumn : ruleColumns) { String name = ruleColumn.getName(); paramCollections.put(name, args.get(name)); } // 一个规则存在多个RuleColumn,多个RuleColumn对应的取值集合做笛卡尔积后的所有集 for (Map<String, Value> parameters : new CrossedCollection(paramCollections)) { TableNode tableNode = null; Object evlValue = evaluator.evaluate(rule, parameters); if (evlValue == null) { throw new RuleEvaluateException("The rule expression " + rule.getExpression() + " evaluate a null value."); } if (evlValue instanceof TableNode) { if (!partion.contains(evlValue)) { throw new RuleEvaluateException("The rule expression " + rule.getExpression() + " evaluated " + evlValue + " is not in partition list."); } tableNode = (TableNode) evlValue; } else if (evlValue.getClass() == int.class || evlValue.getClass() == Integer.class || evlValue.getClass() == long.class || evlValue.getClass() == Long.class || evlValue.getClass() == short.class || evlValue.getClass() == Short.class || evlValue.getClass() == byte.class || evlValue.getClass() == Byte.class) { try { int index = Integer.parseInt(evlValue.toString()); tableNode = partion.get(index); } catch (IndexOutOfBoundsException e) { throw new RuleEvaluateException("The rule expression " + rule.getExpression() + " evaluated " + evlValue + " is out of range partition list."); } } else { throw new RuleEvaluateException("The group rule expression " + rule.getExpression() + " return a value " + evlValue.getClass() + " which type is unsupported."); } result.add(tableNode); } return result; } /** * 对于分库分表存在多个Rule的情况下,choiceRule负责根据表的字段值选取一个符合条件的Rule做为sharding规则, * 先择的规则按优先顺序,优先最大匹配,先匹配所有列,找不到再去除可选列之后匹配 * * @param rules * @param columnValue * @return */ private boolean canUseRule(RuleExpression rule, Map<String, List<Value>> columnValue) { if (rule == null) { return false; } // 完全匹配所有列 List<RuleColumn> ruleColumns = rule.getRuleColumns(); for (RuleColumn ruleColumn : ruleColumns) { List<Value> values = columnValue.get(ruleColumn.getName()); if (values == null || values.isEmpty()) { return false; } } return true; } /** * 将列的值域通过笛卡尔积运算,转化为参数一一对应的值域 * <p> * <p> * <pre> * 如输入参数: { * column1:{ 1, 2, 3 }, * column2:{ a, b, c, d } * } * </pre> * <p> * <p> * <pre> * 输出结果:{ * {column1=1, column2=a} * {column1=1, column2=b} * {column1=1, column2=c} * {column1=2, column2=a} * {column1=2, column2=b} * {column1=2, column2=c} * } * </pre> * * @author <a href="mailto:jorgie.mail@gmail.com">jorgie li</a> */ private static class CrossedCollection implements Iterable<Map<String, Value>> { private Map<String, List<Value>> collection; private List<Map<String, Value>> crossedResult; /** * @param collection */ private CrossedCollection(Map<String, List<Value>> collection) { this.collection = collection; } /* (non-Javadoc) * @see java.lang.Iterable#iterator() */ @Override public Iterator<Map<String, Value>> iterator() { crossedResult = cross(this.collection); return crossedResult.iterator(); } private List<Map<String, Value>> cross(Map<String, List<Value>> crossArgs) { // Set是无顺的且不能按顺号迭代,先将Set转为List Map<String, List<Value>> crossSource = New.hashMap(crossArgs.size(), 1L); List<String> columnNames = new ArrayList<String>(crossArgs.keySet()); // 计算出笛卡尔积行数 int rows = columnNames.size() > 0 ? 1 : 0; for (String column : columnNames) { crossSource.put(column, new ArrayList<Value>(crossArgs.get(column))); rows *= crossArgs.get(column).size(); } // 笛卡尔积索引记录 int[] record = new int[columnNames.size()]; List<Map<String, Value>> results = new ArrayList<Map<String, Value>>(); // 产生笛卡尔积 for (int i = 0; i < rows; i++) { // List<String> row = new ArrayList<String>(); Map<String, Value> row = New.linkedHashMap(record.length, 1L); // 生成笛卡尔积的每组数据 for (int index = 0; index < record.length; index++) { String columnName = columnNames.get(index); List<Value> columnValues = crossSource.get(columnName); row.put(columnName, columnValues.get(record[index])); } results.add(row); crossRecord(columnNames, crossSource, record, crossArgs.size() - 1); } return results; } /** * 产生笛卡尔积当前行索引记录. * * @param sourceArgs 要产生笛卡尔积的源数据 * @param record 每行笛卡尔积的索引组合 * @param level 索引组合的当前计算层级 */ private void crossRecord(List<String> columnNames, Map<String, List<Value>> crossSource, int[] record, int level) { record[level] = record[level] + 1; if (record[level] >= crossSource.get(columnNames.get(level)).size() && level > 0) { record[level] = 0; crossRecord(columnNames, crossSource, record, level - 1); } } } }