package org.araqne.logdb.query.command; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import org.araqne.logdb.Query; import org.araqne.logdb.QueryCommand; import org.araqne.logdb.QueryResultSet; import org.araqne.logdb.QueryStopReason; import org.araqne.logdb.QueryTask; import org.araqne.logdb.Row; import org.araqne.logdb.RowBatch; import org.araqne.logdb.SubQueryCommand; import org.araqne.logdb.SubQueryTask; import org.araqne.logdb.query.command.Sort.SortField; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Join extends QueryCommand implements SubQueryCommand { public enum JoinType { Inner, Left, Right, Full } private final Logger logger = LoggerFactory.getLogger(Join.class); private final JoinType joinType; // for hash join private HashMap<JoinKeys, List<Object>> hashJoinMap; private JoinKeys joinKeys; private int joinKeyCount; private SortField[] sortFields; private Query subQuery; private SubQueryTask subQueryTask; // tasks private PostSubQueryTask postSubQueryTask = new PostSubQueryTask(); private SortMergeJoiner sortMergeJoiner; public Join(JoinType joinType, SortField[] sortFields, Query subQuery) { this.joinType = joinType; this.joinKeyCount = sortFields.length; this.joinKeys = new JoinKeys(new Object[joinKeyCount]); this.sortFields = sortFields; this.subQuery = subQuery; this.subQueryTask = new SubQueryTask(subQuery); this.sortMergeJoiner = new SortMergeJoiner(joinType, sortFields, new SortMergeJoinerCallback(this)); logger.debug("araqne logdb: join subquery created [{}:{}]", subQuery.getId(), subQuery.getQueryString()); } @Override public String getName() { return "join"; } @Override public boolean isReducer() { return true; } @Override public Query getSubQuery() { return subQuery; } @Override public void onStart() { postSubQueryTask.addDependency(subQueryTask); postSubQueryTask.addSubTask(subQueryTask); } @Override public void onClose(QueryStopReason reason) { if (hashJoinMap != null) { hashJoinMap = null; } else { if (reason == QueryStopReason.PartialFetch || reason == QueryStopReason.End) { sortMergeJoiner.merge(); } else { try { sortMergeJoiner.cancel(); } catch (Throwable t) { logger.error("araqne logdb: can not cancel sortMergeJoiner", t); } } } try { subQuery.cancel(reason); } catch (Throwable t) { logger.error("araqne logdb: cannot stop subquery [" + subQuery.getQueryString() + "]", t); } finally { subQuery.purge(); } } @Override public void onPush(RowBatch rowBatch) { if (rowBatch.selectedInUse) { for (int i = 0; i < rowBatch.size; i++) { Row row = rowBatch.rows[rowBatch.selected[i]]; onPush(row); } } else { for (int i = 0; i < rowBatch.size; i++) { Row row = rowBatch.rows[i]; onPush(row); } } } @Override public void onPush(Row m) { if (hashJoinMap != null) { int i = 0; for (SortField f : sortFields) { Object joinValue = m.get(f.getName()); if (joinValue instanceof Integer || joinValue instanceof Short) joinValue = ((Number) joinValue).longValue(); joinKeys.keys[i++] = joinValue; } List<Object> l = hashJoinMap.get(joinKeys); if (l == null) { if (joinType == JoinType.Left) pushPipe(m); return; } for (Object o : l) { @SuppressWarnings("unchecked") Map<String, Object> sm = (Map<String, Object>) o; Map<String, Object> joinMap = new HashMap<String, Object>(m.map()); joinMap.putAll(sm); pushPipe(new Row(joinMap)); } return; } else { try { sortMergeJoiner.setR(m); } catch (Throwable t) { logger.error("araqne logdb: cannot setR on sortMergeJoiner[" + m.toString() + "]", t); } } } @Override public QueryTask getMainTask() { return postSubQueryTask; } public JoinType getType() { return joinType; } public SortField[] getSortFields() { return sortFields; } @Override public String toString() { String typeOpt = ""; if (joinType != JoinType.Inner) typeOpt = " type=" + joinType.toString().toLowerCase(); return "join" + typeOpt + " " + SortField.serialize(sortFields) + " [ " + subQuery.getQueryString() + " ] "; } // bulid hash table or sort private class PostSubQueryTask extends QueryTask { private final int HASH_JOIN_THRESHOLD = Integer.parseInt(System.getProperty("araqne.hashjointhreshold", "100000")); @Override public void run() { logger.debug("araqne logdb: join subquery end, main query [{}] sub query [{}]", query.getId(), subQuery.getId()); QueryResultSet rs = null; try { rs = subQuery.getResultSet(); logger.debug("araqne logdb: join fetch subquery result of query [{}:{}]", query.getId(), query.getQueryString()); if (rs.size() <= HASH_JOIN_THRESHOLD && (joinType == JoinType.Inner || joinType == JoinType.Left)) buildHashJoinTable(rs); else sortMergeJoiner.setS(rs); } catch (Throwable e) { logger.error("araqne logdb: cannot get subquery result of query " + query.getId(), e); } finally { if (rs != null) { rs.close(); } } } private void buildHashJoinTable(QueryResultSet rs) { hashJoinMap = new HashMap<JoinKeys, List<Object>>(HASH_JOIN_THRESHOLD); while (rs.hasNext()) { Map<String, Object> sm = rs.next(); Object[] keys = new Object[joinKeyCount]; for (int i = 0; i < joinKeyCount; i++) { Object joinValue = sm.get(sortFields[i].getName()); if (joinValue instanceof Integer || joinValue instanceof Short) { joinValue = ((Number) joinValue).longValue(); } keys[i] = joinValue; } JoinKeys joinKeys = new JoinKeys(keys); List<Object> l = hashJoinMap.get(joinKeys); if (l == null) { l = new ArrayList<Object>(2); hashJoinMap.put(joinKeys, l); } l.add(sm); } } } public static class JoinKeys { public Object[] keys; public JoinKeys(Object[] keys) { this.keys = keys; } @Override public int hashCode() { return Arrays.hashCode(keys); } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; JoinKeys other = (JoinKeys) obj; return equals(keys, other.keys); } public static boolean equals(Object[] a, Object[] a2) { if (a == a2) return true; if (a == null || a2 == null) return false; int length = a.length; if (a2.length != length) return false; for (int i = 0; i < length; i++) { Object o1 = a[i]; Object o2 = a2[i]; if ((o1 == null || o2 == null) || (!o1.equals(o2))) return false; } return true; } } class SortMergeJoinerCallback implements SortMergeJoinerListener { Join join; SortMergeJoinerCallback(Join join) { this.join = join; } @Override public void onPushPipe(Row row) { join.pushPipe(row); } } }