/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.presto.client.NodeVersion;
import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.execution.scheduler.LegacyNetworkTopology;
import com.facebook.presto.execution.scheduler.NetworkLocation;
import com.facebook.presto.execution.scheduler.NetworkLocationCache;
import com.facebook.presto.execution.scheduler.NetworkTopology;
import com.facebook.presto.execution.scheduler.NodeScheduler;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.execution.scheduler.NodeSelector;
import com.facebook.presto.metadata.InMemoryNodeManager;
import com.facebook.presto.metadata.PrestoNode;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.ConnectorSplit;
import com.facebook.presto.spi.HostAddress;
import com.facebook.presto.spi.Node;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.facebook.presto.util.FinalizerService;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import static com.facebook.presto.execution.scheduler.NetworkLocation.ROOT_LOCATION;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
@Test(singleThreaded = true)
public class TestNodeScheduler
{
// Connector id under which every test node is registered and every test split is created
private static final ConnectorId CONNECTOR_ID = new ConnectorId("connector_id");
// Started in setUp() and destroyed in tearDown(); handed to each NodeTaskMap
private FinalizerService finalizerService;
// Tracks tasks and partitioned split counts per node for the shared fixture
private NodeTaskMap nodeTaskMap;
// In-memory node registry, repopulated with three nodes before each test
private InMemoryNodeManager nodeManager;
// Selector under test, created from the scheduler built in setUp()
private NodeSelector nodeSelector;
// contents of taskMap indicate the node-task map for the current stage
private Map<Node, RemoteTask> taskMap;
// Executor passed to MockRemoteTaskFactory; shut down in tearDown()
private ExecutorService remoteTaskExecutor;
/**
 * Rebuilds the shared fixture before each test method: a node manager holding three nodes
 * registered under {@code CONNECTOR_ID}, a node selector obtained from a scheduler that uses
 * {@link LegacyNetworkTopology}, an empty stage task map, and the executor used by mock
 * remote tasks. The finalizer service is started last.
 */
@BeforeMethod
public void setUp()
        throws Exception
{
    finalizerService = new FinalizerService();
    nodeTaskMap = new NodeTaskMap(finalizerService);
    nodeManager = new InMemoryNodeManager();

    // Three nodes on the loopback address, distinguished by identifier and port
    ImmutableList<Node> initialNodes = ImmutableList.<Node>of(
            new PrestoNode("other1", URI.create("http://127.0.0.1:11"), NodeVersion.UNKNOWN, false),
            new PrestoNode("other2", URI.create("http://127.0.0.1:12"), NodeVersion.UNKNOWN, false),
            new PrestoNode("other3", URI.create("http://127.0.0.1:13"), NodeVersion.UNKNOWN, false));
    nodeManager.addNode(CONNECTOR_ID, initialNodes);

    NodeSchedulerConfig schedulerConfig = new NodeSchedulerConfig()
            .setMaxSplitsPerNode(20)
            .setIncludeCoordinator(false)
            .setMaxPendingSplitsPerTask(10);
    NodeScheduler scheduler = new NodeScheduler(new LegacyNetworkTopology(), nodeManager, schedulerConfig, nodeTaskMap);

    // contents of taskMap indicate the node-task map for the current stage
    taskMap = new HashMap<>();
    nodeSelector = scheduler.createNodeSelector(CONNECTOR_ID);
    remoteTaskExecutor = Executors.newCachedThreadPool(daemonThreadsNamed("remoteTaskExecutor-%s"));
    finalizerService.start();
}
/**
 * Releases the per-test resources created in {@code setUp()}.
 *
 * The finalizer service is destroyed in a {@code finally} block so that it is still
 * stopped when shutting down the executor throws.
 */
@AfterMethod
public void tearDown()
        throws Exception
{
    try {
        remoteTaskExecutor.shutdown();
    }
    finally {
        // Must run even if the executor shutdown failed, otherwise the service started in setUp() leaks
        finalizerService.destroy();
    }
}
/**
 * Schedules a single split that is not remotely accessible and checks that the one resulting
 * assignment targets the node whose host-and-port equals the split's first address.
 */
@Test
public void testScheduleLocal()
        throws Exception
{
    Split localSplit = new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitLocal());
    Set<Split> pending = ImmutableSet.of(localSplit);

