/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.lucene.search.grouping; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NavigableSet; import java.util.TreeSet; import org.apache.lucene.search.FieldComparator; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; /** * Represents a group that is found during the first pass search. * * @lucene.experimental */ public class SearchGroup<T> { /** The value that defines this group */ public T groupValue; /** The sort values used during sorting. These are the * groupSort field values of the highest rank document * (by the groupSort) within the group. Can be * <code>null</code> if <code>fillFields=false</code> had * been passed to {@link FirstPassGroupingCollector#getTopGroups} */ public Object[] sortValues; @Override public String toString() { return("SearchGroup(groupValue=" + groupValue + " sortValues=" + Arrays.toString(sortValues) + ")"); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SearchGroup<?> that = (SearchGroup<?>) o; if (groupValue == null) { if (that.groupValue != null) { return false; } } else if (!groupValue.equals(that.groupValue)) { return false; } return true; } @Override public int hashCode() { return groupValue != null ? groupValue.hashCode() : 0; } private static class ShardIter<T> { public final Iterator<SearchGroup<T>> iter; public final int shardIndex; public ShardIter(Collection<SearchGroup<T>> shard, int shardIndex) { this.shardIndex = shardIndex; iter = shard.iterator(); assert iter.hasNext(); } public SearchGroup<T> next() { assert iter.hasNext(); final SearchGroup<T> group = iter.next(); if (group.sortValues == null) { throw new IllegalArgumentException("group.sortValues is null; you must pass fillFields=true to the first pass collector"); } return group; } @Override public String toString() { return "ShardIter(shard=" + shardIndex + ")"; } } // Holds all shards currently on the same group private static class MergedGroup<T> { // groupValue may be null! public final T groupValue; public Object[] topValues; public final List<ShardIter<T>> shards = new ArrayList<>(); public int minShardIndex; public boolean processed; public boolean inQueue; public MergedGroup(T groupValue) { this.groupValue = groupValue; } // Only for assert private boolean neverEquals(Object _other) { if (_other instanceof MergedGroup) { MergedGroup<?> other = (MergedGroup<?>) _other; if (groupValue == null) { assert other.groupValue != null; } else { assert !groupValue.equals(other.groupValue); } } return true; } @Override public boolean equals(Object _other) { // We never have another MergedGroup instance with // same groupValue assert neverEquals(_other); if (_other instanceof MergedGroup) { MergedGroup<?> other = (MergedGroup<?>) _other; if (groupValue == null) { return other == null; } else { return groupValue.equals(other); } } else { return false; } } @Override public int hashCode() { if (groupValue == null) { return 0; } else { return groupValue.hashCode(); } } } private static class GroupComparator<T> implements Comparator<MergedGroup<T>> { @SuppressWarnings("rawtypes") public final FieldComparator[] comparators; public final int[] reversed; @SuppressWarnings({"unchecked", "rawtypes"}) public GroupComparator(Sort groupSort) { final SortField[] sortFields = groupSort.getSort(); comparators = new FieldComparator[sortFields.length]; reversed = new int[sortFields.length]; for (int compIDX = 0; compIDX < sortFields.length; compIDX++) { final SortField sortField = sortFields[compIDX]; comparators[compIDX] = sortField.getComparator(1, compIDX); reversed[compIDX] = sortField.getReverse() ? -1 : 1; } } @Override @SuppressWarnings({"unchecked","rawtypes"}) public int compare(MergedGroup<T> group, MergedGroup<T> other) { if (group == other) { return 0; } //System.out.println("compare group=" + group + " other=" + other); final Object[] groupValues = group.topValues; final Object[] otherValues = other.topValues; //System.out.println(" groupValues=" + groupValues + " otherValues=" + otherValues); for (int compIDX = 0;compIDX < comparators.length; compIDX++) { final int c = reversed[compIDX] * comparators[compIDX].compareValues(groupValues[compIDX], otherValues[compIDX]); if (c != 0) { return c; } } // Tie break by min shard index: assert group.minShardIndex != other.minShardIndex; return group.minShardIndex - other.minShardIndex; } } private static class GroupMerger<T> { private final GroupComparator<T> groupComp; private final NavigableSet<MergedGroup<T>> queue; private final Map<T,MergedGroup<T>> groupsSeen; public GroupMerger(Sort groupSort) { groupComp = new GroupComparator<>(groupSort); queue = new TreeSet<>(groupComp); groupsSeen = new HashMap<>(); } @SuppressWarnings({"unchecked","rawtypes"}) private void updateNextGroup(int topN, ShardIter<T> shard) { while(shard.iter.hasNext()) { final SearchGroup<T> group = shard.next(); MergedGroup<T> mergedGroup = groupsSeen.get(group.groupValue); final boolean isNew = mergedGroup == null; //System.out.println(" next group=" + (group.groupValue == null ? "null" : ((BytesRef) group.groupValue).utf8ToString()) + " sort=" + Arrays.toString(group.sortValues)); if (isNew) { // Start a new group: //System.out.println(" new"); mergedGroup = new MergedGroup<>(group.groupValue); mergedGroup.minShardIndex = shard.shardIndex; assert group.sortValues != null; mergedGroup.topValues = group.sortValues; groupsSeen.put(group.groupValue, mergedGroup); mergedGroup.inQueue = true; queue.add(mergedGroup); } else if (mergedGroup.processed) { // This shard produced a group that we already // processed; move on to next group... continue; } else { //System.out.println(" old"); boolean competes = false; for(int compIDX=0;compIDX<groupComp.comparators.length;compIDX++) { final int cmp = groupComp.reversed[compIDX] * groupComp.comparators[compIDX].compareValues(group.sortValues[compIDX], mergedGroup.topValues[compIDX]); if (cmp < 0) { // Definitely competes competes = true; break; } else if (cmp > 0) { // Definitely does not compete break; } else if (compIDX == groupComp.comparators.length-1) { if (shard.shardIndex < mergedGroup.minShardIndex) { competes = true; } } } //System.out.println(" competes=" + competes); if (competes) { // Group's sort changed -- remove & re-insert if (mergedGroup.inQueue) { queue.remove(mergedGroup); } mergedGroup.topValues = group.sortValues; mergedGroup.minShardIndex = shard.shardIndex; queue.add(mergedGroup); mergedGroup.inQueue = true; } } mergedGroup.shards.add(shard); break; } // Prune un-competitive groups: while(queue.size() > topN) { final MergedGroup<T> group = queue.pollLast(); //System.out.println("PRUNE: " + group); group.inQueue = false; } } public Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> shards, int offset, int topN) { final int maxQueueSize = offset + topN; //System.out.println("merge"); // Init queue: for(int shardIDX=0;shardIDX<shards.size();shardIDX++) { final Collection<SearchGroup<T>> shard = shards.get(shardIDX); if (!shard.isEmpty()) { //System.out.println(" insert shard=" + shardIDX); updateNextGroup(maxQueueSize, new ShardIter<>(shard, shardIDX)); } } // Pull merged topN groups: final List<SearchGroup<T>> newTopGroups = new ArrayList<>(topN); int count = 0; while(!queue.isEmpty()) { final MergedGroup<T> group = queue.pollFirst(); group.processed = true; //System.out.println(" pop: shards=" + group.shards + " group=" + (group.groupValue == null ? "null" : (((BytesRef) group.groupValue).utf8ToString())) + " sortValues=" + Arrays.toString(group.topValues)); if (count++ >= offset) { final SearchGroup<T> newGroup = new SearchGroup<>(); newGroup.groupValue = group.groupValue; newGroup.sortValues = group.topValues; newTopGroups.add(newGroup); if (newTopGroups.size() == topN) { break; } //} else { // System.out.println(" skip < offset"); } // Advance all iters in this group: for(ShardIter<T> shardIter : group.shards) { updateNextGroup(maxQueueSize, shardIter); } } if (newTopGroups.isEmpty()) { return null; } else { return newTopGroups; } } } /** Merges multiple collections of top groups, for example * obtained from separate index shards. The provided * groupSort must match how the groups were sorted, and * the provided SearchGroups must have been computed * with fillFields=true passed to {@link * FirstPassGroupingCollector#getTopGroups}. * * <p>NOTE: this returns null if the topGroups is empty. */ public static <T> Collection<SearchGroup<T>> merge(List<Collection<SearchGroup<T>>> topGroups, int offset, int topN, Sort groupSort) { if (topGroups.isEmpty()) { return null; } else { return new GroupMerger<T>(groupSort).merge(topGroups, offset, topN); } } }