/*
* Copyright 2013 Goldman Sachs.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.gs.collections.impl.forkjoin;
import java.io.Serializable;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import com.gs.collections.api.block.function.Function;
import com.gs.collections.api.block.procedure.Procedure;
import com.gs.collections.impl.list.mutable.FastList;
import com.gs.collections.impl.parallel.BatchIterable;
import com.gs.collections.impl.parallel.Combiner;
import com.gs.collections.impl.parallel.ProcedureFactory;
/**
 * Submits {@code taskCount} {@link FJBatchIterableProcedureTask}s to a {@link ForkJoinPool} and hands the
 * procedures they produce to a {@link Combiner}.
 * <p>
 * Two combining modes are supported, selected by {@link Combiner#useCombineOne()}:
 * <ul>
 * <li>combine-one: each finished task's raw result is placed on an internal queue by
 * {@link #taskCompleted(ForkJoinTask)} and the submitting thread combines results one at a time;</li>
 * <li>combine-all: the submitting thread waits for every task and then passes all results to
 * {@code combineAll} in a single call.</li>
 * </ul>
 * In both modes a failure recorded through {@link #setFailed(Throwable)} is rethrown from
 * {@link #executeAndCombine} once all tasks have been waited for.
 *
 * @param <T>  element type of the iterable being processed
 * @param <PT> type of the procedure created for each task
 */
public class FJBatchIterableProcedureRunner<T, PT extends Procedure<? super T>> implements Serializable
{
    private static final long serialVersionUID = 1L;

    // volatile: written through setFailed() and read by the thread running executeAndCombine().
    // NOTE(review): presumably setFailed() is invoked from pool worker threads - confirm against
    // FJBatchIterableProcedureTask.
    private volatile Throwable error;
    private final Combiner<PT> combiner;
    private final int taskCount;
    // Only allocated in combine-one mode; null otherwise.
    private final BlockingQueue<PT> outputQueue;

    /**
     * @param newCombiner combiner that receives the per-task procedures
     * @param taskCount   number of tasks to create; also the capacity of the result queue in combine-one mode
     * @throws IllegalArgumentException in combine-one mode if {@code taskCount < 1}
     *                                  (thrown by the {@link ArrayBlockingQueue} constructor)
     */
    public FJBatchIterableProcedureRunner(Combiner<PT> newCombiner, int taskCount)
    {
        this.combiner = newCombiner;
        this.taskCount = taskCount;
        this.outputQueue = this.combiner.useCombineOne() ? new ArrayBlockingQueue<PT>(taskCount) : null;
    }

    /**
     * Creates one task per section index in {@code [0, taskCount)} and submits each to the pool.
     *
     * @return the submitted tasks, in index order
     */
    private FastList<ForkJoinTask<PT>> createAndExecuteTasks(ForkJoinPool executor, ProcedureFactory<PT> procedureFactory, BatchIterable<T> iterable)
    {
        FastList<ForkJoinTask<PT>> tasks = FastList.newList(this.taskCount);
        for (int index = 0; index < this.taskCount; index++)
        {
            ForkJoinTask<PT> task = new FJBatchIterableProcedureTask<>(this, procedureFactory, iterable, index, this.taskCount);
            tasks.add(task);
            executor.execute(task);
        }
        return tasks;
    }

    /**
     * Records a failure to be rethrown by {@link #executeAndCombine}. A later call replaces an earlier one.
     *
     * @param newError the failure to record
     */
    public void setFailed(Throwable newError)
    {
        this.error = newError;
    }

    /**
     * Publishes the raw result of a finished task to the result queue when running in combine-one mode;
     * does nothing in combine-all mode.
     * <p>
     * NOTE(review): {@link ArrayBlockingQueue#add} rejects {@code null}. If a task can finish with a null raw
     * result, this call would throw and {@link #executeAndCombine} would wait forever for the missing
     * element - verify against FJBatchIterableProcedureTask.
     *
     * @param task the task that has finished
     */
    public void taskCompleted(ForkJoinTask<PT> task)
    {
        if (this.combiner.useCombineOne())
        {
            this.outputQueue.add(task.getRawResult());
        }
    }

    /**
     * Submits all tasks, waits for them, and combines their procedures.
     *
     * @param executor         pool the tasks are submitted to
     * @param procedureFactory factory handed to every task
     * @param list             iterable handed to every task
     * @throws RuntimeException if a failure was recorded through {@link #setFailed(Throwable)}, if waiting for
     *                          a task was interrupted, or if retrieving a task's result failed
     */
    public void executeAndCombine(ForkJoinPool executor, ProcedureFactory<PT> procedureFactory, BatchIterable<T> list)
    {
        FastList<ForkJoinTask<PT>> tasks = this.createAndExecuteTasks(executor, procedureFactory, list);
        boolean combineOne = this.combiner.useCombineOne();
        if (combineOne)
        {
            this.join();
        }
        else
        {
            // Wait for every task before looking at the error field; checking it right after
            // submission could miss a failure that is reported later.
            this.awaitAll(tasks);
        }
        Throwable failure = this.error;
        if (failure != null)
        {
            throw new RuntimeException("One or more parallel tasks failed", failure);
        }
        if (!combineOne)
        {
            this.combiner.combineAll(tasks.asLazy().collect(new ProcedureExtractor()));
        }
    }

    /**
     * Blocks until every given task has completed, discarding the results.
     */
    private void awaitAll(FastList<ForkJoinTask<PT>> tasks)
    {
        ProcedureExtractor extractor = new ProcedureExtractor();
        for (ForkJoinTask<PT> task : tasks)
        {
            extractor.valueOf(task);
        }
    }

    /**
     * Takes {@code taskCount} results from the queue, combining each one as it arrives.
     *
     * @throws RuntimeException if the calling thread is interrupted while waiting; the interrupt status is
     *                          restored before the exception is thrown
     */
    private void join()
    {
        try
        {
            int remainingTaskCount = this.taskCount;
            while (remainingTaskCount > 0)
            {
                this.combiner.combineOne(this.outputQueue.take());
                remainingTaskCount--;
            }
        }
        catch (InterruptedException e)
        {
            // Preserve the interrupt so code further up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Combine failed", e);
        }
    }

    /**
     * Maps a task to its result by calling {@link ForkJoinTask#get()}, converting checked exceptions to
     * {@link RuntimeException}.
     */
    private final class ProcedureExtractor implements Function<ForkJoinTask<PT>, PT>
    {
        private static final long serialVersionUID = 1L;

        @Override
        public PT valueOf(ForkJoinTask<PT> object)
        {
            try
            {
                return object.get();
            }
            catch (InterruptedException e)
            {
                // Preserve the interrupt so code further up the stack can observe it.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            catch (ExecutionException e)
            {
                throw new RuntimeException(e);
            }
        }
    }
}