    Multimap<Node, Split> assignments = nodeSelector.computeAssignments(pending, ImmutableList.copyOf(taskMap.values())).getAssignments();

    // Exactly one (node, split) pair is expected; getOnlyElement fails otherwise
    Map.Entry<Node, Split> onlyAssignment = Iterables.getOnlyElement(assignments.entries());
    assertEquals(onlyAssignment.getKey().getHostAndPort(), localSplit.getAddresses().get(0));
    assertEquals(onlyAssignment.getValue(), localSplit);
}
/**
 * Drives a node selector configured with the "test" network topology through four phases:
 * splits whose data lives on an unrelated rack, splits whose data lives on the racks of the
 * test nodes, a retry of the rack-local leftovers once the location cache has resolved both
 * data hosts, and finally splits addressed to the node hosts themselves. After each phase the
 * number (and, for rack-local splits, the rack) of unassigned splits is asserted.
 *
 * This test builds its own node manager, node task map, task map and selector instead of
 * using the fixture created in {@code setUp()}.
 */
@Test(timeOut = 60 * 1000)
public void testTopologyAwareScheduling()
        throws Exception
{
    TestingTransactionHandle transaction = TestingTransactionHandle.create();
    NodeTaskMap rackNodeTaskMap = new NodeTaskMap(finalizerService);
    InMemoryNodeManager rackNodeManager = new InMemoryNodeManager();
    ImmutableList<Node> rackNodes = ImmutableList.<Node>of(
            new PrestoNode("node1", URI.create("http://host1.rack1:11"), NodeVersion.UNKNOWN, false),
            new PrestoNode("node2", URI.create("http://host2.rack1:12"), NodeVersion.UNKNOWN, false),
            new PrestoNode("node3", URI.create("http://host3.rack2:13"), NodeVersion.UNKNOWN, false));
    rackNodeManager.addNode(CONNECTOR_ID, rackNodes);

    // contents of stageTasks indicate the node-task map for the current stage
    Map<Node, RemoteTask> stageTasks = new HashMap<>();

    NodeSchedulerConfig schedulerConfig = new NodeSchedulerConfig()
            .setMaxSplitsPerNode(25)
            .setIncludeCoordinator(false)
            .setNetworkTopology("test")
            .setMaxPendingSplitsPerTask(20);

    TestNetworkTopology topology = new TestNetworkTopology();
    NetworkLocationCache locationCache = new NetworkLocationCache(topology)
    {
        @Override
        public NetworkLocation get(HostAddress host)
        {
            // Bypass the cache for workers, since we only look them up once and they would all be unresolved otherwise
            if (!host.getHostText().startsWith("host")) {
                return super.get(host);
            }
            return topology.locate(host);
        }
    };

    NodeScheduler scheduler = new NodeScheduler(locationCache, topology, rackNodeManager, schedulerConfig, rackNodeTaskMap);
    NodeSelector selector = scheduler.createNodeSelector(CONNECTOR_ID);

    // Fill up the nodes with non-local data
    HostAddress otherRackHost = HostAddress.fromParts("data.other_rack", 1);
    int nonLocalSplitCount = (25 + 11) * 3;
    ImmutableSet.Builder<Split> nonLocalBuilder = ImmutableSet.builder();
    for (int i = 0; i < nonLocalSplitCount; i++) {
        nonLocalBuilder.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote(otherRackHost)));
    }
    Set<Split> nonRackLocalSplits = nonLocalBuilder.build();
    Multimap<Node, Split> assignments = selector.computeAssignments(nonRackLocalSplits, ImmutableList.copyOf(stageTasks.values())).getAssignments();

    // Create one task per node that received splits and start 25 of its splits
    MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor);
    int nextTaskNumber = 0;
    for (Node node : assignments.keySet()) {
        TaskId taskId = new TaskId("test", 1, nextTaskNumber++);
        MockRemoteTaskFactory.MockRemoteTask createdTask = taskFactory.createTableScanTask(
                taskId,
                node,
                ImmutableList.copyOf(assignments.get(node)),
                rackNodeTaskMap.createPartitionedSplitCountTracker(node, taskId));
        createdTask.startSplits(25);
        rackNodeTaskMap.addTask(node, createdTask);
        stageTasks.put(node, createdTask);
    }

    // Continue assigning to fill up part of the queue
    nonRackLocalSplits = Sets.difference(nonRackLocalSplits, new HashSet<>(assignments.values()));
    assignments = selector.computeAssignments(nonRackLocalSplits, ImmutableList.copyOf(stageTasks.values())).getAssignments();
    addAssignedSplitsToStageTasks(assignments, stageTasks);
    nonRackLocalSplits = Sets.difference(nonRackLocalSplits, new HashSet<>(assignments.values()));

    // Check that 3 of the splits were rejected, since they're non-local
    assertEquals(nonRackLocalSplits.size(), 3);

    // Assign rack-local splits
    HostAddress dataHost1 = HostAddress.fromParts("data.rack1", 1);
    HostAddress dataHost2 = HostAddress.fromParts("data.rack2", 1);
    ImmutableSet.Builder<Split> rackLocalSplits = ImmutableSet.builder();
    for (int i = 0; i < 6 * 2; i++) {
        rackLocalSplits.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote(dataHost1)));
    }
    for (int i = 0; i < 6; i++) {
        rackLocalSplits.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote(dataHost2)));
    }
    assignments = selector.computeAssignments(rackLocalSplits.build(), ImmutableList.copyOf(stageTasks.values())).getAssignments();
    addAssignedSplitsToStageTasks(assignments, stageTasks);
    Set<Split> unassigned = Sets.difference(rackLocalSplits.build(), new HashSet<>(assignments.values()));

    // Compute the assignments a second time to account for the fact that some splits may not have been assigned due to asynchronous
    // loading of the NetworkLocationCache
    boolean cacheRefreshed;
    do {
        // Both hosts are looked up on every pass, and the pause happens even after a successful pass
        boolean firstResolved = !locationCache.get(dataHost1).equals(ROOT_LOCATION);
        boolean secondResolved = !locationCache.get(dataHost2).equals(ROOT_LOCATION);
        cacheRefreshed = firstResolved && secondResolved;
        MILLISECONDS.sleep(10);
    }
    while (!cacheRefreshed);

    assignments = selector.computeAssignments(unassigned, ImmutableList.copyOf(stageTasks.values())).getAssignments();
    addAssignedSplitsToStageTasks(assignments, stageTasks);
    unassigned = Sets.difference(unassigned, new HashSet<>(assignments.values()));
    assertEquals(unassigned.size(), 3);

    // Tally the leftover splits by the first segment of their data host's location
    int rack1 = 0;
    int rack2 = 0;
    for (Split leftover : unassigned) {
        String rack = topology.locate(leftover.getAddresses().get(0)).getSegments().get(0);
        if (rack.equals("rack1")) {
            rack1++;
        }
        else if (rack.equals("rack2")) {
            rack2++;
        }
        else {
            fail();
        }
    }
    assertEquals(rack1, 2);
    assertEquals(rack2, 1);

    // Assign local splits
    ImmutableSet<Split> localSplits = ImmutableSet.of(
            new Split(CONNECTOR_ID, transaction, new TestSplitRemote(HostAddress.fromParts("host1.rack1", 1))),
            new Split(CONNECTOR_ID, transaction, new TestSplitRemote(HostAddress.fromParts("host2.rack1", 1))),
            new Split(CONNECTOR_ID, transaction, new TestSplitRemote(HostAddress.fromParts("host3.rack2", 1))));
    assignments = selector.computeAssignments(localSplits, ImmutableList.copyOf(stageTasks.values())).getAssignments();
    assertEquals(assignments.size(), 3);
    assertEquals(assignments.keySet().size(), 3);
}

