/**
 * Copyright (C) 2014-2016 LinkedIn Corp. (pinot-core@linkedin.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.linkedin.pinot.query.aggregation.groupby;

import com.linkedin.pinot.common.response.broker.GroupByResult;
import com.linkedin.pinot.core.query.aggregation.AggregationFunctionContext;
import com.linkedin.pinot.core.query.aggregation.function.AggregationFunction;
import com.linkedin.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import com.linkedin.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.lang.RandomStringUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;


/**
 * Tests {@link AggregationGroupByTrimmingService} with a single SUM aggregation over {@link #NUM_GROUPS} distinct
 * random groups, where the group stored at index {@code i} carries the value {@code i}.
 *
 * <p>All randomness (key lengths and key contents) is drawn from {@link #RANDOM}, and every assertion carries the
 * seed in its message, so a failing run can be replayed by hard-coding the reported seed into {@link #RANDOM_SEED}.
 */
public class AggregationGroupByTrimmingServiceTest {
  private static final long RANDOM_SEED = System.currentTimeMillis();
  private static final Random RANDOM = new Random(RANDOM_SEED);
  private static final String ERROR_MESSAGE = "Random seed: " + RANDOM_SEED;

  private static final AggregationFunction SUM = AggregationFunctionFactory.getAggregationFunction("SUM");
  private static final AggregationFunctionContext[] AGGREGATION_FUNCTION_CONTEXTS =
      {new AggregationFunctionContext(new String[]{"column"}, SUM)};
  private static final AggregationFunction[] AGGREGATION_FUNCTIONS = {SUM};
  private static final int NUM_GROUP_KEYS = 3;
  private static final int GROUP_BY_TOP_N = 100;
  private static final int NUM_GROUPS = 50000;

  // Character this test places between the keys of a group string; random keys must never contain it.
  private static final char GROUP_KEY_DELIMITER = '\t';
  // Exclusive upper bound on the length of a randomly generated group key (so lengths are 0..9).
  private static final int MAX_GROUP_KEY_LENGTH = 10;

  private List<String> _groups;
  private AggregationGroupByTrimmingService _serverTrimmingService;
  private AggregationGroupByTrimmingService _brokerTrimmingService;

  /**
   * Builds {@link #NUM_GROUPS} distinct group strings, the last of which is the group made of empty keys only, and
   * creates the server side and broker side trimming services.
   */
  @BeforeClass
  public void setUp() {
    String emptyGroup = buildEmptyGroup();

    // Generate NUM_GROUPS - 1 distinct random groups. The empty group is rejected here because it is appended
    // explicitly below: a random group has all keys empty with probability ~1/1000, so without this check the list
    // would almost always contain the empty group twice, and the value expected for the earlier index would not
    // match the value stored in the results map.
    Set<String> groupSet = new HashSet<>(NUM_GROUPS);
    while (groupSet.size() < NUM_GROUPS - 1) {
      String group = buildRandomGroup();
      if (!group.equals(emptyGroup)) {
        groupSet.add(group);
      }
    }

    // Explicitly put the empty group at index NUM_GROUPS - 1.
    _groups = new ArrayList<>(NUM_GROUPS);
    _groups.addAll(groupSet);
    _groups.add(emptyGroup);

    _serverTrimmingService = new AggregationGroupByTrimmingService(AGGREGATION_FUNCTION_CONTEXTS, GROUP_BY_TOP_N);
    _brokerTrimmingService = new AggregationGroupByTrimmingService(AGGREGATION_FUNCTIONS, GROUP_BY_TOP_N);
  }

