/**
* Copyright (C) 2014-2016 LinkedIn Corp. (pinot-core@linkedin.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.linkedin.pinot.core.startree;
import com.linkedin.pinot.common.data.Schema;
import com.linkedin.pinot.common.request.BrokerRequest;
import com.linkedin.pinot.common.segment.SegmentMetadata;
import com.linkedin.pinot.common.utils.request.FilterQueryTree;
import com.linkedin.pinot.common.utils.request.RequestUtils;
import com.linkedin.pinot.core.common.BlockDocIdIterator;
import com.linkedin.pinot.core.common.BlockSingleValIterator;
import com.linkedin.pinot.core.common.Constants;
import com.linkedin.pinot.core.common.DataSource;
import com.linkedin.pinot.core.common.Operator;
import com.linkedin.pinot.core.indexsegment.IndexSegment;
import com.linkedin.pinot.core.operator.filter.StarTreeIndexOperator;
import com.linkedin.pinot.core.plan.FilterPlanNode;
import com.linkedin.pinot.core.segment.index.readers.Dictionary;
import com.linkedin.pinot.pql.parsers.Pql2Compiler;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
/**
 * Base class containing common functionality for star-tree index tests that verify 'sum' aggregation.
 *
 * <p>For every query, the per-group metric sums computed over the documents selected by
 * {@link FilterPlanNode} are compared against the sums computed over the documents selected by
 * {@link StarTreeIndexOperator}. Both must produce the same groups and the same sums.
 */
public class BaseSumStarTreeIndexTest {
  private static final Logger LOGGER = LoggerFactory.getLogger(BaseSumStarTreeIndexTest.class);

  /**
   * Value reported in assertion failure messages so that a failing run can be reproduced.
   * NOTE(review): presumably subclasses use this to seed their data generation -- confirm against callers.
   */
  protected final long _randomSeed = System.nanoTime();

  /** Queries exercised by {@link #testHardCodedQueries(IndexSegment, Schema)}. */
  protected String[] _hardCodedQueries =
      new String[]{
          "select sum(m1) from T",
          "select sum(m1) from T where d1 = 'd1-v1'",
          "select sum(m1) from T where d1 <> 'd1-v1'",
          "select sum(m1) from T where d1 between 'd1-v1' and 'd1-v3'",
          "select sum(m1) from T where d1 in ('d1-v1', 'd1-v2')",
          "select sum(m1) from T where d1 in ('d1-v1', 'd1-v2') and d2 not in ('d2-v1')",
          "select sum(m1) from T group by d1", "select sum(m1) from T group by d1, d2",
          "select sum(m1) from T where d1 = 'd1-v2' group by d1",
          "select sum(m1) from T where d1 between 'd1-v1' and 'd1-v3' group by d2",
          "select sum(m1) from T where d1 = 'd1-v2' group by d2, d3",
          "select sum(m1) from T where d1 <> 'd1-v1' group by d2",
          "select sum(m1) from T where d1 in ('d1-v1', 'd1-v2') group by d2",
          "select sum(m1) from T where d1 in ('d1-v1', 'd1-v2') and d2 not in ('d2-v1') group by d3",
          "select sum(m1) from T where d1 in ('d1-v1', 'd1-v2') and d2 not in ('d2-v1') group by d3, d4"};

  /**
   * Runs every query in {@link #_hardCodedQueries} against the given segment and asserts that the sums
   * computed from the star-tree selected documents match the sums computed from the filter-plan selected
   * documents, for every group and every metric column of the schema.
   *
   * @param segment Segment to run the queries against.
   * @param schema Schema of the segment; its metric names determine which columns are summed.
   */
  protected void testHardCodedQueries(IndexSegment segment, Schema schema) {
    // Test against all metric columns, instead of just the aggregation column in the query.
    List<String> metricNames = schema.getMetricNames();
    SegmentMetadata segmentMetadata = segment.getSegmentMetadata();
    Pql2Compiler compiler = new Pql2Compiler();

    for (String query : _hardCodedQueries) {
      // Appended to every failure message so a failure identifies the query and can be reproduced.
      String context = "Query: '" + query + "', Random Seed: " + _randomSeed;

      BrokerRequest brokerRequest = compiler.compileToBrokerRequest(query);
      FilterQueryTree filterQueryTree = RequestUtils.generateFilterQueryTree(brokerRequest);
      Assert.assertTrue(RequestUtils.isFitForStarTreeIndex(segmentMetadata, filterQueryTree, brokerRequest),
          "Query is not fit for star-tree index, " + context);

      Map<String, double[]> expectedResult = computeSumUsingRawDocs(segment, metricNames, brokerRequest);
      Map<String, double[]> actualResult = computeSumUsingAggregatedDocs(segment, metricNames, brokerRequest);
      Assert.assertEquals(expectedResult.size(), actualResult.size(),
          "Mis-match in number of groups, " + context);

      for (Map.Entry<String, double[]> entry : expectedResult.entrySet()) {
        String expectedKey = entry.getKey();
        Assert.assertTrue(actualResult.containsKey(expectedKey),
            "Missing group key '" + expectedKey + "', " + context);

        double[] expectedSums = entry.getValue();
        double[] actualSums = actualResult.get(expectedKey);
        for (int j = 0; j < expectedSums.length; j++) {
          Assert.assertEquals(actualSums[j], expectedSums[j],
              "Mis-match sum for key '" + expectedKey + "', Metric: " + metricNames.get(j) + ", " + context);
        }
      }
    }
  }