/**
 * Adds each node's newly assigned splits to the task already registered for that node,
 * under the plan node id "sourceId".
 */
private static void addAssignedSplitsToStageTasks(Multimap<Node, Split> assignments, Map<Node, RemoteTask> stageTasks)
{
    for (Node node : assignments.keySet()) {
        RemoteTask existingTask = stageTasks.get(node);
        ImmutableMultimap<PlanNodeId, Split> additional = ImmutableMultimap.<PlanNodeId, Split>builder()
                .putAll(new PlanNodeId("sourceId"), assignments.get(node))
                .build();
        existingTask.addSplits(additional);
    }
}
/**
 * Schedules one remotely accessible split and checks that exactly one assignment is produced.
 */
@Test
public void testScheduleRemote()
        throws Exception
{
    Split remoteSplit = new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote());
    Set<Split> pending = ImmutableSet.of(remoteSplit);

    Multimap<Node, Split> assignments = nodeSelector.computeAssignments(pending, ImmutableList.copyOf(taskMap.values())).getAssignments();

    assertEquals(assignments.size(), 1);
}
/**
 * Schedules three remotely accessible splits and checks that three assignments are produced
 * and that every active node of the connector received at least one of them.
 */
@Test
public void testBasicAssignment()
        throws Exception
{
    // One split for each node
    ImmutableSet.Builder<Split> pending = ImmutableSet.builder();
    for (int remaining = 3; remaining > 0; remaining--) {
        pending.add(new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote()));
    }

    Multimap<Node, Split> assignments = nodeSelector.computeAssignments(pending.build(), ImmutableList.copyOf(taskMap.values())).getAssignments();

    assertEquals(assignments.entries().size(), 3);
    for (Node activeNode : nodeManager.getActiveConnectorNodes(CONNECTOR_ID)) {
        assertTrue(assignments.containsKey(activeNode));
    }
}
/**
 * Registers a fourth node, gives it two tasks of ten splits each (matching the configured
 * maximum of 20 splits per node), and checks that scheduling five more splits assigns none
 * of them to that node. Aborting both tasks must bring the node's split count back to zero.
 */
