package org.apache.lucene.search.grouping.term;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.search.*;
import org.apache.lucene.search.grouping.AbstractAllGroupHeadsCollector;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SentinelIntSet;

import java.io.IOException;
import java.util.*;

/**
 * A base implementation of {@link org.apache.lucene.search.grouping.AbstractAllGroupHeadsCollector} for retrieving
 * the most relevant groups when grouping on a string based group field. More specifically, all concrete
 * implementations of this base implementation use {@link org.apache.lucene.search.FieldCache.DocTermsIndex}.
 *
 * @lucene.experimental
 */
public abstract class TermAllGroupHeadsCollector<GH extends AbstractAllGroupHeadsCollector.GroupHead<?>> extends AbstractAllGroupHeadsCollector<GH> {

  private static final int DEFAULT_INITIAL_SIZE = 128;

  // Name of the field whose terms define the groups.
  final String groupField;
  // Shared scratch buffer handed to the terms index for lookups; its content is only valid until the next lookup.
  final BytesRef scratchBytesRef = new BytesRef();

  // Terms index of the group field for the current segment; assigned in setNextReader(...).
  FieldCache.DocTermsIndex groupIndex;
  // Context of the current segment; its docBase is added to segment-local doc ids stored in group heads.
  AtomicReaderContext readerContext;

  protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) {
    super(numberOfSorts);
    this.groupField = groupField;
  }

  /**
   * Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
   * This factory method decides which implementation is best suited.
   *
   * Delegates to {@link #create(String, org.apache.lucene.search.Sort, int)} with an initialSize of 128.
   *
   * @param groupField      The field to group by
   * @param sortWithinGroup The sort within each group
   * @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
   */
  public static AbstractAllGroupHeadsCollector<?> create(String groupField, Sort sortWithinGroup) {
    return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE);
  }

  /**
   * Creates an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments.
   * This factory method decides which implementation is best suited:
   * <ul>
   *   <li>if any sort field is neither a score nor a string sort, the general implementation is used;</li>
   *   <li>if every sort field is a score sort, the score optimized implementation is used;</li>
   *   <li>if every sort field is a string sort, the ord optimized implementation is used;</li>
   *   <li>otherwise (a mix of score and string sorts) the ord and score optimized implementation is used.</li>
   * </ul>
   *
   * @param groupField      The field to group by
   * @param sortWithinGroup The sort within each group
   * @param initialSize     The initial allocation size of the internal int set and group list which should roughly match
   *                        the total number of expected unique groups. Be aware that the heap usage is
   *                        4 bytes * initialSize. This value is not used when the general implementation is selected.
   * @return an <code>AbstractAllGroupHeadsCollector</code> instance based on the supplied arguments
   */
  public static AbstractAllGroupHeadsCollector<?> create(String groupField, Sort sortWithinGroup, int initialSize) {
    boolean sortAllScore = true;
    boolean sortAllFieldValue = true;

    for (SortField sortField : sortWithinGroup.getSort()) {
      if (sortField.getType() == SortField.Type.SCORE) {
        sortAllFieldValue = false;
      } else if (needGeneralImpl(sortField)) {
        // A single unsupported sort type is enough to rule out all optimized implementations.
        return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup);
      } else {
        sortAllScore = false;
      }
    }

    if (sortAllScore) {
      return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
    } else if (sortAllFieldValue) {
      return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
    } else {
      return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
    }
  }

  // Returns true when a sort field needs the general impl, i.e. when it is neither a string sort nor a score sort.
  private static boolean needGeneralImpl(SortField sortField) {
    SortField.Type sortType = sortField.getType();
    // Note (MvG): We can also make an optimized impl when sorting is SortField.DOC
    return sortType != SortField.Type.STRING_VAL &&
        sortType != SortField.Type.STRING &&
        sortType != SortField.Type.SCORE;
  }

  // Compares the score stored for the current group head with the score of a candidate document.
  // Returns 1 when the candidate's score is higher, -1 when it is lower and 0 otherwise (including NaN operands).
  private static int compareScores(float headScore, float docScore) {
    if (headScore < docScore) {
      return 1;
    } else if (headScore > docScore) {
      return -1;
    }
    return 0;
  }

  // A general impl that works for any group sort.
  static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector<GeneralAllGroupHeadsCollector.GroupHead> {

    private final Sort sortWithinGroup;
    // Group heads keyed by group value; the null key holds the head for documents with group ord 0.
    private final Map<BytesRef, GroupHead> groups;

    private Scorer scorer;

    GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) {
      super(groupField, sortWithinGroup.getSort().length);
      this.sortWithinGroup = sortWithinGroup;
      groups = new HashMap<BytesRef, GroupHead>();

      final SortField[] sortFields = sortWithinGroup.getSort();
      for (int i = 0; i < sortFields.length; i++) {
        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
      }
    }

    // Looks up the group head for the group of the given segment-local doc, creating it when the group is new.
    // The outcome is reported through temporalResult: stop is true when a new head was created for this doc.
    protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
      final int ord = groupIndex.getOrd(doc);
      // Ord 0 is mapped to the null group value.
      final BytesRef groupValue = ord == 0 ? null : groupIndex.lookup(ord, scratchBytesRef);
      GroupHead groupHead = groups.get(groupValue);
      if (groupHead == null) {
        // The looked-up value is presumably backed by the shared scratch buffer, which the next lookup refills.
        // Take one private copy and use it both as the map key and as the head's group value, so that the
        // head never aliases the scratch buffer.
        final BytesRef ownedGroupValue = groupValue == null ? null : BytesRef.deepCopyOf(groupValue);
        groupHead = new GroupHead(ownedGroupValue, sortWithinGroup, doc);
        groups.put(ownedGroupValue, groupHead);
        temporalResult.stop = true;
      } else {
        temporalResult.stop = false;
      }
      temporalResult.groupHead = groupHead;
    }

    protected Collection<GroupHead> getCollectedGroupHeads() {
      return groups.values();
    }

    // Switches to the next segment: reloads the group terms index and moves every existing comparator
    // to the new segment (setNextReader may return a different comparator instance, hence the reassignment).
    public void setNextReader(AtomicReaderContext context) throws IOException {
      this.readerContext = context;
      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);

      for (GroupHead groupHead : groups.values()) {
        for (int i = 0; i < groupHead.comparators.length; i++) {
          groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context);
        }
      }
    }

    // Remembers the scorer for heads created later and passes it on to the comparators of all existing heads.
    public void setScorer(Scorer scorer) throws IOException {
      this.scorer = scorer;
      for (GroupHead groupHead : groups.values()) {
        for (FieldComparator<?> comparator : groupHead.comparators) {
          comparator.setScorer(scorer);
        }
      }
    }

    // Group head that keeps one single-slot FieldComparator per sort field; slot 0 always holds the current head doc.
    class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {

      final FieldComparator<?>[] comparators;

      @SuppressWarnings({"unchecked","rawtypes"})
      private GroupHead(BytesRef groupValue, Sort sort, int doc) throws IOException {
        super(groupValue, doc + readerContext.docBase);
        final SortField[] sortFields = sort.getSort();
        comparators = new FieldComparator[sortFields.length];
        for (int i = 0; i < sortFields.length; i++) {
          comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext);
          comparators[i].setScorer(scorer);
          comparators[i].copy(0, doc);
          comparators[i].setBottom(0);
        }
      }

      // Delegates to the comparator of the given sort position, comparing its bottom (the head) with the doc.
      public int compare(int compIDX, int doc) throws IOException {
        return comparators[compIDX].compareBottom(doc);
      }

      // Makes the given segment-local doc the new head of this group.
      public void updateDocHead(int doc) throws IOException {
        for (FieldComparator<?> comparator : comparators) {
          comparator.copy(0, doc);
          comparator.setBottom(0);
        }
        this.doc = doc + readerContext.docBase;
      }
    }
  }


  // AbstractAllGroupHeadsCollector optimized for ord fields and scores.
  static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdScoreAllGroupHeadsCollector.GroupHead> {

    // Group ords of the current segment for which a group head is known.
    private final SentinelIntSet ordSet;
    // All group heads collected so far, across segments.
    private final List<GroupHead> collectedGroups;
    private final SortField[] fields;

    // Per sort position the terms index of the current segment; entries for score sorts are never assigned.
    private FieldCache.DocTermsIndex[] sortsIndex;
    private Scorer scorer;
    // Group heads of the current segment, indexed by group ord.
    private GroupHead[] segmentGroupHeads;

    OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
      super(groupField, sortWithinGroup.getSort().length);
      ordSet = new SentinelIntSet(initialSize, -1);
      collectedGroups = new ArrayList<GroupHead>(initialSize);

      final SortField[] sortFields = sortWithinGroup.getSort();
      fields = new SortField[sortFields.length];
      sortsIndex = new FieldCache.DocTermsIndex[sortFields.length];
      for (int i = 0; i < sortFields.length; i++) {
        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
        fields[i] = sortFields[i];
      }
    }

    protected Collection<GroupHead> getCollectedGroupHeads() {
      return collectedGroups;
    }

    public void setScorer(Scorer scorer) throws IOException {
      this.scorer = scorer;
    }

    // Looks up the group head for the group of the given segment-local doc, creating it when the group is new.
    // The outcome is reported through temporalResult: stop is true when a new head was created for this doc.
    protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
      int key = groupIndex.getOrd(doc);
      GroupHead groupHead;
      if (!ordSet.exists(key)) {
        ordSet.put(key);
        // Ord 0 is mapped to the null group value; otherwise the term is read into a fresh BytesRef.
        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
        groupHead = new GroupHead(doc, term);
        collectedGroups.add(groupHead);
        segmentGroupHeads[key] = groupHead;
        temporalResult.stop = true;
      } else {
        temporalResult.stop = false;
        groupHead = segmentGroupHeads[key];
      }
      temporalResult.groupHead = groupHead;
    }

    public void setNextReader(AtomicReaderContext context) throws IOException {
      this.readerContext = context;
      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);
      for (int i = 0; i < fields.length; i++) {
        if (fields[i].getType() == SortField.Type.SCORE) {
          continue;
        }

        sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader(), fields[i].getField());
      }

      // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
      ordSet.clear();
      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
      for (GroupHead collectedGroup : collectedGroups) {
        int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef);
        if (ord >= 0) {
          ordSet.put(ord);
          segmentGroupHeads[ord] = collectedGroup;

          // Re-resolve the head's sort values against this segment; a negative result marks a value
          // that does not occur in this segment.
          for (int i = 0; i < sortsIndex.length; i++) {
            if (fields[i].getType() == SortField.Type.SCORE) {
              continue;
            }

            collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef);
          }
        }
      }
    }

    // Group head that stores, per sort position, either the score or the sort term plus its ord in the current segment.
    class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {

      BytesRef[] sortValues;
      int[] sortOrds;
      float[] scores;

      private GroupHead(int doc, BytesRef groupValue) throws IOException {
        super(groupValue, doc + readerContext.docBase);
        sortValues = new BytesRef[sortsIndex.length];
        sortOrds = new int[sortsIndex.length];
        scores = new float[sortsIndex.length];
        for (int i = 0; i < sortsIndex.length; i++) {
          if (fields[i].getType() == SortField.Type.SCORE) {
            scores[i] = scorer.score();
          } else {
            sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
            sortOrds[i] = sortsIndex[i].getOrd(doc);
          }
        }
      }

      // Compares this head with the given segment-local doc on the sort field at position compIDX.
      // Score sorts: 1 when the doc scores higher, -1 when lower, 0 otherwise.
      // Term sorts: positive when the head's value orders after the doc's value, negative when before, 0 when equal.
      public int compare(int compIDX, int doc) throws IOException {
        if (fields[compIDX].getType() == SortField.Type.SCORE) {
          final float score = scorer.score();
          return compareScores(scores[compIDX], score);
        } else {
          if (sortOrds[compIDX] < 0) {
            // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
            return sortValues[compIDX].compareTo(sortsIndex[compIDX].getTerm(doc, scratchBytesRef));
          } else {
            return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
          }
        }
      }

      // Makes the given segment-local doc the new head of this group.
      public void updateDocHead(int doc) throws IOException {
        for (int i = 0; i < sortsIndex.length; i++) {
          if (fields[i].getType() == SortField.Type.SCORE) {
            scores[i] = scorer.score();
          } else {
            // The head's own BytesRef is passed back in as the reuse instance.
            sortValues[i] = sortsIndex[i].getTerm(doc, sortValues[i]);
            sortOrds[i] = sortsIndex[i].getOrd(doc);
          }
        }
        this.doc = doc + readerContext.docBase;
      }
    }
  }


  // AbstractAllGroupHeadsCollector optimized for ord fields.
  static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector<OrdAllGroupHeadsCollector.GroupHead> {

    // Group ords of the current segment for which a group head is known.
    private final SentinelIntSet ordSet;
    // All group heads collected so far, across segments.
    private final List<GroupHead> collectedGroups;
    private final SortField[] fields;

    // Per sort position the terms index of the current segment.
    private FieldCache.DocTermsIndex[] sortsIndex;
    // Group heads of the current segment, indexed by group ord.
    private GroupHead[] segmentGroupHeads;

    OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
      super(groupField, sortWithinGroup.getSort().length);
      ordSet = new SentinelIntSet(initialSize, -1);
      collectedGroups = new ArrayList<GroupHead>(initialSize);

      final SortField[] sortFields = sortWithinGroup.getSort();
      fields = new SortField[sortFields.length];
      sortsIndex = new FieldCache.DocTermsIndex[sortFields.length];
      for (int i = 0; i < sortFields.length; i++) {
        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
        fields[i] = sortFields[i];
      }
    }

    protected Collection<GroupHead> getCollectedGroupHeads() {
      return collectedGroups;
    }

    // Intentionally empty: this collector never reads scores.
    public void setScorer(Scorer scorer) throws IOException {
    }

    // Looks up the group head for the group of the given segment-local doc, creating it when the group is new.
    // The outcome is reported through temporalResult: stop is true when a new head was created for this doc.
    protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
      int key = groupIndex.getOrd(doc);
      GroupHead groupHead;
      if (!ordSet.exists(key)) {
        ordSet.put(key);
        // Ord 0 is mapped to the null group value; otherwise the term is read into a fresh BytesRef.
        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
        groupHead = new GroupHead(doc, term);
        collectedGroups.add(groupHead);
        segmentGroupHeads[key] = groupHead;
        temporalResult.stop = true;
      } else {
        temporalResult.stop = false;
        groupHead = segmentGroupHeads[key];
      }
      temporalResult.groupHead = groupHead;
    }

    public void setNextReader(AtomicReaderContext context) throws IOException {
      this.readerContext = context;
      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);
      for (int i = 0; i < fields.length; i++) {
        sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader(), fields[i].getField());
      }

      // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
      ordSet.clear();
      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
      for (GroupHead collectedGroup : collectedGroups) {
        int groupOrd = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef);
        if (groupOrd >= 0) {
          ordSet.put(groupOrd);
          segmentGroupHeads[groupOrd] = collectedGroup;

          // Re-resolve the head's sort values against this segment; a negative result marks a value
          // that does not occur in this segment.
          for (int i = 0; i < sortsIndex.length; i++) {
            collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef);
          }
        }
      }
    }

    // Group head that stores, per sort position, the sort term plus its ord in the current segment.
    class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {

      BytesRef[] sortValues;
      int[] sortOrds;

      private GroupHead(int doc, BytesRef groupValue) {
        super(groupValue, doc + readerContext.docBase);
        sortValues = new BytesRef[sortsIndex.length];
        sortOrds = new int[sortsIndex.length];
        for (int i = 0; i < sortsIndex.length; i++) {
          sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
          sortOrds[i] = sortsIndex[i].getOrd(doc);
        }
      }

      // Compares this head with the given segment-local doc on the sort field at position compIDX:
      // positive when the head's value orders after the doc's value, negative when before, 0 when equal.
      public int compare(int compIDX, int doc) throws IOException {
        if (sortOrds[compIDX] < 0) {
          // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative.
          return sortValues[compIDX].compareTo(sortsIndex[compIDX].getTerm(doc, scratchBytesRef));
        } else {
          return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
        }
      }

      // Makes the given segment-local doc the new head of this group.
      public void updateDocHead(int doc) throws IOException {
        for (int i = 0; i < sortsIndex.length; i++) {
          // The head's own BytesRef is passed back in as the reuse instance.
          sortValues[i] = sortsIndex[i].getTerm(doc, sortValues[i]);
          sortOrds[i] = sortsIndex[i].getOrd(doc);
        }
        this.doc = doc + readerContext.docBase;
      }
    }
  }


  // AbstractAllGroupHeadsCollector optimized for scores.
  static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector<ScoreAllGroupHeadsCollector.GroupHead> {

    // Group ords of the current segment for which a group head is known.
    private final SentinelIntSet ordSet;
    // All group heads collected so far, across segments.
    private final List<GroupHead> collectedGroups;
    private final SortField[] fields;

    private Scorer scorer;
    // Group heads of the current segment, indexed by group ord.
    private GroupHead[] segmentGroupHeads;

    ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
      super(groupField, sortWithinGroup.getSort().length);
      ordSet = new SentinelIntSet(initialSize, -1);
      collectedGroups = new ArrayList<GroupHead>(initialSize);

      final SortField[] sortFields = sortWithinGroup.getSort();
      fields = new SortField[sortFields.length];
      for (int i = 0; i < sortFields.length; i++) {
        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
        fields[i] = sortFields[i];
      }
    }

    protected Collection<GroupHead> getCollectedGroupHeads() {
      return collectedGroups;
    }

    public void setScorer(Scorer scorer) throws IOException {
      this.scorer = scorer;
    }

    // Looks up the group head for the group of the given segment-local doc, creating it when the group is new.
    // The outcome is reported through temporalResult: stop is true when a new head was created for this doc.
    protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
      int key = groupIndex.getOrd(doc);
      GroupHead groupHead;
      if (!ordSet.exists(key)) {
        ordSet.put(key);
        // Ord 0 is mapped to the null group value; otherwise the term is read into a fresh BytesRef.
        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
        groupHead = new GroupHead(doc, term);
        collectedGroups.add(groupHead);
        segmentGroupHeads[key] = groupHead;
        temporalResult.stop = true;
      } else {
        temporalResult.stop = false;
        groupHead = segmentGroupHeads[key];
      }
      temporalResult.groupHead = groupHead;
    }

    public void setNextReader(AtomicReaderContext context) throws IOException {
      this.readerContext = context;
      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);

      // Clear ordSet and fill it with previous encountered groups that can occur in the current segment.
      ordSet.clear();
      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
      for (GroupHead collectedGroup : collectedGroups) {
        int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef);
        if (ord >= 0) {
          ordSet.put(ord);
          segmentGroupHeads[ord] = collectedGroup;
        }
      }
    }

    // Group head that stores the head document's score once per sort position.
    class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead<BytesRef> {

      float[] scores;

      private GroupHead(int doc, BytesRef groupValue) throws IOException {
        super(groupValue, doc + readerContext.docBase);
        scores = new float[fields.length];
        float score = scorer.score();
        for (int i = 0; i < scores.length; i++) {
          scores[i] = score;
        }
      }

      // Compares this head with the scorer's current document on the sort position compIDX:
      // 1 when the document scores higher than the head, -1 when lower, 0 otherwise.
      public int compare(int compIDX, int doc) throws IOException {
        final float score = scorer.score();
        return compareScores(scores[compIDX], score);
      }

      // Makes the given segment-local doc the new head of this group.
      public void updateDocHead(int doc) throws IOException {
        float score = scorer.score();
        for (int i = 0; i < scores.length; i++) {
          scores[i] = score;
        }
        this.doc = doc + readerContext.docBase;
      }
    }
  }

}