  /**
   * Checks that server side trimming keeps the groups with the largest values, and that broker side trimming then
   * returns the top {@link #GROUP_BY_TOP_N} groups ordered from the largest value downwards.
   */
  @SuppressWarnings("unchecked")
  @Test
  public void testTrimming() {
    // Test server side trimming.
    Map<String, Object[]> intermediateResultsMap = new HashMap<>(NUM_GROUPS);
    for (int i = 0; i < NUM_GROUPS; i++) {
      intermediateResultsMap.put(_groups.get(i), new Double[]{(double) i});
    }
    // Only one aggregation function is configured, so only the first trimmed map is inspected.
    Map<String, Object> trimmedIntermediateResultsMap =
        _serverTrimmingService.trimIntermediateResultsMap(intermediateResultsMap).get(0);
    int trimSize = trimmedIntermediateResultsMap.size();
    // Group i carries value i, so the surviving groups must be exactly the last trimSize entries of _groups.
    for (int i = NUM_GROUPS - trimSize; i < NUM_GROUPS; i++) {
      Assert.assertEquals(trimmedIntermediateResultsMap.get(_groups.get(i)), (double) i, ERROR_MESSAGE);
    }

    // Test broker side trimming.
    List<GroupByResult> groupByResults =
        _brokerTrimmingService.trimFinalResults(new Map[]{trimmedIntermediateResultsMap})[0];
    for (int i = 0; i < GROUP_BY_TOP_N; i++) {
      int expectedGroupIndex = NUM_GROUPS - 1 - i;
      GroupByResult groupByResult = groupByResults.get(i);
      List<String> group = groupByResult.getGroup();
      Assert.assertEquals(group.size(), NUM_GROUP_KEYS, ERROR_MESSAGE);
      Assert.assertEquals(joinGroupKeys(group), _groups.get(expectedGroupIndex), ERROR_MESSAGE);
      Assert.assertEquals(Double.parseDouble((String) groupByResult.getValue()), (double) expectedGroupIndex,
          ERROR_MESSAGE);
    }
  }

  /** Returns the group string whose {@link #NUM_GROUP_KEYS} keys are all empty, i.e. delimiters only. */
  private static String buildEmptyGroup() {
    StringBuilder emptyGroupBuilder = new StringBuilder();
    for (int i = 1; i < NUM_GROUP_KEYS; i++) {
      emptyGroupBuilder.append(GROUP_KEY_DELIMITER);
    }
    return emptyGroupBuilder.toString();
  }

  /** Returns a group string made of {@link #NUM_GROUP_KEYS} random keys joined by {@link #GROUP_KEY_DELIMITER}. */
  private static String buildRandomGroup() {
    StringBuilder groupBuilder = new StringBuilder();
    for (int i = 0; i < NUM_GROUP_KEYS; i++) {
      if (i != 0) {
        groupBuilder.append(GROUP_KEY_DELIMITER);
      }
      groupBuilder.append(buildRandomGroupKey());
    }
    return groupBuilder.toString();
  }

  /**
   * Returns a random key of length 0 to {@code MAX_GROUP_KEY_LENGTH - 1} that does not contain
   * {@link #GROUP_KEY_DELIMITER}.
   *
   * <p>The arguments {@code (0, 0, false, false, null)} select the same character range as
   * {@code RandomStringUtils.random(int)}; the only difference is that the characters are drawn from the seeded
   * {@link #RANDOM} instead of the library's internal generator, which makes the generated data reproducible from
   * the reported seed.
   */
  private static String buildRandomGroupKey() {
    String groupKey;
    do {
      groupKey = RandomStringUtils.random(RANDOM.nextInt(MAX_GROUP_KEY_LENGTH), 0, 0, false, false, null, RANDOM);
    } while (groupKey.indexOf(GROUP_KEY_DELIMITER) >= 0);
    return groupKey;
  }

  /** Joins the first {@link #NUM_GROUP_KEYS} keys of {@code groupKeys} with {@link #GROUP_KEY_DELIMITER}. */
  private static String joinGroupKeys(List<String> groupKeys) {
    StringBuilder groupBuilder = new StringBuilder();
    for (int i = 0; i < NUM_GROUP_KEYS; i++) {
      if (i != 0) {
        groupBuilder.append(GROUP_KEY_DELIMITER);
      }
      groupBuilder.append(groupKeys.get(i));
    }
    return groupBuilder.toString();
  }
}