/*
 * Copyright © 2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.app.runtime.spark.distributed;

import co.cask.cdap.api.workflow.WorkflowToken;
import co.cask.cdap.app.runtime.spark.SparkMainWrapper;
import co.cask.cdap.common.BadRequestException;
import co.cask.cdap.common.HttpExceptionHandler;
import co.cask.cdap.proto.id.ProgramRunId;
import co.cask.http.AbstractHttpHandler;
import co.cask.http.HttpHandler;
import co.cask.http.HttpResponder;
import co.cask.http.NettyHttpService;
import com.google.common.base.Charsets;
import com.google.common.reflect.TypeToken;
import com.google.common.util.concurrent.AbstractIdleService;
import com.google.gson.Gson;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBufferInputStream;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.lang.reflect.Type;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;

/**
 * The HTTP service for communicating with the {@link SparkMainWrapper} running in the driver,
 * both for controlling the program lifecycle and for receiving {@link WorkflowToken} updates.
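 * <p>
 * For illustration only (a hypothetical client sketch, not code shipped with this class; the
 * endpoint paths come from {@link SparkControllerHandler} below, while {@code programName} and
 * {@code runId} are placeholders), the driver side could heartbeat against this service
 * roughly as follows:
 * <pre>{@code
 * URI heartbeat = service.getBaseURI().resolve(
 *   String.format("/v1/spark/%s/runs/%s/heartbeat", programName, runId));
 * HttpURLConnection conn = (HttpURLConnection) heartbeat.toURL().openConnection();
 * conn.setRequestMethod("POST");
 * // A 200 response with a JSON body carrying the "stop" command means the program should
 * // shut down; a plain 200 means keep running.
 * int responseCode = conn.getResponseCode();
 * }</pre>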
 */
public final class SparkExecutionService extends AbstractIdleService {

  private static final Logger LOG = LoggerFactory.getLogger(SparkExecutionService.class);
  private static final Gson GSON = new Gson();
  private static final Type TOKEN_TYPE = new TypeToken<Map<String, String>>() { }.getType();
  private static final long SHUTDOWN_WAIT_SECONDS = 30L;

  private final NettyHttpService httpServer;
  private final ProgramRunId programRunId;
  @Nullable
  private final WorkflowToken workflowToken;
  private final AtomicBoolean stopping;
  private final CountDownLatch stopLatch;

  public SparkExecutionService(String host, ProgramRunId programRunId,
                               @Nullable WorkflowToken workflowToken) {
    this.httpServer = NettyHttpService.builder()
      .addHttpHandlers(Collections.singletonList(new SparkControllerHandler()))
      .setHost(host)
      .setExceptionHandler(new HttpExceptionHandler())
      .build();
    this.stopping = new AtomicBoolean();
    this.stopLatch = new CountDownLatch(1);
    this.programRunId = programRunId;
    this.workflowToken = workflowToken;
  }

  /**
   * Returns the base {@link URI} for talking to this service remotely through HTTP.
   */
  public URI getBaseURI() {
    InetSocketAddress bindAddress = httpServer.getBindAddress();
    if (bindAddress == null) {
      throw new IllegalStateException("SparkExecutionService hasn't been started");
    }
    return URI.create(String.format("http://%s:%d", bindAddress.getHostName(), bindAddress.getPort()));
  }

  @Override
  protected void startUp() throws Exception {
    httpServer.startAndWait();
  }

  @Override
  protected void shutDown() throws Exception {
    // Signal the running Spark program to stop on its next heartbeat, then wait for it to
    // report completion before shutting down the HTTP server.
    stopping.set(true);
    if (!stopLatch.await(SHUTDOWN_WAIT_SECONDS, TimeUnit.SECONDS)) {
      LOG.warn("Timeout in waiting for Spark program to stop: {}", programRunId);
    }
    httpServer.stopAndWait();
  }

  /**
   * Stops this service without waiting for the Spark program to complete.
   */
  public void shutdownNow() {
    stopLatch.countDown();
    stop();
  }

  /**
   * The {@link HttpHandler} for communicating with the Spark driver.
   */
  public final class SparkControllerHandler extends AbstractHttpHandler {

    /**
     * Handles heartbeat requests from the running Spark program.
     */
    @POST
    @Path("/v1/spark/{programName}/runs/{runId}/heartbeat")
    public synchronized void heartbeat(HttpRequest request, HttpResponder responder,
                                       @PathParam("programName") String programName,
                                       @PathParam("runId") String runId) throws Exception {
      // A zero-timeout await is a non-blocking check of whether the latch has already counted down
      if (stopLatch.await(0, TimeUnit.SECONDS)) {
        throw new BadRequestException(
          String.format("Spark program '%s' is already stopped. Heartbeat is not accepted.", programRunId));
      }

      validateRequest(programName, runId);
      updateWorkflowToken(request.getContent());

      // If a stop was requested, reply with the "stop" command; otherwise just acknowledge
      if (stopping.get()) {
        responder.sendJson(HttpResponseStatus.OK, SparkCommand.STOP);
      } else {
        responder.sendStatus(HttpResponseStatus.OK);
      }
    }

    /**
     * Handles the execution completion request from the running Spark program.
     */
    @PUT
    @Path("/v1/spark/{programName}/runs/{runId}/completed")
    public synchronized void completed(HttpRequest request, HttpResponder responder,
                                       @PathParam("programName") String programName,
                                       @PathParam("runId") String runId) throws Exception {
      validateRequest(programName, runId);
      try {
        updateWorkflowToken(request.getContent());
        responder.sendStatus(HttpResponseStatus.OK);
      } finally {
        stopLatch.countDown();
      }
    }

    /**
     * Verifies that the call is from the right client.
     */
    private void validateRequest(String programName, String runId) throws Exception {
      if (!programRunId.getProgram().equals(programName)) {
        throw new BadRequestException(
          String.format("Request program name '%s' is not the same as the context program name '%s'",
                        programName, programRunId.getProgram()));
      }
      if (runId == null || !programRunId.getRun().equals(runId)) {
        throw new BadRequestException(
          String.format("Request runId '%s' is not the same as the context runId '%s'",
                        runId, programRunId.getRun()));
      }
    }

    /**
     * Updates the {@link WorkflowToken} of the program from a JSON {@code Map<String, String>}
     * carried in the request body.
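     * <p>
     * For example (illustrative values only), a request body of the following shape would put
     * two entries into the token:
     * <pre>{@code
     * {"key1": "value1", "key2": "value2"}
     * }</pre>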
     */
    private void updateWorkflowToken(ChannelBuffer requestBody) {
      if (!requestBody.readable()) {
        return;
      }
      if (workflowToken == null) {
        // This shouldn't happen. Just log and ignore the update
        LOG.warn("Spark program is not running inside Workflow. Ignoring workflow token update: {}", programRunId);
        return;
      }
      try (Reader reader = new InputStreamReader(new ChannelBufferInputStream(requestBody), Charsets.UTF_8)) {
        Map<String, String> token = GSON.fromJson(reader, TOKEN_TYPE);
        for (Map.Entry<String, String> entry : token.entrySet()) {
          workflowToken.put(entry.getKey(), entry.getValue());
        }
      } catch (IOException e) {
        // Shouldn't happen, since all reading is from an in-memory buffer
        LOG.warn("Exception when decoding workflow token update request for {}", programRunId, e);
      }
    }
  }
}