/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hive.spark.client;

import com.google.common.base.Throwables;
import com.google.common.io.Files;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.nio.NioEventLoopGroup;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.io.FileUtils;
import org.apache.hadoop.hive.common.classification.InterfaceAudience;
import org.apache.hive.spark.client.metrics.Metrics;
import org.apache.hive.spark.client.rpc.Rpc;
import org.apache.hive.spark.client.rpc.RpcConfiguration;
import org.apache.hive.spark.counter.SparkCounters;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkJobInfo;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.SparkListenerJobEnd;
import org.apache.spark.scheduler.SparkListenerJobStart;
import org.apache.spark.scheduler.SparkListenerTaskEnd;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ThreadFactoryBuilder;

/**
 * Driver code for the Spark client library.
 */
@InterfaceAudience.Private
public class RemoteDriver {

  private static final Logger LOG = LoggerFactory.getLogger(RemoteDriver.class);

  private final Map<String, JobWrapper<?>> activeJobs;
  private final Object jcLock;
  private final Object shutdownLock;
  private final ExecutorService executor;
  private final NioEventLoopGroup egroup;
  private final Rpc clientRpc;
  private final DriverProtocol protocol;

  // a local temp dir specific to this driver
  private final File localTmpDir;

  // Used to queue up requests while the SparkContext is being created.
  private final List<JobWrapper<?>> jobQueue = Lists.newLinkedList();

  // jc is effectively final, but it has to be volatile since it's accessed by different
  // threads while the constructor is running.
  private volatile JobContextImpl jc;

  private volatile boolean running;

  private RemoteDriver(String[] args) throws Exception {
    this.activeJobs = Maps.newConcurrentMap();
    this.jcLock = new Object();
    this.shutdownLock = new Object();
    localTmpDir = Files.createTempDir();

    SparkConf conf = new SparkConf();
    String serverAddress = null;
    int serverPort = -1;
    for (int idx = 0; idx < args.length; idx += 2) {
      String key = args[idx];
      if (key.equals("--remote-host")) {
        serverAddress = getArg(args, idx);
      } else if (key.equals("--remote-port")) {
        serverPort = Integer.parseInt(getArg(args, idx));
      } else if (key.equals("--client-id")) {
        conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
      } else if (key.equals("--secret")) {
        conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
      } else if (key.equals("--conf")) {
        String[] val = getArg(args, idx).split("[=]", 2);
        conf.set(val[0], val[1]);
      } else {
        throw new IllegalArgumentException("Invalid command line: "
            + Joiner.on(" ").join(args));
      }
    }

    executor = Executors.newCachedThreadPool();

    LOG.info("Connecting to: {}:{}", serverAddress, serverPort);

    Map<String, String> mapConf = Maps.newHashMap();
    for (Tuple2<String, String> e : conf.getAll()) {
      mapConf.put(e._1(), e._2());
      LOG.debug("Remote Driver configured with: " + e._1() + "=" + e._2());
    }

    String clientId = mapConf.get(SparkClientFactory.CONF_CLIENT_ID);
    Preconditions.checkArgument(clientId != null, "No client ID provided.");
    String secret = mapConf.get(SparkClientFactory.CONF_KEY_SECRET);
    Preconditions.checkArgument(secret != null, "No secret provided.");

    int threadCount = new RpcConfiguration(mapConf).getRpcThreadCount();
    this.egroup = new NioEventLoopGroup(
        threadCount,
        new ThreadFactoryBuilder()
            .setNameFormat("Driver-RPC-Handler-%d")
            .setDaemon(true)
            .build());
    this.protocol = new DriverProtocol();

    // The RPC library takes care of timing this out.
    this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort,
        clientId, secret, protocol).get();
    this.running = true;

    this.clientRpc.addListener(new Rpc.Listener() {
      @Override
      public void rpcClosed(Rpc rpc) {
        LOG.warn("Shutting down driver because RPC channel was closed.");
        shutdown(null);
      }
    });

    try {
      JavaSparkContext sc = new JavaSparkContext(conf);
      sc.sc().addSparkListener(new ClientListener());
      synchronized (jcLock) {
        jc = new JobContextImpl(sc, localTmpDir);
        jcLock.notifyAll();
      }
    } catch (Exception e) {
      LOG.error("Failed to start SparkContext: " + e, e);
      shutdown(e);
      synchronized (jcLock) {
        jcLock.notifyAll();
      }
      throw e;
    }

    synchronized (jcLock) {
      for (Iterator<JobWrapper<?>> it = jobQueue.iterator(); it.hasNext();) {
        it.next().submit();
      }
    }
  }

  private void run() throws InterruptedException {
    synchronized (shutdownLock) {
      while (running) {
        shutdownLock.wait();
      }
    }
    executor.shutdownNow();
    try {
      FileUtils.deleteDirectory(localTmpDir);
    } catch (IOException e) {
      LOG.warn("Failed to delete local tmp dir: " + localTmpDir, e);
    }
  }

  private void submit(JobWrapper<?> job) {
    synchronized (jcLock) {
      if (jc != null) {
        job.submit();
      } else {
        LOG.info("SparkContext not yet up, queueing job request.");
        jobQueue.add(job);
      }
    }
  }

  private synchronized void shutdown(Throwable error) {
    if (running) {
      if (error == null) {
        LOG.info("Shutting down remote driver.");
      } else {
        LOG.error("Shutting down remote driver due to error: " + error, error);
      }
      running = false;
      for (JobWrapper<?> job : activeJobs.values()) {
        cancelJob(job);
      }
      if (error != null) {
        protocol.sendError(error);
      }
      if (jc != null) {
        jc.stop();
      }
      clientRpc.close();
      egroup.shutdownGracefully();
      synchronized (shutdownLock) {
        shutdownLock.notifyAll();
      }
    }
  }

  private boolean cancelJob(JobWrapper<?> job) {
    boolean cancelled = false;
    for (JavaFutureAction<?> action : job.jobs) {
      cancelled |= action.cancel(true);
    }
    return cancelled | (job.future != null && job.future.cancel(true));
  }

  private String getArg(String[] args, int keyIdx) {
    int valIdx = keyIdx + 1;
    if (args.length <= valIdx) {
      throw new IllegalArgumentException("Invalid command line: "
          + Joiner.on(" ").join(args));
    }
    return args[valIdx];
  }

  private class DriverProtocol extends BaseProtocol {

    void sendError(Throwable error) {
      LOG.debug("Send error to Client: {}", Throwables.getStackTraceAsString(error));
      clientRpc.call(new Error(error));
    }

    <T extends Serializable> void jobFinished(String jobId, T result,
        Throwable error, SparkCounters counters) {
      LOG.debug("Send job({}) result to Client.", jobId);
      clientRpc.call(new JobResult(jobId, result, error, counters));
    }

    void jobStarted(String jobId) {
      clientRpc.call(new JobStarted(jobId));
    }

    void jobSubmitted(String jobId, int sparkJobId) {
      LOG.debug("Send job({}/{}) submitted to Client.", jobId, sparkJobId);
      clientRpc.call(new JobSubmitted(jobId, sparkJobId));
    }

    void sendMetrics(String jobId, int sparkJobId, int stageId, long taskId, Metrics metrics) {
      LOG.debug("Send task({}/{}/{}/{}) metric to Client.", jobId, sparkJobId, stageId, taskId);
      clientRpc.call(new JobMetrics(jobId, sparkJobId, stageId, taskId, metrics));
    }

    private void handle(ChannelHandlerContext ctx, CancelJob msg) {
      JobWrapper<?> job = activeJobs.get(msg.id);
      if (job == null || !cancelJob(job)) {
        LOG.info("Requested to cancel an already finished job.");
      }
    }

    private void handle(ChannelHandlerContext ctx, EndSession msg) {
      LOG.debug("Shutting down due to EndSession request.");
      shutdown(null);
    }

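    // Handles an asynchronous job request from the client: the request is wrapped in a
    // JobWrapper, registered as an active job, and either submitted to the executor right
    // away or queued until the SparkContext is up (see submit() above).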
    private void handle(ChannelHandlerContext ctx, JobRequest msg) {
      LOG.info("Received job request {}", msg.id);
      JobWrapper<?> wrapper = new JobWrapper<Serializable>(msg);
      activeJobs.put(msg.id, wrapper);
      submit(wrapper);
    }

    private Object handle(ChannelHandlerContext ctx, SyncJobRequest msg) throws Exception {
      // In case the job context is not up yet, let's wait, since this is supposed to be a
      // "synchronous" RPC.
      if (jc == null) {
        synchronized (jcLock) {
          while (jc == null) {
            jcLock.wait();
            if (!running) {
              throw new IllegalStateException("Remote context is shutting down.");
            }
          }
        }
      }

      jc.setMonitorCb(new MonitorCallback() {
        @Override
        public void call(JavaFutureAction<?> future,
            SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
          throw new IllegalStateException(
              "JobContext.monitor() is not available for synchronous jobs.");
        }
      });

      try {
        return msg.job.call(jc);
      } finally {
        jc.setMonitorCb(null);
      }
    }

  }

  private class JobWrapper<T extends Serializable> implements Callable<Void> {

    private final BaseProtocol.JobRequest<T> req;
    private final List<JavaFutureAction<?>> jobs;
    private final AtomicInteger jobEndReceived;
    private int completed;
    private SparkCounters sparkCounters;
    private Set<Integer> cachedRDDIds;
    private Integer sparkJobId;

    private Future<?> future;

    JobWrapper(BaseProtocol.JobRequest<T> req) {
      this.req = req;
      this.jobs = Lists.newArrayList();
      completed = 0;
      jobEndReceived = new AtomicInteger(0);
      this.sparkCounters = null;
      this.cachedRDDIds = null;
      this.sparkJobId = null;
    }

    @Override
    public Void call() throws Exception {
      protocol.jobStarted(req.id);

      try {
        jc.setMonitorCb(new MonitorCallback() {
          @Override
          public void call(JavaFutureAction<?> future,
              SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
            monitorJob(future, sparkCounters, cachedRDDIds);
          }
        });

        T result = req.job.call(jc);
        // In case the job is empty, there won't be JobStart/JobEnd events. The only way
        // to know if the job has finished is to check the futures here ourselves.
        for (JavaFutureAction<?> future : jobs) {
          future.get();
          completed++;
          LOG.debug("Client job {}: {} of {} Spark jobs finished.",
              req.id, completed, jobs.size());
        }

        // If the job is not empty (but runs fast), we have to wait until all the TaskEnd/JobEnd
        // events are processed. Otherwise, task metrics may get lost. See HIVE-13525.
        if (sparkJobId != null) {
          SparkJobInfo sparkJobInfo = jc.sc().statusTracker().getJobInfo(sparkJobId);
          if (sparkJobInfo != null && sparkJobInfo.stageIds() != null
              && sparkJobInfo.stageIds().length > 0) {
            synchronized (jobEndReceived) {
              while (jobEndReceived.get() < jobs.size()) {
                jobEndReceived.wait();
              }
            }
          }
        }

        SparkCounters counters = null;
        if (sparkCounters != null) {
          counters = sparkCounters.snapshot();
        }
        protocol.jobFinished(req.id, result, null, counters);
      } catch (Throwable t) {
        // Catch throwables in a best-effort attempt to report job status back to the client.
        // It's re-thrown so that the executor can destroy the affected thread (or the JVM can
        // die or whatever would happen if the throwable bubbled up).
        LOG.info("Failed to run job " + req.id, t);
        protocol.jobFinished(req.id, null, t,
            sparkCounters != null ? sparkCounters.snapshot() : null);
        throw new ExecutionException(t);
      } finally {
        jc.setMonitorCb(null);
        activeJobs.remove(req.id);
        releaseCache();
      }
      return null;
    }

    void submit() {
      this.future = executor.submit(this);
    }

    void jobDone() {
      synchronized (jobEndReceived) {
        jobEndReceived.incrementAndGet();
        jobEndReceived.notifyAll();
      }
    }

    /**
     * Release cached RDDs as soon as the job is done.
     * This differs from the local Spark client so as to save an RPC round trip and to avoid
     * passing cached RDD id information around. Otherwise, we could follow the local Spark
     * client's approach for consistency.
     */
    void releaseCache() {
      if (cachedRDDIds != null) {
        for (Integer cachedRDDId : cachedRDDIds) {
          jc.sc().sc().unpersistRDD(cachedRDDId, false);
        }
      }
    }

    private void monitorJob(JavaFutureAction<?> job,
        SparkCounters sparkCounters, Set<Integer> cachedRDDIds) {
      jobs.add(job);
      if (!jc.getMonitoredJobs().containsKey(req.id)) {
        jc.getMonitoredJobs().put(req.id, new CopyOnWriteArrayList<JavaFutureAction<?>>());
      }
      jc.getMonitoredJobs().get(req.id).add(job);
      this.sparkCounters = sparkCounters;
      this.cachedRDDIds = cachedRDDIds;
      sparkJobId = job.jobIds().get(0);
      protocol.jobSubmitted(req.id, sparkJobId);
    }

  }

  private class ClientListener extends SparkListener {

    private final Map<Integer, Integer> stageToJobId = Maps.newHashMap();

    @Override
    public void onJobStart(SparkListenerJobStart jobStart) {
      synchronized (stageToJobId) {
        for (int i = 0; i < jobStart.stageIds().length(); i++) {
          stageToJobId.put((Integer) jobStart.stageIds().apply(i), jobStart.jobId());
        }
      }
    }

    @Override
    public void onJobEnd(SparkListenerJobEnd jobEnd) {
      synchronized (stageToJobId) {
        for (Iterator<Map.Entry<Integer, Integer>> it = stageToJobId.entrySet().iterator();
            it.hasNext();) {
          Map.Entry<Integer, Integer> e = it.next();
          if (e.getValue() == jobEnd.jobId()) {
            it.remove();
          }
        }
      }

      String clientId = getClientId(jobEnd.jobId());
      if (clientId != null) {
        activeJobs.get(clientId).jobDone();
      }
    }

    @Override
    public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
      if (taskEnd.reason() instanceof org.apache.spark.Success$
          && !taskEnd.taskInfo().speculative()) {
        Metrics metrics = new Metrics(taskEnd.taskMetrics());
        Integer jobId;
        synchronized (stageToJobId) {
          jobId = stageToJobId.get(taskEnd.stageId());
        }

        // TODO: implement implicit AsyncRDDActions conversion instead of jc.monitor()?
        // TODO: how to handle stage failures?

        String clientId = getClientId(jobId);
        if (clientId != null) {
          protocol.sendMetrics(clientId, jobId, taskEnd.stageId(),
              taskEnd.taskInfo().taskId(), metrics);
        }
      }
    }

    /**
     * Returns the client job ID for the given Spark job ID.
     *
     * This will only work for jobs monitored via JobContext#monitor(). Other jobs won't be
     * matched, and this method will return null.
     */
    private String getClientId(Integer jobId) {
      for (Map.Entry<String, JobWrapper<?>> e : activeJobs.entrySet()) {
        for (JavaFutureAction<?> future : e.getValue().jobs) {
          if (future.jobIds().contains(jobId)) {
            return e.getKey();
          }
        }
      }
      return null;
    }

  }

  public static void main(String[] args) throws Exception {
    new RemoteDriver(args).run();
  }

}
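// Usage sketch (informational only; the exact launch mechanism, e.g. spark-submit, is handled
// by the client side and is not shown in this file). The driver runs inside the Spark
// application and is started with the arguments parsed in the constructor, for example:
//
//   RemoteDriver --remote-host <client host> --remote-port <client port> \
//       --client-id <id> --secret <secret> --conf <key>=<value>
//
// On startup it connects back to the client over the RPC channel, creates the SparkContext,
// and then serves job requests until the session ends or the RPC channel closes.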