/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.blur.lucene.security.index; import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Set; import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocsAndPositionsEnum; import org.apache.lucene.index.DocsEnum; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.Fields; import org.apache.lucene.index.FilterAtomicReader; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; import com.google.common.base.Splitter; /** * The current {@link SecureAtomicReader} will protect access to documents based * on the {@link AccessControl} object. * * NOTE: If you are using the {@link Fields} and {@link Terms} with * {@link TermsEnum} to create a type ahead. Make sure that you check that the * {@link TermsEnum} actually points to a single document because the * {@link SecureAtomicReader} will leak terms that users don't have access to * read or discover. */ public class SecureAtomicReader extends FilterAtomicReader { private final AccessControlReader _accessControl; private final AtomicReader _original; public static SecureAtomicReader create(AccessControlFactory accessControlFactory, AtomicReader in, Collection<String> readAuthorizations, Collection<String> discoverAuthorizations, Set<String> discoverableFields, String defaultReadMaskMessage) throws IOException { AccessControlReader accessControlReader = accessControlFactory.getReader(readAuthorizations, discoverAuthorizations, discoverableFields, defaultReadMaskMessage); return new SecureAtomicReader(in, accessControlReader); } public SecureAtomicReader(AtomicReader in, AccessControlReader accessControlReader) throws IOException { super(in); _accessControl = accessControlReader.clone(in); _original = in; } public AtomicReader getOriginalReader() { return _original; } @Override public Bits getLiveDocs() { final Bits liveDocs = in.getLiveDocs(); final int maxDoc = maxDoc(); return new Bits() { @Override public boolean get(int index) { if (liveDocs == null || liveDocs.get(index)) { // Need to check access try { if (_accessControl.hasAccess(ReadType.LIVEDOCS, index)) { return true; } } catch (IOException e) { throw new RuntimeException(e); } } return false; } @Override public int length() { return maxDoc; } }; } @Override public Fields getTermVectors(int docID) throws IOException { // use doc auth throw new RuntimeException("Not implemented."); } @Override public void document(int docID, final StoredFieldVisitor visitor) throws IOException { if (_accessControl.hasAccess(ReadType.DOCUMENT_FETCH_READ, docID)) { GetReadMaskFields getReadMaskFields = new GetReadMaskFields(_accessControl.getDefaultReadMaskMessage()); in.document(docID, getReadMaskFields); Map<String, String> readMaskFields = getReadMaskFields.getReadMaskFields(); if (readMaskFields.isEmpty()) { in.document(docID, visitor); } else { in.document(docID, new ReadMaskStoredFieldVisitor(visitor, readMaskFields)); } return; } if (_accessControl.hasAccess(ReadType.DOCUMENT_FETCH_DISCOVER, docID)) { in.document(docID, new StoredFieldVisitor() { @Override public Status needsField(FieldInfo fieldInfo) throws IOException { if (_accessControl.canDiscoverField(fieldInfo.name)) { return visitor.needsField(fieldInfo); } else { return Status.NO; } } @Override public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { visitor.binaryField(fieldInfo, value); } @Override public void stringField(FieldInfo fieldInfo, String value) throws IOException { visitor.stringField(fieldInfo, value); } @Override public void intField(FieldInfo fieldInfo, int value) throws IOException { visitor.intField(fieldInfo, value); } @Override public void longField(FieldInfo fieldInfo, long value) throws IOException { visitor.longField(fieldInfo, value); } @Override public void floatField(FieldInfo fieldInfo, float value) throws IOException { visitor.floatField(fieldInfo, value); } @Override public void doubleField(FieldInfo fieldInfo, double value) throws IOException { visitor.doubleField(fieldInfo, value); } }); return; } } private static class ReadMaskStoredFieldVisitor extends StoredFieldVisitor { private final StoredFieldVisitor _visitor; private final Map<String, String> _readMaskFieldsAndMessages; public ReadMaskStoredFieldVisitor(StoredFieldVisitor visitor, Map<String, String> readMaskFieldsAndMessages) { _visitor = visitor; _readMaskFieldsAndMessages = readMaskFieldsAndMessages; } @Override public Status needsField(FieldInfo fieldInfo) throws IOException { return Status.YES; } @Override public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.binaryField(fieldInfo, value); } } private boolean checkReadMask(FieldInfo fieldInfo) throws IOException { final String message = _readMaskFieldsAndMessages.get(fieldInfo.name); if (message != null) { if (message.isEmpty()) { return true; } _visitor.stringField(fieldInfo, message); return true; } return false; } @Override public void stringField(FieldInfo fieldInfo, String value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.stringField(fieldInfo, value); } } @Override public void intField(FieldInfo fieldInfo, int value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.intField(fieldInfo, value); } } @Override public void longField(FieldInfo fieldInfo, long value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.longField(fieldInfo, value); } } @Override public void floatField(FieldInfo fieldInfo, float value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.floatField(fieldInfo, value); } } @Override public void doubleField(FieldInfo fieldInfo, double value) throws IOException { if (!checkReadMask(fieldInfo)) { _visitor.doubleField(fieldInfo, value); } } } private static class GetReadMaskFields extends StoredFieldVisitor { private final Map<String, String> _fieldsAndMessages = new HashMap<String, String>(); private final Splitter splitter = Splitter.on('|'); private final String _defaultReadMask; GetReadMaskFields(String defaultReadMask) { _defaultReadMask = defaultReadMask == null ? "" : defaultReadMask; } @Override public Status needsField(FieldInfo fieldInfo) throws IOException { if (fieldInfo.name.equals(FilterAccessControlFactory.READ_MASK_FIELD)) { return Status.YES; } return Status.NO; } @Override public void stringField(FieldInfo fieldInfo, String value) throws IOException { Iterable<String> split = splitter.split(value); Iterator<String> iterator = split.iterator(); String field; String message = null; if (iterator.hasNext()) { field = iterator.next(); } else { return; } if (iterator.hasNext()) { message = iterator.next(); } if (message != null) { _fieldsAndMessages.put(field, message); } else { _fieldsAndMessages.put(field, _defaultReadMask); } } Map<String, String> getReadMaskFields() { return _fieldsAndMessages; } } @Override public Fields fields() throws IOException { return new SecureFields(in.fields(), _accessControl, maxDoc()); } @Override public NumericDocValues getNumericDocValues(String field) throws IOException { return secureNumericDocValues(in.getNumericDocValues(field), ReadType.NUMERIC_DOC_VALUE); } private NumericDocValues secureNumericDocValues(final NumericDocValues numericDocValues, final ReadType type) { if (numericDocValues == null) { return null; } return new NumericDocValues() { @Override public long get(int docID) { try { if (_accessControl.hasAccess(type, docID)) { return numericDocValues.get(docID); } return 0L; // Default missing value. } catch (IOException e) { throw new RuntimeException(e); } } }; } @Override public BinaryDocValues getBinaryDocValues(String field) throws IOException { final BinaryDocValues binaryDocValues = in.getBinaryDocValues(field); if (binaryDocValues == null) { return null; } return new BinaryDocValues() { @Override public void get(int docID, BytesRef result) { try { if (_accessControl.hasAccess(ReadType.BINARY_DOC_VALUE, docID)) { binaryDocValues.get(docID, result); return; } // Default missing value. result.bytes = MISSING; result.length = 0; result.offset = 0; } catch (IOException e) { throw new RuntimeException(e); } } }; } @Override public SortedDocValues getSortedDocValues(String field) throws IOException { final SortedDocValues sortedDocValues = in.getSortedDocValues(field); if (sortedDocValues == null) { return null; } return new SortedDocValues() { @Override public void lookupOrd(int ord, BytesRef result) { sortedDocValues.lookupOrd(ord, result); } @Override public int getValueCount() { return sortedDocValues.getValueCount(); } @Override public int getOrd(int docID) { try { if (_accessControl.hasAccess(ReadType.SORTED_DOC_VALUE, docID)) { return sortedDocValues.getOrd(docID); } return -1; // Default missing value. } catch (IOException e) { throw new RuntimeException(e); } } }; } @Override public SortedSetDocValues getSortedSetDocValues(String field) throws IOException { final SortedSetDocValues sortedSetDocValues = in.getSortedSetDocValues(field); if (sortedSetDocValues == null) { return null; } return new SortedSetDocValues() { private boolean _access; @Override public void setDocument(int docID) { try { if (_access = _accessControl.hasAccess(ReadType.SORTED_SET_DOC_VALUE, docID)) { sortedSetDocValues.setDocument(docID); } } catch (IOException e) { throw new RuntimeException(e); } } @Override public long nextOrd() { if (_access) { return sortedSetDocValues.nextOrd(); } return NO_MORE_ORDS; } @Override public void lookupOrd(long ord, BytesRef result) { if (_access) { sortedSetDocValues.lookupOrd(ord, result); } else { result.bytes = BinaryDocValues.MISSING; result.length = 0; result.offset = 0; } } @Override public long getValueCount() { return sortedSetDocValues.getValueCount(); } }; } @Override public NumericDocValues getNormValues(String field) throws IOException { return secureNumericDocValues(in.getNormValues(field), ReadType.NORM_VALUE); } static class SecureFields extends FilterFields { private final int _maxDoc; private final AccessControlReader _accessControlReader; public SecureFields(Fields in, AccessControlReader accessControlReader, int maxDoc) { super(in); _accessControlReader = accessControlReader; _maxDoc = maxDoc; } @Override public Terms terms(String field) throws IOException { Terms terms = in.terms(field); if (terms == null) { return null; } Terms readMask = getReadMaskTerms(in, field); SecureTerms secureTerms = new SecureTerms(terms, _accessControlReader, _maxDoc); if (readMask == null) { return secureTerms; } else { return new ReadMaskTerms(secureTerms, readMask); } } private Terms getReadMaskTerms(Fields in, String field) throws IOException { return in.terms(field + FilterAccessControlFactory.READ_MASK_SUFFIX); } } static class ReadMaskTerms extends FilterTerms { private final Terms _readMask; public ReadMaskTerms(Terms in, Terms readMask) { super(in); _readMask = readMask; } @Override public TermsEnum iterator(TermsEnum reuse) throws IOException { TermsEnum maskTermsEnum = _readMask.iterator(null); return new ReadMaskTermsEnum(maskTermsEnum, in.iterator(reuse)); } @Override public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException { TermsEnum maskTermsEnum = _readMask.intersect(compiled, startTerm); return new ReadMaskTermsEnum(maskTermsEnum, in.intersect(compiled, startTerm)); } } static class SecureTerms extends FilterTerms { private final int _maxDoc; private final AccessControlReader _accessControlReader; public SecureTerms(Terms in, AccessControlReader accessControlReader, int maxDoc) { super(in); _accessControlReader = accessControlReader; _maxDoc = maxDoc; } @Override public TermsEnum iterator(TermsEnum reuse) throws IOException { return new SecureTermsEnum(in.iterator(reuse), _accessControlReader, _maxDoc); } @Override public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException { return new SecureTermsEnum(in.intersect(compiled, startTerm), _accessControlReader, _maxDoc); } } static class ReadMaskTermsEnum extends FilterTermsEnum { private final TermsEnum _maskTermsEnum; public ReadMaskTermsEnum(TermsEnum maskTermsEnum, TermsEnum realTermsEnum) { super(realTermsEnum); _maskTermsEnum = maskTermsEnum; } @Override public SeekStatus seekCeil(BytesRef text, boolean useCache) throws IOException { SeekStatus seekStatus = in.seekCeil(text, useCache); if (seekStatus == SeekStatus.END) { return SeekStatus.END; } BytesRef term = in.term(); BytesRef bytesRef = new BytesRef(); bytesRef.copyBytes(term); // If not in mask then return. if (!_maskTermsEnum.seekExact(bytesRef, true)) { return seekStatus; } else { if (checkDocs()) { return seekStatus; } // check if any docs for given term are not present in mask terms enums if (next() == null) { return SeekStatus.END; } else { return SeekStatus.NOT_FOUND; } } } @Override public void seekExact(long ord) throws IOException { throw new IOException("Not supported"); } @Override public BytesRef next() throws IOException { while (true) { BytesRef ref = in.next(); if (ref == null) { return null; } if (!_maskTermsEnum.seekExact(ref, true)) { return ref; } if (checkDocs()) { return ref; } } } private boolean checkDocs() throws IOException { DocsEnum maskDocsEnum = _maskTermsEnum.docs(null, null, DocsEnum.FLAG_NONE); DocsEnum docsEnum = in.docs(null, null, DocsEnum.FLAG_NONE); int docId; while ((docId = docsEnum.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (maskDocsEnum.advance(docId) != docId) { return true; } } return false; } } static class SecureTermsEnum extends FilterTermsEnum { private final int _maxDoc; private final AccessControlReader _accessControlReader; public SecureTermsEnum(TermsEnum in, AccessControlReader accessControlReader, int maxDoc) { super(in); _accessControlReader = accessControlReader; _maxDoc = maxDoc; } @Override public SeekStatus seekCeil(BytesRef text, boolean useCache) throws IOException { SeekStatus seekStatus = in.seekCeil(text, useCache); if (seekStatus == SeekStatus.END) { return SeekStatus.END; } BytesRef term = in.term(); if (hasAccess(term)) { return seekStatus; } if (next() == null) { return SeekStatus.END; } else { return SeekStatus.NOT_FOUND; } } @Override public BytesRef next() throws IOException { BytesRef t; while ((t = in.next()) != null) { if (hasAccess(t)) { return t; } } return null; } private boolean hasAccess(BytesRef term) throws IOException { DocsEnum docsEnum = in.docs(null, null, DocsEnum.FLAG_NONE); int docId; while ((docId = docsEnum.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { if (_accessControlReader.hasAccess(ReadType.TERMS_ENUM, docId)) { return true; } } return false; } @Override public DocsEnum docs(Bits liveDocs, DocsEnum reuse, int flags) throws IOException { Bits secureLiveDocs = getSecureLiveDocs(liveDocs, _maxDoc, _accessControlReader); return in.docs(secureLiveDocs, reuse, flags); } @Override public DocsAndPositionsEnum docsAndPositions(Bits liveDocs, DocsAndPositionsEnum reuse, int flags) throws IOException { Bits secureLiveDocs = getSecureLiveDocs(liveDocs, _maxDoc, _accessControlReader); return in.docsAndPositions(secureLiveDocs, reuse, flags); } } public static Bits getSecureLiveDocs(Bits bits, int maxDoc, final AccessControlReader accessControlReader) { final Bits liveDocs; if (bits == null) { liveDocs = getMatchAll(maxDoc); } else { liveDocs = bits; } final int length = liveDocs.length(); Bits secureLiveDocs = new Bits() { @Override public boolean get(int index) { if (liveDocs.get(index)) { try { if (accessControlReader.hasAccess(ReadType.DOCS_ENUM, index)) { return true; } } catch (IOException e) { throw new RuntimeException(e); } } return false; } @Override public int length() { return length; } }; return secureLiveDocs; } public static Bits getMatchAll(final int length) { return new Bits() { @Override public int length() { return length; } @Override public boolean get(int index) { return true; } }; } }