/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.drill.exec.physical.unit;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import mockit.Delegate;
import mockit.Injectable;
import mockit.NonStrictExpectations;
import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.drill.DrillTestWrapper;
import org.apache.drill.common.config.DrillConfig;
import org.apache.drill.common.exceptions.ExecutionSetupException;
import org.apache.drill.common.expression.FieldReference;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.PathSegment;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.expression.parser.ExprLexer;
import org.apache.drill.common.expression.parser.ExprParser;
import org.apache.drill.common.logical.data.JoinCondition;
import org.apache.drill.common.logical.data.NamedExpression;
import org.apache.drill.common.logical.data.Order;
import org.apache.drill.common.scanner.ClassPathScanner;
import org.apache.drill.common.scanner.persistence.ScanResult;
import org.apache.drill.exec.ExecTest;
import org.apache.drill.exec.compile.CodeCompiler;
import org.apache.drill.exec.compile.TemplateClassDefinition;
import org.apache.drill.exec.exception.ClassTransformationException;
import org.apache.drill.exec.exception.SchemaChangeException;
import org.apache.drill.exec.expr.ClassGenerator;
import org.apache.drill.exec.expr.CodeGenerator;
import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry;
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.memory.RootAllocatorFactory;
import org.apache.drill.exec.ops.BufferManagerImpl;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.ops.OperatorContext;
import org.apache.drill.exec.ops.OperatorStats;
import org.apache.drill.exec.physical.base.AbstractBase;
import org.apache.drill.exec.physical.base.PhysicalOperator;
import org.apache.drill.exec.physical.impl.BatchCreator;
import org.apache.drill.exec.physical.impl.OperatorCreatorRegistry;
import org.apache.drill.exec.physical.impl.ScanBatch;
import org.apache.drill.exec.physical.impl.project.Projector;
import org.apache.drill.exec.physical.impl.project.ProjectorTemplate;
import org.apache.drill.exec.proto.ExecProtos;
import org.apache.drill.exec.record.RecordBatch;
import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.store.RecordReader;
import org.apache.drill.exec.testing.ExecutionControls;
import org.apache.drill.exec.util.TestUtilities;
import org.junit.Before;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static org.apache.drill.exec.physical.base.AbstractBase.INIT_ALLOCATION;

/**
 * Look! Doesn't extend BaseTestQuery!!
 *
 * Base class for unit tests of individual physical operators: no Drillbit is started. Instead the
 * {@link FragmentContext} and {@link OperatorContext} are mocked, and operators are fed record
 * batches defined as JSON strings.
 */
public class PhysicalOpUnitTestBase extends ExecTest {
//  public static long INIT_ALLOCATION = 1_000_000l;
//  public static long MAX_ALLOCATION = 10_000_000L;

  @Injectable FragmentContext fragContext;
  @Injectable OperatorContext opContext;
  @Injectable OperatorStats opStats;
  @Injectable PhysicalOperator popConf;
  @Injectable ExecutionControls executionControls;

  private final DrillConfig drillConf = DrillConfig.create();
  private final BufferAllocator allocator = RootAllocatorFactory.newRoot(drillConf);
  private final BufferManagerImpl bufManager = new BufferManagerImpl(allocator);
  private final ScanResult classpathScan = ClassPathScanner.fromPrescan(drillConf);
  private final FunctionImplementationRegistry funcReg = new FunctionImplementationRegistry(drillConf, classpathScan);
  private final TemplateClassDefinition<Projector> templateClassDefinition = new TemplateClassDefinition<Projector>(Projector.class, ProjectorTemplate.class);
  private final OperatorCreatorRegistry opCreatorReg = new OperatorCreatorRegistry(classpathScan);

  @Before
  public void setup() throws Exception {
    mockFragmentContext();
  }

  @Override
  protected LogicalExpression parseExpr(String expr) {
    ExprLexer lexer = new ExprLexer(new ANTLRStringStream(expr));
    CommonTokenStream tokens = new CommonTokenStream(lexer);
    ExprParser parser = new ExprParser(tokens);
    try {
      return parser.parse().e;
    } catch (RecognitionException e) {
      // Keep the parser failure as the cause instead of swallowing it.
      throw new RuntimeException("Error parsing expression: " + expr, e);
    }
  }

  protected Order.Ordering ordering(String expression, RelFieldCollation.Direction direction, RelFieldCollation.NullDirection nullDirection) {
    return new Order.Ordering(direction, parseExpr(expression), nullDirection);
  }

  protected JoinCondition joinCond(String leftExpr, String relationship, String rightExpr) {
    return new JoinCondition(relationship, parseExpr(leftExpr), parseExpr(rightExpr));
  }
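
  // A hedged usage sketch (not part of the original file): subclasses typically feed these helpers
  // straight into physical-operator configs. The column names and the "EQUALS" relationship string
  // are illustrative assumptions about what a subclass might pass, not requirements of this class:
  //
  //   List<Order.Ordering> orderings = Lists.newArrayList(
  //       ordering("b", RelFieldCollation.Direction.ASCENDING, RelFieldCollation.NullDirection.FIRST));
  //   JoinCondition cond = joinCond("x", "EQUALS", "x1");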
  protected List<NamedExpression> parseExprs(String... expressionsAndOutputNames) {
    Preconditions.checkArgument(expressionsAndOutputNames.length % 2 == 0,
        "List of expressions and output field names is not complete, each expression must explicitly give an output name.");
    List<NamedExpression> ret = new ArrayList<>();
    for (int i = 0; i < expressionsAndOutputNames.length; i += 2) {
      ret.add(new NamedExpression(parseExpr(expressionsAndOutputNames[i]),
          new FieldReference(new SchemaPath(new PathSegment.NameSegment(expressionsAndOutputNames[i + 1])))));
    }
    return ret;
  }

  protected static class BatchIterator implements Iterable<VectorAccessible> {

    private RecordBatch operator;

    public BatchIterator(RecordBatch operator) {
      this.operator = operator;
    }

    @Override
    public Iterator<VectorAccessible> iterator() {
      return new Iterator<VectorAccessible>() {
        boolean needToGrabNext = true;
        RecordBatch.IterOutcome lastResultOutcome;

        @Override
        public boolean hasNext() {
          if (needToGrabNext) {
            lastResultOutcome = operator.next();
            needToGrabNext = false;
          }
          if (lastResultOutcome == RecordBatch.IterOutcome.NONE
              || lastResultOutcome == RecordBatch.IterOutcome.STOP) {
            return false;
          } else if (lastResultOutcome == RecordBatch.IterOutcome.OUT_OF_MEMORY) {
            throw new RuntimeException("Operator ran out of memory");
          } else {
            return true;
          }
        }

        @Override
        public VectorAccessible next() {
          if (needToGrabNext) {
            lastResultOutcome = operator.next();
          }
          needToGrabNext = true;
          return operator;
        }

        @Override
        public void remove() {
          throw new UnsupportedOperationException("Remove is not supported.");
        }
      };
    }
  }

  protected OperatorTestBuilder opTestBuilder() {
    return new OperatorTestBuilder();
  }

  protected class OperatorTestBuilder {

    private PhysicalOperator popConfig;
    private String[] baselineColumns;
    private List<Map<String, Object>> baselineRecords;
    private List<List<String>> inputStreamsJSON;
    private long initReservation = AbstractBase.INIT_ALLOCATION;
    private long maxAllocation = AbstractBase.MAX_ALLOCATION;

    public void go() {
      BatchCreator<PhysicalOperator> opCreator;
      RecordBatch testOperator;
      try {
        mockOpContext(initReservation, maxAllocation);
        opCreator = (BatchCreator<PhysicalOperator>) opCreatorReg.getOperatorCreator(popConfig.getClass());
        List<RecordBatch> incomingStreams = Lists.newArrayList();
        if (inputStreamsJSON != null) {
          for (List<String> batchesJson : inputStreamsJSON) {
            incomingStreams.add(new ScanBatch(null, fragContext,
                getRecordReadersForJsonBatches(batchesJson, fragContext)));
          }
        }
        testOperator = opCreator.getBatch(fragContext, popConfig, incomingStreams);
        Map<String, List<Object>> actualSuperVectors = DrillTestWrapper.addToCombinedVectorResults(new BatchIterator(testOperator));
        Map<String, List<Object>> expectedSuperVectors = DrillTestWrapper.translateRecordListToHeapVectors(baselineRecords);
        DrillTestWrapper.compareMergedVectors(expectedSuperVectors, actualSuperVectors);
      } catch (ExecutionSetupException e) {
        throw new RuntimeException(e);
      } catch (UnsupportedEncodingException e) {
        throw new RuntimeException(e);
      } catch (SchemaChangeException e) {
        throw new RuntimeException(e);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }

    public OperatorTestBuilder physicalOperator(PhysicalOperator batch) {
      this.popConfig = batch;
      return this;
    }

    public OperatorTestBuilder initReservation(long initReservation) {
      this.initReservation = initReservation;
      return this;
    }

    public OperatorTestBuilder maxAllocation(long maxAllocation) {
      this.maxAllocation = maxAllocation;
      return this;
    }
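
    // A hedged sketch (assumption, not part of the original class): each input stream is a list of
    // JSON strings, one string per record batch, which TestUtilities turns into one JSON reader per
    // string. A typical single-input setup would look roughly like:
    //
    //   List<String> jsonBatches = Lists.newArrayList(
    //       "[{\"a\": 5, \"b\": 1}]",
    //       "[{\"a\": 5, \"b\": 5}, {\"a\": 3, \"b\": 8}]");
    //   opTestBuilder().inputDataStreamJson(jsonBatches);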
    public OperatorTestBuilder inputDataStreamJson(List<String> jsonBatches) {
      this.inputStreamsJSON = new ArrayList<>();
      this.inputStreamsJSON.add(jsonBatches);
      return this;
    }

    public OperatorTestBuilder inputDataStreamsJson(List<List<String>> childStreams) {
      this.inputStreamsJSON = childStreams;
      return this;
    }

    public OperatorTestBuilder baselineColumns(String... columns) {
      for (int i = 0; i < columns.length; i++) {
        LogicalExpression ex = parseExpr(columns[i]);
        if (ex instanceof SchemaPath) {
          columns[i] = ((SchemaPath) ex).toExpr();
        } else {
          throw new IllegalStateException("Schema path is not a valid format.");
        }
      }
      this.baselineColumns = columns;
      return this;
    }

    public OperatorTestBuilder baselineValues(Object... baselineValues) {
      if (baselineRecords == null) {
        baselineRecords = new ArrayList<>();
      }
      Map<String, Object> ret = new HashMap<>();
      int i = 0;
      Preconditions.checkArgument(baselineValues.length == baselineColumns.length,
          "Must supply the same number of baseline values as columns.");
      for (String s : baselineColumns) {
        ret.put(s, baselineValues[i]);
        i++;
      }
      this.baselineRecords.add(ret);
      return this;
    }
  }

  protected void mockFragmentContext() throws Exception {
    final CodeCompiler compiler = new CodeCompiler(drillConf, optionManager);
//    final BufferAllocator allocator = this.allocator.newChildAllocator("allocator_for_operator_test", initReservation, maxAllocation);
    new NonStrictExpectations() {
      {
//        optManager.getOption(withAny(new TypeValidators.BooleanValidator("", false))); result = false;
//        // TODO(DRILL-4450) - Probably want to just create a default option manager, this is a hack to prevent
//        // the code compilation from failing when trying to decide if scalar replacement is turned on
//        // this will cause other code paths to fail because this return value won't be valid for most
//        // string options
//        optManager.getOption(withAny(new TypeValidators.StringValidator("", "try"))); result = "try";
//        optManager.getOption(withAny(new TypeValidators.PositiveLongValidator("", 1l, 1l))); result = 10;
        fragContext.getOptions(); result = optionManager;
        fragContext.getManagedBuffer(); result = bufManager.getManagedBuffer();
        fragContext.shouldContinue(); result = true;
        fragContext.getExecutionControls(); result = executionControls;
        fragContext.getFunctionRegistry(); result = funcReg;
        fragContext.getConfig(); result = drillConf;
        fragContext.getHandle(); result = ExecProtos.FragmentHandle.getDefaultInstance();
        try {
          CodeGenerator<?> cg = CodeGenerator.get(templateClassDefinition, funcReg);
          cg.plainJavaCapable(true);
//          cg.saveCodeForDebugging(true);
          fragContext.getImplementationClass(withAny(cg));
          result = new Delegate<Object>() {
            @SuppressWarnings("unused")
            Object getImplementationClass(CodeGenerator<Object> gen) throws IOException, ClassTransformationException {
              return compiler.createInstance(gen);
            }
          };
          fragContext.getImplementationClass(withAny(CodeGenerator.get(templateClassDefinition, funcReg).getRoot()));
          result = new Delegate<Object>() {
            @SuppressWarnings("unused")
            Object getImplementationClass(ClassGenerator<Object> gen) throws IOException, ClassTransformationException {
              return compiler.createInstance(gen.getCodeGenerator());
            }
          };
        } catch (ClassTransformationException e) {
          throw new RuntimeException(e);
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }
    };
  }
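
  // A hedged note (assumption, not from the original file): OperatorTestBuilder.go() calls
  // mockOpContext() below with the builder's reservation and limit, so a test that needs a larger
  // operator allocator would typically configure it through the builder rather than re-mocking the
  // context by hand, e.g.
  //
  //   opTestBuilder().initReservation(INIT_ALLOCATION).maxAllocation(1_000_000_000L) ...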
  protected void mockOpContext(long initReservation, long maxAllocation) throws Exception {
    final BufferAllocator allocator = this.allocator.newChildAllocator("allocator_for_operator_test", initReservation, maxAllocation);
    new NonStrictExpectations() {
      {
        opContext.getStats(); result = opStats;
        opContext.getAllocator(); result = allocator;
        fragContext.newOperatorContext(withAny(popConf)); result = opContext;
      }
    };
  }

  protected OperatorCreatorRegistry getOpCreatorReg() {
    return opCreatorReg;
  }

  private Iterator<RecordReader> getRecordReadersForJsonBatches(List<String> jsonBatches, FragmentContext fragContext) {
    return TestUtilities.getJsonReadersFromBatchString(jsonBatches, fragContext, Collections.singletonList(SchemaPath.getSimplePath("*")));
  }
}
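
// A hedged end-to-end sketch (assumption, not part of this file): a subclass test would typically
// build an operator config, wire in a JSON input stream, and assert on the produced batches
// roughly like this, where projectConf stands in for some PhysicalOperator built by the test:
//
//   opTestBuilder()
//       .physicalOperator(projectConf)
//       .inputDataStreamJson(Lists.newArrayList("[{\"a\": 5, \"b\": 1}]"))
//       .baselineColumns("a", "b")
//       .baselineValues(5l, 1l)
//       .go();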