@Test
public void testMaxSplitsPerNode()
        throws Exception
{
    TestingTransactionHandle transaction = TestingTransactionHandle.create();

    Node saturatedNode = new PrestoNode("other4", URI.create("http://127.0.0.1:14"), NodeVersion.UNKNOWN, false);
    nodeManager.addNode(CONNECTOR_ID, saturatedNode);

    ImmutableList.Builder<Split> tenSplitsBuilder = ImmutableList.builder();
    for (int i = 0; i < 10; i++) {
        tenSplitsBuilder.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote()));
    }
    ImmutableList<Split> tenSplits = tenSplitsBuilder.build();

    MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor);

    // Max out number of splits on node
    TaskId firstId = new TaskId("test", 1, 1);
    RemoteTask firstTask = taskFactory.createTableScanTask(firstId, saturatedNode, tenSplits, nodeTaskMap.createPartitionedSplitCountTracker(saturatedNode, firstId));
    nodeTaskMap.addTask(saturatedNode, firstTask);

    TaskId secondId = new TaskId("test", 1, 2);
    RemoteTask secondTask = taskFactory.createTableScanTask(secondId, saturatedNode, tenSplits, nodeTaskMap.createPartitionedSplitCountTracker(saturatedNode, secondId));
    nodeTaskMap.addTask(saturatedNode, secondTask);

    ImmutableSet.Builder<Split> pending = ImmutableSet.builder();
    for (int i = 0; i < 5; i++) {
        pending.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote()));
    }
    Multimap<Node, Split> assignments = nodeSelector.computeAssignments(pending.build(), ImmutableList.copyOf(taskMap.values())).getAssignments();

    // no split should be assigned to the newNode, as it already has maxNodeSplits assigned to it
    assertFalse(assignments.keySet().contains(saturatedNode));

    firstTask.abort();
    secondTask.abort();
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(saturatedNode), 0);
}
/**
 * Registers a fourth node, creates a 20-split task on every node the node manager reports as
 * active for the connector, then creates a second 20-split task on the fourth node and records
 * it in the stage task map. Scheduling five more splits must use three nodes and must not use
 * the fourth node. Aborting all tasks must bring the fourth node's split count back to zero.
 */
