/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution.scheduler;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TestingColumnHandle;
import com.facebook.presto.sql.planner.TestingTableHandle;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.testng.Assert.assertEquals;
/**
 * Verifies how {@link PhasedExecutionSchedule#extractPhases} groups plan fragments into
 * scheduling phases for exchanges, unions, partitioned joins, and broadcast (replicated) joins.
 */
public class TestPhasedExecutionSchedule
{
    /**
     * An exchange over three table scans: the exchange fragment forms the first phase,
     * followed by one phase per source fragment.
     */
    @Test
    public void testExchange()
            throws Exception
    {
        PlanFragment aFragment = createTableScanPlanFragment("a");
        PlanFragment bFragment = createTableScanPlanFragment("b");
        PlanFragment cFragment = createTableScanPlanFragment("c");
        PlanFragment exchangeFragment = createExchangePlanFragment("exchange", aFragment, bFragment, cFragment);

        List<PlanFragment> fragments = ImmutableList.of(aFragment, bFragment, cFragment, exchangeFragment);
        List<Set<PlanFragmentId>> expectedPhases = ImmutableList.of(
                ImmutableSet.of(exchangeFragment.getId()),
                ImmutableSet.of(aFragment.getId()),
                ImmutableSet.of(bFragment.getId()),
                ImmutableSet.of(cFragment.getId()));

        List<Set<PlanFragmentId>> actualPhases = PhasedExecutionSchedule.extractPhases(fragments);
        assertEquals(actualPhases, expectedPhases);
    }

    /**
     * A union over three table scans: the union fragment forms the first phase,
     * followed by one phase per source fragment.
     */
    @Test
    public void testUnion()
            throws Exception
    {
        PlanFragment aFragment = createTableScanPlanFragment("a");
        PlanFragment bFragment = createTableScanPlanFragment("b");
        PlanFragment cFragment = createTableScanPlanFragment("c");
        PlanFragment unionFragment = createUnionPlanFragment("union", aFragment, bFragment, cFragment);

        List<PlanFragment> fragments = ImmutableList.of(aFragment, bFragment, cFragment, unionFragment);
        List<Set<PlanFragmentId>> expectedPhases = ImmutableList.of(
                ImmutableSet.of(unionFragment.getId()),
                ImmutableSet.of(aFragment.getId()),
                ImmutableSet.of(bFragment.getId()),
                ImmutableSet.of(cFragment.getId()));

        List<Set<PlanFragmentId>> actualPhases = PhasedExecutionSchedule.extractPhases(fragments);
        assertEquals(actualPhases, expectedPhases);
    }

    /**
     * A partitioned INNER join: join, build and probe each get their own phase, in that order.
     */
    @Test
    public void testJoin()
            throws Exception
    {
        assertPartitionedJoinPhases(INNER);
    }

    /**
     * A partitioned RIGHT join: join, build and probe each get their own phase, in that order.
     */
    @Test
    public void testRightJoin()
            throws Exception
    {
        assertPartitionedJoinPhases(RIGHT);
    }

    /**
     * A replicated (broadcast) join: the build fragment shares a single phase with the join fragment.
     */
    @Test
    public void testBroadcastJoin()
            throws Exception
    {
        PlanFragment buildFragment = createTableScanPlanFragment("build");
        PlanFragment joinFragment = createBroadcastJoinPlanFragment("join", buildFragment);

        List<PlanFragment> fragments = ImmutableList.of(joinFragment, buildFragment);
        List<Set<PlanFragmentId>> expectedPhases = ImmutableList.of(
                ImmutableSet.of(joinFragment.getId(), buildFragment.getId()));

        List<Set<PlanFragmentId>> actualPhases = PhasedExecutionSchedule.extractPhases(fragments);
        assertEquals(actualPhases, expectedPhases);
    }

