/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.task;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URL;
import java.nio.ByteBuffer;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.apache.tez.common.ContainerContext;
import org.apache.tez.common.ContainerTask;
import org.apache.tez.common.TezLocalResource;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.Limits;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.dag.utils.RelocalizationUtils;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.common.objectregistry.ObjectLifeCycle;
import org.apache.tez.runtime.common.objectregistry.ObjectRegistryImpl;
import org.apache.tez.runtime.common.objectregistry.ObjectRegistryModule;
import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.inject.Guice;
import com.google.inject.Injector;
/**
 * Main class of a Tez task-container JVM.
 * <p>
 * On construction it creates an RPC proxy to the AM's {@link TezTaskUmbilicalProtocol}
 * (authenticated with the session job token). {@link #run()} then repeatedly obtains the next
 * {@link ContainerTask} by running a {@link ContainerReporter} on a single-threaded executor and
 * executes it with a {@link TezTaskRunner}. The loop ends and the container is shut down when
 * the AM says the container should die, when a task run reports that it should die, or when a
 * handled error occurs.
 */
public class TezChild {

  private static final Logger LOG = Logger.getLogger(TezChild.class);

  private final Configuration defaultConf;
  private final String containerIdString;
  private final int appAttemptNumber;
  /** Socket address of the AM umbilical RPC endpoint (built from the host/port arguments). */
  private final InetSocketAddress address;
  private final String[] localDirs;

  private final AtomicLong heartbeatCounter = new AtomicLong(0);
  private final int getTaskMaxSleepTime;
  private final int amHeartbeatInterval;
  private final long sendCounterInterval;
  private final int maxEventsToGet;

  /** Single-threaded daemon executor used to fetch tasks; also handed to each TezTaskRunner. */
  private final ListeningExecutorService executor;
  private final ObjectRegistryImpl objectRegistry;
  /** Per-service metadata passed to every task; holds the serialized shuffle job token. */
  private final Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<String, ByteBuffer>();
  /** Shared with each TezTaskRunner; replaced with an empty map whenever the DAG changes. */
  private Multimap<String, String> startedInputsMap = HashMultimap.create();

  private TaskReporter taskReporter;
  private TezTaskUmbilicalProtocol umbilical;
  /** Number of tasks that have been started in this container. */
  private int taskCount = 0;
  /** Vertex of the most recently started task; null until the first task arrives. */
  private TezVertexID lastVertexID;

  /**
   * Creates the child, reads heartbeat/polling settings from {@code conf} and opens the
   * umbilical RPC proxy to the AM.
   *
   * @param conf base configuration for the container
   * @param host AM host for the umbilical protocol
   * @param port AM port for the umbilical protocol
   * @param containerIdentifier identifier of this container, reported to the AM
   * @param tokenIdentifier user name of the remote UGI used to authenticate the umbilical
   * @param appAttemptNumber application attempt number, passed on to each task
   * @param localDirs local directories available to tasks
   * @param objectRegistry registry whose VERTEX/DAG caches are cleared as tasks change
   * @throws IOException if the current user's credentials cannot be obtained
   * @throws InterruptedException if creating the proxy as the task owner is interrupted
   */
  public TezChild(Configuration conf, String host, int port, String containerIdentifier,
      String tokenIdentifier, int appAttemptNumber, String[] localDirs,
      ObjectRegistryImpl objectRegistry) throws IOException, InterruptedException {
    this.defaultConf = conf;
    this.containerIdString = containerIdentifier;
    this.appAttemptNumber = appAttemptNumber;
    this.localDirs = localDirs;

    getTaskMaxSleepTime = defaultConf.getInt(
        TezConfiguration.TEZ_TASK_GET_TASK_SLEEP_INTERVAL_MS_MAX,
        TezConfiguration.TEZ_TASK_GET_TASK_SLEEP_INTERVAL_MS_MAX_DEFAULT);
    amHeartbeatInterval = defaultConf.getInt(TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS,
        TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS_DEFAULT);
    sendCounterInterval = defaultConf.getLong(
        TezConfiguration.TEZ_TASK_AM_HEARTBEAT_COUNTER_INTERVAL_MS,
        TezConfiguration.TEZ_TASK_AM_HEARTBEAT_COUNTER_INTERVAL_MS_DEFAULT);
    maxEventsToGet = defaultConf.getInt(TezConfiguration.TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT,
        TezConfiguration.TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT_DEFAULT);

    address = NetUtils.createSocketAddrForHost(host, port);

    ExecutorService executor = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder()
        .setDaemon(true).setNameFormat("TezChild").build());
    this.executor = MoreExecutors.listeningDecorator(executor);

