/**
* Copyright (C) 2014-2016 LinkedIn Corp. (pinot-core@linkedin.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.linkedin.pinot.query.transform;
import com.linkedin.pinot.common.data.DimensionFieldSpec;
import com.linkedin.pinot.common.data.FieldSpec;
import com.linkedin.pinot.common.data.MetricFieldSpec;
import com.linkedin.pinot.common.data.Schema;
import com.linkedin.pinot.common.data.TimeFieldSpec;
import com.linkedin.pinot.common.request.AggregationInfo;
import com.linkedin.pinot.common.request.BrokerRequest;
import com.linkedin.pinot.common.request.GroupBy;
import com.linkedin.pinot.common.segment.ReadMode;
import com.linkedin.pinot.core.common.BlockValSet;
import com.linkedin.pinot.core.common.Operator;
import com.linkedin.pinot.core.data.GenericRow;
import com.linkedin.pinot.core.data.readers.FileFormat;
import com.linkedin.pinot.core.data.readers.RecordReader;
import com.linkedin.pinot.core.indexsegment.IndexSegment;
import com.linkedin.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
import com.linkedin.pinot.core.operator.BReusableFilteredDocIdSetOperator;
import com.linkedin.pinot.core.operator.BaseOperator;
import com.linkedin.pinot.core.operator.MProjectionOperator;
import com.linkedin.pinot.core.operator.blocks.IntermediateResultsBlock;
import com.linkedin.pinot.core.operator.filter.MatchEntireSegmentOperator;
import com.linkedin.pinot.core.operator.query.AggregationGroupByOperator;
import com.linkedin.pinot.core.operator.transform.TransformExpressionOperator;
import com.linkedin.pinot.core.operator.transform.function.TimeConversionTransform;
import com.linkedin.pinot.core.operator.transform.function.TransformFunction;
import com.linkedin.pinot.core.operator.transform.function.TransformFunctionFactory;
import com.linkedin.pinot.core.plan.AggregationFunctionInitializer;
import com.linkedin.pinot.core.plan.DocIdSetPlanNode;
import com.linkedin.pinot.core.plan.TransformPlanNode;
import com.linkedin.pinot.core.query.aggregation.AggregationFunctionContext;
import com.linkedin.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import com.linkedin.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import com.linkedin.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
import com.linkedin.pinot.core.segment.index.loader.Loaders;
import com.linkedin.pinot.pql.parsers.Pql2Compiler;
import com.linkedin.pinot.util.TestUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
* Unit test for transforms on group by columns.
*/
public class TransformGroupByTest {
  private static final Logger LOGGER = LoggerFactory.getLogger(TransformGroupByTest.class);
  private static final String SEGMENT_DIR_NAME = System.getProperty("java.io.tmpdir") + File.separator + "xformGroupBy";
  private static final String SEGMENT_NAME = "xformGroupBySeg";
  private static final String TABLE_NAME = "xformGroupByTable";
  // Time-based seed; it is logged in setup() so that a failing run can be reproduced.
  private static final long RANDOM_SEED = System.nanoTime();
  private static final int NUM_ROWS = DocIdSetPlanNode.MAX_DOC_PER_CALL;
  private static final double EPSILON = 1e-5;
  private static final String DIMENSION_NAME = "dimension";
  private static final String TIME_COLUMN_NAME = "millisSinceEpoch";
  private static final String METRIC_NAME = "metric";
  private static final String[] _dimensionValues = new String[]{"abcd", "ABCD", "bcde", "BCDE", "cdef", "CDEF"};

  private IndexSegment _indexSegment;
  private RecordReader _recordReader;

  /**
   * Registers the transform functions under test, builds the test segment on disk and loads it in heap mode.
   */
  @BeforeClass
  public void setup()
      throws Exception {
    LOGGER.info("Using random seed {}", RANDOM_SEED);
    TransformFunctionFactory.init(new String[]{ToUpper.class.getName(), TimeConversionTransform.class.getName()});
    Schema schema = buildSchema();
    _recordReader = buildSegment(SEGMENT_DIR_NAME, SEGMENT_NAME, schema);
    _indexSegment = Loaders.IndexSegment.load(new File(SEGMENT_DIR_NAME, SEGMENT_NAME), ReadMode.heap);
  }

  /**
   * Releases the loaded segment, then removes the segment directory from disk.
   */
  @AfterClass
  public void tearDown()
      throws IOException {
    if (_indexSegment != null) {
      _indexSegment.destroy();
      _indexSegment = null;
    }
    FileUtils.deleteDirectory(new File(SEGMENT_DIR_NAME));
  }