    /**
     * A partitioned join whose build and probe sides are each three-fragment chains
     * (scan -> middle exchange -> top exchange). Every fragment gets its own phase:
     * the join, then the whole build chain top-down, then the whole probe chain top-down.
     */
    @Test
    public void testJoinWithDeepSources()
            throws Exception
    {
        PlanFragment buildSourceFragment = createTableScanPlanFragment("buildSource");
        PlanFragment buildMiddleFragment = createExchangePlanFragment("buildMiddle", buildSourceFragment);
        PlanFragment buildTopFragment = createExchangePlanFragment("buildTop", buildMiddleFragment);

        PlanFragment probeSourceFragment = createTableScanPlanFragment("probeSource");
        PlanFragment probeMiddleFragment = createExchangePlanFragment("probeMiddle", probeSourceFragment);
        PlanFragment probeTopFragment = createExchangePlanFragment("probeTop", probeMiddleFragment);

        PlanFragment joinFragment = createJoinPlanFragment(INNER, "join", buildTopFragment, probeTopFragment);

        List<PlanFragment> fragments = ImmutableList.of(
                joinFragment,
                buildTopFragment,
                buildMiddleFragment,
                buildSourceFragment,
                probeTopFragment,
                probeMiddleFragment,
                probeSourceFragment);
        List<Set<PlanFragmentId>> expectedPhases = ImmutableList.of(
                ImmutableSet.of(joinFragment.getId()),
                ImmutableSet.of(buildTopFragment.getId()),
                ImmutableSet.of(buildMiddleFragment.getId()),
                ImmutableSet.of(buildSourceFragment.getId()),
                ImmutableSet.of(probeTopFragment.getId()),
                ImmutableSet.of(probeMiddleFragment.getId()),
                ImmutableSet.of(probeSourceFragment.getId()));

        List<Set<PlanFragmentId>> actualPhases = PhasedExecutionSchedule.extractPhases(fragments);
        assertEquals(actualPhases, expectedPhases);
    }

    /**
     * Builds a partitioned join of the given type over two table-scan fragments and asserts
     * that join, build and probe are scheduled as three separate phases, in that order.
     */
    private static void assertPartitionedJoinPhases(JoinNode.Type joinType)
            throws Exception
    {
        PlanFragment buildFragment = createTableScanPlanFragment("build");
        PlanFragment probeFragment = createTableScanPlanFragment("probe");
        PlanFragment joinFragment = createJoinPlanFragment(joinType, "join", buildFragment, probeFragment);

        List<PlanFragment> fragments = ImmutableList.of(joinFragment, buildFragment, probeFragment);
        List<Set<PlanFragmentId>> expectedPhases = ImmutableList.of(
                ImmutableSet.of(joinFragment.getId()),
                ImmutableSet.of(buildFragment.getId()),
                ImmutableSet.of(probeFragment.getId()));

        List<Set<PlanFragmentId>> actualPhases = PhasedExecutionSchedule.extractPhases(fragments);
        assertEquals(actualPhases, expectedPhases);
    }

    /**
     * Creates a fragment whose root is a single remote source reading from all given fragments.
     * The output layout is taken from the first fragment, so at least one fragment is required.
     */
    private static PlanFragment createExchangePlanFragment(String name, PlanFragment... fragments)
    {
        List<PlanFragmentId> sourceFragmentIds = Stream.of(fragments)
                .map(PlanFragment::getId)
                .collect(toImmutableList());
        List<Symbol> outputLayout = fragments[0].getPartitioningScheme().getOutputLayout();
        return createFragment(new RemoteSourceNode(new PlanNodeId(name + "_id"), sourceFragmentIds, outputLayout));
    }

    /**
     * Creates a fragment whose root is a union with one remote source per input fragment.
     * The union has no output symbols or symbol mapping.
     */
    private static PlanFragment createUnionPlanFragment(String name, PlanFragment... fragments)
    {
        List<PlanNode> unionSources = Stream.of(fragments)
                .map(TestPhasedExecutionSchedule::createRemoteSource)
                .collect(toImmutableList());
        PlanNode union = new UnionNode(
                new PlanNodeId(name + "_id"),
                unionSources,
                ImmutableListMultimap.of(),
                ImmutableList.of());
        return createFragment(union);
    }

