/*
* Copyright (C) 2012 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.stats.cardinality;
import com.google.common.primitives.Ints;
import org.testng.Assert;
import org.testng.annotations.Test;
import java.io.ByteArrayOutputStream;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static com.facebook.stats.cardinality.StaticModelUtil.SMALLEST_PROBABILITY;
import static java.lang.Math.E;
import static java.lang.Math.PI;
import static java.lang.Math.log;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
@SuppressWarnings({"RedundantArrayCreation", "NonReproducibleMathCall", "ConstantMathCall"})
/**
 * Round-trip tests for {@link ArithmeticEncoder} / {@link ArithmeticDecoder}.
 *
 * <p>Every case encodes a symbol sequence with a {@link Model}, decodes the resulting bytes with the
 * same model, and asserts that the original sequence is reproduced symbol by symbol.
 */
@SuppressWarnings({"RedundantArrayCreation", "NonReproducibleMathCall", "ConstantMathCall"})
public class TestArithmeticCodec {
  /**
   * Seed for {@link #random}. A fresh seed is drawn from {@link SecureRandom} on each run, so every
   * run still explores new data. The seed is included in every failure message, so a failing run
   * can be replayed by temporarily hard-coding it here. The generator is shared by all tests, so a
   * replay also needs the same tests to run in the same order.
   */
  private static final long SEED = new SecureRandom().nextLong();

  /** Source of random histogram weights and random symbol data; seeded from {@link #SEED}. */
  private static final Random random = new Random(SEED);

  /**
   * Fixed 16-symbol sequences over a 32-symbol sorted, exponentially decreasing model. Judging by
   * the name, these pin a past overflow in the encoder's close path. NOTE(review): confirm which
   * failure they originally reproduced.
   */
  @Test
  public void testPossibleOverflowInClose() throws Exception {
    testRoundTrip(
        new SortedStaticModel(new ExponentiallyDecreasingHistogramFactory().create(32)),
        32,
        Ints.asList(new int[]{25, 22, 1, 11, 5, 3, 6, 5, 25, 9, 2, 9, 3, 3, 17, 20}),
        16);
    testRoundTrip(
        new SortedStaticModel(new ExponentiallyDecreasingHistogramFactory().create(32)),
        32,
        Ints.asList(new int[]{27, 28, 1, 25, 4, 5, 15, 7, 14, 3, 23, 15, 25, 12, 3, 15}),
        16);
  }

  /**
   * Fixed 16-symbol sequences over 8- and 16-symbol sorted, exponentially decreasing models. Judging
   * by the name, these pin the handling of underflow bytes when the encoder is closed.
   * NOTE(review): confirm which failure they originally reproduced.
   */
  @Test
  public void testUnderflowBytesInClose() throws Exception {
    testRoundTrip(
        new SortedStaticModel(new ExponentiallyDecreasingHistogramFactory().create(8)),
        8,
        Ints.asList(new int[]{3, 5, 4, 0, 2, 4, 7, 5, 7, 7, 3, 1, 1, 5, 0, 3}),
        16);
    testRoundTrip(
        new SortedStaticModel(new ExponentiallyDecreasingHistogramFactory().create(16)),
        16,
        Ints.asList(new int[]{12, 12, 4, 4, 13, 2, 9, 8, 9, 1, 0, 8, 2, 11, 12, 1}),
        16);
  }

  @Test
  public void testDecodeZeroPaddingRequired() throws Exception {
    // ArithmeticDecoder buffers 6 bytes; when there are fewer than 6, it should treat the input as
    // if it had zeros for the missing bytes. In practice, this rarely matters, but for the case
    // below, getting it wrong results in an "IllegalArgumentException: targetCount is negative" due
    // to ArithmeticDecoder.bufferByte() removing underflow bytes from high and low, but not value.
    int[] buckets = new int[2048];
    buckets[860] = 1;
    buckets[1258] = 1;
    buckets[1618] = 1;
    buckets[2033] = 1;
    testRoundTrip(
        HyperLogLogCodec.createHyperLogLogSymbolModel(4, 2048, (byte) 1),
        2,
        Ints.asList(buckets),
        2048);
  }

