/*
* Copyright 2013 Goldman Sachs.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.gs.collections.impl.forkjoin;
import java.io.Serializable;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import com.gs.collections.api.block.function.Function;
import com.gs.collections.api.block.procedure.primitive.ObjectIntProcedure;
import com.gs.collections.impl.list.mutable.FastList;
import com.gs.collections.impl.parallel.Combiner;
import com.gs.collections.impl.parallel.ObjectIntProcedureFactory;
/**
 * Runs an {@link ObjectIntProcedure} over a {@link List} in parallel on a {@link ForkJoinPool}.
 * <p>
 * The list is split into {@code taskCount} contiguous sections, one {@link ForkJoinTask} is created and
 * executed per section, and the per-task procedures are then merged with a {@link Combiner}:
 * <ul>
 * <li>if {@code combiner.useCombineOne()} is true, each task hands itself to the calling thread through a
 * bounded queue when it reports completion, and the calling thread combines results one at a time;</li>
 * <li>otherwise the calling thread waits for every task and passes all results to
 * {@code combiner.combineAll(...)} at once.</li>
 * </ul>
 * If any task reported a failure through {@link #setFailed(Throwable)}, no {@code combineAll} is performed
 * and {@link #executeAndCombine} throws a {@link RuntimeException} wrapping the first reported failure.
 *
 * @param <T> the element type of the list
 * @param <PT> the type of procedure each task creates and runs over its section
 */
public class FJListObjectIntProcedureRunner<T, PT extends ObjectIntProcedure<? super T>> implements Serializable
{
    private static final long serialVersionUID = 1L;

    // Written from worker threads (setFailed) and read by the calling thread; volatile for visibility.
    private volatile Throwable error;
    private final Combiner<PT> combiner;
    private final int taskCount;
    // Completed tasks, not their raw results: ArrayBlockingQueue rejects null, and a task that failed may
    // have no result, yet it must still count as completed or join() would block forever.
    // Only created when the combiner combines results one at a time.
    private final BlockingQueue<ForkJoinTask<PT>> outputQueue;

    /**
     * @param newCombiner merges the procedures produced by the individual tasks
     * @param taskCount   number of sections (and tasks) the list is split into
     */
    public FJListObjectIntProcedureRunner(Combiner<PT> newCombiner, int taskCount)
    {
        this.combiner = newCombiner;
        this.taskCount = taskCount;
        this.outputQueue = this.combiner.useCombineOne() ? new ArrayBlockingQueue<ForkJoinTask<PT>>(taskCount) : null;
    }

    /**
     * Creates one task per section and submits each to {@code executor}; the last task is flagged so it can
     * also cover the remainder left by the integer division of the list size.
     */
    private FastList<ForkJoinTask<PT>> createAndExecuteTasks(ForkJoinPool executor, ObjectIntProcedureFactory<PT> procedureFactory, List<T> list)
    {
        FastList<ForkJoinTask<PT>> tasks = FastList.newList(this.taskCount);
        int sectionSize = list.size() / this.taskCount;
        int taskCountMinusOne = this.taskCount - 1;
        for (int index = 0; index < this.taskCount; index++)
        {
            ForkJoinTask<PT> task = this.createTask(procedureFactory, list, sectionSize, taskCountMinusOne, index);
            tasks.add(task);
            executor.execute(task);
        }
        return tasks;
    }

    /**
     * Creates the task for section {@code index}. Subclasses may override to supply a different task type.
     */
    protected FJListObjectIntProcedureTask<T, PT> createTask(ObjectIntProcedureFactory<PT> procedureFactory, List<T> list, int sectionSize, int taskCountMinusOne, int index)
    {
        return new FJListObjectIntProcedureTask<>(this, procedureFactory, list, index, sectionSize, index == taskCountMinusOne);
    }

    /**
     * Records a task failure. The first failure is kept; later distinct failures are attached to it as
     * suppressed exceptions rather than overwriting it, since several tasks may fail concurrently.
     *
     * @param newError the failure raised by a task
     */
    public synchronized void setFailed(Throwable newError)
    {
        Throwable current = this.error;
        if (current == null)
        {
            this.error = newError;
        }
        else if (newError != null && newError != current)
        {
            current.addSuppressed(newError);
        }
    }

    /**
     * Called by a task when it finishes. When combining one at a time, the task is queued for the calling
     * thread; the queue is sized to {@code taskCount}, so each task's single call always fits.
     *
     * @param task the task that has finished
     */
    public void taskCompleted(ForkJoinTask<PT> task)
    {
        if (this.combiner.useCombineOne())
        {
            this.outputQueue.add(task);
        }
    }

    /**
     * Runs the procedure over {@code list} in parallel and combines the per-task procedures.
     *
     * @param executor         pool the tasks are submitted to
     * @param procedureFactory creates one procedure per task
     * @param list             the list to process
     * @throws RuntimeException if any task reported a failure, or if waiting for results was interrupted
     */
    public void executeAndCombine(ForkJoinPool executor, ObjectIntProcedureFactory<PT> procedureFactory, List<T> list)
    {
        FastList<ForkJoinTask<PT>> tasks = this.createAndExecuteTasks(executor, procedureFactory, list);
        if (this.combiner.useCombineOne())
        {
            // join() returns only after every task has reported completion, so all failures are visible here.
            this.join();
            this.throwIfFailed();
        }
        else
        {
            // Wait for every task before checking for failures; checking first would miss failures that
            // have not been reported yet and combine their partial results.
            FastList<PT> results = this.awaitResults(tasks);
            this.throwIfFailed();
            this.combiner.combineAll(results);
        }
    }

    /** Throws a RuntimeException wrapping the recorded failure, if any task reported one. */
    private void throwIfFailed()
    {
        Throwable failure = this.error;
        if (failure != null)
        {
            throw new RuntimeException("One or more parallel tasks failed", failure);
        }
    }

    /** Blocks until every task has finished and returns their results in task order. */
    private FastList<PT> awaitResults(FastList<ForkJoinTask<PT>> tasks)
    {
        ProcedureExtractor extractor = new ProcedureExtractor();
        FastList<PT> results = FastList.newList(tasks.size());
        for (ForkJoinTask<PT> task : tasks)
        {
            results.add(extractor.valueOf(task));
        }
        return results;
    }

    /**
     * Takes exactly {@code taskCount} completed tasks from the queue and combines each non-null result.
     */
    private void join()
    {
        try
        {
            for (int remaining = this.taskCount; remaining > 0; remaining--)
            {
                // The queue hand-off orders the worker's writes before this read of the raw result.
                PT result = this.outputQueue.take().getRawResult();
                // NOTE(review): presumably a null result only comes from a task that failed before creating
                // its procedure; such failures are reported by throwIfFailed() after join() returns.
                if (result != null)
                {
                    this.combiner.combineOne(result);
                }
            }
        }
        catch (InterruptedException e)
        {
            // Restore the interrupt status so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Combine failed", e);
        }
    }

    /** Extracts a task's result, blocking until the task has completed. */
    private final class ProcedureExtractor implements Function<ForkJoinTask<PT>, PT>
    {
        private static final long serialVersionUID = 1L;

        @Override
        public PT valueOf(ForkJoinTask<PT> object)
        {
            try
            {
                return object.get();
            }
            catch (InterruptedException e)
            {
                // Restore the interrupt status before converting to an unchecked exception.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            catch (ExecutionException e)
            {
                throw new RuntimeException(e);
            }
        }
    }
}