/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.raptor.storage;

import com.facebook.presto.raptor.NodeSupplier;
import com.facebook.presto.raptor.RaptorConnectorId;
import com.facebook.presto.raptor.backup.BackupStore;
import com.facebook.presto.raptor.metadata.ShardManager;
import com.facebook.presto.raptor.metadata.ShardMetadata;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.NodeManager;
import com.google.common.annotations.VisibleForTesting;
import io.airlift.log.Logger;
import io.airlift.stats.CounterStat;
import io.airlift.units.Duration;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;

import java.io.File;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Maps.filterKeys;
import static com.google.common.collect.Maps.filterValues;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static java.lang.Math.round;
import static java.util.Comparator.comparingLong;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

public class ShardEjector
{
    private static final Logger log = Logger.get(ShardEjector.class);

    private final String currentNode;
    private final NodeSupplier nodeSupplier;
    private final ShardManager shardManager;
    private final StorageService storageService;
    private final Duration interval;
    private final Optional<BackupStore> backupStore;
    private final ScheduledExecutorService executor;

    private final AtomicBoolean started = new AtomicBoolean();
    private final CounterStat shardsEjected = new CounterStat();
    private final CounterStat jobErrors = new CounterStat();

    @Inject
    public ShardEjector(
            NodeManager nodeManager,
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            StorageService storageService,
            StorageManagerConfig config,
            Optional<BackupStore> backupStore,
            RaptorConnectorId connectorId)
    {
        this(nodeManager.getCurrentNode().getNodeIdentifier(),
                nodeSupplier,
                shardManager,
                storageService,
                config.getShardEjectorInterval(),
                backupStore,
                connectorId.toString());
    }

    public ShardEjector(
            String currentNode,
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            StorageService storageService,
            Duration interval,
            Optional<BackupStore> backupStore,
            String connectorId)
    {
        this.currentNode = requireNonNull(currentNode, "currentNode is null");
        this.nodeSupplier = requireNonNull(nodeSupplier, "nodeSupplier is null");
        this.shardManager = requireNonNull(shardManager, "shardManager is null");
        this.storageService = requireNonNull(storageService, "storageService is null");
        this.interval = requireNonNull(interval, "interval is null");
        this.backupStore = requireNonNull(backupStore, "backupStore is null");
        this.executor = newScheduledThreadPool(1, daemonThreadsNamed("shard-ejector-" + connectorId));
    }

    @PostConstruct
    public void start()
    {
        if (!backupStore.isPresent()) {
            return;
        }
        if (!started.getAndSet(true)) {
            startJob();
        }
    }

    @PreDestroy
    public void shutdown()
    {
        executor.shutdownNow();
    }

    @Managed
    @Nested
    public CounterStat getShardsEjected()
    {
        return shardsEjected;
    }

    @Managed
    @Nested
    public CounterStat getJobErrors()
    {
        return jobErrors;
    }

    private void startJob()
    {
        executor.scheduleWithFixedDelay(() -> {
            try {
                // jitter to avoid overloading database
                long interval = this.interval.roundTo(SECONDS);
                SECONDS.sleep(ThreadLocalRandom.current().nextLong(1, interval));
                process();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            catch (Throwable t) {
                log.error(t, "Error ejecting shards");
                jobErrors.update(1);
            }
        }, 0, interval.toMillis(), MILLISECONDS);
    }

    @VisibleForTesting
    void process()
    {
        checkState(backupStore.isPresent(), "backup store must be present");

        // get the size of assigned shards for each node
        Map<String, Long> nodes = shardManager.getNodeBytes();

        Set<String> activeNodes = nodeSupplier.getWorkerNodes().stream()
                .map(Node::getNodeIdentifier)
                .collect(toSet());

        // only include active nodes
        nodes = new HashMap<>(filterKeys(nodes, activeNodes::contains));
        if (nodes.isEmpty()) {
            return;
        }

        // get current node size
        if (!nodes.containsKey(currentNode)) {
            return;
        }
        long nodeSize = nodes.get(currentNode);

        // get average node size
        long averageSize = round(nodes.values().stream()
                .mapToLong(Long::longValue)
                .average()
                .getAsDouble());
        long maxSize = round(averageSize * 1.01);

        // skip if not above max
        if (nodeSize <= maxSize) {
            return;
        }

        // only include nodes that are below threshold
        nodes = new HashMap<>(filterValues(nodes, size -> size <= averageSize));

        // get non-bucketed node shards by size, largest to smallest
        List<ShardMetadata> shards = shardManager.getNodeShards(currentNode).stream()
                .filter(shard -> !shard.getBucketNumber().isPresent())
                .sorted(comparingLong(ShardMetadata::getCompressedSize).reversed())
                .collect(toList());

        // eject shards while current node is above max
        Queue<ShardMetadata> queue = new ArrayDeque<>(shards);
        while ((nodeSize > maxSize) && !queue.isEmpty()) {
            ShardMetadata shard = queue.remove();
            long shardSize = shard.getCompressedSize();
            UUID shardUuid = shard.getShardUuid();

            // verify backup exists
            if (!backupStore.get().shardExists(shardUuid)) {
                log.warn("No backup for shard: %s", shardUuid);
            }

            // pick target node
            String target = pickTargetNode(nodes, shardSize, averageSize);
            if (target == null) {
                return;
            }
            long targetSize = nodes.get(target);

            // stats
            log.info("Moving shard %s to node %s (shard: %s, node: %s, average: %s, target: %s)",
                    shardUuid, target, shardSize, nodeSize, averageSize, targetSize);
            shardsEjected.update(1);

            // update size
            nodes.put(target, targetSize + shardSize);
            nodeSize -= shardSize;

            // move assignment
            shardManager.assignShard(shard.getTableId(), shardUuid, target, false);
            shardManager.unassignShard(shard.getTableId(), shardUuid, currentNode);

            // delete local file
            File file = storageService.getStorageFile(shardUuid);
            if (file.exists() && !file.delete()) {
                log.warn("Failed to delete shard file: %s", file);
            }
        }
    }

    private static String pickTargetNode(Map<String, Long> nodes, long shardSize, long maxSize)
    {
        while (!nodes.isEmpty()) {
            String node = pickCandidateNode(nodes);
            if ((nodes.get(node) + shardSize) <= maxSize) {
                return node;
            }
            nodes.remove(node);
        }
        return null;
    }

    private static String pickCandidateNode(Map<String, Long> nodes)
    {
        checkArgument(!nodes.isEmpty());
        if (nodes.size() == 1) {
            return nodes.keySet().iterator().next();
        }

        // pick two random candidates, then choose the smaller one
        List<String> candidates = new ArrayList<>(nodes.keySet());
        Collections.shuffle(candidates);
        String first = candidates.get(0);
        String second = candidates.get(1);
        return (nodes.get(first) <= nodes.get(second)) ? first : second;
    }
}
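// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the class above): ShardEjector picks a
// target with a "two random choices" heuristic -- sample two candidate nodes
// and take the less loaded one -- and only accepts a candidate whose size
// after receiving the shard would stay at or below the cap. The standalone,
// package-private class below re-implements that selection on synthetic
// sizes so the heuristic can be run and inspected in isolation; the class
// name and sample data are hypothetical, not part of Presto.
class TwoChoicesSelectionSketch
{
    public static void main(String[] args)
    {
        // synthetic node sizes in bytes (hypothetical data)
        Map<String, Long> nodes = new HashMap<>();
        nodes.put("node1", 100L);
        nodes.put("node2", 400L);
        nodes.put("node3", 250L);

        long shardSize = 50;
        long maxSize = 300;

        // node2 can never fit the shard under the cap, so it is discarded
        // if sampled; node1 or node3 is returned
        System.out.println("selected target: " + pickTarget(nodes, shardSize, maxSize));
    }

    // mirrors pickTargetNode: discard candidates that would exceed the cap
    // until one fits or none remain
    private static String pickTarget(Map<String, Long> nodes, long shardSize, long maxSize)
    {
        while (!nodes.isEmpty()) {
            String node = pickCandidate(nodes);
            if ((nodes.get(node) + shardSize) <= maxSize) {
                return node;
            }
            nodes.remove(node);
        }
        return null;
    }

    // mirrors pickCandidateNode: two random choices, keep the smaller
    private static String pickCandidate(Map<String, Long> nodes)
    {
        if (nodes.size() == 1) {
            return nodes.keySet().iterator().next();
        }
        List<String> candidates = new ArrayList<>(nodes.keySet());
        Collections.shuffle(candidates);
        String first = candidates.get(0);
        String second = candidates.get(1);
        return (nodes.get(first) <= nodes.get(second)) ? first : second;
    }
}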