  /**
   * Helper method to compute the sums by scanning the documents selected by the {@link FilterPlanNode}
   * built for the request.
   *
   * @param segment Segment to scan.
   * @param metricNames Names of the metric columns to sum.
   * @param brokerRequest Request providing the filter and the optional group-by columns.
   * @return Map from group key to per-metric sums (indexed in the order of metricNames).
   */
  private Map<String, double[]> computeSumUsingRawDocs(IndexSegment segment, List<String> metricNames,
      BrokerRequest brokerRequest) {
    FilterPlanNode planNode = new FilterPlanNode(segment, brokerRequest);
    Operator rawOperator = planNode.run();
    BlockDocIdIterator rawDocIdIterator = rawOperator.nextBlock().getBlockDocIdSet().iterator();
    return computeSum(segment, rawDocIdIterator, metricNames, getGroupByColumns(brokerRequest));
  }

  /**
   * Helper method to compute the sums by scanning the documents selected by the
   * {@link StarTreeIndexOperator} built for the request.
   *
   * @param segment Segment to scan.
   * @param metricNames Names of the metric columns to sum.
   * @param brokerRequest Request providing the filter and the optional group-by columns.
   * @return Map from group key to per-metric sums (indexed in the order of metricNames).
   */
  private Map<String, double[]> computeSumUsingAggregatedDocs(IndexSegment segment, List<String> metricNames,
      BrokerRequest brokerRequest) {
    StarTreeIndexOperator starTreeOperator = new StarTreeIndexOperator(segment, brokerRequest);
    starTreeOperator.open();
    BlockDocIdIterator starTreeDocIdIterator = starTreeOperator.nextBlock().getBlockDocIdSet().iterator();
    return computeSum(segment, starTreeDocIdIterator, metricNames, getGroupByColumns(brokerRequest));
  }

  /**
   * Returns the group-by columns of the request.
   *
   * @param brokerRequest Request to inspect.
   * @return Group-by columns when the request has both aggregations and a group-by set, empty list otherwise.
   */
  private static List<String> getGroupByColumns(BrokerRequest brokerRequest) {
    if (brokerRequest.isSetAggregationsInfo() && brokerRequest.isSetGroupBy()) {
      return brokerRequest.getGroupBy().getColumns();
    }
    return Collections.emptyList();
  }

  /**
   * Compute 'sum' for a given list of metrics, by scanning the given set of doc-ids.
   *
   * <p>All metric and group-by columns are read as single-value columns through their dictionaries: the
   * dictionary id stored for a document is looked up and converted to a double (metrics) or a string
   * (group-by columns).
   *
   * @param segment Segment providing the data sources for the columns.
   * @param docIdIterator Iterator over the doc-ids to scan; consumed until {@link Constants#EOF}.
   * @param metricNames Names of the metric columns to sum.
   * @param groupByColumns Names of the columns to group by; may be empty.
   * @return Map from group key to per-metric sums (indexed in the order of metricNames). The group key is
   *         the concatenation of the group-by column values, each followed by "_"; it is the empty string
   *         when there are no group-by columns.
   */
  private Map<String, double[]> computeSum(IndexSegment segment, BlockDocIdIterator docIdIterator,
      List<String> metricNames, List<String> groupByColumns) {
    int numMetrics = metricNames.size();
    Dictionary[] metricDictionaries = new Dictionary[numMetrics];
    BlockSingleValIterator[] metricValIterators = new BlockSingleValIterator[numMetrics];
    for (int i = 0; i < numMetrics; i++) {
      DataSource dataSource = segment.getDataSource(metricNames.get(i));
      metricDictionaries[i] = dataSource.getDictionary();
      metricValIterators[i] = (BlockSingleValIterator) dataSource.getNextBlock().getBlockValueSet().iterator();
    }

    int numGroupByColumns = groupByColumns.size();
    Dictionary[] groupByDictionaries = new Dictionary[numGroupByColumns];
    BlockSingleValIterator[] groupByValIterators = new BlockSingleValIterator[numGroupByColumns];
    for (int i = 0; i < numGroupByColumns; i++) {
      DataSource dataSource = segment.getDataSource(groupByColumns.get(i));
      groupByDictionaries[i] = dataSource.getDictionary();
      groupByValIterators[i] = (BlockSingleValIterator) dataSource.getNextBlock().getBlockValueSet().iterator();
    }

    Map<String, double[]> result = new HashMap<String, double[]>();
    int docId;
    while ((docId = docIdIterator.next()) != Constants.EOF) {
      // Build the group key from the group-by column values of the current document.
      StringBuilder keyBuilder = new StringBuilder();
      for (int i = 0; i < numGroupByColumns; i++) {
        groupByValIterators[i].skipTo(docId);
        int dictId = groupByValIterators[i].nextIntVal();
        keyBuilder.append(groupByDictionaries[i].getStringValue(dictId));
        keyBuilder.append("_");
      }
      String key = keyBuilder.toString();

      // Single lookup; the map never stores null values, so null means the group has not been seen yet.
      double[] sumsSoFar = result.get(key);
      if (sumsSoFar == null) {
        sumsSoFar = new double[numMetrics];
        result.put(key, sumsSoFar);
      }

      for (int i = 0; i < numMetrics; i++) {
        metricValIterators[i].skipTo(docId);
        int dictId = metricValIterators[i].nextIntVal();
        sumsSoFar[i] += metricDictionaries[i].getDoubleValue(dictId);
      }
    }
    return result;
  }
}