  /**
   * Test for group-by with transformed string dimension column.
   */
  @Test
  public void testGroupByString()
      throws Exception {
    String query =
        String.format("select sum(%s) from %s group by ToUpper(%s)", METRIC_NAME, TABLE_NAME, DIMENSION_NAME);
    AggregationGroupByResult groupByResult = executeGroupByQuery(_indexSegment, query);
    Assert.assertNotNull(groupByResult);

    // Compute the expected answer for the query from the raw rows.
    Map<String, Double> expectedValuesMap = new HashMap<>();
    _recordReader.rewind();
    for (int row = 0; row < NUM_ROWS; row++) {
      GenericRow genericRow = _recordReader.next();
      // Must match the casing rule used by ToUpper.transform().
      String key = ((String) genericRow.getValue(DIMENSION_NAME)).toUpperCase(Locale.ROOT);
      double value = (Double) genericRow.getValue(METRIC_NAME);
      accumulate(expectedValuesMap, key, value);
    }
    compareGroupByResults(groupByResult, expectedValuesMap);
  }

  /**
   * Test for group-by with transformed time column from millis to days.
   *
   * @throws Exception
   */
  @Test
  public void testTimeRollUp()
      throws Exception {
    String query =
        String.format("select sum(%s) from %s group by timeConvert(%s, 'MILLISECONDS', 'DAYS')", METRIC_NAME,
            TABLE_NAME, TIME_COLUMN_NAME);
    AggregationGroupByResult groupByResult = executeGroupByQuery(_indexSegment, query);
    Assert.assertNotNull(groupByResult);

    // Compute the expected answer for the query from the raw rows.
    Map<String, Double> expectedValuesMap = new HashMap<>();
    _recordReader.rewind();
    for (int row = 0; row < NUM_ROWS; row++) {
      GenericRow genericRow = _recordReader.next();
      long daysSinceEpoch = TimeUnit.MILLISECONDS.toDays((Long) genericRow.getValue(TIME_COLUMN_NAME));
      double value = (Double) genericRow.getValue(METRIC_NAME);
      accumulate(expectedValuesMap, String.valueOf(daysSinceEpoch), value);
    }
    compareGroupByResults(groupByResult, expectedValuesMap);
  }

  /**
   * Helper method that executes the group by query on the index and returns the group by result.
   *
   * @param indexSegment Segment to run the query against
   * @param query Query to execute
   * @return Group by result
   */
  private AggregationGroupByResult executeGroupByQuery(IndexSegment indexSegment, String query) {
    int totalDocs = indexSegment.getSegmentMetadata().getTotalDocs();
    Operator filterOperator = new MatchEntireSegmentOperator(totalDocs);
    final BReusableFilteredDocIdSetOperator docIdSetOperator =
        new BReusableFilteredDocIdSetOperator(filterOperator, totalDocs, NUM_ROWS);
    final Map<String, BaseOperator> dataSourceMap = buildDataSourceMap(indexSegment);
    final MProjectionOperator projectionOperator = new MProjectionOperator(dataSourceMap, docIdSetOperator);

    Pql2Compiler compiler = new Pql2Compiler();
    BrokerRequest brokerRequest = compiler.compileToBrokerRequest(query);

    List<AggregationInfo> aggregationsInfo = brokerRequest.getAggregationsInfo();
    int numAggFunctions = aggregationsInfo.size();
    AggregationFunctionContext[] aggrFuncContextArray = new AggregationFunctionContext[numAggFunctions];
    AggregationFunctionInitializer aggFuncInitializer =
        new AggregationFunctionInitializer(indexSegment.getSegmentMetadata());
    for (int i = 0; i < numAggFunctions; i++) {
      AggregationInfo aggregationInfo = aggregationsInfo.get(i);
      aggrFuncContextArray[i] = AggregationFunctionContext.instantiate(aggregationInfo);
      aggrFuncContextArray[i].getAggregationFunction().accept(aggFuncInitializer);
    }

    GroupBy groupBy = brokerRequest.getGroupBy();
    Set<String> expressions = new HashSet<>(groupBy.getExpressions());
    TransformExpressionOperator transformOperator = new TransformExpressionOperator(projectionOperator,
        TransformPlanNode.buildTransformExpressionTrees(expressions));
    AggregationGroupByOperator groupByOperator =
        new AggregationGroupByOperator(aggrFuncContextArray, groupBy, Integer.MAX_VALUE, transformOperator, NUM_ROWS);

    IntermediateResultsBlock block = (IntermediateResultsBlock) groupByOperator.nextBlock();
    return block.getAggregationGroupByResult();
  }