  /**
   * Sweeps every combination of {sorted, unsorted} static model x {exponential, Gaussian, random}
   * histogram x {sequential, random} symbol data.
   */
  @Test
  public void testRoundTrip() throws Exception {
    testRoundTrip(new SortedStaticDataModelFactory(new ExponentiallyDecreasingHistogramFactory()));
    testRoundTrip(new SortedStaticDataModelFactory(new GaussianHistogramFactory()));
    testRoundTrip(new SortedStaticDataModelFactory(new RandomHistogramFactory()));
    testRoundTrip(new StaticDataModelFactory(new ExponentiallyDecreasingHistogramFactory()));
    testRoundTrip(new StaticDataModelFactory(new GaussianHistogramFactory()));
    testRoundTrip(new StaticDataModelFactory(new RandomHistogramFactory()));
  }

  /**
   * Runs the size/alphabet sweep for one model factory with both sequential and random data.
   *
   * @param modelFactory builds the model under test for each alphabet size
   */
  public void testRoundTrip(DataModelFactory modelFactory) throws Exception {
    testRoundTrip(modelFactory, new SequentialDataFactory());
    testRoundTrip(modelFactory, new RandomDataFactory());
  }

  /**
   * Round-trips every combination of sequence length (1, 2, 4, ..., 65536) and alphabet size
   * (2, 4, ..., 512).
   */
  private void testRoundTrip(DataModelFactory modelFactory, DataFactory dataFactory)
      throws Exception {
    for (int size = 1; size < 100000; size <<= 1) {
      for (int numberOfSymbols = 2; numberOfSymbols <= 512; numberOfSymbols <<= 1) {
        testRoundTrip(
            modelFactory.create(numberOfSymbols),
            numberOfSymbols,
            dataFactory.create(size, numberOfSymbols),
            size);
      }
    }
  }

  /**
   * Encodes {@code symbols} with {@code model}, decodes them again, and asserts equality symbol by
   * symbol.
   *
   * @param model the model shared by encoder and decoder
   * @param numberOfSymbols alphabet size; used only in failure messages
   * @param symbols the sequence to round-trip; it is iterated twice, so it must be re-iterable
   * @param size length of {@code symbols}; used only in failure messages
   * @throws RuntimeException wrapping any exception thrown by the codec, with the test parameters
   *     and random seed in its message. Assertion failures are {@link AssertionError}s, not
   *     {@link Exception}s, so they propagate unwrapped.
   */
  private void testRoundTrip(Model model, int numberOfSymbols, Iterable<Integer> symbols, int size)
      throws Exception {
    try {
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      ArithmeticEncoder encoder = new ArithmeticEncoder(model, out);
      for (Integer symbol : symbols) {
        encoder.encode(symbol);
      }
      // Close before snapshotting the output so that anything close() writes (see the *InClose
      // tests above) is part of the decoded bytes.
      encoder.close();

      ArithmeticDecoder decoder = new ArithmeticDecoder(model, out.toByteArray());
      int index = 0;
      for (Integer symbol : symbols) {
        int decoded = decoder.decode();
        // Only format a message on mismatch: this loop runs for every symbol of every sweep.
        if (decoded != symbol) {
          Assert.assertEquals(
              decoded,
              (int) symbol,
              String.format(
                  "size=%d, numberOfSymbols=%d, index=%d, seed=%d",
                  size,
                  numberOfSymbols,
                  index,
                  SEED));
        }
        index++;
      }
    } catch (Exception e) {
      throw new RuntimeException(
          String.format("size=%d, numberOfSymbols=%d, seed=%d", size, numberOfSymbols, SEED), e);
    }
  }

  /** Builds the {@link Model} under test for a given alphabet size. */
  public interface DataModelFactory {
    Model create(int numberOfSymbols);
  }

  /** Wraps a histogram in a {@link StaticModel}. */
  private static class StaticDataModelFactory implements DataModelFactory {
    private final HistogramFactory histogramFactory;