@Test
public void testMaxSplitsPerNodePerTask()
        throws Exception
{
    TestingTransactionHandle transaction = TestingTransactionHandle.create();

    Node saturatedNode = new PrestoNode("other4", URI.create("http://127.0.0.1:14"), NodeVersion.UNKNOWN, false);
    nodeManager.addNode(CONNECTOR_ID, saturatedNode);

    ImmutableList.Builder<Split> twentySplitsBuilder = ImmutableList.builder();
    for (int i = 0; i < 20; i++) {
        twentySplitsBuilder.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote()));
    }
    ImmutableList<Split> twentySplits = twentySplitsBuilder.build();

    List<RemoteTask> createdTasks = new ArrayList<>();
    MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor);

    // Max out number of splits on node
    TaskId perNodeTaskId = new TaskId("test", 1, 1);
    for (Node activeNode : nodeManager.getActiveConnectorNodes(CONNECTOR_ID)) {
        RemoteTask perNodeTask = taskFactory.createTableScanTask(perNodeTaskId, activeNode, twentySplits, nodeTaskMap.createPartitionedSplitCountTracker(activeNode, perNodeTaskId));
        nodeTaskMap.addTask(activeNode, perNodeTask);
        createdTasks.add(perNodeTask);
    }

    TaskId stageTaskId = new TaskId("test", 1, 2);
    RemoteTask stageTask = taskFactory.createTableScanTask(stageTaskId, saturatedNode, twentySplits, nodeTaskMap.createPartitionedSplitCountTracker(saturatedNode, stageTaskId));
    // Max out pending splits on new node
    taskMap.put(saturatedNode, stageTask);
    nodeTaskMap.addTask(saturatedNode, stageTask);
    createdTasks.add(stageTask);

    ImmutableSet.Builder<Split> pending = ImmutableSet.builder();
    for (int i = 0; i < 5; i++) {
        pending.add(new Split(CONNECTOR_ID, transaction, new TestSplitRemote()));
    }
    Multimap<Node, Split> assignments = nodeSelector.computeAssignments(pending.build(), ImmutableList.copyOf(taskMap.values())).getAssignments();

    // no split should be assigned to the newNode, as it already has
    // maxSplitsPerNode + maxSplitsPerNodePerTask assigned to it
    Set<Node> usedNodes = assignments.keySet();
    assertEquals(usedNodes.size(), 3); // Splits should be scheduled on the other three nodes
    assertFalse(usedNodes.contains(saturatedNode)); // No splits scheduled on the maxed out node

    for (RemoteTask created : createdTasks) {
        created.abort();
    }
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(saturatedNode), 0);
}
/**
 * Creates a one-split task on the first active node and checks that the node's partitioned
 * split count is 1 while the task is registered, drops to 0 after the task is aborted, and
 * stays 0 when the task is aborted a second time.
 */
@Test
public void testTaskCompletion()
        throws Exception
{
    MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor);
    Node targetNode = Iterables.get(nodeManager.getActiveConnectorNodes(CONNECTOR_ID), 0);
    TaskId taskId = new TaskId("test", 1, 1);

    ImmutableList<Split> singleSplit = ImmutableList.of(new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote()));
    RemoteTask task = taskFactory.createTableScanTask(taskId, targetNode, singleSplit, nodeTaskMap.createPartitionedSplitCountTracker(targetNode, taskId));
    nodeTaskMap.addTask(targetNode, task);
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 1);

    task.abort();
    MILLISECONDS.sleep(100); // Sleep until cache expires
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 0);

    // A repeated abort must leave the count unchanged
    task.abort();
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 0);
}
/**
 * Registers a two-split task and a one-split task on the first active node and checks that
 * the node's partitioned split count goes 3, then 1, then 0 as the tasks are aborted in order.
 */
