/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package gobblin.runtime.task;

import java.util.List;
import java.util.concurrent.CountDownLatch;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;

import gobblin.configuration.State;
import gobblin.configuration.WorkUnitState;
import gobblin.runtime.Task;
import gobblin.runtime.TaskContext;
import gobblin.runtime.TaskState;
import gobblin.runtime.TaskStateTracker;
import gobblin.runtime.fork.Fork;


/**
 * An extension of {@link Task} that wraps a generic {@link TaskIFace} for backwards compatibility.
 *
 * <p>Lifecycle and state calls ({@link #run()}, {@link #commit()}, {@link #shutdown()}, progress and
 * state getters) are forwarded to the wrapped {@link TaskIFace}. The job id and task id are read from
 * the {@link TaskContext} once, at construction time. The supplied {@link TaskStateTracker} is notified
 * when {@link #run()} and {@link #commit()} finish, whether they return normally or throw.
 */
public class TaskIFaceWrapper extends Task {

  /** The task all work is delegated to. */
  private final TaskIFace underlyingTask;
  /** Context this wrapper was created with; also the source of the {@link TaskState} it reports. */
  private final TaskContext taskContext;
  /** Job id captured from the task context's state at construction. */
  private final String jobId;
  /** Task id captured from the task context's state at construction. */
  private final String taskId;
  /** Latch decremented by {@link #markTaskCompletion()}; may be {@code null}. */
  private final CountDownLatch countDownLatch;
  /** Tracker notified when a run or a commit has ended. */
  private final TaskStateTracker taskStateTracker;
  // Plain int with a non-atomic increment.
  // NOTE(review): assumes retries are counted from a single thread at a time -- confirm against callers.
  private int retryCount = 0;

  /**
   * Creates a wrapper around the given task.
   *
   * @param underlyingTask the task to delegate to
   * @param taskContext the context whose {@link TaskState} supplies the job id and task id
   * @param countDownLatch latch to count down in {@link #markTaskCompletion()}; {@code null} is tolerated
   * @param taskStateTracker tracker to notify on run and commit completion
   */
  public TaskIFaceWrapper(TaskIFace underlyingTask, TaskContext taskContext, CountDownLatch countDownLatch,
      TaskStateTracker taskStateTracker) {
    super();
    this.underlyingTask = underlyingTask;
    this.taskContext = taskContext;
    this.jobId = taskContext.getTaskState().getJobId();
    this.taskId = taskContext.getTaskState().getTaskId();
    this.countDownLatch = countDownLatch;
    this.taskStateTracker = taskStateTracker;
  }

  /**
   * Delegates to the wrapped task's {@code awaitShutdown}.
   *
   * @param timeoutInMillis timeout, in milliseconds, passed through to the wrapped task
   * @return whatever the wrapped task returns
   * @throws InterruptedException if the wrapped task's {@code awaitShutdown} throws it
   */
  @Override
  public boolean awaitShutdown(long timeoutInMillis) throws InterruptedException {
    return this.underlyingTask.awaitShutdown(timeoutInMillis);
  }

  /** Delegates to the wrapped task's {@code shutdown}. */
  @Override
  public void shutdown() {
    this.underlyingTask.shutdown();
  }

  /**
   * @return the progress string reported by the wrapped task
   */
  @Override
  public String getProgress() {
    return this.underlyingTask.getProgress();
  }

  /**
   * Runs the wrapped task and then notifies the tracker that the run has ended.
   *
   * <p>The notification is issued from a {@code finally} block, so it also fires when the wrapped task
   * throws; the exception then continues to propagate to the caller. If the tracker callback itself
   * throws while an exception from the wrapped task is in flight, the callback's exception is the one
   * the caller sees.
   */
  @Override
  public void run() {
    try {
      this.underlyingTask.run();
    } finally {
      // Without this, a throwing delegate would leave the tracker waiting for a completion signal
      // that never arrives.
      // NOTE(review): presumably the tracker is what eventually triggers markTaskCompletion() -- confirm.
      this.taskStateTracker.onTaskRunCompletion(this);
    }
  }

  /**
   * @return the job id captured at construction
   */
  @Override
  public String getJobId() {
    return this.jobId;
  }

  /**
   * @return the task id captured at construction
   */
  @Override
  public String getTaskId() {
    return this.taskId;
  }

  /**
   * @return the {@link TaskContext} this wrapper was constructed with
   */
  @Override
  public TaskContext getTaskContext() {
    return this.taskContext;
  }

  /**
   * Returns the {@link TaskState} obtained from the task context after updating it in place.
   *
   * <p>The state's task id, job id and working state are overwritten with this wrapper's current values,
   * then the wrapped task's execution metadata and persistent state are added to it, in that order.
   *
   * @return the updated task state object obtained from the task context
   */
  @Override
  public TaskState getTaskState() {
    TaskState taskState = this.taskContext.getTaskState();
    taskState.setTaskId(getTaskId());
    taskState.setJobId(getJobId());
    taskState.setWorkingState(getWorkingState());
    taskState.addAll(getExecutionMetadata());
    taskState.addAll(getPersistentState());
    return taskState;
  }

  /**
   * @return the persistent state reported by the wrapped task
   */
  @Override
  public State getPersistentState() {
    return this.underlyingTask.getPersistentState();
  }

  /**
   * @return the execution metadata reported by the wrapped task
   */
  @Override
  public State getExecutionMetadata() {
    return this.underlyingTask.getExecutionMetadata();
  }

  /**
   * @return the working state reported by the wrapped task
   */
  @Override
  public WorkUnitState.WorkingState getWorkingState() {
    return this.underlyingTask.getWorkingState();
  }

  /**
   * This wrapper exposes no forks.
   *
   * @return a new, empty list on every call
   */
  @Override
  public List<Optional<Fork>> getForks() {
    return Lists.newArrayList();
  }

  /** Intentionally a no-op in this wrapper. */
  @Override
  public void updateRecordMetrics() {
  }

  /** Intentionally a no-op in this wrapper. */
  @Override
  public void updateByteMetrics() {
  }

  /** Increments this wrapper's own retry counter by one. */
  @Override
  public void incrementRetryCount() {
    this.retryCount++;
  }

  /**
   * @return the number of times {@link #incrementRetryCount()} has been called on this wrapper
   */
  @Override
  public int getRetryCount() {
    return this.retryCount;
  }

  /** Counts down the latch supplied at construction, if one was supplied. */
  @Override
  public void markTaskCompletion() {
    if (this.countDownLatch != null) {
      this.countDownLatch.countDown();
    }
  }

  /**
   * @return the wrapped task's string representation
   */
  @Override
  public String toString() {
    return this.underlyingTask.toString();
  }

  /**
   * Commits the wrapped task and then notifies the tracker that the commit has ended.
   *
   * <p>As in {@link #run()}, the notification is issued from a {@code finally} block: it fires even if
   * the wrapped task's {@code commit} throws, and that exception still propagates to the caller.
   */
  @Override
  public void commit() {
    try {
      this.underlyingTask.commit();
    } finally {
      this.taskStateTracker.onTaskCommitCompletion(this);
    }
  }

  /** Intentionally a no-op in this wrapper. */
  @Override
  protected void submitTaskCommittedEvent() {
  }

  /**
   * @return whether the wrapped task reports itself as safe for speculative execution
   */
  @Override
  public boolean isSpeculativeExecutionSafe() {
    return this.underlyingTask.isSpeculativeExecutionSafe();
  }
}