/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator; import com.facebook.presto.Session; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.InterleavedBlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.testing.MaterializedResult; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.facebook.presto.type.TypeJsonUtils.appendToBlockBuilder; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public final class OperatorAssertion { private OperatorAssertion() { } public static List<Page> toPages(Operator operator, Iterator<Page> input) { ImmutableList.Builder<Page> outputPages = ImmutableList.builder(); boolean finishing = false; while (operator.needsInput() && input.hasNext()) { operator.addInput(input.next()); } for (int loops = 0; !operator.isFinished() && loops < 10_000; loops++) { if (operator.needsInput()) { if (input.hasNext()) { Page inputPage = input.next(); operator.addInput(inputPage); } else if (!finishing) { operator.finish(); finishing = true; } } Page outputPage = operator.getOutput(); if (outputPage != null) { outputPages.add(outputPage); } } assertFalse(operator.needsInput()); assertTrue(operator.isBlocked().isDone()); assertTrue(operator.isFinished()); return outputPages.build(); } public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input) { try (Operator operator = operatorFactory.createOperator(driverContext)) { return toPages(operator, input); } catch (Exception e) { throw Throwables.propagate(e); } } private static List<Page> toPages(Operator operator, List<Page> input) { // verify initial state assertEquals(operator.isFinished(), false); assertEquals(operator.needsInput(), true); assertEquals(operator.getOutput(), null); return toPages(operator, input.iterator()); } public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext) { try (Operator operator = operatorFactory.createOperator(driverContext)) { return toPages(operator); } catch (Exception e) { throw Throwables.propagate(e); } } private static List<Page> toPages(Operator operator) { // operator does not have input so should never require input assertEquals(operator.needsInput(), false); ImmutableList.Builder<Page> outputPages = ImmutableList.builder(); addRemainingOutputPages(operator, outputPages); return outputPages.build(); } private static void addRemainingOutputPages(Operator operator, ImmutableList.Builder<Page> outputPages) { // pull remaining output pages while (!operator.isFinished()) { // at this point the operator should not need more input assertEquals(operator.needsInput(), false); Page outputPage = operator.getOutput(); if (outputPage != null) { outputPages.add(outputPage); } } // verify final state assertEquals(operator.isFinished(), true); assertEquals(operator.needsInput(), false); assertEquals(operator.getOutput(), null); } public static MaterializedResult toMaterializedResult(Session session, List<Type> types, List<Page> pages) { // materialize pages MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(session, types); for (Page outputPage : pages) { resultBuilder.page(outputPage); } return resultBuilder.build(); } public static Block toRow(List<Type> parameterTypes, Object... values) { checkArgument(parameterTypes.size() == values.length, "parameterTypes.size(" + parameterTypes.size() + ") does not equal to values.length(" + values.length + ")"); BlockBuilder blockBuilder = new InterleavedBlockBuilder(parameterTypes, new BlockBuilderStatus(), parameterTypes.size()); for (int i = 0; i < values.length; i++) { appendToBlockBuilder(parameterTypes.get(i), values[i], blockBuilder); } return blockBuilder.build(); } public static void assertOperatorEquals(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, List<Page> expected) throws Exception { List<Page> actual = toPages(operatorFactory, driverContext, input); assertEquals(actual.size(), expected.size()); for (int i = 0; i < actual.size(); i++) { assertPageEquals(operatorFactory.getTypes(), actual.get(i), expected.get(i)); } } public static void assertOperatorEquals(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected) throws Exception { assertOperatorEquals(operatorFactory, driverContext, input, expected, false, ImmutableList.of()); } public static void assertOperatorEquals(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected, boolean hashEnabled, List<Integer> hashChannels) throws Exception { List<Page> pages = toPages(operatorFactory, driverContext, input); MaterializedResult actual; if (hashEnabled && !hashChannels.isEmpty()) { // Drop the hashChannel for all pages List<Page> actualPages = dropChannel(pages, hashChannels); List<Type> expectedTypes = without(operatorFactory.getTypes(), hashChannels); actual = toMaterializedResult(driverContext.getSession(), expectedTypes, actualPages); } else { actual = toMaterializedResult(driverContext.getSession(), operatorFactory.getTypes(), pages); } assertEquals(actual, expected); } public static void assertOperatorEqualsIgnoreOrder( OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected) { assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, false, Optional.empty()); } public static void assertOperatorEqualsIgnoreOrder( OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, MaterializedResult expected, boolean hashEnabled, Optional<Integer> hashChannel) { List<Page> pages = toPages(operatorFactory, driverContext, input); MaterializedResult actual; if (hashEnabled && hashChannel.isPresent()) { // Drop the hashChannel for all pages List<Page> actualPages = dropChannel(pages, ImmutableList.of(hashChannel.get())); List<Type> expectedTypes = without(operatorFactory.getTypes(), ImmutableList.of(hashChannel.get())); actual = toMaterializedResult(driverContext.getSession(), expectedTypes, actualPages); } else { actual = toMaterializedResult(driverContext.getSession(), operatorFactory.getTypes(), pages); } assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); } static <T> List<T> without(List<T> types, List<Integer> channels) { types = new ArrayList<>(types); int removed = 0; for (int hashChannel : channels) { types.remove(hashChannel - removed); removed++; } return ImmutableList.copyOf(types); } static List<Page> dropChannel(List<Page> pages, List<Integer> channels) { List<Page> actualPages = new ArrayList<>(); for (Page page : pages) { int channel = 0; Block[] blocks = new Block[page.getChannelCount() - channels.size()]; for (int i = 0; i < page.getChannelCount(); i++) { if (channels.contains(i)) { continue; } blocks[channel++] = page.getBlock(i); } actualPages.add(new Page(blocks)); } return actualPages; } }