/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.stats;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterators;
import com.google.common.collect.Ordering;
import com.google.common.collect.PeekingIterator;
import com.google.common.util.concurrent.AtomicDouble;

import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;

import javax.annotation.concurrent.ThreadSafe;

/**
 * <p>Implements http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.132.7343, a data
 * structure for approximating quantiles by trading off error with memory requirements.</p>
 *
 * <p>The size of the digest is adjusted dynamically to achieve the error bound and requires
 * O(log2(U) / maxError) space, where <em>U</em> is the number of bits needed to represent the
 * domain of the values added to the digest.</p>
 *
 * <p>The error is defined as the discrepancy between the real rank of the value returned in a
 * quantile query and the rank corresponding to the queried quantile.</p>
 *
 * <p>Thus, for a query for quantile <em>q</em> that returns value <em>v</em>, the error is
 * |rank(v) - q * N| / N, where N is the number of elements added to the digest and rank(v) is the
 * real rank of <em>v</em>.</p>
 *
 * <p>This class also supports exponential decay. The implementation is based on the ideas laid out
 * in http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.159.3978.</p>
 */
@ThreadSafe
public class QuantileDigest {
  private static final int MAX_BITS = 64;
  private static final double MAX_SIZE_FACTOR = 1.5;

  // needs to be such that Math.exp(alpha * seconds) does not grow too big
  static final long RESCALE_THRESHOLD_SECONDS = 50;
  static final double ZERO_WEIGHT_THRESHOLD = 1e-5;

  private final double maxError;
  private final Clock clock;
  private final double alpha;
  private final boolean compressAutomatically;

  private Node root;

  private double weightedCount;
  private long max;
  private long min = Long.MAX_VALUE;

  private long landmarkInSeconds;

  private int totalNodeCount = 0;
  private int nonZeroNodeCount = 0;
  private int compressions = 0;
  private int maxTotalNodeCount = 0;
  private int maxTotalNodesAfterCompress = 0;

  private enum TraversalOrder {
    FORWARD, REVERSE
  }
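  /*
   * Usage sketch (illustrative only, not part of the original source): builds a digest with a 1%
   * rank-error bound and no decay, then one with exponential decay. The input values and the
   * decay factor below are arbitrary examples.
   *
   *   QuantileDigest digest = new QuantileDigest(0.01);
   *   for (long latencyMillis : new long[] {12, 7, 25, 40, 9}) {
   *     digest.add(latencyMillis);
   *   }
   *
   *   // with decay, older samples contribute exponentially less weight over time
   *   QuantileDigest decaying = new QuantileDigest(0.01, 0.001);
   */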
  /**
   * <p>Create a QuantileDigest with a maximum error guarantee of "maxError" and no decay.</p>
   *
   * @param maxError the max error tolerance
   */
  public QuantileDigest(double maxError) {
    this(maxError, 0);
  }

  /**
   * <p>Create a QuantileDigest with a maximum error guarantee of "maxError" and exponential decay
   * with factor "alpha".</p>
   *
   * @param maxError the max error tolerance
   * @param alpha the exponential decay factor (0.0 => no decay)
   */
  public QuantileDigest(double maxError, double alpha) {
    this(maxError, alpha, new RealtimeClock(), true);
  }

  @VisibleForTesting
  QuantileDigest(double maxError, double alpha, Clock clock, boolean compressAutomatically) {
    checkArgument(maxError >= 0 && maxError <= 1, "maxError must be in range [0, 1]");
    checkArgument(alpha >= 0 && alpha < 1, "alpha must be in range [0, 1)");

    this.maxError = maxError;
    this.alpha = alpha;
    this.clock = clock;
    this.compressAutomatically = compressAutomatically;

    landmarkInSeconds = TimeUnit.MILLISECONDS.toSeconds(clock.getMillis());
  }

  /**
   * Adds a value to this digest. The value must be >= 0.
   *
   * @param value the value to add
   */
  public synchronized void add(long value) {
    checkArgument(value >= 0, "value must be >= 0");

    long nowInSeconds = TimeUnit.MILLISECONDS.toSeconds(clock.getMillis());

    int maxExpectedNodeCount = 3 * calculateCompressionFactor();
    if (nowInSeconds - landmarkInSeconds >= RESCALE_THRESHOLD_SECONDS) {
      rescale(nowInSeconds);
      compress(); // need to compress to get rid of nodes that may have decayed to ~ 0
    } else if (nonZeroNodeCount > MAX_SIZE_FACTOR * maxExpectedNodeCount && compressAutomatically) {
      // The size (number of non-zero nodes) of the digest is at most 3 * compression factor.
      // If we're over MAX_SIZE_FACTOR of the expected size, compress.
      // Note: we don't compress as soon as we go over expectedNodeCount to avoid unnecessarily
      // running a compression for every newly added element when we're close to the boundary.
      compress();
    }

    double weight = weight(TimeUnit.MILLISECONDS.toSeconds(clock.getMillis()));

    weightedCount += weight;

    max = Math.max(max, value);
    min = Math.min(min, value);

    insert(value, weight);
  }

  /**
   * Gets the values at the specified quantiles +/- maxError. The list of quantiles must be sorted
   * in increasing order, and each value must be in the range [0, 1].
   */
  public synchronized List<Long> getQuantiles(List<Double> quantiles) {
    checkArgument(Ordering.natural().isOrdered(quantiles),
        "quantiles must be sorted in increasing order");
    for (double quantile : quantiles) {
      checkArgument(quantile >= 0 && quantile <= 1, "quantile must be between [0,1]");
    }

    final ImmutableList.Builder<Long> builder = ImmutableList.builder();
    final PeekingIterator<Double> iterator = Iterators.peekingIterator(quantiles.iterator());

    postOrderTraversal(root, new Callback() {
      private double sum = 0;

      public boolean process(Node node) {
        sum += node.weightedCount;

        while (iterator.hasNext() && sum > iterator.peek() * weightedCount) {
          iterator.next();

          // we know the max value ever seen, so cap the percentile to provide better error
          // bounds in this case
          long value = Math.min(node.getUpperBound(), max);

          builder.add(value);
        }

        return iterator.hasNext();
      }
    });

    // we finished the traversal without consuming all quantiles. This means the remaining
    // quantiles correspond to the max known value.
    while (iterator.hasNext()) {
      builder.add(max);
      iterator.next();
    }

    return builder.build();
  }
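  /*
   * Query sketch (illustrative only, not part of the original source): a batch query requires the
   * quantiles to be sorted in increasing order; a single quantile can be fetched directly. The
   * variable names below are arbitrary.
   *
   *   List<Long> percentiles = digest.getQuantiles(ImmutableList.of(0.5, 0.9, 0.99));
   *   long p50 = digest.getQuantile(0.5);
   */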
  /**
   * Gets the value at the specified quantile +/- maxError. The quantile must be in the range
   * [0, 1].
   */
  public synchronized long getQuantile(double quantile) {
    return getQuantiles(ImmutableList.of(quantile)).get(0);
  }

  /**
   * Returns the (decayed) number of elements added to this quantile digest.
   */
  public synchronized double getCount() {
    return weightedCount / weight(TimeUnit.MILLISECONDS.toSeconds(clock.getMillis()));
  }

  /**
   * Get the exponentially-decayed approximate counts of values in multiple buckets. The elements
   * in the provided list denote the upper bound of each of the buckets and must be sorted in
   * ascending order.
   *
   * The approximate count in each bucket is guaranteed to be within 2 * totalCount * maxError of
   * the real count.
   */
  public synchronized List<Bucket> getHistogram(List<Long> bucketUpperBounds) {
    checkArgument(
        Ordering.natural().isOrdered(bucketUpperBounds),
        "buckets must be sorted in increasing order"
    );

    final ImmutableList.Builder<Bucket> builder = ImmutableList.builder();
    final PeekingIterator<Long> iterator = Iterators.peekingIterator(bucketUpperBounds.iterator());

    final AtomicDouble sum = new AtomicDouble();
    final AtomicDouble lastSum = new AtomicDouble();

    // for computing the weighted average of values in each bucket
    final AtomicDouble bucketWeightedSum = new AtomicDouble();

    final double normalizationFactor = weight(TimeUnit.MILLISECONDS.toSeconds(clock.getMillis()));

    postOrderTraversal(root, new Callback() {
      public boolean process(Node node) {
        while (iterator.hasNext() && iterator.peek() <= node.getUpperBound()) {
          double bucketCount = sum.get() - lastSum.get();

          Bucket bucket = new Bucket(
              bucketCount / normalizationFactor, bucketWeightedSum.get() / bucketCount);

          builder.add(bucket);
          lastSum.set(sum.get());
          bucketWeightedSum.set(0);
          iterator.next();
        }

        bucketWeightedSum.addAndGet(node.getMiddle() * node.weightedCount);
        sum.addAndGet(node.weightedCount);
        return iterator.hasNext();
      }
    });

    while (iterator.hasNext()) {
      double bucketCount = sum.get() - lastSum.get();

      Bucket bucket = new Bucket(
          bucketCount / normalizationFactor, bucketWeightedSum.get() / bucketCount);

      builder.add(bucket);

      iterator.next();
    }

    return builder.build();
  }
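  /*
   * Histogram sketch (illustrative only, not part of the original source): the sorted upper
   * bounds below are arbitrary examples; each Bucket approximates the decayed count of values
   * falling between consecutive bounds, together with their approximate mean.
   *
   *   List<QuantileDigest.Bucket> buckets = digest.getHistogram(ImmutableList.of(10L, 100L, 1000L));
   *   double countInFirstBucket = buckets.get(0).getCount();
   *   double meanOfFirstBucket = buckets.get(0).getMean();
   */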
  public long getMin() {
    final AtomicLong chosen = new AtomicLong(min);
    postOrderTraversal(root, new Callback() {
      public boolean process(Node node) {
        if (node.weightedCount >= ZERO_WEIGHT_THRESHOLD) {
          chosen.set(node.getLowerBound());
          return false;
        }
        return true;
      }
    }, TraversalOrder.FORWARD);

    return Math.max(min, chosen.get());
  }

  public long getMax() {
    final AtomicLong chosen = new AtomicLong(max);
    postOrderTraversal(root, new Callback() {
      public boolean process(Node node) {
        if (node.weightedCount >= ZERO_WEIGHT_THRESHOLD) {
          chosen.set(node.getUpperBound());
          return false;
        }
        return true;
      }
    }, TraversalOrder.REVERSE);

    return Math.min(max, chosen.get());
  }

  @VisibleForTesting
  synchronized int getTotalNodeCount() {
    return totalNodeCount;
  }

  @VisibleForTesting
  synchronized int getNonZeroNodeCount() {
    return nonZeroNodeCount;
  }

  @VisibleForTesting
  synchronized int getCompressions() {
    return compressions;
  }

  @VisibleForTesting
  synchronized void compress() {
    ++compressions;

    final int compressionFactor = calculateCompressionFactor();

    postOrderTraversal(root, new Callback() {
      public boolean process(Node node) {
        if (node.isLeaf()) {
          return true;
        }

        // if children's weights are ~0 remove them and shift the weight to their parent

        double leftWeight = 0;
        if (node.left != null) {
          leftWeight = node.left.weightedCount;
        }

        double rightWeight = 0;
        if (node.right != null) {
          rightWeight = node.right.weightedCount;
        }

        boolean shouldCompress =
            node.weightedCount + leftWeight + rightWeight < weightedCount / compressionFactor;

        double oldNodeWeight = node.weightedCount;
        if (shouldCompress || leftWeight < ZERO_WEIGHT_THRESHOLD) {
          node.left = tryRemove(node.left);

          weightedCount += leftWeight;
          node.weightedCount += leftWeight;
        }

        if (shouldCompress || rightWeight < ZERO_WEIGHT_THRESHOLD) {
          node.right = tryRemove(node.right);

          weightedCount += rightWeight;
          node.weightedCount += rightWeight;
        }

        if (oldNodeWeight < ZERO_WEIGHT_THRESHOLD && node.weightedCount >= ZERO_WEIGHT_THRESHOLD) {
          ++nonZeroNodeCount;
        }

        return true;
      }
    });

    if (root != null && root.weightedCount < ZERO_WEIGHT_THRESHOLD) {
      root = tryRemove(root);
    }

    maxTotalNodesAfterCompress = Math.max(maxTotalNodesAfterCompress, totalNodeCount);
  }

  private double weight(long timestamp) {
    return Math.exp(alpha * (timestamp - landmarkInSeconds));
  }

  private void rescale(long newLandmarkInSeconds) {
    // rescale the weights based on a new landmark to avoid numerical overflow issues
    final double factor = Math.exp(-alpha * (newLandmarkInSeconds - landmarkInSeconds));

    weightedCount *= factor;

    postOrderTraversal(root, new Callback() {
      public boolean process(Node node) {
        double oldWeight = node.weightedCount;

        node.weightedCount *= factor;

        if (oldWeight >= ZERO_WEIGHT_THRESHOLD && node.weightedCount < ZERO_WEIGHT_THRESHOLD) {
          --nonZeroNodeCount;
        }

        return true;
      }
    });

    landmarkInSeconds = newLandmarkInSeconds;
  }

  private int calculateCompressionFactor() {
    if (root == null) {
      return 1;
    }

    return Math.max((int) ((root.level + 1) / maxError), 1);
  }

  private void insert(long value, double weight) {
    long lastBranch = 0;
    Node parent = null;
    Node current = root;

    while (true) {
      if (current == null) {
        setChild(parent, lastBranch, createLeaf(value, weight));
        return;
      } else if ((value >>> current.level) != (current.value >>> current.level)) {
        // if value and node.value are not in the same branch given node's level,
        // insert a parent above them at the point at which the branches diverge
        setChild(parent, lastBranch, makeSiblings(current, createLeaf(value, weight)));
        return;
      } else if (current.level == 0 && current.value == value) {
        // found the node

        double oldWeight = current.weightedCount;

        current.weightedCount += weight;

        if (current.weightedCount >= ZERO_WEIGHT_THRESHOLD && oldWeight < ZERO_WEIGHT_THRESHOLD) {
          ++nonZeroNodeCount;
        }

        return;
      }

      // we're on the correct branch of the tree and we haven't reached a leaf, so keep going down
      long branch = value & current.getBranchMask();

      parent = current;
      lastBranch = branch;

      if (branch == 0) {
        current = current.left;
      } else {
        current = current.right;
      }
    }
  }

  private void setChild(Node parent, long branch, Node child) {
    if (parent == null) {
      root = child;
    } else if (branch == 0) {
      parent.left = child;
    } else {
      parent.right = child;
    }
  }

  private Node makeSiblings(Node node, Node sibling) {
    int parentLevel = MAX_BITS - Long.numberOfLeadingZeros(node.value ^ sibling.value);

    Node parent = new Node(node.value, parentLevel, 0);

    // the branch is given by the bit at the level one below the parent
    long branch = sibling.value & parent.getBranchMask();
    if (branch == 0) {
      parent.left = sibling;
      parent.right = node;
    } else {
      parent.left = node;
      parent.right = sibling;
    }

    ++totalNodeCount;
    maxTotalNodeCount = Math.max(maxTotalNodeCount, totalNodeCount);

    return parent;
  }

  private Node createLeaf(long value, double weight) {
    ++totalNodeCount;
    maxTotalNodeCount = Math.max(maxTotalNodeCount, totalNodeCount);
    ++nonZeroNodeCount;
    return new Node(value, 0, weight);
  }
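  /*
   * Worked example (illustrative only, not part of the original source) of how makeSiblings
   * places two leaves under a common ancestor, using the sample values 5 (binary 101) and
   * 6 (binary 110):
   *
   *   5 ^ 6 = 3 (binary 011), so Long.numberOfLeadingZeros(3) = 62 and parentLevel = 64 - 62 = 2.
   *   The parent node therefore covers the value range [4, 7].
   *   parent.getBranchMask() = 1L << 1 = 2; since 6 & 2 != 0, the leaf for 6 becomes the right
   *   child and the leaf for 5 becomes the left child.
   */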
  /**
   * Remove the node if possible or set its count to 0 if it has children and it needs to be kept
   * around.
   */
  private Node tryRemove(Node node) {
    if (node == null) {
      return null;
    }

    if (node.weightedCount >= ZERO_WEIGHT_THRESHOLD) {
      --nonZeroNodeCount;
    }

    weightedCount -= node.weightedCount;

    Node result = null;
    if (node.isLeaf()) {
      --totalNodeCount;
    } else if (node.hasSingleChild()) {
      result = node.getSingleChild();
      --totalNodeCount;
    } else {
      node.weightedCount = 0;
      result = node;
    }

    return result;
  }

  private boolean postOrderTraversal(Node node, Callback callback) {
    return postOrderTraversal(node, callback, TraversalOrder.FORWARD);
  }

  // returns true if traversal should continue
  private boolean postOrderTraversal(Node node, Callback callback, TraversalOrder order) {
    if (node == null) {
      return false;
    }

    Node first;
    Node second;

    if (order == TraversalOrder.FORWARD) {
      first = node.left;
      second = node.right;
    } else {
      first = node.right;
      second = node.left;
    }

    if (first != null && !postOrderTraversal(first, callback, order)) {
      return false;
    }

    if (second != null && !postOrderTraversal(second, callback, order)) {
      return false;
    }

    return callback.process(node);
  }

  /**
   * Computes the maximum error of the current digest.
   */
  public synchronized double getConfidenceFactor() {
    return computeMaxPathWeight(root) * 1.0 / weightedCount;
  }

  /**
   * Computes the max "weight" of any path starting at node and ending at a leaf in the
   * hypothetical complete tree. The weight is the sum of counts in the ancestors of a given node.
   */
  private double computeMaxPathWeight(Node node) {
    if (node == null || node.level == 0) {
      return 0;
    }

    double leftMaxWeight = computeMaxPathWeight(node.left);
    double rightMaxWeight = computeMaxPathWeight(node.right);

    return Math.max(leftMaxWeight, rightMaxWeight) + node.weightedCount;
  }
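  /*
   * Worked example (illustrative only, not part of the original source) for computeMaxPathWeight:
   * with a root at level 2 holding weight 1.0 whose only child is a level-1 node with weight 2.0
   * above a level-0 leaf, the heaviest path of internal nodes weighs 1.0 + 2.0 = 3.0. If the
   * digest's total weightedCount were 30.0, getConfidenceFactor() would return 3.0 / 30.0 = 0.1.
   */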
  @VisibleForTesting
  synchronized void validate() {
    final AtomicDouble sumOfWeights = new AtomicDouble();
    final AtomicInteger actualNodeCount = new AtomicInteger();
    final AtomicInteger actualNonZeroNodeCount = new AtomicInteger();

    if (root != null) {
      validateStructure(root);

      postOrderTraversal(root, new Callback() {
        @Override
        public boolean process(Node node) {
          sumOfWeights.addAndGet(node.weightedCount);
          actualNodeCount.incrementAndGet();

          if (node.weightedCount > ZERO_WEIGHT_THRESHOLD) {
            actualNonZeroNodeCount.incrementAndGet();
          }

          return true;
        }
      });
    }

    checkState(Math.abs(sumOfWeights.get() - weightedCount) < ZERO_WEIGHT_THRESHOLD,
        "Computed weight (%s) doesn't match summary (%s)", sumOfWeights.get(), weightedCount);

    checkState(actualNodeCount.get() == totalNodeCount,
        "Actual node count (%s) doesn't match summary (%s)", actualNodeCount.get(), totalNodeCount);

    checkState(actualNonZeroNodeCount.get() == nonZeroNodeCount,
        "Actual non-zero node count (%s) doesn't match summary (%s)",
        actualNonZeroNodeCount.get(), nonZeroNodeCount);
  }

  private void validateStructure(Node node) {
    checkState(node.level >= 0);

    if (node.left != null) {
      validateBranchStructure(node, node.left, node.right, true);
      validateStructure(node.left);
    }

    if (node.right != null) {
      validateBranchStructure(node, node.right, node.left, false);
      validateStructure(node.right);
    }
  }

  private void validateBranchStructure(Node parent, Node child, Node otherChild, boolean isLeft) {
    checkState(child.level < parent.level,
        "Child level (%s) should be smaller than parent level (%s)", child.level, parent.level);

    long branch = child.value & (1L << (parent.level - 1));
    checkState(branch == 0 && isLeft || branch != 0 && !isLeft,
        "Value of child node is inconsistent with its branch");

    Preconditions.checkState(parent.weightedCount >= ZERO_WEIGHT_THRESHOLD ||
        child.weightedCount >= ZERO_WEIGHT_THRESHOLD || otherChild != null,
        "Found a linear chain of zero-weight nodes");
  }

  public static class Bucket {
    private double count;
    private double mean;

    public Bucket(double count, double mean) {
      this.count = count;
      this.mean = mean;
    }

    public double getCount() {
      return count;
    }

    public double getMean() {
      return mean;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }

      final Bucket bucket = (Bucket) o;

      if (Double.compare(bucket.count, count) != 0) {
        return false;
      }
      if (Double.compare(bucket.mean, mean) != 0) {
        return false;
      }

      return true;
    }

    @Override
    public int hashCode() {
      int result;
      long temp;
      temp = count != +0.0d ? Double.doubleToLongBits(count) : 0L;
      result = (int) (temp ^ (temp >>> 32));
      temp = mean != +0.0d ? Double.doubleToLongBits(mean) : 0L;
      result = 31 * result + (int) (temp ^ (temp >>> 32));
      return result;
    }

    public String toString() {
      return String.format("[count: %f, mean: %f]", count, mean);
    }
  }

  private static class Node {
    private double weightedCount;
    private int level;
    private long value;
    private Node left;
    private Node right;

    private Node(long value, int level, double weightedCount) {
      this.value = value;
      this.level = level;
      this.weightedCount = weightedCount;
    }

    public boolean isLeaf() {
      return left == null && right == null;
    }

    public boolean hasSingleChild() {
      return left == null && right != null || left != null && right == null;
    }

    public Node getSingleChild() {
      checkState(hasSingleChild(), "Node does not have a single child");
      return MoreObjects.firstNonNull(left, right);
    }

    public long getUpperBound() {
      // set all lsb below level to 1 (we're looking for the highest value of the range covered
      // by this node)
      long mask = (1L << level) - 1;
      return value | mask;
    }

    public long getBranchMask() {
      return (1L << (level - 1));
    }

    public long getLowerBound() {
      // set all lsb below level to 0 (we're looking for the lowest value of the range covered
      // by this node)
      long mask = (0x7FFFFFFFFFFFFFFFL << level);
      return value & mask;
    }

    public long getMiddle() {
      return getLowerBound() + (getUpperBound() - getLowerBound()) / 2;
    }

    public String toString() {
      return format("%s (level = %d, count = %s, left = %s, right = %s)",
          value, level, weightedCount, left != null, right != null);
    }
  }

  private static interface Callback {
    /**
     * @param node the node to process
     * @return true if processing should continue
     */
    boolean process(Node node);
  }
}