/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.lucene.util; import java.util.Arrays; /** Radix selector. * <p>This implementation works similarly to a MSB radix sort except that it * only recurses into the sub partition that contains the desired value. * @lucene.internal */ public abstract class RadixSelector extends Selector { // after that many levels of recursion we fall back to introselect anyway // this is used as a protection against the fact that radix sort performs // worse when there are long common prefixes (probably because of cache // locality) private static final int LEVEL_THRESHOLD = 8; // size of histograms: 256 + 1 to indicate that the string is finished private static final int HISTOGRAM_SIZE = 257; // buckets below this size will be sorted with introselect private static final int LENGTH_THRESHOLD = 100; // we store one histogram per recursion level private final int[] histogram = new int[HISTOGRAM_SIZE]; private final int[] commonPrefix; private final int maxLength; /** * Sole constructor. * @param maxLength the maximum length of keys, pass {@link Integer#MAX_VALUE} if unknown. */ protected RadixSelector(int maxLength) { this.maxLength = maxLength; this.commonPrefix = new int[Math.min(24, maxLength)]; } /** Return the k-th byte of the entry at index {@code i}, or {@code -1} if * its length is less than or equal to {@code k}. This may only be called * with a value of {@code i} between {@code 0} included and * {@code maxLength} excluded. */ protected abstract int byteAt(int i, int k); /** Get a fall-back selector which may assume that the first {@code d} bytes * of all compared strings are equal. This fallback selector is used when * the range becomes narrow or when the maximum level of recursion has * been exceeded. */ protected Selector getFallbackSelector(int d) { return new IntroSelector() { @Override protected void swap(int i, int j) { RadixSelector.this.swap(i, j); } @Override protected int compare(int i, int j) { for (int o = d; o < maxLength; ++o) { final int b1 = byteAt(i, o); final int b2 = byteAt(j, o); if (b1 != b2) { return b1 - b2; } else if (b1 == -1) { break; } } return 0; } @Override protected void setPivot(int i) { pivot.setLength(0); for (int o = d; o < maxLength; ++o) { final int b = byteAt(i, o); if (b == -1) { break; } pivot.append((byte) b); } } @Override protected int comparePivot(int j) { for (int o = 0; o < pivot.length(); ++o) { final int b1 = pivot.byteAt(o) & 0xff; final int b2 = byteAt(j, d + o); if (b1 != b2) { return b1 - b2; } } if (d + pivot.length() == maxLength) { return 0; } return -1 - byteAt(j, d + pivot.length()); } private final BytesRefBuilder pivot = new BytesRefBuilder(); }; } @Override public void select(int from, int to, int k) { checkArgs(from, to, k); select(from, to, k, 0, 0); } private void select(int from, int to, int k, int d, int l) { if (to - from <= LENGTH_THRESHOLD || d >= LEVEL_THRESHOLD) { getFallbackSelector(d).select(from, to, k); } else { radixSelect(from, to, k, d, l); } } /** * @param d the character number to compare * @param l the level of recursion */ private void radixSelect(int from, int to, int k, int d, int l) { final int[] histogram = this.histogram; Arrays.fill(histogram, 0); final int commonPrefixLength = computeCommonPrefixLengthAndBuildHistogram(from, to, d, histogram); if (commonPrefixLength > 0) { // if there are no more chars to compare or if all entries fell into the // first bucket (which means strings are shorter than d) then we are done // otherwise recurse if (d + commonPrefixLength < maxLength && histogram[0] < to - from) { radixSelect(from, to, k, d + commonPrefixLength, l); } return; } assert assertHistogram(commonPrefixLength, histogram); int bucketFrom = from; for (int bucket = 0; bucket < HISTOGRAM_SIZE; ++bucket) { final int bucketTo = bucketFrom + histogram[bucket]; if (bucketTo > k) { partition(from, to, bucket, bucketFrom, bucketTo, d); if (bucket != 0 && d + 1 < maxLength) { // all elements in bucket 0 are equal so we only need to recurse if bucket != 0 select(bucketFrom, bucketTo, k, d + 1, l + 1); } return; } bucketFrom = bucketTo; } throw new AssertionError("Unreachable code"); } // only used from assert private boolean assertHistogram(int commonPrefixLength, int[] histogram) { int numberOfUniqueBytes = 0; for (int freq : histogram) { if (freq > 0) { numberOfUniqueBytes++; } } if (numberOfUniqueBytes == 1) { assert commonPrefixLength >= 1; } else { assert commonPrefixLength == 0; } return true; } /** Return a number for the k-th character between 0 and {@link #HISTOGRAM_SIZE}. */ private int getBucket(int i, int k) { return byteAt(i, k) + 1; } /** Build a histogram of the number of values per {@link #getBucket(int, int) bucket} * and return a common prefix length for all visited values. * @see #buildHistogram */ private int computeCommonPrefixLengthAndBuildHistogram(int from, int to, int k, int[] histogram) { final int[] commonPrefix = this.commonPrefix; int commonPrefixLength = Math.min(commonPrefix.length, maxLength - k); for (int j = 0; j < commonPrefixLength; ++j) { final int b = byteAt(from, k + j); commonPrefix[j] = b; if (b == -1) { commonPrefixLength = j + 1; break; } } int i; outer: for (i = from + 1; i < to; ++i) { for (int j = 0; j < commonPrefixLength; ++j) { final int b = byteAt(i, k + j); if (b != commonPrefix[j]) { commonPrefixLength = j; if (commonPrefixLength == 0) { // we have no common prefix histogram[commonPrefix[0] + 1] = i - from; histogram[b + 1] = 1; break outer; } break; } } } if (i < to) { // the loop got broken because there is no common prefix assert commonPrefixLength == 0; buildHistogram(i + 1, to, k, histogram); } else { assert commonPrefixLength > 0; histogram[commonPrefix[0] + 1] = to - from; } return commonPrefixLength; } /** Build an histogram of the k-th characters of values occurring between * offsets {@code from} and {@code to}, using {@link #getBucket}. */ private void buildHistogram(int from, int to, int k, int[] histogram) { for (int i = from; i < to; ++i) { histogram[getBucket(i, k)]++; } } /** Reorder elements so that all of them that fall into {@code bucket} are * between offsets {@code bucketFrom} and {@code bucketTo}. */ private void partition(int from, int to, int bucket, int bucketFrom, int bucketTo, int d) { int left = from; int right = to - 1; int slot = bucketFrom; for (;;) { int leftBucket = getBucket(left, d); int rightBucket = getBucket(right, d); while (leftBucket <= bucket && left < bucketFrom) { if (leftBucket == bucket) { swap(left, slot++); } else { ++left; } leftBucket = getBucket(left, d); } while (rightBucket >= bucket && right >= bucketTo) { if (rightBucket == bucket) { swap(right, slot++); } else { --right; } rightBucket = getBucket(right, d); } if (left < bucketFrom && right >= bucketTo) { swap(left++, right--); } else { assert left == bucketFrom; assert right == bucketTo - 1; break; } } } }