/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator; import com.facebook.presto.ScheduledSplit; import com.facebook.presto.Session; import com.facebook.presto.TaskSource; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.Split; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.FixedPageSource; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.Type; import com.facebook.presto.split.PageSourceProvider; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.PageConsumerOperator; import com.facebook.presto.testing.TestingTransactionHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import java.io.Closeable; import java.io.IOException; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Function; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @Test(singleThreaded = true) public class TestDriver { private ExecutorService executor; private DriverContext driverContext; @BeforeMethod public void setUp() throws Exception { executor = newCachedThreadPool(daemonThreadsNamed("test-%s")); driverContext = createTaskContext(executor, TEST_SESSION) .addPipelineContext(0, true, true) .addDriverContext(); } @AfterMethod public void tearDown() { executor.shutdownNow(); } @Test public void testNormalFinish() { List<Type> types = ImmutableList.of(VARCHAR, BIGINT, BIGINT); ValuesOperator source = new ValuesOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "values"), types, rowPagesBuilder(types) .addSequencePage(10, 20, 30, 40) .build()); Operator sink = createSinkOperator(source); Driver driver = new Driver(driverContext, source, sink); assertSame(driver.getDriverContext(), driverContext); assertFalse(driver.isFinished()); ListenableFuture<?> blocked = driver.processFor(new Duration(1, TimeUnit.SECONDS)); assertTrue(blocked.isDone()); assertTrue(driver.isFinished()); assertTrue(sink.isFinished()); assertTrue(source.isFinished()); } @Test public void testAbruptFinish() { List<Type> types = ImmutableList.of(VARCHAR, BIGINT, BIGINT); ValuesOperator source = new ValuesOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "values"), types, rowPagesBuilder(types) .addSequencePage(10, 20, 30, 40) .build()); PageConsumerOperator sink = createSinkOperator(source); Driver driver = new Driver(driverContext, source, sink); assertSame(driver.getDriverContext(), driverContext); assertFalse(driver.isFinished()); driver.close(); assertTrue(driver.isFinished()); // finish is only called in normal operations assertFalse(source.isFinished()); assertFalse(sink.isFinished()); // close is always called (values operator doesn't have a closed state) assertTrue(sink.isClosed()); } @Test public void testAddSourceFinish() { PlanNodeId sourceId = new PlanNodeId("source"); final List<Type> types = ImmutableList.of(VARCHAR, BIGINT, BIGINT); TableScanOperator source = new TableScanOperator(driverContext.addOperatorContext(99, new PlanNodeId("test"), "values"), sourceId, new PageSourceProvider() { @Override public ConnectorPageSource createPageSource(Session session, Split split, List<ColumnHandle> columns) { return new FixedPageSource(rowPagesBuilder(types) .addSequencePage(10, 20, 30, 40) .build()); } }, types, ImmutableList.of()); PageConsumerOperator sink = createSinkOperator(source); Driver driver = new Driver(driverContext, source, sink); assertSame(driver.getDriverContext(), driverContext); assertFalse(driver.isFinished()); assertFalse(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertFalse(driver.isFinished()); driver.updateSource(new TaskSource(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); assertFalse(driver.isFinished()); assertTrue(driver.processFor(new Duration(1, TimeUnit.SECONDS)).isDone()); assertTrue(driver.isFinished()); assertTrue(sink.isFinished()); assertTrue(source.isFinished()); } @Test public void testBrokenOperatorCloseWhileProcessing() throws Exception { BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source"), false); final Driver driver = new Driver(driverContext, brokenOperator, createSinkOperator(brokenOperator)); assertSame(driver.getDriverContext(), driverContext); // block thread in operator processing Future<Boolean> driverProcessFor = executor.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { return driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone(); } }); brokenOperator.waitForLocked(); driver.close(); assertTrue(driver.isFinished()); try { driverProcessFor.get(1, TimeUnit.SECONDS); fail("Expected InterruptedException"); } catch (ExecutionException e) { assertDriverInterrupted(e.getCause()); } } @Test public void testBrokenOperatorProcessWhileClosing() throws Exception { BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source"), true); final Driver driver = new Driver(driverContext, brokenOperator, createSinkOperator(brokenOperator)); assertSame(driver.getDriverContext(), driverContext); // block thread in operator close Future<Boolean> driverClose = executor.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { driver.close(); return true; } }); brokenOperator.waitForLocked(); assertTrue(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertTrue(driver.isFinished()); brokenOperator.unlock(); assertTrue(driverClose.get()); } @Test public void testBrokenOperatorAddSource() throws Exception { PlanNodeId sourceId = new PlanNodeId("source"); final List<Type> types = ImmutableList.of(VARCHAR, BIGINT, BIGINT); // create a table scan operator that does not block, which will cause the driver loop to busy wait TableScanOperator source = new NotBlockedTableScanOperator(driverContext.addOperatorContext(99, new PlanNodeId("test"), "values"), sourceId, new PageSourceProvider() { @Override public ConnectorPageSource createPageSource(Session session, Split split, List<ColumnHandle> columns) { return new FixedPageSource(rowPagesBuilder(types) .addSequencePage(10, 20, 30, 40) .build()); } }, types, ImmutableList.of()); BrokenOperator brokenOperator = new BrokenOperator(driverContext.addOperatorContext(0, new PlanNodeId("test"), "source")); final Driver driver = new Driver(driverContext, source, brokenOperator); // block thread in operator processing Future<Boolean> driverProcessFor = executor.submit(new Callable<Boolean>() { @Override public Boolean call() throws Exception { return driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone(); } }); brokenOperator.waitForLocked(); assertSame(driver.getDriverContext(), driverContext); assertFalse(driver.isFinished()); // processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired assertTrue(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertFalse(driver.isFinished()); driver.updateSource(new TaskSource(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); assertFalse(driver.isFinished()); // processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired assertTrue(driver.processFor(new Duration(1, TimeUnit.SECONDS)).isDone()); assertFalse(driver.isFinished()); driver.close(); assertTrue(driver.isFinished()); try { driverProcessFor.get(1, TimeUnit.SECONDS); fail("Expected InterruptedException"); } catch (ExecutionException e) { assertDriverInterrupted(e.getCause()); } } private void assertDriverInterrupted(Throwable cause) { checkArgument(cause instanceof PrestoException, "Expected root cause exception to be an instance of PrestoException"); assertEquals(((PrestoException) cause).getErrorCode(), GENERIC_INTERNAL_ERROR.toErrorCode()); assertEquals(cause.getMessage(), "Driver was interrupted"); } private static Split newMockSplit() { return new Split(new ConnectorId("test"), TestingTransactionHandle.create(), new MockSplit()); } private PageConsumerOperator createSinkOperator(Operator source) { // materialize the output to catch some type errors MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(driverContext.getSession(), source.getTypes()); return new PageConsumerOperator(driverContext.addOperatorContext(1, new PlanNodeId("test"), "sink"), resultBuilder::page, Function.identity()); } private static class BrokenOperator implements Operator, Closeable { private final OperatorContext operatorContext; private final ReentrantLock lock = new ReentrantLock(); private final CountDownLatch lockedLatch = new CountDownLatch(1); private final CountDownLatch unlockLatch = new CountDownLatch(1); private final boolean lockForClose; private BrokenOperator(OperatorContext operatorContext) { this(operatorContext, false); } private BrokenOperator(OperatorContext operatorContext, boolean lockForClose) { this.operatorContext = operatorContext; this.lockForClose = lockForClose; } @Override public OperatorContext getOperatorContext() { return operatorContext; } public void unlock() { unlockLatch.countDown(); } private void waitForLocked() { try { assertTrue(lockedLatch.await(10, TimeUnit.SECONDS)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted", e); } } private void waitForUnlock() { try { assertTrue(lock.tryLock(1, TimeUnit.SECONDS)); try { lockedLatch.countDown(); assertTrue(unlockLatch.await(5, TimeUnit.SECONDS)); } finally { lock.unlock(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted", e); } } @Override public List<Type> getTypes() { return ImmutableList.of(); } @Override public void finish() { waitForUnlock(); } @Override public boolean isFinished() { waitForUnlock(); return true; } @Override public ListenableFuture<?> isBlocked() { waitForUnlock(); return NOT_BLOCKED; } @Override public boolean needsInput() { waitForUnlock(); return false; } @Override public void addInput(Page page) { waitForUnlock(); } @Override public Page getOutput() { waitForUnlock(); return null; } @Override public void close() throws IOException { if (lockForClose) { waitForUnlock(); } } } private static class NotBlockedTableScanOperator extends TableScanOperator { public NotBlockedTableScanOperator( OperatorContext operatorContext, PlanNodeId planNodeId, PageSourceProvider pageSourceProvider, List<Type> types, Iterable<ColumnHandle> columns) { super(operatorContext, planNodeId, pageSourceProvider, types, columns); } @Override public ListenableFuture<?> isBlocked() { return NOT_BLOCKED; } } private static class MockSplit implements ConnectorSplit { @Override public boolean isRemotelyAccessible() { return false; } @Override public List<HostAddress> getAddresses() { return ImmutableList.of(); } @Override public Object getInfo() { return null; } } }