/*
* Copyright (C) 2012 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.collections.specialized;
import com.facebook.collections.SetFactory;
import com.facebook.collections.WrappedIterator;
import com.facebook.util.digest.DigestFunction;
import com.facebook.util.serialization.SerDe;
import com.facebook.util.serialization.SerDeException;
import com.google.common.collect.ImmutableSet;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
* Thread-safe implementation of SampledSet
* http://algo.inria.fr/flajolet/Publications/Slides/aofa07.pdf
* section 2.1 Adaptive Sampling
*
* @param <T> type of element in the set
*/
// TODO: optimize concurrency if it proves to be a bottleneck
public class SampledSetImpl<T> implements SampledSet<T> {
  private final DigestFunction<T> digestFunction;
  private final SetFactory<T, SnapshotableSet<T>> setFactory;
  private final int maxSetSize;
  // number of elements in baseSet plus the adds that have reserved a slot
  // but not finished yet; every removal from baseSet must release its slot
  // here, otherwise add() is forced onto the write-lock path for good
  private final AtomicInteger proposedSize = new AtomicInteger(0);
  // this protects baseSet and currentSampleRate; writeLock is held to
  // safely perform downsampling or downsample + add(). add() and copy
  // operations also need the read lock
  private final ReadWriteLock downSampleLock = new ReentrantReadWriteLock();
  // set on every mutation, cleared by hasChanged()
  private final AtomicBoolean dirty = new AtomicBoolean(false);
  private volatile SnapshotableSet<T> baseSet;
  // if md5 % sampleRate == 0, we will keep the value
  private volatile int currentSampleRate;

  /**
   * Wraps an existing set that is already sampled at {@code currentSampleRate}.
   * <p>
   * Ideally this would be private, but HashDistinctCountAggregation needs it
   * to convert legacy SampledSetImpl&lt;Long&gt; to SampledSetImpl&lt;Integer&gt;.
   *
   * @param maxSetSize maximum number of elements kept before down-sampling
   * @param digestFunction computes the digest used to decide sample membership
   * @param baseSet backing set; used directly, not copied
   * @param setFactory factory for backing sets
   * @param currentSampleRate rate at which {@code baseSet} is sampled
   */
  @Deprecated
  public SampledSetImpl(
    int maxSetSize,
    DigestFunction<T> digestFunction,
    SnapshotableSet<T> baseSet,
    SetFactory<T, SnapshotableSet<T>> setFactory,
    int currentSampleRate
  ) {
    this.maxSetSize = maxSetSize;
    this.digestFunction = digestFunction;
    this.setFactory = setFactory;
    this.baseSet = baseSet;
    this.currentSampleRate = currentSampleRate;
    proposedSize.set(baseSet.size());
  }

  /**
   * Creates an empty set with sample rate 1 (every element is kept).
   *
   * @param maxSetSize maximum number of elements kept before down-sampling
   * @param digestFunction computes the digest used to decide sample membership
   * @param setFactory factory used to create the backing set
   */
  public SampledSetImpl(
    int maxSetSize,
    DigestFunction<T> digestFunction,
    SetFactory<T, SnapshotableSet<T>> setFactory
  ) {
    this(maxSetSize, digestFunction, setFactory.create(), setFactory, 1);
  }