    this.objectRegistry = objectRegistry;

    // Security framework already loaded the tokens into current ugi
    Credentials credentials = UserGroupInformation.getCurrentUser().getCredentials();

    if (LOG.isDebugEnabled()) {
      LOG.debug("Executing with tokens:");
      for (Token<?> token : credentials.getAllTokens()) {
        LOG.debug(token);
      }
    }

    // The umbilical is authenticated with the session job token, bound to the AM address.
    UserGroupInformation taskOwner = UserGroupInformation.createRemoteUser(tokenIdentifier);
    Token<JobTokenIdentifier> jobToken = TokenCache.getSessionToken(credentials);
    SecurityUtil.setTokenService(jobToken, address);
    taskOwner.addToken(jobToken);

    // The same job token is what the shuffle handler service expects from consumers.
    serviceConsumerMetadata.put(ShuffleUtils.SHUFFLE_HANDLER_SERVICE_ID,
        ShuffleUtils.convertJobTokenToBytes(jobToken));

    umbilical = taskOwner.doAs(new PrivilegedExceptionAction<TezTaskUmbilicalProtocol>() {
      @Override
      public TezTaskUmbilicalProtocol run() throws Exception {
        return (TezTaskUmbilicalProtocol) RPC.getProxy(TezTaskUmbilicalProtocol.class,
            TezTaskUmbilicalProtocol.versionID, address, defaultConf);
      }
    });
  }

  /**
   * Fetches and executes tasks until told to die or an error occurs.
   *
   * @throws IOException if closing the per-task file systems fails
   * @throws InterruptedException if propagated from the task runner
   * @throws TezException if propagated from localizing task resources
   */
  void run() throws IOException, InterruptedException, TezException {

    ContainerContext containerContext = new ContainerContext(containerIdString);
    ContainerReporter containerReporter = new ContainerReporter(umbilical, containerContext,
        getTaskMaxSleepTime);

    taskReporter = new TaskReporter(umbilical, amHeartbeatInterval,
        sendCounterInterval, maxEventsToGet, heartbeatCounter, containerIdString);

    UserGroupInformation childUGI = null;

    while (!executor.isTerminated()) {
      if (taskCount > 0) {
        // Reset the per-task logger suffix while waiting for the next task.
        TezUtils.updateLoggers("");
      }
      ListenableFuture<ContainerTask> getTaskFuture = executor.submit(containerReporter);
      ContainerTask containerTask = null;
      try {
        containerTask = getTaskFuture.get();
      } catch (ExecutionException e) {
        Throwable cause = e.getCause();
        // Fall back to the wrapper so the failure is never reported as null.
        handleError(cause != null ? cause : e);
        return;
      } catch (InterruptedException e) {
        // containerTask is still null here: no task has been received yet, so there is no
        // task attempt ID to report (dereferencing it would throw an NPE and skip shutdown).
        LOG.info("Interrupted while waiting for next task");
        handleError(e);
        // Restore the interrupt status only after shutdown, so cleanup is not disrupted.
        Thread.currentThread().interrupt();
        return;
      }
      if (containerTask.shouldDie()) {
        LOG.info("ContainerTask returned shouldDie=true, Exiting");
        shutdown();
        return;
      } else {
        String loggerAddend = containerTask.getTaskSpec().getTaskAttemptID().toString();
        taskCount++;
        TezUtils.updateLoggers(loggerAddend);
        FileSystem.clearStatistics();

        childUGI = handleNewTaskCredentials(containerTask, childUGI);
        handleNewTaskLocalResources(containerTask);
        cleanupOnTaskChanged(containerTask);

        // Execute the Actual Task
        TezTaskRunner taskRunner = new TezTaskRunner(new TezConfiguration(defaultConf), childUGI,
            localDirs, containerTask.getTaskSpec(), umbilical, appAttemptNumber,
            serviceConsumerMetadata, startedInputsMap, taskReporter, executor);
        boolean shouldDie = false;
        try {
          shouldDie = !taskRunner.run();
          if (shouldDie) {
            LOG.info("Got a shouldDie notification via hearbeats. Shutting down");
            shutdown();
            return;
          }
        } catch (IOException e) {
          handleError(e);
          return;
        } catch (TezException e) {
          handleError(e);
          return;
        } finally {
          FileSystem.closeAllForUGI(childUGI);
        }
      }
    }
  }

  /**
   * Chooses the UGI to run the new task with.
   * <p>
   * A fresh remote-user UGI (user taken from the {@code USER} environment variable) carrying
   * the task's credentials is created only when the credentials have changed and are
   * non-null. Otherwise the previous UGI is returned unchanged, which may be null if no
   * credentials were ever provided.
   *
   * @param containerTask the new task specification; must not be a shouldDie task and must
   *          carry a task spec
   * @param childUGI the UGI used for the previous task, possibly null
   * @return the UGI to use for {@code containerTask}
   * @throws IllegalStateException if {@code containerTask} is not a valid task
   */
  UserGroupInformation handleNewTaskCredentials(ContainerTask containerTask,
      UserGroupInformation childUGI) {
    // Re-use the UGI only if the Credentials have not changed.
    Preconditions.checkState(!containerTask.shouldDie(), "Cannot run a shouldDie task");
    Preconditions.checkState(containerTask.getTaskSpec() != null, "TaskSpec must not be null");
    if (containerTask.haveCredentialsChanged()) {
      LOG.info("Refreshing UGI since Credentials have changed");
      Credentials taskCreds = containerTask.getCredentials();
      if (taskCreds != null) {
        LOG.info("Credentials : #Tokens=" + taskCreds.numberOfTokens() + ", #SecretKeys="
            + taskCreds.numberOfSecretKeys());
        childUGI = UserGroupInformation.createRemoteUser(System
            .getenv(ApplicationConstants.Environment.USER.toString()));
        childUGI.addCredentials(taskCreds);
      } else {
        LOG.info("Not loading any credentials, since no credentials provided");
      }
    }
    return childUGI;
  }

  /**
   * Downloads any additional resources attached to the new task and adds them to the
   * classpath.
   *
   * @param containerTask the new task specification
   * @throws IOException if downloading or registering the resources fails
   * @throws TezException if propagated from {@link RelocalizationUtils}
   */
  private void handleNewTaskLocalResources(ContainerTask containerTask) throws IOException,
      TezException {
    // NOTE(review): assumes getAdditionalResources() never returns null; Maps.transformValues
    // would throw an NPE on a null map. Confirm against ContainerTask.
    Map<String, TezLocalResource> additionalResources = containerTask.getAdditionalResources();
    LOG.info("Localizing additional local resources for Task : " + additionalResources);
    // Only the URI of each resource is needed for relocalization.
    List<URL> downloadedUrls = RelocalizationUtils.processAdditionalResources(
        Maps.transformValues(additionalResources, new Function<TezLocalResource, URI>() {
          @Override
          public URI apply(TezLocalResource input) {
            return input.getUri();
          }
        }), defaultConf);
    RelocalizationUtils.addUrlsToClassPath(downloadedUrls);

    LOG.info("Done localizing additional resources");
    final TaskSpec taskSpec = containerTask.getTaskSpec();
    if (LOG.isDebugEnabled()) {
      LOG.debug("New container task context:" + taskSpec.toString());
    }
  }

  /**
   * Clears object-registry caches whose scope ended with the previous task. The VERTEX cache
   * is cleared when the vertex changes. The DAG cache is cleared, and startedInputsMap is
   * reset, when the DAG changes. Nothing is cleared for the first task.
   *
   * @param containerTask the new task specification; must not be a shouldDie task and must
   *          carry a task spec
   * @throws IllegalStateException if {@code containerTask} is not a valid task
   */
  private void cleanupOnTaskChanged(ContainerTask containerTask) {
    Preconditions.checkState(!containerTask.shouldDie(), "Cannot run a shouldDie task");
    Preconditions.checkState(containerTask.getTaskSpec() != null, "TaskSpec must not be null");
    TezVertexID newVertexID = containerTask.getTaskSpec().getTaskAttemptID().getTaskID()
        .getVertexID();
    if (lastVertexID != null) {
      if (!lastVertexID.equals(newVertexID)) {
        objectRegistry.clearCache(ObjectLifeCycle.VERTEX);
      }
      if (!lastVertexID.getDAGId().equals(newVertexID.getDAGId())) {
        objectRegistry.clearCache(ObjectLifeCycle.DAG);
        startedInputsMap = HashMultimap.create();
      }
    }
    lastVertexID = newVertexID;
  }

  /**
   * Stops the executor, the task reporter (if started), the umbilical proxy, the metrics
   * system and log4j. Nothing can be logged after this returns.
   */
  private void shutdown() {
    executor.shutdownNow();
    if (taskReporter != null) {
      taskReporter.shutdown();
    }
    RPC.stopProxy(umbilical);
    DefaultMetricsSystem.shutdown();
    LogManager.shutdown();
  }

  /**
   * Container entry point.
   *
   * @param args {@code <amHost> <amPort> <containerIdentifier> <tokenIdentifier>
   *          <appAttemptNumber>}
   */
  public static void main(String[] args) throws IOException, InterruptedException, TezException {

    Thread.setDefaultUncaughtExceptionHandler(new YarnUncaughtExceptionHandler());
    LOG.info("TezChild starting");

    final Configuration defaultConf = new Configuration();
    // Pull in configuration specified for the session.
    // TODO TEZ-1233. This needs to be moved over the wire rather than localizing the file
    // for each and every task, and reading it back from disk. Also needs to be per vertex.
    TezUtils.addUserSpecifiedTezConfiguration(defaultConf);
    UserGroupInformation.setConfiguration(defaultConf);
    Limits.setConfiguration(defaultConf);

    // Only checked when the JVM runs with assertions enabled (-ea).
    assert args.length == 5;
    String host = args[0];
    int port = Integer.parseInt(args[1]);
    final String containerIdentifier = args[2];
    final String tokenIdentifier = args[3];
    final int attemptNumber = Integer.parseInt(args[4]);
    final String pid = System.getenv().get("JVM_PID");
    final String[] localDirs = StringUtils.getTrimmedStrings(System.getenv(Environment.LOCAL_DIRS
        .name()));
    LOG.info("PID, containerIdentifier: " + pid + ", " + containerIdentifier);
    if (LOG.isDebugEnabled()) {
      LOG.debug("Info from cmd line: AM-host: " + host + " AM-port: " + port
          + " containerIdentifier: " + containerIdentifier + " appAttemptNumber: " + attemptNumber
          + " tokenIdentifier: " + tokenIdentifier);
    }

    // Should this be part of main - Metrics and ObjectRegistry. TezTask setup should be independent
    // of this class. Leaving it here, till there's some entity representing a running JVM.
    DefaultMetricsSystem.initialize("TezTask");

    ObjectRegistryImpl objectRegistry = new ObjectRegistryImpl();
    // Created for its side effect of wiring the object registry module.
    @SuppressWarnings("unused")
    Injector injector = Guice.createInjector(new ObjectRegistryModule(objectRegistry));

    TezChild tezChild = new TezChild(defaultConf, host, port, containerIdentifier, tokenIdentifier,
        attemptNumber, localDirs, objectRegistry);

    tezChild.run();
  }

  /**
   * Logs {@code t} and shuts the container down. The error is logged first because
   * {@link #shutdown()} shuts down log4j.
   *
   * @param t the failure that ended the task loop
   */
  private void handleError(Throwable t) {
    LOG.error("TezChild encountered an error, shutting down", t);
    shutdown();
  }
}