package demo.catlets; import io.mycat.cache.LayerCachePool; import io.mycat.route.RouteResultset; import io.mycat.route.RouteResultsetNode; import io.mycat.route.factory.RouteStrategyFactory; import io.mycat.server.ErrorCode; import io.mycat.server.Fields; import io.mycat.server.MySQLFrontConnection; import io.mycat.server.config.node.SchemaConfig; import io.mycat.server.config.node.SystemConfig; import io.mycat.server.packet.FieldPacket; import io.mycat.server.packet.RowDataPacket; import io.mycat.server.parser.ServerParse; import io.mycat.sqlengine.AllJobFinishedListener; import io.mycat.sqlengine.Catlet; import io.mycat.sqlengine.EngineCtx; import io.mycat.sqlengine.SQLJobHandler; import io.mycat.sqlengine.sharejoin.JoinParser; import io.mycat.util.ByteUtil; import io.mycat.util.ResultSetUtil; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; //import org.opencloudb.route.RouteStrategy; //import org.opencloudb.route.impl.DruidMysqlRouteStrategy; //import org.opencloudb.parser.druid.DruidParser; import com.alibaba.druid.sql.ast.SQLStatement; import com.alibaba.druid.sql.ast.statement.SQLSelectQuery; import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; /** * 功能详细描述:分片join * @author sohudo[http://blog.csdn.net/wind520] * @create 2015年01月22日 下午6:50:23 * @version 0.0.1 */ public class ShareJoin implements Catlet { private EngineCtx ctx; private RouteResultset rrs ; private JoinParser joinParser; private Map<String, byte[]> rows = new ConcurrentHashMap<String, byte[]>(); private Map<String,String> ids = new ConcurrentHashMap<String,String>(); //private ConcurrentLinkedQueue<String> ids = new ConcurrentLinkedQueue<String>(); private List<byte[]> fields; //主表的字段 private ArrayList<byte[]> allfields;//所有的字段 private boolean isMfield=false; private int mjob=0; private int maxjob=0; private int joinindex=0;//关联join表字段的位置 private int sendField=0; private boolean childRoute=false; private boolean jointTableIsData=false; // join 字段的类型,一般情况都是int, long; 增加该字段为了支持非int,long类型的(一般为varchar)joinkey的sharejoin // 参见:io.mycat.server.packet.FieldPacket 属性: public int type; // 参见:http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition private int joinKeyType = Fields.FIELD_TYPE_LONG; // 默认 join 字段为int型 //重新路由使用 private SystemConfig sysConfig; private SchemaConfig schema; private int sqltype; private String charset; private MySQLFrontConnection sc; private LayerCachePool cachePool; public void setRoute(RouteResultset rrs){ this.rrs =rrs; } public void route(SystemConfig sysConfig, SchemaConfig schema,int sqlType, String realSQL, String charset, MySQLFrontConnection sc, LayerCachePool cachePool) { int rs = ServerParse.parse(realSQL); this.sqltype = rs & 0xff; this.sysConfig=sysConfig; this.schema=schema; this.charset=charset; this.sc=sc; this.cachePool=cachePool; try { // RouteStrategy routes=RouteStrategyFactory.getRouteStrategy(); // rrs =RouteStrategyFactory.getRouteStrategy().route(sysConfig, schema, sqlType2, realSQL,charset, sc, cachePool); MySqlStatementParser parser = new MySqlStatementParser(realSQL); SQLStatement statement = parser.parseStatement(); if(statement instanceof SQLSelectStatement) { SQLSelectStatement st=(SQLSelectStatement)statement; SQLSelectQuery sqlSelectQuery =st.getSelect().getQuery(); if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)st.getSelect().getQuery(); joinParser=new JoinParser(mysqlSelectQuery,realSQL); joinParser.parser(); } } /* if (routes instanceof DruidMysqlRouteStrategy) { SQLSelectStatement st=((DruidMysqlRouteStrategy) routes).getSQLStatement(); SQLSelectQuery sqlSelectQuery =st.getSelect().getQuery(); if(sqlSelectQuery instanceof MySqlSelectQueryBlock) { MySqlSelectQueryBlock mysqlSelectQuery = (MySqlSelectQueryBlock)st.getSelect().getQuery(); joinParser=new JoinParser(mysqlSelectQuery,realSQL); joinParser.parser(); } } */ } catch (Exception e) { } } private void getRoute(String sql){ try { if (joinParser!=null){ rrs =RouteStrategyFactory.getRouteStrategy().route(sysConfig, schema, sqltype,sql,charset, sc, cachePool); } } catch (Exception e) { } } private String[] getDataNodes(){ String[] dataNodes =new String[rrs.getNodes().length] ; for (int i=0;i<rrs.getNodes().length;i++){ dataNodes[i]=rrs.getNodes()[i].getName(); } return dataNodes; } private String getDataNode(String[] dataNodes){ String dataNode=""; for (int i=0;i<dataNodes.length;i++){ dataNode+=dataNodes[i]+","; } return dataNode; } public void processSQL(String sql, EngineCtx ctx) { String ssql=joinParser.getSql(); getRoute(ssql); RouteResultsetNode[] nodes = rrs.getNodes(); if (nodes == null || nodes.length == 0 || nodes[0].getName() == null || nodes[0].getName().equals("")) { ctx.getSession().getSource().writeErrMessage(ErrorCode.ER_NO_DB_ERROR, "No dataNode found ,please check tables defined in schema:" + ctx.getSession().getSource().getSchema()); return; } this.ctx=ctx; String[] dataNodes =getDataNodes(); maxjob=dataNodes.length; ShareDBJoinHandler joinHandler = new ShareDBJoinHandler(this,joinParser.getJoinLkey()); ctx.executeNativeSQLSequnceJob(dataNodes, ssql, joinHandler); EngineCtx.LOGGER.info("Catlet exec:"+getDataNode(getDataNodes())+" sql:" +ssql); ctx.setAllJobFinishedListener(new AllJobFinishedListener() { @Override public void onAllJobFinished(EngineCtx ctx) { if (!jointTableIsData) { ctx.writeHeader(fields); } ctx.writeEof(); EngineCtx.LOGGER.info("发送数据OK"); } }); } public void putDBRow(String id,String nid, byte[] rowData,int findex){ rows.put(id, rowData); ids.put(id, nid); joinindex=findex; //ids.offer(nid); int batchSize = 999; // 满1000条,发送一个查询请求 if (ids.size() > batchSize) { createQryJob(batchSize); } } public void putDBFields(List<byte[]> mFields){ if (!isMfield){ fields=mFields; } } public void endJobInput(String dataNode, boolean failed){ mjob++; if (mjob>=maxjob){ createQryJob(Integer.MAX_VALUE); ctx.endJobInput(); } // EngineCtx.LOGGER.info("完成"+mjob+":" + dataNode+" failed:"+failed); } //private void createQryJob(String dataNode,int batchSize) { private void createQryJob(int batchSize) { int count = 0; Map<String, byte[]> batchRows = new ConcurrentHashMap<String, byte[]>(); String theId = null; StringBuilder sb = new StringBuilder().append('('); String svalue=""; for(Map.Entry<String,String> e: ids.entrySet() ){ theId=e.getKey(); batchRows.put(theId, rows.remove(theId)); if (!svalue.equals(e.getValue())){ if(joinKeyType == Fields.FIELD_TYPE_VAR_STRING || joinKeyType == Fields.FIELD_TYPE_STRING){ // joinkey 为varchar sb.append("'").append(e.getValue()).append("'").append(','); // ('digdeep','yuanfang') }else{ // 默认joinkey为int/long sb.append(e.getValue()).append(','); // (1,2,3) } } svalue=e.getValue(); if (count++ > batchSize) { break; } } /* while ((theId = ids.poll()) != null) { batchRows.put(theId, rows.remove(theId)); sb.append(theId).append(','); if (count++ > batchSize) { break; } } */ if (count == 0) { return; } jointTableIsData=true; sb.deleteCharAt(sb.length() - 1).append(')'); String sql = String.format(joinParser.getChildSQL(), sb); //if (!childRoute){ getRoute(sql); //childRoute=true; //} ctx.executeNativeSQLParallJob(getDataNodes(),sql, new ShareRowOutPutDataHandler(this,fields,joinindex,joinParser.getJoinRkey(), batchRows)); EngineCtx.LOGGER.info("SQLParallJob:"+getDataNode(getDataNodes())+" sql:" + sql); } public void writeHeader(String dataNode,List<byte[]> afields, List<byte[]> bfields) { sendField++; if (sendField==1){ ctx.writeHeader(afields, bfields); setAllFields(afields, bfields); // EngineCtx.LOGGER.info("发送字段2:" + dataNode); } } private void setAllFields(List<byte[]> afields, List<byte[]> bfields){ allfields=new ArrayList<byte[]>(); for (byte[] field : afields) { allfields.add(field); } //EngineCtx.LOGGER.info("所有字段2:" +allfields.size()); for (int i=1;i<bfields.size();i++){ allfields.add(bfields.get(i)); } } public List<byte[]> getAllFields(){ return allfields; } public void writeRow(RowDataPacket rowDataPkg){ ctx.writeRow(rowDataPkg); } public int getFieldIndex(List<byte[]> fields,String fkey){ int i=0; for (byte[] field :fields) { FieldPacket fieldPacket = new FieldPacket(); fieldPacket.read(field); if (ByteUtil.getString(fieldPacket.name).equals(fkey)){ joinKeyType = fieldPacket.type; return i; } i++; } return i; } } class ShareDBJoinHandler implements SQLJobHandler { private List<byte[]> fields; private final ShareJoin ctx; private String joinkey; public ShareDBJoinHandler(ShareJoin ctx,String joinField) { super(); this.ctx = ctx; this.joinkey=joinField; //EngineCtx.LOGGER.info("二次查询:" +" sql:" + querySQL+"/"+joinkey); } //private Map<String, byte[]> rows = new ConcurrentHashMap<String, byte[]>(); //private ConcurrentLinkedQueue<String> ids = new ConcurrentLinkedQueue<String>(); @Override public void onHeader(String dataNode, byte[] header, List<byte[]> fields) { this.fields = fields; ctx.putDBFields(fields); } /* public static String getFieldNames(List<byte[]> fields){ String str=""; for (byte[] field :fields) { FieldPacket fieldPacket = new FieldPacket(); fieldPacket.read(field); str+=ByteUtil.getString(fieldPacket.name)+","; } return str; } public static String getFieldName(byte[] field){ FieldPacket fieldPacket = new FieldPacket(); fieldPacket.read(field); return ByteUtil.getString(fieldPacket.name); } */ @Override public boolean onRowData(String dataNode, byte[] rowData) { int fid=this.ctx.getFieldIndex(fields,joinkey); String id = ResultSetUtil.getColumnValAsString(rowData, fields, 0);//主键,默认id String nid = ResultSetUtil.getColumnValAsString(rowData, fields, fid); // 放入结果集 //rows.put(id, rowData); ctx.putDBRow(id,nid, rowData,fid); return false; } @Override public void finished(String dataNode, boolean failed) { ctx.endJobInput(dataNode,failed); } } class ShareRowOutPutDataHandler implements SQLJobHandler { private final List<byte[]> afields; private List<byte[]> bfields; private final ShareJoin ctx; private final Map<String, byte[]> arows; private int joinL;//A表(左边)关联字段的位置 private int joinR;//B表(右边)关联字段的位置 private String joinRkey;//B表(右边)关联字段 public ShareRowOutPutDataHandler(ShareJoin ctx,List<byte[]> afields,int joini,String joinField,Map<String, byte[]> arows) { super(); this.afields = afields; this.ctx = ctx; this.arows = arows; this.joinL =joini; this.joinRkey= joinField; //EngineCtx.LOGGER.info("二次查询:" +arows.size()+ " afields:"+FenDBJoinHandler.getFieldNames(afields)); } @Override public void onHeader(String dataNode, byte[] header, List<byte[]> bfields) { this.bfields=bfields; joinR=this.ctx.getFieldIndex(bfields,joinRkey); ctx.writeHeader(dataNode,afields, bfields); } //不是主键,获取join左边的的记录 private byte[] getRow(String value,int index){ for(Map.Entry<String,byte[]> e: arows.entrySet() ){ String key=e.getKey(); RowDataPacket rowDataPkg = ResultSetUtil.parseRowData(e.getValue(), afields); String id = ByteUtil.getString(rowDataPkg.fieldValues.get(index)); if (id.equals(value)){ return arows.remove(key); } } return null; } @Override public boolean onRowData(String dataNode, byte[] rowData) { RowDataPacket rowDataPkgold = ResultSetUtil.parseRowData(rowData, bfields); // 获取Id字段, String id = ByteUtil.getString(rowDataPkgold.fieldValues.get(joinR)); // 查找ID对应的A表的记录 byte[] arow = getRow(id,joinL);//arows.remove(id); while (arow!=null) { RowDataPacket rowDataPkg = ResultSetUtil.parseRowData(arow,afields );//ctx.getAllFields()); for (int i=1;i<rowDataPkgold.fieldCount;i++){ // 设置b.name 字段 byte[] bname = rowDataPkgold.fieldValues.get(i); rowDataPkg.add(bname); rowDataPkg.addFieldCount(1); } //RowData(rowDataPkg); ctx.writeRow(rowDataPkg); arow = getRow(id,joinL); } return false; } @Override public void finished(String dataNode, boolean failed) { // EngineCtx.LOGGER.info("完成2:" + dataNode+" failed:"+failed); } }