  /**
   * Adds {@code element} if its digest falls in the current sample,
   * down-sampling first when the set is full.
   *
   * @return true if the element was newly added to the backing set
   * @throws IllegalStateException if the set is full and the sample rate
   *   cannot be doubled any further
   */
  @Override
  public boolean add(T element) {
    // algorithm:
    // 1. store sample rate
    // 2. check if element is in sample
    // 3. increment proposed size
    // 4. if > max, grab writeLock, downsample, re-check if in sample,
    //    and add to set if in sample;
    // 5. if <= max, grab read lock. if sample rate changed, re-check if in
    //    sample. add if still in sample or rate hasn't changed
    // 6. either case, decrement proposed size if element ends up
    //    not being added
    // 7. release read or write lock
    long elementDigest = digestFunction.computeDigest(element);
    int sampleRateSnapshot = currentSampleRate;

    // not in our current sample: nothing to do
    if (!inSample(elementDigest, sampleRateSnapshot)) {
      return false;
    }

    boolean added = false;

    if (proposedSize.incrementAndGet() > maxSetSize) {
      // we may exceed the max size: downsample + add while holding the
      // writeLock
      downSampleLock.writeLock().lock();
      try {
        if (inSample(elementDigest, currentSampleRate) && !baseSet.contains(element)) {
          // adding something new
          downSample();
          // need to add while we hold the lock to guarantee we don't exceed
          // the max
          if (inSample(elementDigest, currentSampleRate)) {
            added = baseSet.add(element);
          }
        }
      } finally {
        if (!added) {
          proposedSize.decrementAndGet();
        }
        downSampleLock.writeLock().unlock();
      }
    } else {
      // we won't exceed max size; make sure the sample rate holds constant
      // and add to the set
      downSampleLock.readLock().lock();
      try {
        // we only need to check if this element is in the sample again if
        // currentSampleRate has changed
        if (currentSampleRate == sampleRateSnapshot ||
          inSample(elementDigest, currentSampleRate)) {
          added = baseSet.add(element);
        }
      } finally {
        if (!added) {
          proposedSize.decrementAndGet();
        }
        downSampleLock.readLock().unlock();
      }
    }

    if (added) {
      dirty.set(true);
    }

    return added;
  }

  /** An element is in the sample when its digest is a multiple of the rate. */
  private boolean inSample(long digest, int sampleRate) {
    return digest % sampleRate == 0;
  }

  /**
   * Doubles the sample rate until the backing set has room for one more
   * element. Callers in this class hold the write lock.
   *
   * @throws IllegalStateException if the rate can no longer be doubled
   *   without overflowing an int
   */
  private void downSample() {
    // very unlikely, but possible that increasing the sample rate won't
    // remove a single value; so do this in a loop
    while (baseSet.size() >= maxSetSize) {
      // an overflowed rate would become negative and then 0, making
      // inSample() divide by zero or this loop spin forever
      if (currentSampleRate > (Integer.MAX_VALUE >> 1)) {
        throw new IllegalStateException(
          "cannot down-sample below maxSetSize " + maxSetSize +
            ": sample rate " + currentSampleRate + " cannot be doubled"
        );
      }

      int doubledRate = currentSampleRate << 1;

      currentSampleRate = doubledRate;

      int removed = downSampleAtRate(doubledRate, baseSet);

      // release the slots right away so the count stays right even if a
      // later iteration throws
      if (removed > 0) {
        proposedSize.addAndGet(-removed);
      }
    }
  }

  /**
   * Removes every element of {@code set} that is not in the sample at
   * {@code sampleRate}.
   *
   * @return number of elements removed
   */
  private int downSampleAtRate(int sampleRate, Set<T> set) {
    int removed = 0;

    for (Iterator<T> iterator = set.iterator(); iterator.hasNext(); ) {
      T value = iterator.next();

      if (!inSample(digestFunction.computeDigest(value), sampleRate)) {
        iterator.remove();
        removed++;
      }
    }

    return removed;
  }

  /**
   * Returns a snapshot of the backing set, down-sampled to {@code sampleRate}
   * when that is higher than the current rate.
   */
  private SnapshotableSet<T> copyAtRate(int sampleRate) {
    // make a fast-copy and, if needed, down-sample it--faster to remove
    // elements than re-add them
    SnapshotableSet<T> target = baseSet.makeSnapshot();

    if (sampleRate > currentSampleRate) {
      downSampleAtRate(sampleRate, target);
    }

    return target;
  }

  @Override
  public int getMaxSetSize() {
    return maxSetSize;
  }

