/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.blur.lucene.security.search; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.List; import org.apache.blur.lucene.security.DocumentAuthorizations; import org.apache.blur.lucene.security.DocumentVisibility; import org.apache.blur.lucene.security.DocumentVisibilityEvaluator; import org.apache.blur.lucene.security.search.DocumentVisibilityFilterCacheStrategy.Builder; import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.DocsEnum; import org.apache.lucene.index.Fields; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.DocIdSet; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Filter; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; public class DocumentVisibilityFilter extends Filter { private static final Comparator<DocIdSetIterator> COMPARATOR = new Comparator<DocIdSetIterator>() { @Override public int compare(DocIdSetIterator o1, DocIdSetIterator o2) { int docID1 = o1.docID(); int docID2 = o2.docID(); return docID1 - docID2; } }; private final String _fieldName; private final DocumentAuthorizations _authorizations; private final DocumentVisibilityFilterCacheStrategy _filterCacheStrategy; public DocumentVisibilityFilter(String fieldName, DocumentAuthorizations authorizations, DocumentVisibilityFilterCacheStrategy filterCacheStrategy) { _fieldName = fieldName; _authorizations = authorizations; _filterCacheStrategy = filterCacheStrategy; } @Override public String toString() { return "DocumentVisibilityFilter [_fieldName=" + _fieldName + ", _authorizations=" + _authorizations + ", _filterCacheStrategy=" + _filterCacheStrategy + "]"; } @Override public DocIdSet getDocIdSet(AtomicReaderContext context, Bits acceptDocs) throws IOException { AtomicReader reader = context.reader(); List<DocIdSet> list = new ArrayList<DocIdSet>(); Fields fields = reader.fields(); Terms terms = fields.terms(_fieldName); if (terms == null) { // if field is not present then show nothing. return DocIdSet.EMPTY_DOCIDSET; } TermsEnum iterator = terms.iterator(null); BytesRef bytesRef; DocumentVisibilityEvaluator visibilityEvaluator = new DocumentVisibilityEvaluator(_authorizations); while ((bytesRef = iterator.next()) != null) { if (isVisible(visibilityEvaluator, bytesRef)) { DocIdSet docIdSet = _filterCacheStrategy.getDocIdSet(_fieldName, bytesRef, reader); if (docIdSet != null) { list.add(docIdSet); } else { // Do not use acceptDocs because we want the acl cache to be version // agnostic. DocsEnum docsEnum = iterator.docs(null, null); list.add(buildCache(reader, docsEnum, bytesRef)); } } } return getLogicalOr(list); } private DocIdSet buildCache(AtomicReader reader, DocIdSetIterator it, BytesRef bytesRef) throws IOException { Builder builder = _filterCacheStrategy.createBuilder(_fieldName, bytesRef, reader); builder.or(it); return builder.getDocIdSet(); } private boolean isVisible(DocumentVisibilityEvaluator visibilityEvaluator, BytesRef bytesRef) throws IOException { DocumentVisibility visibility = new DocumentVisibility(trim(bytesRef)); return visibilityEvaluator.evaluate(visibility); } private byte[] trim(BytesRef bytesRef) { byte[] buf = new byte[bytesRef.length]; System.arraycopy(bytesRef.bytes, bytesRef.offset, buf, 0, bytesRef.length); return buf; } public static DocIdSet getLogicalOr(DocIdSet... list) throws IOException { return getLogicalOr(Arrays.asList(list)); } public static DocIdSet getLogicalOr(final List<DocIdSet> list) throws IOException { if (list.size() == 0) { return DocIdSet.EMPTY_DOCIDSET; } if (list.size() == 1) { DocIdSet docIdSet = list.get(0); Bits bits = docIdSet.bits(); if (bits == null) { throw new IOException("Bits are not allowed to be null for DocIdSet [" + docIdSet + "]."); } return docIdSet; } int index = 0; final Bits[] bitsArray = new Bits[list.size()]; int length = -1; for (DocIdSet docIdSet : list) { Bits bits = docIdSet.bits(); if (bits == null) { throw new IOException("Bits are not allowed to be null for DocIdSet [" + docIdSet + "]."); } bitsArray[index] = bits; index++; if (length < 0) { length = bits.length(); } else if (length != bits.length()) { throw new IOException("Bits length need to be the same [" + length + "] and [" + bits.length() + "]"); } } final int len = length; return new DocIdSet() { @Override public Bits bits() throws IOException { return new Bits() { @Override public boolean get(int index) { for (int i = 0; i < bitsArray.length; i++) { if (bitsArray[i].get(index)) { return true; } } return false; } @Override public int length() { return len; } }; } @Override public boolean isCacheable() { return true; } @Override public DocIdSetIterator iterator() throws IOException { final DocIdSetIterator[] docIdSetIteratorArray = new DocIdSetIterator[list.size()]; long c = 0; int index = 0; for (DocIdSet docIdSet : list) { DocIdSetIterator iterator = docIdSet.iterator(); iterator.nextDoc(); docIdSetIteratorArray[index] = iterator; c += iterator.cost(); index++; } final long cost = c; return new DocIdSetIterator() { private int _docId = -1; @Override public int advance(int target) throws IOException { callAdvanceOnAllThatAreBehind(target); Arrays.sort(docIdSetIteratorArray, COMPARATOR); DocIdSetIterator iterator = docIdSetIteratorArray[0]; return _docId = iterator.docID(); } private void callAdvanceOnAllThatAreBehind(int target) throws IOException { for (int i = 0; i < docIdSetIteratorArray.length; i++) { DocIdSetIterator iterator = docIdSetIteratorArray[i]; if (iterator.docID() < target) { iterator.advance(target); } } } @Override public int nextDoc() throws IOException { return advance(_docId + 1); } @Override public int docID() { return _docId; } @Override public long cost() { return cost; } }; } }; } }