/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.stats.cardinality;

import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static com.facebook.stats.cardinality.StaticModelUtil.weightsToProbabilities;

/**
 * A {@link Model} built once from a fixed array of weights and never modified afterwards.
 *
 * <p>Symbols are laid out in order of descending probability (ties broken by ascending symbol
 * number). Each symbol owns the half-open count range
 * {@code [countsByIndex[i], countsByIndex[i + 1])}, where {@code i} is the symbol's position in
 * that ordering. Because the most probable symbols come first, {@link #countToSymbol(int)} uses a
 * plain linear scan.
 */
class SortedStaticModel implements Model {
  /** Largest number of weights the constructor accepts. */
  private static final int MAX_SYMBOLS = 512;

  /** Maps a symbol to its position in the sorted order; the extra last slot holds -1. */
  private final int[] symbolToIndex;
  /** Maps a position in the sorted order back to the symbol stored there. */
  private final int[] indexToSymbol;
  /** Strictly increasing range boundaries; entry 0 is 0 and the last entry is MAX_COUNT. */
  private final int[] countsByIndex;
  /** Number of symbols, which is also the index of the final boundary in countsByIndex. */
  private final int totalIndex;

  /**
   * Builds the model from per-symbol weights, where {@code weights[s]} belongs to symbol
   * {@code s}.
   *
   * @param weights between 2 and 512 weights (inclusive)
   * @throws NullPointerException if {@code weights} is null
   * @throws IllegalArgumentException if {@code weights} has fewer than 2 or more than 512 entries
   * @throws IllegalStateException if the computed count ranges are not strictly increasing
   */
  public SortedStaticModel(double[] weights) {
    Preconditions.checkNotNull(weights, "weights is null");
    Preconditions.checkArgument(
        weights.length > 1,
        "weights must have at least 2 entries, but has %s",
        weights.length
    );
    Preconditions.checkArgument(
        weights.length <= MAX_SYMBOLS,
        "weights can not have more than %s entries",
        MAX_SYMBOLS
    );

    symbolToIndex = new int[weights.length + 1];
    indexToSymbol = new int[weights.length + 1];
    countsByIndex = new int[weights.length + 1];
    totalIndex = weights.length;

    // NOTE(review): the meaning of the second argument (10) is defined by StaticModelUtil
    double[] probabilities = weightsToProbabilities(weights, 10);
    List<SymbolProbability> symbolProbabilities = sortProbabilities(probabilities);

    int symbolIndex = 0;
    for (SymbolProbability symbolProbability : symbolProbabilities) {
      int symbol = symbolProbability.symbol;
      int low = countsByIndex[symbolIndex];

      symbolToIndex[symbol] = symbolIndex;
      indexToSymbol[symbolIndex] = symbol;

      // high is the low count + this symbol's share of MAX_COUNT
      int high = low + ((int) (StaticModelUtil.MAX_COUNT * symbolProbability.probability));

      // reserve one space for each symbol that still has to be placed
      high = Math.min(high, StaticModelUtil.MAX_COUNT - (probabilities.length - symbolIndex));

      // high must be at least one bigger than the low
      high = Math.max(high, low + 1);

      countsByIndex[symbolIndex + 1] = high;
      Preconditions.checkState(
          high > low,
          "Internal error: symbol %s high value %s is not greater than the low value %s",
          symbol,
          high,
          low
      );

      symbolIndex++;
    }

    // The last symbol's low bound must leave room below MAX_COUNT, because its high bound is
    // forced to MAX_COUNT just below.
    Preconditions.checkState(
        countsByIndex[totalIndex - 1] < StaticModelUtil.MAX_COUNT,
        "Internal error: model max value %s must be less than %s",
        countsByIndex[totalIndex - 1],
        StaticModelUtil.MAX_COUNT
    );

    symbolToIndex[totalIndex] = -1;
    countsByIndex[totalIndex] = StaticModelUtil.MAX_COUNT;

    // verify model: every range must be non-empty
    for (int i = 1; i < countsByIndex.length; i++) {
      Preconditions.checkState(
          countsByIndex[i - 1] < countsByIndex[i],
          "Internal error: model is invalid"
      );
    }
  }

