/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.raptor.util;
import com.google.common.collect.ComparisonChain;
import com.google.common.util.concurrent.ExecutionList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import javax.annotation.concurrent.ThreadSafe;
import java.util.Comparator;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.FutureTask;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
/**
* This class is based on io.airlift.concurrent.BoundedExecutor
*/
@ThreadSafe
public class PrioritizedFifoExecutor<T extends Runnable>
{
private static final Logger log = Logger.get(PrioritizedFifoExecutor.class);
private final Queue<FifoRunnableTask<T>> queue;
private final AtomicInteger queueSize = new AtomicInteger(0);
private final AtomicLong sequenceNumber = new AtomicLong(0);
private final Runnable triggerTask = this::executeOrMerge;
private final ExecutorService executorService;
private final int maxThreads;
private final Comparator<T> taskComparator;
public PrioritizedFifoExecutor(ExecutorService coreExecutor, int maxThreads, Comparator<T> taskComparator)
{
checkArgument(maxThreads > 0, "maxThreads must be greater than zero");
this.taskComparator = requireNonNull(taskComparator, "taskComparator is null");
this.executorService = requireNonNull(coreExecutor, "coreExecutor is null");
this.maxThreads = maxThreads;
this.queue = new PriorityBlockingQueue<>(maxThreads);
}
public ListenableFuture<?> submit(T task)
{
FifoRunnableTask<T> fifoTask = new FifoRunnableTask<>(task, sequenceNumber.incrementAndGet(), taskComparator);
queue.add(fifoTask);
executorService.submit(triggerTask);
return fifoTask;
}
private void executeOrMerge()
{
int size = queueSize.incrementAndGet();
if (size > maxThreads) {
return;
}
do {
try {
queue.poll().run();
}
catch (Throwable e) {
log.error(e, "Task failed");
}
}
while (queueSize.getAndDecrement() > maxThreads);
}
private static class FifoRunnableTask<T extends Runnable>
extends FutureTask<Void>
implements ListenableFuture<Void>, Comparable<FifoRunnableTask<T>>
{
private final ExecutionList executionList = new ExecutionList();
private final T task;
private final long sequenceNumber;
private final Comparator<T> taskComparator;
public FifoRunnableTask(T task, long sequenceNumber, Comparator<T> taskComparator)
{
super(requireNonNull(task, "task is null"), null);
this.task = task;
this.sequenceNumber = sequenceNumber;
this.taskComparator = requireNonNull(taskComparator, "taskComparator is null");
}
@Override
public void addListener(Runnable listener, Executor executor)
{
executionList.add(listener, executor);
}
@Override
protected void done()
{
executionList.execute();
}
@Override
public int compareTo(FifoRunnableTask<T> other)
{
return ComparisonChain.start()
.compare(this.task, other.task, taskComparator)
.compare(this.sequenceNumber, other.sequenceNumber)
.result();
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FifoRunnableTask<?> other = (FifoRunnableTask<?>) o;
return Objects.equals(this.task, other.task) &&
Objects.equals(this.sequenceNumber, other.sequenceNumber);
}
@Override
public int hashCode()
{
return Objects.hash(task, sequenceNumber);
}
}
}