  /**
   * Returns the sample size multiplied by the sample rate. Size and rate are
   * read under the read lock so they belong together, and the product is
   * computed as a long and capped at {@link Integer#MAX_VALUE} rather than
   * wrapping around.
   */
  @Override
  public int getScaledSize() {
    downSampleLock.readLock().lock();
    try {
      long scaled = (long) baseSet.size() * currentSampleRate;

      return scaled > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) scaled;
    } finally {
      downSampleLock.readLock().unlock();
    }
  }

  @Override
  public int getSampleRate() {
    return currentSampleRate;
  }

  @Override
  public int getSize() {
    return baseSet.size();
  }

  /** Returns an immutable copy of the sampled elements. */
  @Override
  public Set<T> getEntries() {
    return ImmutableSet.copyOf(baseSet);
  }

  /**
   * Returns a snapshot sampled at the larger of {@code rate} and the current
   * sample rate.
   */
  @Override
  public SampledSetSnapshot<T> sampleAt(int rate) {
    SnapshotableSet<T> setCopy;
    int setCopySampleRate;

    // grab this lock to make sure we have a consistent view of the sample rate
    // and the set;
    downSampleLock.readLock().lock();
    try {
      setCopySampleRate = Math.max(rate, currentSampleRate);
      setCopy = copyAtRate(setCopySampleRate);
    } finally {
      downSampleLock.readLock().unlock();
    }

    return new SampledSetSnapshot<T>(setCopySampleRate, maxSetSize, setCopy);
  }

  /**
   * Returns a new set holding the merge of this set and {@code sampledSet};
   * this set is not modified and the result starts out unchanged.
   */
  @Override
  public SampledSet<T> merge(SampledSet<T> sampledSet) {
    // fast-copy of ourself for merging
    SampledSet<T> mergedSampleSet = this.makeSnapshot();

    // now merge sampledSet into the copy
    mergedSampleSet.mergeInPlaceWith(sampledSet);
    // clear the changed status
    mergedSampleSet.hasChanged();

    return mergedSampleSet;
  }

  /**
   * Merges {@code sampledSet} into this set, raising our sample rate to the
   * other set's rate when that is higher and it has elements.
   *
   * @return true if this set was modified
   */
  @Override
  public boolean mergeInPlaceWith(SampledSet<T> sampledSet) {
    boolean changed = false;
    // take a snapshot of the other set at our sample rate. Note that it
    // only will use this rate if it is higher than its current sample rate
    SampledSetSnapshot<T> snapshot = sampledSet.sampleAt(currentSampleRate);

    // grab our downSampleLock.writeLock to make sure the sampleRate doesn't
    // change while we work
    downSampleLock.writeLock().lock();
    try {
      // shortcut for fast copy: we're empty, the snapshot's sample rate is
      // compatible with our sampleRate, and fits within our maxSize
      // NOTE(review): this also fires for an empty snapshot, adopting its
      // sample rate and reporting a change; the branch below deliberately
      // skips empty snapshots -- confirm which behavior is intended
      if (currentSampleRate <= snapshot.getSampleRate() && baseSet.isEmpty() &&
        maxSetSize >= snapshot.getElements().size()
        ) {
        // adopt the snapshot's set and sample rate, and mark ourselves changed
        baseSet = snapshot.getElements();
        currentSampleRate = snapshot.getSampleRate();
        proposedSize.set(baseSet.size());
        dirty.set(true);

        return true;
      } else if (!snapshot.getElements().isEmpty() &&
        snapshot.getSampleRate() > currentSampleRate
        ) {
        // only downsample ourself if there are actually elements in the other
        // set to merge into ourself
        int removed = downSampleAtRate(snapshot.getSampleRate(), baseSet);

        if (removed > 0) {
          changed = true;
          proposedSize.addAndGet(-removed);
        }

        currentSampleRate = snapshot.getSampleRate();
      }
    } finally {
      downSampleLock.writeLock().unlock();
    }

    // safe to do this outside the lock since we know our sampleRate is
    // at least as high as that of the elements we are adding
    for (T element : snapshot.getElements()) {
      if (add(element)) {
        changed = true;
      }
    }

    if (changed) {
      dirty.set(true);
    }

    return changed;
  }

  /** Returns whether the set changed since the last call, and resets the flag. */
  @Override
  public boolean hasChanged() {
    return dirty.getAndSet(false);
  }

  /**
   * Iterates over the backing set; {@code remove()} on the iterator marks the
   * set as changed and releases the removed element's slot.
   */
  @Override
  public Iterator<T> iterator() {
    return new WrappedIterator<T>(baseSet.iterator()) {
      @Override
      public void remove() {
        super.remove();
        // keep the slot accounting used by add() in step with the set
        proposedSize.decrementAndGet();
        dirty.set(true);
      }
    };
  }

  @Override
  public int size() {
    return getSize();
  }

  @Override
  public boolean isEmpty() {
    return baseSet.isEmpty();
  }

  @Override
  public boolean contains(Object o) {
    return baseSet.contains(o);
  }

  @Override
  public Object[] toArray() {
    return baseSet.toArray();
  }

  @Override
  public <V> V[] toArray(V[] a) {
    return baseSet.toArray(a);
  }

  /**
   * Removes {@code o} from the backing set and releases its slot.
   *
   * @return true if the element was present
   */
  @Override
  public boolean remove(Object o) {
    boolean removed;

    // like add(), hold the read lock so a concurrent down-sample cannot
    // count the same element as removed a second time
    downSampleLock.readLock().lock();
    try {
      removed = baseSet.remove(o);

      if (removed) {
        proposedSize.decrementAndGet();
      }
    } finally {
      downSampleLock.readLock().unlock();
    }

    if (removed) {
      dirty.set(true);
    }

    return removed;
  }

  @Override
  public boolean containsAll(Collection<?> c) {
    return baseSet.containsAll(c);
  }

  /**
   * Adds every element of {@code c} through {@link #add(Object)}.
   *
   * @return true if at least one element was added
   */
  @Override
  public boolean addAll(Collection<? extends T> c) {
    boolean added = false;

    for (T item : c) {
      added |= add(item);
    }

    if (added) {
      dirty.set(true);
    }

    return added;
  }

  /**
   * Keeps only the elements also contained in {@code c} and releases the
   * slots of the removed ones.
   *
   * @return true if the set was modified
   */
  @Override
  public boolean retainAll(Collection<?> c) {
    boolean changed;

    // the write lock keeps add() out, so the size difference is exactly the
    // number of elements this call removed
    downSampleLock.writeLock().lock();
    try {
      int sizeBefore = baseSet.size();

      changed = baseSet.retainAll(c);
      proposedSize.addAndGet(baseSet.size() - sizeBefore);
    } finally {
      downSampleLock.writeLock().unlock();
    }

    if (changed) {
      dirty.set(true);
    }

    return changed;
  }

  /**
   * Removes all elements contained in {@code c} and releases their slots.
   *
   * @return true if the set was modified
   */
  @Override
  public boolean removeAll(Collection<?> c) {
    boolean changed;

    // the write lock keeps add() out, so the size difference is exactly the
    // number of elements this call removed
    downSampleLock.writeLock().lock();
    try {
      int sizeBefore = baseSet.size();

      changed = baseSet.removeAll(c);
      proposedSize.addAndGet(baseSet.size() - sizeBefore);
    } finally {
      downSampleLock.writeLock().unlock();
    }

    if (changed) {
      dirty.set(true);
    }

    return changed;
  }

  /**
   * Removes all elements and releases their slots. The sample rate is left
   * as it is.
   */
  @Override
  public void clear() {
    downSampleLock.writeLock().lock();
    try {
      // subtract instead of resetting to 0 so that adds which reserved a
      // slot and are waiting for the lock stay accounted for
      int sizeBefore = baseSet.size();

      baseSet.clear();
      proposedSize.addAndGet(baseSet.size() - sizeBefore);
    } finally {
      downSampleLock.writeLock().unlock();
    }

    dirty.set(true);
  }

  /**
   * Returns an independent copy of this set. The backing-set snapshot and the
   * sample rate are read under the read lock so they belong together.
   */
  @Override
  public SampledSet<T> makeSnapshot() {
    downSampleLock.readLock().lock();
    try {
      return new SampledSetImpl<T>(
        maxSetSize,
        digestFunction,
        baseSet.makeSnapshot(),
        setFactory,
        currentSampleRate
      );
    } finally {
      downSampleLock.readLock().unlock();
    }
  }

  /**
   * Like {@link #makeSnapshot()}, but built on the backing set's transient
   * snapshot and a hash-set based factory.
   */
  @Override
  public SampledSet<T> makeTransientSnapshot() {
    SnapshotableSetImplFactory<T> cpuEfficientHashSetFactory =
      new SnapshotableSetImplFactory<T>(new HashSetFactory<T>());

    downSampleLock.readLock().lock();
    try {
      SnapshotableSet<T> cpuEfficientHashSet = baseSet.makeTransientSnapshot();

      return new SampledSetImpl<T>(
        maxSetSize,
        digestFunction,
        cpuEfficientHashSet,
        cpuEfficientHashSetFactory,
        currentSampleRate
      );
    } finally {
      downSampleLock.readLock().unlock();
    }
  }

  /**
   * Two instances are equal when sample rate, max size and backing set are
   * equal. Only this instance is locked; the other one is read as-is.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }

    if (o == null || getClass() != o.getClass()) {
      return false;
    }

    final SampledSetImpl<?> that = (SampledSetImpl<?>) o;

    // the write lock (not the read lock) is needed to keep add() out, since
    // add() mutates the backing set while holding only the read lock
    downSampleLock.writeLock().lock();
    try {
      if (currentSampleRate != that.currentSampleRate) {
        return false;
      }

      if (maxSetSize != that.maxSetSize) {
        return false;
      }

      if (baseSet != null ? !baseSet.equals(that.baseSet) : that.baseSet != null) {
        return false;
      }

      return true;
    } finally {
      downSampleLock.writeLock().unlock();
    }
  }

  @Override
  public int hashCode() {
    // write lock for the same reason as in equals()
    downSampleLock.writeLock().lock();
    try {
      int result = baseSet != null ? baseSet.hashCode() : 0;

      result = 31 * result + maxSetSize;
      result = 31 * result + currentSampleRate;

      return result;
    } finally {
      downSampleLock.writeLock().unlock();
    }
  }

  /**
   * Serializes a sampled set as max size, sample rate, element count and
   * then the elements themselves.
   */
  public static class SerDeImpl<T> implements SerDe<SampledSet<T>> {
    private final SetFactory<T, SnapshotableSet<T>> setFactory;
    private final DigestFunction<T> digestFunction;
    private final SerDe<T> elementSerDe;

    /**
     * @param setFactory factory for the backing set of deserialized instances
     * @param digestFunction digest function given to deserialized instances
     * @param elementSerDe reads and writes individual elements
     */
    public SerDeImpl(
      SetFactory<T, SnapshotableSet<T>> setFactory,
      DigestFunction<T> digestFunction,
      SerDe<T> elementSerDe
    ) {
      this.setFactory = setFactory;
      this.digestFunction = digestFunction;
      this.elementSerDe = elementSerDe;
    }

    /**
     * Reads a set written by {@link #serialize}.
     *
     * @throws SerDeException on I/O failure, or when the stream carries a
     *   sample rate below 1 or a negative element count
     */
    @Override
    public SampledSet<T> deserialize(DataInput in) throws SerDeException {
      try {
        int maxSize = in.readInt();
        int sampleRate = in.readInt();
        int numElements = in.readInt();

        // a rate below 1 would make the sample check divide by zero or
        // misbehave later; fail here where the bad data is visible
        if (sampleRate < 1) {
          throw new IOException("invalid sample rate: " + sampleRate);
        }

        if (numElements < 0) {
          throw new IOException("invalid element count: " + numElements);
        }

        SnapshotableSet<T> baseSet = setFactory.create();

        for (int i = 0; i < numElements; i++) {
          baseSet.add(elementSerDe.deserialize(in));
        }

        return new SampledSetImpl<T>(
          maxSize, digestFunction, baseSet, setFactory, sampleRate
        );
      } catch (IOException e) {
        throw new SerDeException(e);
      }
    }

    /**
     * Writes max size, sample rate, element count and the elements of a
     * snapshot of {@code value}.
     *
     * @throws SerDeException on I/O failure
     */
    @Override
    public void serialize(SampledSet<T> value, DataOutput out)
      throws SerDeException {
      try {
        // sampling at 0 will make a copy at the existing sample rate (since the
        // rate is only used if it is larger than the existing rate)
        SampledSetSnapshot<T> snapshot = value.sampleAt(0);
        Set<T> elements = snapshot.getElements();

        out.writeInt(snapshot.getMaxSetSize());
        out.writeInt(snapshot.getSampleRate());
        out.writeInt(elements.size());

        for (T element : elements) {
          elementSerDe.serialize(element, out);
        }
      } catch (IOException e) {
        throw new SerDeException(e);
      }
    }
  }
}