@Test
public void testSplitCount()
        throws Exception
{
    MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor);
    Node targetNode = Iterables.get(nodeManager.getActiveConnectorNodes(CONNECTOR_ID), 0);

    // First task carries two splits
    TaskId twoSplitId = new TaskId("test", 1, 1);
    ImmutableList<Split> twoSplits = ImmutableList.of(
            new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote()),
            new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote()));
    RemoteTask twoSplitTask = taskFactory.createTableScanTask(twoSplitId, targetNode, twoSplits, nodeTaskMap.createPartitionedSplitCountTracker(targetNode, twoSplitId));

    // Second task carries one split
    TaskId oneSplitId = new TaskId("test", 1, 2);
    ImmutableList<Split> oneSplit = ImmutableList.of(new Split(CONNECTOR_ID, TestingTransactionHandle.create(), new TestSplitRemote()));
    RemoteTask oneSplitTask = taskFactory.createTableScanTask(oneSplitId, targetNode, oneSplit, nodeTaskMap.createPartitionedSplitCountTracker(targetNode, oneSplitId));

    nodeTaskMap.addTask(targetNode, twoSplitTask);
    nodeTaskMap.addTask(targetNode, oneSplitTask);
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 3);

    twoSplitTask.abort();
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 1);

    oneSplitTask.abort();
    assertEquals(nodeTaskMap.getPartitionedSplitsOnNode(targetNode), 0);
}
/**
 * Split that reports itself as not remotely accessible and always lists the single
 * address 127.0.0.1:11.
 */
private static class TestSplitLocal
        implements ConnectorSplit
{
    private static final List<HostAddress> ADDRESSES = ImmutableList.of(HostAddress.fromString("127.0.0.1:11"));

    @Override
    public boolean isRemotelyAccessible()
    {
        return false;
    }

    @Override
    public List<HostAddress> getAddresses()
    {
        return ADDRESSES;
    }

    @Override
    public Object getInfo()
    {
        return this;
    }
}
/**
 * Split that reports itself as remotely accessible and lists exactly one address: either the
 * one supplied by the caller or 127.0.0.1 with a random port below 5000.
 */
private static class TestSplitRemote
        implements ConnectorSplit
{
    private final List<HostAddress> addresses;

    public TestSplitRemote()
    {
        this(randomLoopbackAddress());
    }

    public TestSplitRemote(HostAddress host)
    {
        requireNonNull(host, "host is null");
        this.addresses = ImmutableList.of(host);
    }

    // Loopback address with a port drawn from [0, 5000)
    private static HostAddress randomLoopbackAddress()
    {
        int port = ThreadLocalRandom.current().nextInt(5000);
        return HostAddress.fromString("127.0.0.1:" + port);
    }

    @Override
    public boolean isRemotelyAccessible()
    {
        return true;
    }

    @Override
    public List<HostAddress> getAddresses()
    {
        return addresses;
    }

    @Override
    public Object getInfo()
    {
        return this;
    }
}
/**
 * Topology that derives a location from the host name alone: the host text is split on "."
 * and the pieces are reversed, so "host1.rack1" yields the segments ["rack1", "host1"].
 */
private static class TestNetworkTopology
        implements NetworkTopology
{
    @Override
    public NetworkLocation locate(HostAddress address)
    {
        List<String> segments = new ArrayList<>();
        for (String piece : Splitter.on(".").split(address.getHostText())) {
            segments.add(piece);
        }
        // Last dot-separated piece of the host name becomes the first segment
        Collections.reverse(segments);
        return NetworkLocation.create(segments);
    }

    @Override
    public List<String> getLocationSegmentNames()
    {
        return ImmutableList.of("rack", "machine");
    }
}
}