  /**
   * Returns the count range assigned to {@code symbol}.
   *
   * @param symbol a symbol in {@code [0, number of weights)}
   * @return the symbol together with its low and high count
   * @throws IndexOutOfBoundsException if {@code symbol} is outside the valid range
   */
  @Override
  public SymbolInfo getSymbolInfo(int symbol) {
    // Only real symbols are accepted. The sentinel slot at totalIndex maps to -1 and must never
    // be used as an index into countsByIndex.
    Preconditions.checkElementIndex(symbol, totalIndex, "symbol");

    int symbolIndex = symbolToIndex[symbol];
    return new SymbolInfo(symbol, countsByIndex[symbolIndex], countsByIndex[symbolIndex + 1]);
  }

  /** Returns {@code StaticModelUtil.COUNT_BITS}. */
  @Override
  public int log2MaxCount() {
    return StaticModelUtil.COUNT_BITS;
  }

  /**
   * Finds the symbol whose count range contains {@code targetCount}.
   *
   * @param targetCount a count in {@code [0, StaticModelUtil.MAX_COUNT)}
   * @return the matching symbol together with its low and high count
   * @throws IllegalArgumentException if {@code targetCount} is negative or not below
   *     {@code StaticModelUtil.MAX_COUNT}
   */
  @Override
  public SymbolInfo countToSymbol(int targetCount) {
    Preconditions.checkArgument(targetCount >= 0, "targetCount is negative %s", targetCount);

    // Since symbols are sorted by probability, simply linearly search for the symbol.
    // Stop at totalIndex: countsByIndex[totalIndex] is the last valid high bound.
    for (int symbolIndex = 0; symbolIndex < totalIndex; symbolIndex++) {
      int high = countsByIndex[symbolIndex + 1];
      if (targetCount < high) {
        return new SymbolInfo(indexToSymbol[symbolIndex], countsByIndex[symbolIndex], high);
      }
    }
    throw new IllegalArgumentException("invalid target count " + targetCount);
  }

  /**
   * Two models are equal when they have the same count boundaries and assign them to the same
   * symbols.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }

    SortedStaticModel that = (SortedStaticModel) o;
    // indexToSymbol must take part: identical boundaries with a different symbol ordering
    // describe a different model. symbolToIndex is its inverse and adds nothing further.
    return totalIndex == that.totalIndex
        && Arrays.equals(countsByIndex, that.countsByIndex)
        && Arrays.equals(indexToSymbol, that.indexToSymbol);
  }

  @Override
  public int hashCode() {
    int result = Arrays.hashCode(countsByIndex);
    result = 31 * result + Arrays.hashCode(indexToSymbol);
    result = 31 * result + totalIndex;
    return result;
  }

  /** A symbol paired with its probability; sorts by descending probability, then by symbol. */
  private static class SymbolProbability implements Comparable<SymbolProbability> {
    private final int symbol;
    private final double probability;

    private SymbolProbability(int symbol, double probability) {
      this.symbol = symbol;
      this.probability = probability;
    }

    @Override
    public int compareTo(SymbolProbability o) {
      return ComparisonChain
          .start()
          // arguments are swapped on purpose: higher probability sorts first
          .compare(o.probability, probability)
          .compare(symbol, o.symbol)
          .result();
    }

    @Override
    public String toString() {
      return "SymbolProbability{symbol=" + symbol + ", probability=" + probability + '}';
    }
  }

  /**
   * Pairs every probability with its array index (the symbol) and returns the pairs sorted by
   * descending probability, ties broken by ascending symbol.
   */
  private static List<SymbolProbability> sortProbabilities(double[] probabilities) {
    List<SymbolProbability> symbolProbabilities =
        new ArrayList<SymbolProbability>(probabilities.length);
    for (int symbol = 0; symbol < probabilities.length; symbol++) {
      symbolProbabilities.add(new SymbolProbability(symbol, probabilities[symbol]));
    }
    Collections.sort(symbolProbabilities);
    return symbolProbabilities;
  }
}