  /**
   * Helper method to build a segment with one dimension column containing values from {@link #_dimensionValues},
   * one double metric column with random values, and one millis-since-epoch time column whose value decreases by a
   * fixed delta per row.
   *
   * The returned record reader is over the exact rows used to build the segment, so tests can rewind it and
   * compute expected query results.
   *
   * @param segmentDirName Name of segment directory
   * @param segmentName Name of segment
   * @param schema Schema for segment
   * @return Record reader over the rows that were written into the segment
   * @throws Exception
   */
  private RecordReader buildSegment(String segmentDirName, String segmentName, Schema schema)
      throws Exception {
    SegmentGeneratorConfig config = new SegmentGeneratorConfig(schema);
    config.setOutDir(segmentDirName);
    config.setFormat(FileFormat.AVRO);
    config.setTableName(TABLE_NAME);
    config.setSegmentName(segmentName);

    Random random = new Random(RANDOM_SEED);
    long currentTimeMillis = System.currentTimeMillis();

    // Divide the day into fixed parts, and decrement time column value by this delta, so as to get
    // continuous days in the input. This gives about 10 days per 10k rows.
    long timeDelta = TimeUnit.MILLISECONDS.convert(1, TimeUnit.DAYS) / 1000;
    final List<GenericRow> data = new ArrayList<>(NUM_ROWS);
    int numDimValues = _dimensionValues.length;

    for (int row = 0; row < NUM_ROWS; row++) {
      HashMap<String, Object> map = new HashMap<>();
      map.put(DIMENSION_NAME, _dimensionValues[random.nextInt(numDimValues)]);
      map.put(METRIC_NAME, random.nextDouble());
      map.put(TIME_COLUMN_NAME, currentTimeMillis);
      currentTimeMillis -= timeDelta;

      GenericRow genericRow = new GenericRow();
      genericRow.init(map);
      data.add(genericRow);
    }

    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
    RecordReader reader = new TestUtils.GenericRowRecordReader(schema, data);
    driver.init(config, reader);
    driver.build();
    LOGGER.info("Built segment {} at {}", segmentName, segmentDirName);
    return reader;
  }

  /**
   * Helper method to build a schema with one string dimension, one double metric and one long time column
   * (in milliseconds).
   */
  private static Schema buildSchema() {
    Schema schema = new Schema();
    DimensionFieldSpec dimensionFieldSpec = new DimensionFieldSpec(DIMENSION_NAME, FieldSpec.DataType.STRING, true);
    schema.addField(dimensionFieldSpec);

    MetricFieldSpec metricFieldSpec = new MetricFieldSpec(METRIC_NAME, FieldSpec.DataType.DOUBLE);
    schema.addField(metricFieldSpec);

    TimeFieldSpec timeFieldSpec = new TimeFieldSpec(TIME_COLUMN_NAME, FieldSpec.DataType.LONG, TimeUnit.MILLISECONDS);
    schema.setTimeFieldSpec(timeFieldSpec);
    return schema;
  }

  /**
   * Helper method to build the data source map for every column in the given segment's schema.
   *
   * @param indexSegment Segment whose columns should be mapped
   * @return Map of column name to its data source.
   */
  private static Map<String, BaseOperator> buildDataSourceMap(IndexSegment indexSegment) {
    final Map<String, BaseOperator> dataSourceMap = new HashMap<>();
    for (String columnName : indexSegment.getSegmentMetadata().getSchema().getColumnNames()) {
      dataSourceMap.put(columnName, indexSegment.getDataSource(columnName));
    }
    return dataSourceMap;
  }

  /**
   * Adds {@code value} to the running sum stored under {@code key}, creating the entry if it is absent.
   *
   * @param sums Map of group key to running sum (mutated in place)
   * @param key Group key
   * @param value Value to add
   */
  private static void accumulate(Map<String, Double> sums, String key, double value) {
    Double prevValue = sums.get(key);
    sums.put(key, (prevValue == null) ? value : prevValue + value);
  }

  /**
   * Helper method to compare group by result from query execution against a map of group keys and values.
   * Asserts that every actual key is expected, every value matches within {@link #EPSILON}, and that the number of
   * actual keys equals the number of expected keys.
   *
   * @param groupByResult Group by result from query
   * @param expectedValuesMap Map of expected keys and values
   */
  private void compareGroupByResults(AggregationGroupByResult groupByResult, Map<String, Double> expectedValuesMap) {
    Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = groupByResult.getGroupKeyIterator();
    Assert.assertNotNull(groupKeyIterator);

    int numGroupKeys = 0;
    while (groupKeyIterator.hasNext()) {
      GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next();
      String stringKey = groupKey.getStringKey();
      Double actual = (Double) groupByResult.getResultForKey(groupKey, 0 /* aggregation function index */);
      Double expected = expectedValuesMap.get(stringKey);

      Assert.assertNotNull(expected, "Unexpected key in actual result: " + stringKey);
      // Guard against a null-unboxing NPE hiding the real failure.
      Assert.assertNotNull(actual, "Missing aggregation result for key: " + stringKey);
      Assert.assertEquals(actual, expected, EPSILON, "Mis-match in value for key: " + stringKey);
      numGroupKeys++;
    }
    Assert.assertEquals(numGroupKeys, expectedValuesMap.size(), "Mis-match in number of group keys");
  }

  /**
   * Implementation of TransformFunction that converts strings to upper case using locale-independent
   * ({@link Locale#ROOT}) casing rules.
   */
  public static class ToUpper implements TransformFunction {
    @Override
    public String[] transform(int length, BlockValSet... input) {
      String[] inputStrings = input[0].getStringValuesSV();
      String[] outputStrings = new String[length];
      for (int i = 0; i < length; i++) {
        outputStrings[i] = inputStrings[i].toUpperCase(Locale.ROOT);
      }
      return outputStrings;
    }

    @Override
    public FieldSpec.DataType getOutputType() {
      return FieldSpec.DataType.STRING;
    }

    @Override
    public String getName() {
      return "ToUpper";
    }
  }
}