package org.apache.lucene.search.suggest; /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import java.io.File; import java.io.IOException; import java.util.Comparator; import org.apache.lucene.search.spell.TermFreqIterator; import org.apache.lucene.search.suggest.fst.Sort; import org.apache.lucene.search.suggest.fst.Sort.ByteSequencesReader; import org.apache.lucene.search.suggest.fst.Sort.ByteSequencesWriter; import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.ByteArrayDataOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; /** * This wrapper buffers incoming elements and makes sure they are sorted based on given comparator. * @lucene.experimental */ public class SortedTermFreqIteratorWrapper implements TermFreqIterator { private final TermFreqIterator source; private File tempInput; private File tempSorted; private final ByteSequencesReader reader; private boolean done = false; private long weight; private final BytesRef scratch = new BytesRef(); private final Comparator<BytesRef> comparator; /** * Calls {@link #SortedTermFreqIteratorWrapper(TermFreqIterator, Comparator, boolean) * SortedTermFreqIteratorWrapper(source, comparator, false)} */ public SortedTermFreqIteratorWrapper(TermFreqIterator source, Comparator<BytesRef> comparator) throws IOException { this(source, comparator, false); } /** * Creates a new sorted wrapper. if <code>compareRawBytes</code> is true, then * only the bytes (not the weight) will be used for comparison. */ public SortedTermFreqIteratorWrapper(TermFreqIterator source, Comparator<BytesRef> comparator, boolean compareRawBytes) throws IOException { this.source = source; this.comparator = comparator; this.reader = sort(compareRawBytes ? comparator : new BytesOnlyComparator(this.comparator)); } @Override public BytesRef next() throws IOException { boolean success = false; if (done) { return null; } try { ByteArrayDataInput input = new ByteArrayDataInput(); if (reader.read(scratch)) { weight = decode(scratch, input); success = true; return scratch; } close(); success = done = true; return null; } finally { if (!success) { done = true; close(); } } } @Override public Comparator<BytesRef> getComparator() { return comparator; } @Override public long weight() { return weight; } private Sort.ByteSequencesReader sort(Comparator<BytesRef> comparator) throws IOException { String prefix = getClass().getSimpleName(); File directory = Sort.defaultTempDir(); tempInput = File.createTempFile(prefix, ".input", directory); tempSorted = File.createTempFile(prefix, ".sorted", directory); final Sort.ByteSequencesWriter writer = new Sort.ByteSequencesWriter(tempInput); boolean success = false; try { BytesRef spare; byte[] buffer = new byte[0]; ByteArrayDataOutput output = new ByteArrayDataOutput(buffer); while ((spare = source.next()) != null) { encode(writer, output, buffer, spare, source.weight()); } writer.close(); new Sort(comparator).sort(tempInput, tempSorted); ByteSequencesReader reader = new Sort.ByteSequencesReader(tempSorted); success = true; return reader; } finally { if (success) { IOUtils.close(writer); } else { try { IOUtils.closeWhileHandlingException(writer); } finally { close(); } } } } private void close() throws IOException { IOUtils.close(reader); if (tempInput != null) { tempInput.delete(); } if (tempSorted != null) { tempSorted.delete(); } } private final static class BytesOnlyComparator implements Comparator<BytesRef> { final Comparator<BytesRef> other; private final BytesRef leftScratch = new BytesRef(); private final BytesRef rightScratch = new BytesRef(); public BytesOnlyComparator(Comparator<BytesRef> other) { this.other = other; } @Override public int compare(BytesRef left, BytesRef right) { wrap(leftScratch, left); wrap(rightScratch, right); return other.compare(leftScratch, rightScratch); } private void wrap(BytesRef wrapper, BytesRef source) { wrapper.bytes = source.bytes; wrapper.offset = source.offset; wrapper.length = source.length - 8; } } /** encodes an entry (bytes+weight) to the provided writer */ protected void encode(ByteSequencesWriter writer, ByteArrayDataOutput output, byte[] buffer, BytesRef spare, long weight) throws IOException { if (spare.length + 8 >= buffer.length) { buffer = ArrayUtil.grow(buffer, spare.length + 8); } output.reset(buffer); output.writeBytes(spare.bytes, spare.offset, spare.length); output.writeLong(weight); writer.write(buffer, 0, output.getPosition()); } /** decodes the weight at the current position */ protected long decode(BytesRef scratch, ByteArrayDataInput tmpInput) { tmpInput.reset(scratch.bytes); tmpInput.skipBytes(scratch.length - 8); // suggestion + separator scratch.length -= 8; // sep + long return tmpInput.readLong(); } }