/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.operator.HashPartitionMaskOperator.HashPartitionMaskOperatorFactory; import com.facebook.presto.spi.Page; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.type.BigintOperators; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.airlift.slice.XxHash64; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.stream.IntStream; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @Test public class TestHashPartitionMaskOperator { private static final int PARTITION_COUNT = 5; private static final int ROW_COUNT = 100; private ExecutorService executor; @BeforeClass public void setUp() { executor = newCachedThreadPool(daemonThreadsNamed("test-%s")); } @AfterClass public void tearDown() { executor.shutdownNow(); } @DataProvider(name = "hashEnabledValues") public static Object[][] hashEnabledValuesProvider() { return new Object[][] { { true }, { false } }; } @Test(dataProvider = "hashEnabledValues") public void testHashPartitionMask(boolean hashEnabled) throws Exception { RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, Ints.asList(0), BIGINT); List<Page> input = rowPagesBuilder .addSequencePage(ROW_COUNT, 0) .build(); OperatorFactory operatorFactory = new HashPartitionMaskOperatorFactory( 0, new PlanNodeId("test"), PARTITION_COUNT, rowPagesBuilder.getTypes(), ImmutableList.of(), ImmutableList.of(0), rowPagesBuilder.getHashChannel()); int[] rowPartition = new int[ROW_COUNT]; Arrays.fill(rowPartition, -1); for (int partition = 0; partition < PARTITION_COUNT; partition++) { MaterializedResult.Builder expected = resultBuilder(TEST_SESSION, BIGINT, BOOLEAN); for (int i = 0; i < ROW_COUNT; i++) { long rawHash = BigintOperators.hashCode(i); // mix the bits so we don't use the same hash used to distribute between stages rawHash = XxHash64.hash(Long.reverse(rawHash)); rawHash &= Long.MAX_VALUE; boolean active = (rawHash % PARTITION_COUNT == partition); expected.row((long) i, active); if (active) { assertEquals(rowPartition[i], -1); rowPartition[i] = partition; } } OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, createDriverContext(), input, expected.build(), hashEnabled, Optional.of(1)); } assertTrue(IntStream.of(rowPartition).noneMatch(partition -> partition == -1)); } @Test(dataProvider = "hashEnabledValues") public void testHashPartitionMaskWithMask(boolean hashEnabled) throws Exception { RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, Ints.asList(0), BIGINT, BOOLEAN, BOOLEAN); List<Page> input = rowPagesBuilder .addSequencePage(ROW_COUNT, 0, 0, 1) .build(); OperatorFactory operatorFactory = new HashPartitionMaskOperatorFactory( 0, new PlanNodeId("test"), PARTITION_COUNT, rowPagesBuilder.getTypes(), ImmutableList.of(1, 2), ImmutableList.of(0), rowPagesBuilder.getHashChannel()); int[] rowPartition = new int[ROW_COUNT]; Arrays.fill(rowPartition, -1); for (int partition = 0; partition < PARTITION_COUNT; partition++) { MaterializedResult.Builder expected = resultBuilder(TEST_SESSION, BIGINT, BOOLEAN, BOOLEAN, BOOLEAN); for (int i = 0; i < ROW_COUNT; i++) { long rawHash = BigintOperators.hashCode(i); // mix the bits so we don't use the same hash used to distribute between stages rawHash = XxHash64.hash(Long.reverse(rawHash)); rawHash &= Long.MAX_VALUE; boolean active = (rawHash % PARTITION_COUNT == partition); boolean maskValue = i % 2 == 0; expected.row((long) i, active && maskValue, active && !maskValue, active); if (active) { assertEquals(rowPartition[i], -1); rowPartition[i] = partition; } } OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, createDriverContext(), input, expected.build(), hashEnabled, Optional.of(3)); } assertTrue(IntStream.of(rowPartition).noneMatch(partition -> partition == -1)); } public DriverContext createDriverContext() { return createTaskContext(executor, TEST_SESSION) .addPipelineContext(0, true, true) .addDriverContext(); } }