    private StaticDataModelFactory(HistogramFactory histogramFactory) {
      this.histogramFactory = histogramFactory;
    }

    @Override
    public StaticModel create(int numberOfSymbols) {
      double[] weights = histogramFactory.create(numberOfSymbols);
      return new StaticModel(weights);
    }
  }

  /** Wraps a histogram in a {@link SortedStaticModel}. */
  private static class SortedStaticDataModelFactory implements DataModelFactory {
    private final HistogramFactory histogramFactory;

    private SortedStaticDataModelFactory(HistogramFactory histogramFactory) {
      this.histogramFactory = histogramFactory;
    }

    @Override
    public SortedStaticModel create(int numberOfSymbols) {
      double[] weights = histogramFactory.create(numberOfSymbols);
      return new SortedStaticModel(weights);
    }
  }

  /** Produces one (not necessarily normalized) weight per symbol. */
  private interface HistogramFactory {
    double[] create(int numberOfSymbols);
  }

  /**
   * Weight of symbol {@code s} is {@code SMALLEST_PROBABILITY^(s / n)}. Weights therefore fall
   * geometrically from 1 for symbol 0 towards {@code SMALLEST_PROBABILITY}.
   */
  private static class ExponentiallyDecreasingHistogramFactory implements HistogramFactory {
    @Override
    public double[] create(int numberOfSymbols) {
      double maxExponent = log(1 / SMALLEST_PROBABILITY);
      double[] probability = new double[numberOfSymbols];
      for (int symbol = 0; symbol < probability.length; symbol++) {
        double exponent = symbol * maxExponent / numberOfSymbols;
        probability[symbol] = 1.0D / pow(E, exponent);
      }
      return probability;
    }
  }

  /**
   * Samples the normal probability density with mean {@code n / 2} and standard deviation
   * {@code sqrt(n / 2)} at each symbol index. Values are not re-normalized here.
   */
  private static class GaussianHistogramFactory implements HistogramFactory {
    @Override
    public double[] create(int numberOfSymbols) {
      double mean = numberOfSymbols / 2.0;
      double std = sqrt(mean);
      // Loop-invariant parts of the normal PDF: 1 / (std * sqrt(2 pi)) and 2 * std^2.
      double normalization = 1 / (std * sqrt(2.0 * PI));
      double twoVariance = 2 * pow(std, 2);
      double[] probability = new double[numberOfSymbols];
      for (int i = 0; i < probability.length; i++) {
        probability[i] = normalization * pow(E, -(pow(i - mean, 2) / twoVariance));
      }
      return probability;
    }
  }

  /** Draws every weight uniformly from [0, 1) using the shared seeded {@link #random}. */
  private static class RandomHistogramFactory implements HistogramFactory {
    @Override
    public double[] create(int numberOfSymbols) {
      double[] probability = new double[numberOfSymbols];
      for (int i = 0; i < probability.length; i++) {
        probability[i] = random.nextDouble();
      }
      return probability;
    }
  }

  /** Produces a re-iterable sequence of {@code size} symbols in {@code [0, numberOfSymbols)}. */
  public interface DataFactory {
    Iterable<Integer> create(int size, int numberOfSymbols);
  }

  /** Emits 0, 1, ..., numberOfSymbols - 1, 0, 1, ... until {@code size} symbols are produced. */
  public static class SequentialDataFactory implements DataFactory {
    @Override
    public Iterable<Integer> create(int size, int numberOfSymbols) {
      List<Integer> data = new ArrayList<Integer>(size);
      for (int i = 0; i < size; i++) {
        data.add(i % numberOfSymbols);
      }
      return data;
    }
  }

  /** Emits {@code size} symbols drawn uniformly from the shared seeded {@link #random}. */
  public static class RandomDataFactory implements DataFactory {
    @Override
    public Iterable<Integer> create(int size, int numberOfSymbols) {
      List<Integer> data = new ArrayList<Integer>(size);
      for (int i = 0; i < size; i++) {
        data.add(random.nextInt(numberOfSymbols));
      }
      return data;
    }
  }
}