/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.operator.aggregation.state.NullableLongState;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.operator.aggregation.state.VarianceState;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.GroupedAccumulatorState;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.type.ArrayType;
import com.facebook.presto.type.RowType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;
import org.openjdk.jol.info.ClassLayout;
import org.testng.annotations.Test;

import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.TinyintType.TINYINT;
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.util.StructuralTestUtil.mapType;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.slice.Slices.wrappedDoubleArray;
import static org.testng.Assert.assertEquals;

/**
 * Tests for the state factories and state serializers produced by {@link StateCompiler}.
 *
 * <p>Each serialization test writes a state into a block, reads it back into a second
 * state instance, and compares the two. The estimated-size test checks the value reported
 * by a grouped state as groups are populated.
 */
public class TestStateCompiler
{
    private static final int SLICE_INSTANCE_SIZE = ClassLayout.parseClass(Slice.class).instanceSize();

    /** Estimated size asserted for a freshly created grouped {@link TestComplexState}. */
    private static final long EMPTY_GROUPED_STATE_SIZE = 76064;

    /** Number of groups populated by {@link #testComplexStateEstimatedSize()}. */
    private static final int GROUP_COUNT = 1000;

    /**
     * Serializes a {@link NullableLongState} twice (once non-null, once null) and checks
     * the null flags of the resulting block and the round trip of the non-null position.
     */
    @Test
    public void testPrimitiveNullableLongSerialization()
    {
        AccumulatorStateFactory<NullableLongState> factory = StateCompiler.generateStateFactory(NullableLongState.class);
        AccumulatorStateSerializer<NullableLongState> serializer = StateCompiler.generateStateSerializer(NullableLongState.class);
        NullableLongState state = factory.createSingleState();
        NullableLongState deserializedState = factory.createSingleState();

        state.setLong(2);
        state.setNull(false);

        BlockBuilder builder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), 2);
        // position 0: non-null value
        serializer.serialize(state, builder);
        // position 1: same state flagged as null
        state.setNull(true);
        serializer.serialize(state, builder);

        Block block = builder.build();

        assertEquals(block.isNull(0), false);
        assertEquals(BIGINT.getLong(block, 0), state.getLong());
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getLong(), state.getLong());

        assertEquals(block.isNull(1), true);
    }

    /**
     * Round-trips a {@link LongState} through a BIGINT block.
     */
    @Test
    public void testPrimitiveLongSerialization()
    {
        AccumulatorStateFactory<LongState> factory = StateCompiler.generateStateFactory(LongState.class);
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        LongState state = factory.createSingleState();
        LongState deserializedState = factory.createSingleState();

        state.setLong(2);

        BlockBuilder builder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(state, builder);

        Block block = builder.build();

        assertEquals(BIGINT.getLong(block, 0), state.getLong());
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getLong(), state.getLong());
    }

    /**
     * Checks that the serializer generated for {@link LongState} reports BIGINT as its serialized type.
     */
    @Test
    public void testGetSerializedType()
    {
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        assertEquals(serializer.getSerializedType(), BIGINT);
    }

    /**
     * Round-trips a {@link BooleanState} through a BOOLEAN block.
     */
    @Test
    public void testPrimitiveBooleanSerialization()
    {
        AccumulatorStateFactory<BooleanState> factory = StateCompiler.generateStateFactory(BooleanState.class);
        AccumulatorStateSerializer<BooleanState> serializer = StateCompiler.generateStateSerializer(BooleanState.class);
        BooleanState state = factory.createSingleState();
        BooleanState deserializedState = factory.createSingleState();

        state.setBoolean(true);

        BlockBuilder builder = BOOLEAN.createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(state, builder);

        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.isBoolean(), state.isBoolean());
    }

    /**
     * Round-trips a {@link ByteState} through a TINYINT block.
     */
    @Test
    public void testPrimitiveByteSerialization()
    {
        AccumulatorStateFactory<ByteState> factory = StateCompiler.generateStateFactory(ByteState.class);
        AccumulatorStateSerializer<ByteState> serializer = StateCompiler.generateStateSerializer(ByteState.class);
        ByteState state = factory.createSingleState();
        ByteState deserializedState = factory.createSingleState();

        state.setByte((byte) 3);

        BlockBuilder builder = TINYINT.createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(state, builder);

        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getByte(), state.getByte());
    }

    /**
     * Round-trips a {@link SliceState} through a VARCHAR block, first with a {@code null}
     * slice and then with a non-null slice, reusing the same target state.
     */
    @Test
    public void testNonPrimitiveSerialization()
    {
        AccumulatorStateFactory<SliceState> factory = StateCompiler.generateStateFactory(SliceState.class);
        AccumulatorStateSerializer<SliceState> serializer = StateCompiler.generateStateSerializer(SliceState.class);
        SliceState state = factory.createSingleState();
        SliceState deserializedState = factory.createSingleState();

        // null slice
        state.setSlice(null);
        BlockBuilder nullBlockBuilder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(state, nullBlockBuilder);
        Block nullBlock = nullBlockBuilder.build();
        serializer.deserialize(nullBlock, 0, deserializedState);
        assertEquals(deserializedState.getSlice(), state.getSlice());

        // non-null slice
        state.setSlice(utf8Slice("test"));
        BlockBuilder builder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(state, builder);
        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getSlice(), state.getSlice());
    }

    /**
     * Round-trips a {@link VarianceState} (count, mean, m2) through a row-typed block.
     */
    @Test
    public void testVarianceStateSerialization()
    {
        AccumulatorStateFactory<VarianceState> factory = StateCompiler.generateStateFactory(VarianceState.class);
        AccumulatorStateSerializer<VarianceState> serializer = StateCompiler.generateStateSerializer(VarianceState.class);
        VarianceState singleState = factory.createSingleState();
        VarianceState deserializedState = factory.createSingleState();

        singleState.setMean(1);
        singleState.setCount(2);
        singleState.setM2(3);

        BlockBuilder builder = new RowType(ImmutableList.of(BIGINT, DOUBLE, DOUBLE), Optional.empty())
                .createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(singleState, builder);

        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);

        assertEquals(deserializedState.getCount(), singleState.getCount());
        assertEquals(deserializedState.getMean(), singleState.getMean());
        assertEquals(deserializedState.getM2(), singleState.getM2());
    }

    /**
     * Round-trips a {@link TestComplexState} that mixes primitive, slice and block fields
     * through a row-typed block and compares every field.
     */
    @Test
    public void testComplexSerialization()
    {
        Type arrayType = new ArrayType(BIGINT);
        Type mapType = mapType(BIGINT, VARCHAR);
        Map<String, Type> fieldMap = ImmutableMap.of("Block", arrayType, "AnotherBlock", mapType);
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(
                TestComplexState.class,
                fieldMap,
                new DynamicClassLoader(TestComplexState.class.getClassLoader()));
        AccumulatorStateSerializer<TestComplexState> serializer = StateCompiler.generateStateSerializer(
                TestComplexState.class,
                fieldMap,
                new DynamicClassLoader(TestComplexState.class.getClassLoader()));
        TestComplexState singleState = factory.createSingleState();
        TestComplexState deserializedState = factory.createSingleState();

        singleState.setBoolean(true);
        singleState.setLong(1);
        singleState.setDouble(2.0);
        singleState.setByte((byte) 3);
        singleState.setSlice(utf8Slice("test"));
        singleState.setAnotherSlice(wrappedDoubleArray(1.0, 2.0, 3.0));
        singleState.setYetAnotherSlice(null);
        Block array = createLongsBlock(45);
        singleState.setBlock(array);
        Block map = createTestMapBlock();
        singleState.setAnotherBlock(map);

        BlockBuilder builder = new RowType(
                ImmutableList.of(BOOLEAN, TINYINT, DOUBLE, BIGINT, mapType, VARBINARY, arrayType, VARBINARY, VARBINARY),
                Optional.empty())
                .createBlockBuilder(new BlockBuilderStatus(), 1);
        serializer.serialize(singleState, builder);

        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);

        assertEquals(deserializedState.getBoolean(), singleState.getBoolean());
        assertEquals(deserializedState.getLong(), singleState.getLong());
        assertEquals(deserializedState.getDouble(), singleState.getDouble());
        assertEquals(deserializedState.getByte(), singleState.getByte());
        assertEquals(deserializedState.getSlice(), singleState.getSlice());
        assertEquals(deserializedState.getAnotherSlice(), singleState.getAnotherSlice());
        assertEquals(deserializedState.getYetAnotherSlice(), singleState.getYetAnotherSlice());
        assertEquals(deserializedState.getBlock().getLong(0, 0), singleState.getBlock().getLong(0, 0));
        assertEquals(deserializedState.getAnotherBlock().getLong(0, 0), singleState.getAnotherBlock().getLong(0, 0));
        assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), singleState.getAnotherBlock().getSlice(1, 0, 9));
    }

    /**
     * Returns the size this test attributes to a single slice: its length plus the
     * shallow instance size of {@link Slice}.
     *
     * <p>see SliceBigArray::getSize
     */
    private long getSize(Slice slice)
    {
        return slice.length() + SLICE_INSTANCE_SIZE;
    }

    /**
     * Builds the interleaved (BIGINT, VARCHAR) block holding the single entry
     * {@code 123 -> "testBlock"} that the complex-state tests store in the
     * {@code AnotherBlock} field.
     */
    private static Block createTestMapBlock()
    {
        BlockBuilder mapBlockBuilder = new InterleavedBlockBuilder(ImmutableList.of(BIGINT, VARCHAR), new BlockBuilderStatus(), 1);
        BIGINT.writeLong(mapBlockBuilder, 123L);
        VARCHAR.writeSlice(mapBlockBuilder, utf8Slice("testBlock"));
        return mapBlockBuilder.build();
    }

    /**
     * Selects {@code groupId} on the grouped state and assigns every field of
     * {@link TestComplexState} a fixed test value.
     *
     * @param groupedState grouped state to populate; must also implement {@link GroupedAccumulatorState}
     * @param groupId group to select before writing the fields
     * @return the size contributed by the slices and blocks written for this group, i.e. the sum of
     *         {@link #getSize(Slice)} for the two non-null slices and
     *         {@code getRetainedSizeInBytes()} for the two blocks
     */
    private long populateGroup(TestComplexState groupedState, int groupId)
    {
        long retainedSize = 0;
        ((GroupedAccumulatorState) groupedState).setGroupId(groupId);

        groupedState.setBoolean(true);
        groupedState.setLong(1);
        groupedState.setDouble(2.0);
        groupedState.setByte((byte) 3);

        Slice slice = utf8Slice("test");
        retainedSize += getSize(slice);
        groupedState.setSlice(slice);

        slice = wrappedDoubleArray(1.0, 2.0, 3.0);
        retainedSize += getSize(slice);
        groupedState.setAnotherSlice(slice);

        groupedState.setYetAnotherSlice(null);

        Block array = createLongsBlock(45);
        retainedSize += array.getRetainedSizeInBytes();
        groupedState.setBlock(array);

        Block map = createTestMapBlock();
        retainedSize += map.getRetainedSizeInBytes();
        groupedState.setAnotherBlock(map);

        return retainedSize;
    }

    /**
     * Checks the estimated size reported by a grouped {@link TestComplexState}.
     *
     * <p>The first pass populates groups {@code 0..GROUP_COUNT-1} one at a time and expects the
     * estimate to grow by one group's worth of slices and blocks per step. The second pass
     * overwrites the same groups with equally sized values and expects the estimate to stay
     * at the total reached after the first pass.
     */
    @Test
    public void testComplexStateEstimatedSize()
    {
        Map<String, Type> fieldMap = ImmutableMap.of("Block", new ArrayType(BIGINT), "AnotherBlock", mapType(BIGINT, VARCHAR));
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(
                TestComplexState.class,
                fieldMap,
                new DynamicClassLoader(TestComplexState.class.getClassLoader()));

        TestComplexState groupedState = factory.createGroupedState();
        assertEquals(groupedState.getEstimatedSize(), EMPTY_GROUPED_STATE_SIZE);

        // first pass: every newly populated group adds its slices and blocks to the estimate
        for (int i = 0; i < GROUP_COUNT; i++) {
            long retainedSize = populateGroup(groupedState, i);
            assertEquals(groupedState.getEstimatedSize(), EMPTY_GROUPED_STATE_SIZE + retainedSize * (i + 1));
        }

        // second pass: overwriting already populated groups must leave the estimate unchanged
        for (int i = 0; i < GROUP_COUNT; i++) {
            long retainedSize = populateGroup(groupedState, i);
            assertEquals(groupedState.getEstimatedSize(), EMPTY_GROUPED_STATE_SIZE + retainedSize * GROUP_COUNT);
        }
    }

    /**
     * Accumulator state with primitive, slice and block fields, used to exercise
     * {@link StateCompiler} on a state with mixed field kinds.
     */
    public interface TestComplexState
            extends AccumulatorState
    {
        double getDouble();

        void setDouble(double value);

        boolean getBoolean();

        void setBoolean(boolean value);

        long getLong();

        void setLong(long value);

        byte getByte();

        void setByte(byte value);

        Slice getSlice();

        void setSlice(Slice slice);

        Slice getAnotherSlice();

        void setAnotherSlice(Slice slice);

        Slice getYetAnotherSlice();

        void setYetAnotherSlice(Slice slice);

        Block getBlock();

        void setBlock(Block block);

        Block getAnotherBlock();

        void setAnotherBlock(Block block);
    }

    /** Accumulator state holding a single {@code boolean}. */
    public interface BooleanState
            extends AccumulatorState
    {
        boolean isBoolean();

        void setBoolean(boolean value);
    }

    /** Accumulator state holding a single {@code byte}. */
    public interface ByteState
            extends AccumulatorState
    {
        byte getByte();

        void setByte(byte value);
    }

    /** Accumulator state holding a single {@link Slice}. */
    public interface SliceState
            extends AccumulatorState
    {
        Slice getSlice();

        void setSlice(Slice slice);
    }
}