/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.exec; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.exec.mr.MapRedTask; import org.apache.hadoop.hive.ql.exec.mr.MapredLocalTask; import org.apache.hadoop.hive.ql.exec.spark.SparkTask; import org.apache.hadoop.hive.ql.exec.tez.TezTask; import org.apache.hadoop.hive.ql.index.IndexMetadataChangeTask; import org.apache.hadoop.hive.ql.index.IndexMetadataChangeWork; import org.apache.hadoop.hive.ql.io.merge.MergeFileTask; import org.apache.hadoop.hive.ql.io.merge.MergeFileWork; import org.apache.hadoop.hive.ql.io.rcfile.stats.PartialScanTask; import org.apache.hadoop.hive.ql.io.rcfile.stats.PartialScanWork; import org.apache.hadoop.hive.ql.plan.ColumnStatsUpdateWork; import org.apache.hadoop.hive.ql.plan.ColumnStatsWork; import org.apache.hadoop.hive.ql.plan.ConditionalWork; import org.apache.hadoop.hive.ql.plan.CopyWork; import org.apache.hadoop.hive.ql.plan.DDLWork; import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork; import org.apache.hadoop.hive.ql.plan.ExplainSQRewriteWork; import org.apache.hadoop.hive.ql.plan.ExplainWork; import org.apache.hadoop.hive.ql.plan.FetchWork; import org.apache.hadoop.hive.ql.plan.FunctionWork; import org.apache.hadoop.hive.ql.plan.MapredLocalWork; import org.apache.hadoop.hive.ql.plan.MapredWork; import org.apache.hadoop.hive.ql.plan.MoveWork; import org.apache.hadoop.hive.ql.plan.ReplCopyWork; import org.apache.hadoop.hive.ql.plan.SparkWork; import org.apache.hadoop.hive.ql.plan.StatsNoJobWork; import org.apache.hadoop.hive.ql.plan.StatsWork; import org.apache.hadoop.hive.ql.plan.TezWork; /** * TaskFactory implementation. **/ public final class TaskFactory { /** * taskTuple. * * @param <T> */ public static final class TaskTuple<T extends Serializable> { public Class<T> workClass; public Class<? extends Task<T>> taskClass; public TaskTuple(Class<T> workClass, Class<? extends Task<T>> taskClass) { this.workClass = workClass; this.taskClass = taskClass; } } public static ArrayList<TaskTuple<? extends Serializable>> taskvec; static { taskvec = new ArrayList<TaskTuple<? extends Serializable>>(); taskvec.add(new TaskTuple<MoveWork>(MoveWork.class, MoveTask.class)); taskvec.add(new TaskTuple<FetchWork>(FetchWork.class, FetchTask.class)); taskvec.add(new TaskTuple<CopyWork>(CopyWork.class, CopyTask.class)); taskvec.add(new TaskTuple<ReplCopyWork>(ReplCopyWork.class, ReplCopyTask.class)); taskvec.add(new TaskTuple<DDLWork>(DDLWork.class, DDLTask.class)); taskvec.add(new TaskTuple<FunctionWork>(FunctionWork.class, FunctionTask.class)); taskvec .add(new TaskTuple<ExplainWork>(ExplainWork.class, ExplainTask.class)); taskvec .add(new TaskTuple<ExplainSQRewriteWork>(ExplainSQRewriteWork.class, ExplainSQRewriteTask.class)); taskvec.add(new TaskTuple<ConditionalWork>(ConditionalWork.class, ConditionalTask.class)); taskvec.add(new TaskTuple<MapredWork>(MapredWork.class, MapRedTask.class)); taskvec.add(new TaskTuple<MapredLocalWork>(MapredLocalWork.class, MapredLocalTask.class)); taskvec.add(new TaskTuple<StatsWork>(StatsWork.class, StatsTask.class)); taskvec.add(new TaskTuple<StatsNoJobWork>(StatsNoJobWork.class, StatsNoJobTask.class)); taskvec.add(new TaskTuple<ColumnStatsWork>(ColumnStatsWork.class, ColumnStatsTask.class)); taskvec.add(new TaskTuple<ColumnStatsUpdateWork>(ColumnStatsUpdateWork.class, ColumnStatsUpdateTask.class)); taskvec.add(new TaskTuple<MergeFileWork>(MergeFileWork.class, MergeFileTask.class)); taskvec.add(new TaskTuple<DependencyCollectionWork>(DependencyCollectionWork.class, DependencyCollectionTask.class)); taskvec.add(new TaskTuple<PartialScanWork>(PartialScanWork.class, PartialScanTask.class)); taskvec.add(new TaskTuple<IndexMetadataChangeWork>(IndexMetadataChangeWork.class, IndexMetadataChangeTask.class)); taskvec.add(new TaskTuple<TezWork>(TezWork.class, TezTask.class)); taskvec.add(new TaskTuple<SparkWork>(SparkWork.class, SparkTask.class)); } private static ThreadLocal<Integer> tid = new ThreadLocal<Integer>() { @Override protected Integer initialValue() { return Integer.valueOf(0); } }; public static int getAndIncrementId() { int curValue = tid.get().intValue(); tid.set(new Integer(curValue + 1)); return curValue; } public static void resetId() { tid.set(Integer.valueOf(0)); } @SuppressWarnings("unchecked") public static <T extends Serializable> Task<T> get(Class<T> workClass, HiveConf conf) { for (TaskTuple<? extends Serializable> t : taskvec) { if (t.workClass == workClass) { try { Task<T> ret = (Task<T>) t.taskClass.newInstance(); ret.setId("Stage-" + Integer.toString(getAndIncrementId())); return ret; } catch (Exception e) { throw new RuntimeException(e); } } } throw new RuntimeException("No task for work class " + workClass.getName()); } public static <T extends Serializable> Task<T> get(T work, HiveConf conf, Task<? extends Serializable>... tasklist) { Task<T> ret = get((Class<T>) work.getClass(), conf); ret.setWork(work); if (tasklist.length == 0) { return (ret); } ArrayList<Task<? extends Serializable>> clist = new ArrayList<Task<? extends Serializable>>(); for (Task<? extends Serializable> tsk : tasklist) { clist.add(tsk); } ret.setChildTasks(clist); return (ret); } public static <T extends Serializable> Task<T> getAndMakeChild(T work, HiveConf conf, Task<? extends Serializable>... tasklist) { Task<T> ret = get((Class<T>) work.getClass(), conf); ret.setWork(work); if (tasklist.length == 0) { return (ret); } makeChild(ret, tasklist); return (ret); } public static void makeChild(Task<?> ret, Task<? extends Serializable>... tasklist) { // Add the new task as child of each of the passed in tasks for (Task<? extends Serializable> tsk : tasklist) { List<Task<? extends Serializable>> children = tsk.getChildTasks(); if (children == null) { children = new ArrayList<Task<? extends Serializable>>(); } children.add(ret); tsk.setChildTasks(children); } } private TaskFactory() { // prevent instantiation } }