/*
* Copyright (C) 2012 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.stats.cardinality;
import com.google.common.base.Preconditions;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.Arrays;
/**
 * A low (compact) cardinality estimator.
 *
 * <p>Only buckets holding a non-zero value are stored. Each stored bucket is an entry of
 * {@code log2(numberOfBuckets) + BITS_PER_BUCKET} bits (bucket index in the high bits, value in
 * the low {@code BITS_PER_BUCKET} bits). Entries are kept sorted by bucket index and packed into
 * an array of longs, so lookups are a binary search and inserts shift the tail by one entry.
 *
 * <p>Buckets that are not present are implicitly zero.
 */
@NotThreadSafe
class SparseEstimator
  implements Estimator {
  private static final int BITS_PER_BUCKET = 4;
  private static final int BUCKET_VALUE_MASK = (1 << BITS_PER_BUCKET) - 1;
  // Exclusive upper bound: the largest value that fits in an entry is MAX_BUCKET_VALUE - 1.
  public static final int MAX_BUCKET_VALUE = (1 << BITS_PER_BUCKET);

  private static final int INSTANCE_SIZE = UnsafeUtil.sizeOf(SparseEstimator.class);

  // number of bits used for bucket index
  private final byte indexBits;

  // Number of entries currently stored. Kept as an int: with more than 2^15 buckets the number
  // of non-zero buckets can exceed Short.MAX_VALUE, and a wrapped counter would corrupt the
  // slot arithmetic below.
  private int bucketCount = 0;

  /*
    This structure keeps a sorted list of bucket entries of size log2(numberOfBuckets) +
    BITS_PER_BUCKET packed into an array of longs.

    Within each long, entries are stored in little-endian order, aligned to the least-significant
    bit edge. For instance, if a bucket entry is 16 bits long (4 entries per slot), the layout is:

    slot 0: [ index 3 | index 2 | index 1 | index 0 ]
    slot 1: [ index 7 | index 6 | index 5 | index 4 ]
    ....

    When the entry size does not divide 64 evenly, the high bits of each slot are unused. They
    may hold leftover bits after an insert has shifted a slot, so every read masks the entry.
   */
  private long[] slots;

  /**
   * Creates an empty estimator.
   *
   * @param numberOfBuckets total number of buckets; must be a power of 2
   */
  SparseEstimator(int numberOfBuckets) {
    this(numberOfBuckets, 1);
  }

  /**
   * Creates an estimator holding the non-zero values of a dense bucket array.
   *
   * @param buckets one value per bucket, indexed by bucket number; the length must be a power
   * of 2 and every value must be in {@code [0, MAX_BUCKET_VALUE)}
   * @throws IllegalArgumentException if the length is not a power of 2 or a value is out of range
   */
  SparseEstimator(int[] buckets) {
    this(buckets.length, countNonZeroBuckets(buckets));

    // Buckets are visited in increasing index order, so appending keeps the entries sorted.
    for (int bucket = 0; bucket < buckets.length; bucket++) {
      int value = buckets[bucket];
      if (value != 0) {
        checkBucketValue(value);
        setEntry(bucketCount, bucket, value);
        ++bucketCount;
      }
    }
  }

  /**
   * Creates an empty estimator with room for a given number of entries.
   *
   * @param numberOfBuckets total number of buckets; must be a power of 2
   * @param initialCapacity number of entries to allocate space for up front
   * @throws IllegalArgumentException if {@code numberOfBuckets} is not a power of 2
   */
  SparseEstimator(int numberOfBuckets, int initialCapacity) {
    Preconditions.checkArgument(
      Numbers.isPowerOf2(numberOfBuckets),
      "numberOfBuckets must be a power of 2"
    );

    this.indexBits = (byte) Integer.numberOfTrailingZeros(numberOfBuckets); // log2(numberOfBuckets)

    // Same rounding as estimateSizeInBytes(int, int): always at least one slot.
    slots = new long[(initialCapacity + getBucketsPerSlot()) / getBucketsPerSlot()];
  }

  /**
   * Stores {@code highestBitPosition} for {@code bucket} if it is larger than the value
   * currently stored (an absent bucket counts as 0).
   *
   * @param bucket bucket index in {@code [0, getNumberOfBuckets())}
   * @param highestBitPosition candidate value in {@code [0, MAX_BUCKET_VALUE)}
   * @return true if the stored value changed
   * @throws IllegalArgumentException if either argument is out of range
   */
  public boolean setIfGreater(int bucket, int highestBitPosition) {
    Preconditions.checkArgument(
      highestBitPosition < MAX_BUCKET_VALUE,
      "highestBitPosition %s is bigger than allowed by BITS_PER_BUCKET (%s)",
      highestBitPosition,
      BITS_PER_BUCKET
    );
    // A negative value or an out-of-range bucket would spill into the bits of neighbouring
    // entries when packed, so reject them instead of silently corrupting the structure.
    if (highestBitPosition < 0) {
      throw new IllegalArgumentException(
        String.format("highestBitPosition %s must not be negative", highestBitPosition)
      );
    }
    if (bucket < 0 || bucket >= getNumberOfBuckets()) {
      throw new IllegalArgumentException(
        String.format("bucket %s is outside the range [0, %s)", bucket, getNumberOfBuckets())
      );
    }

    if (highestBitPosition == 0) {
      return false; // no need to set anything -- 0 is implied if bucket is not present
    }

    int index = findBucket(bucket);

    if (index < 0) {
      // not present: findBucket encoded the insertion point as -(point + 1)
      insertAt(-(index + 1), bucket, highestBitPosition);
      return true;
    }

    if (getEntry(index).getValue() < highestBitPosition) {
      setEntry(index, bucket, highestBitPosition);
      return true;
    }

    return false;
  }

  /**
   * Expands the stored entries into a dense array.
   *
   * @return a new array of length {@link #getNumberOfBuckets()} with the stored value at each
   * present bucket index and 0 elsewhere
   */
  public int[] buckets() {
    int[] buckets = new int[getNumberOfBuckets()];

    for (int i = 0; i < bucketCount; ++i) {
      Entry entry = getEntry(i);
      buckets[entry.getBucket()] = entry.getValue();
    }

    return buckets;
  }

  /**
   * @return the total number of buckets (present or not), i.e. {@code 2^indexBits}
   */
  public int getNumberOfBuckets() {
    return 1 << indexBits;
  }

  /**
   * @return {@link #MAX_BUCKET_VALUE}, the exclusive upper bound enforced by
   * {@link #setIfGreater(int, int)}
   */
  @Override
  public int getMaxAllowedBucketValue() {
    return MAX_BUCKET_VALUE;
  }

  /**
   * Reads the entry at the given position in the sorted entry list.
   */
  private Entry getEntry(int index) {
    int totalBitsPerBucket = getTotalBitsPerBucket();
    // long arithmetic: an int mask breaks once an entry is 32 bits or wider
    long bucketMask = (1L << totalBitsPerBucket) - 1;

    int bucketsPerSlot = getBucketsPerSlot();
    int slot = index / bucketsPerSlot;
    int offset = index % bucketsPerSlot;

    long bucketEntry = (slots[slot] >>> (offset * totalBitsPerBucket)) & bucketMask;

    return new Entry(
      (int) (bucketEntry >>> BITS_PER_BUCKET),
      (int) (bucketEntry & BUCKET_VALUE_MASK)
    );
  }

  /**
   * @return how many whole entries fit in one 64-bit slot
   */
  private int getBucketsPerSlot() {
    return Long.SIZE / getTotalBitsPerBucket();
  }

  /**
   * @return the size of one entry in bits: bucket index bits plus value bits
   */
  private int getTotalBitsPerBucket() {
    return indexBits + BITS_PER_BUCKET;
  }

  /**
   * Overwrites the entry at the given position. The slot holding it must already exist.
   */
  private void setEntry(int index, int bucket, int value) {
    int totalBitsPerBucket = getTotalBitsPerBucket();
    long bucketMask = (1L << totalBitsPerBucket) - 1;

    int bucketsPerSlot = getBucketsPerSlot();
    int slot = index / bucketsPerSlot;
    int offset = index % bucketsPerSlot;

    // widen before shifting so large bucket indexes do not overflow int (matches insertAt)
    long bucketEntry = (((long) bucket) << BITS_PER_BUCKET) | value;

    long bucketClearMask = bucketMask << (offset * totalBitsPerBucket);
    long bucketSetMask = bucketEntry << (offset * totalBitsPerBucket);

    slots[slot] = (slots[slot] & ~bucketClearMask) | bucketSetMask;
  }

  /**
   * @return the size estimate for this instance, computed from the current entry count
   */
  public int estimateSizeInBytes() {
    return estimateSizeInBytes(bucketCount, getNumberOfBuckets());
  }

  /**
   * Estimates the size of a sparse estimator with the given occupancy.
   *
   * @param nonZeroBuckets number of entries that would be stored
   * @param totalBuckets total number of buckets; must be a power of 2
   * @return bytes taken by the slot array contents plus the instance size
   * @throws IllegalArgumentException if {@code totalBuckets} is not a power of 2
   */
  public static int estimateSizeInBytes(int nonZeroBuckets, int totalBuckets) {
    Preconditions.checkArgument(
      Numbers.isPowerOf2(totalBuckets),
      "totalBuckets must be a power of 2"
    );

    int bits = Integer.numberOfTrailingZeros(totalBuckets); // log2(totalBuckets)
    int bucketsPerSlot = Long.SIZE / (bits + BITS_PER_BUCKET);

    return (nonZeroBuckets + bucketsPerSlot) / bucketsPerSlot * Long.SIZE / 8 + INSTANCE_SIZE;
  }

  /**
   * Computes the cardinality estimate as {@code m * ln(m / z)}, where {@code m} is the total
   * number of buckets and {@code z} the number of buckets with no entry.
   *
   * <p>If every bucket has an entry ({@code z == 0}) the quotient is infinite and the result is
   * {@code Long.MAX_VALUE}.
   */
  public long estimate() {
    int totalBuckets = getNumberOfBuckets();

    // small cardinality estimate
    int zeroBuckets = totalBuckets - bucketCount;

    return Math.round(totalBuckets * Math.log(totalBuckets * 1.0 / zeroBuckets));
  }

  /**
   * Adds a single slot to the end of the slot array.
   */
  private void grow() {
    slots = Arrays.copyOf(slots, slots.length + 1);
  }

  /**
   * Binary-searches the sorted entries for a bucket index.
   *
   * @return the entry position if present; otherwise {@code -(insertionPoint + 1)}
   */
  private int findBucket(int bucket) {
    int low = 0;
    int high = bucketCount - 1;

    while (low <= high) {
      int middle = (low + high) >>> 1;

      Entry middleBucket = getEntry(middle);

      if (bucket > middleBucket.getBucket()) {
        low = middle + 1;
      } else if (bucket < middleBucket.getBucket()) {
        high = middle - 1;
      } else {
        return middle;
      }
    }

    return -(low + 1); // not found... return insertion point
  }

  /**
   * Inserts a new entry at the given position, moving every entry at or after that position
   * one place towards the end and growing the slot array if needed.
   */
  private void insertAt(int index, int bucket, int value) {
    int totalBitsPerBucket = getTotalBitsPerBucket();
    int bucketsPerSlot = getBucketsPerSlot();

    ++bucketCount;
    if ((bucketCount + bucketsPerSlot - 1) / bucketsPerSlot > slots.length) {
      grow();
    }

    // the last slot that would have any data after the bucket is inserted
    int lastUsedSlot = (bucketCount - 1) / bucketsPerSlot;

    int insertAtSlot = index / bucketsPerSlot;
    int insertOffset = index % bucketsPerSlot;

    long bucketMask = (1L << totalBitsPerBucket) - 1;

    // shift all buckets one position to the right, working backwards from the last slot and
    // carrying the top entry of the previous slot into the bottom of the current one
    for (int i = lastUsedSlot; i > insertAtSlot; --i) {
      // kept as a long: narrowing to int would truncate or sign-extend entries of 32+ bits
      long overflow = (slots[i - 1] >>> ((bucketsPerSlot - 1) * totalBitsPerBucket)) & bucketMask;

      slots[i] = (slots[i] << totalBitsPerBucket) | overflow;
    }

    long old = slots[insertAtSlot];

    // entries below the insertion offset stay where they are
    long bottomMask = (1L << (insertOffset * totalBitsPerBucket)) - 1;

    // entries at or above the insertion offset move up by one entry
    long topMask = 0;
    if (insertOffset < bucketsPerSlot - 1) {
      // to get around the fact that X << 64 == X, not 0
      topMask = (0xFFFFFFFFFFFFFFFFL << ((insertOffset + 1) * totalBitsPerBucket));
    }

    long bucketSetMask =
      ((((long) bucket) << BITS_PER_BUCKET) | value) << (insertOffset * totalBitsPerBucket);

    slots[insertAtSlot] = ((old << totalBitsPerBucket) & topMask) | bucketSetMask |
      (old & bottomMask);
  }

  /**
   * Counts the buckets the array constructor will store, using the same {@code != 0} test as
   * its copy loop so the allocated capacity always matches.
   */
  private static int countNonZeroBuckets(int[] buckets) {
    int count = 0;
    for (int bucket : buckets) {
      if (bucket != 0) {
        ++count;
      }
    }
    return count;
  }

  /**
   * Rejects values that do not fit in the {@code BITS_PER_BUCKET} value bits of an entry.
   */
  private static void checkBucketValue(int value) {
    if (value < 0 || value >= MAX_BUCKET_VALUE) {
      throw new IllegalArgumentException(
        String.format("bucket value %s is outside the range [0, %s)", value, MAX_BUCKET_VALUE)
      );
    }
  }

  /**
   * An unpacked (bucket index, value) pair.
   */
  private static final class Entry {
    private final int bucket;
    private final int value;

    private Entry(int bucket, int value) {
      this.bucket = bucket;
      this.value = value;
    }

    public int getBucket() {
      return bucket;
    }

    public int getValue() {
      return value;
    }
  }
}