    /**
     * Creates a remote source that reads the given fragment.
     * The plan node ID is the fragment ID's string form, and the output layout is the fragment's.
     */
    private static PlanNode createRemoteSource(PlanFragment fragment)
    {
        PlanFragmentId fragmentId = fragment.getId();
        return new RemoteSourceNode(new PlanNodeId(fragmentId.toString()), fragmentId, fragment.getPartitioningScheme().getOutputLayout());
    }

    /**
     * Creates a fragment holding a REPLICATED inner join.
     * The probe (left) side is a local table scan, and the build (right) side is a remote
     * source reading {@code buildFragment}.
     */
    private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFragment buildFragment)
    {
        PlanNode probe = createTableScanNode(name);
        PlanNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of());
        return createFragment(createJoinNode(name, INNER, probe, build, REPLICATED));
    }

    /**
     * Creates a fragment holding a PARTITIONED join of the given type.
     * Both sides are remote sources with empty output layouts: probe on the left, build on the right.
     */
    private static PlanFragment createJoinPlanFragment(JoinNode.Type joinType, String name, PlanFragment buildFragment, PlanFragment probeFragment)
    {
        PlanNode probe = new RemoteSourceNode(new PlanNodeId("probe_id"), probeFragment.getId(), ImmutableList.of());
        PlanNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of());
        return createFragment(createJoinNode(name, joinType, probe, build, PARTITIONED));
    }

    /**
     * Builds a join node with no equi-join criteria, no filter and no hash symbols.
     * Its outputs are the left side's symbols followed by the right side's.
     */
    private static JoinNode createJoinNode(String name, JoinNode.Type joinType, PlanNode left, PlanNode right, JoinNode.DistributionType distributionType)
    {
        List<Symbol> outputSymbols = ImmutableList.<Symbol>builder()
                .addAll(left.getOutputSymbols())
                .addAll(right.getOutputSymbols())
                .build();
        return new JoinNode(
                new PlanNodeId(name + "_id"),
                joinType,
                left,
                right,
                ImmutableList.of(),
                outputSymbols,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.of(distributionType));
    }

    /**
     * Creates a fragment whose root is a single-column table scan named {@code name}.
     */
    private static PlanFragment createTableScanPlanFragment(String name)
    {
        return createFragment(createTableScanNode(name));
    }

    /**
     * Builds a scan of a testing table that exposes one symbol, "column".
     * The scan has an unconstrained TupleDomain, and {@code null} is passed as the last constructor argument.
     */
    private static PlanNode createTableScanNode(String name)
    {
        Symbol column = new Symbol("column");
        return new TableScanNode(
                new PlanNodeId(name),
                new TableHandle(new ConnectorId("test"), new TestingTableHandle()),
                ImmutableList.of(column),
                ImmutableMap.of(column, new TestingColumnHandle("column")),
                Optional.empty(),
                TupleDomain.all(),
                null);
    }

    /**
     * Wraps a plan node in a fragment whose ID is the node ID plus "_fragment_id".
     * Every output symbol is typed VARCHAR. The fragment uses SOURCE_DISTRIBUTION with the
     * root node as its only partitioned source, and its output is SINGLE_DISTRIBUTION over the
     * node's output symbols.
     */
    private static PlanFragment createFragment(PlanNode planNode)
    {
        ImmutableMap.Builder<Symbol, Type> symbolTypes = ImmutableMap.builder();
        planNode.getOutputSymbols().forEach(symbol -> symbolTypes.put(symbol, VARCHAR));

        PlanNodeId rootId = planNode.getId();
        PartitioningScheme outputScheme = new PartitioningScheme(
                Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()),
                planNode.getOutputSymbols());
        return new PlanFragment(
                new PlanFragmentId(rootId + "_fragment_id"),
                planNode,
                symbolTypes.build(),
                SOURCE_DISTRIBUTION,
                ImmutableList.of(rootId),
                outputScheme);
    }
}