package org.shanbo.feluca.distribute.model.vertical; import gnu.trove.list.array.TFloatArrayList; import java.util.HashMap; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.locks.ReentrantLock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class FloatReducerImpl implements FloatReducer{ static Logger log = LoggerFactory.getLogger(FloatReducerImpl.class); static class ReducePreparation{ int shardId; float[] values; public ReducePreparation(int shardId, float[] values){ this.shardId = shardId; this.values = values; } } abstract static class ReduceProcessor{ ReducePreparation[] toReduce ; TFloatArrayList accValues; volatile float[] results ; CyclicBarrier enterBarrier ; CyclicBarrier leaveBarrier ; volatile boolean reduceDone; ReentrantLock lock = new ReentrantLock(); public ReduceProcessor(int total){ accValues = new TFloatArrayList(); enterBarrier = new CyclicBarrier(total); leaveBarrier = new CyclicBarrier(total); toReduce = new ReducePreparation[total]; } public void prepare(int clientId, float[] values) throws InterruptedException{ toReduce[clientId] = new ReducePreparation(clientId, values); } public void doReduce(){ lock.lock(); try{ if (reduceDone == false){ accValues.resetQuick(); processValues(); results = accValues.toArray(); reduceDone = true; enterBarrier.reset(); leaveBarrier.reset(); } }finally{ lock.unlock(); } } public abstract void processValues() ; public float[] getResult() throws InterruptedException, BrokenBarrierException{ leaveBarrier.await(); reduceDone = false; //it's ok; because fastest thread still have to wait for others to enter the doReduce() return results; } public void waitForOther() throws InterruptedException, BrokenBarrierException{ enterBarrier.await(); } } HashMap<String, ReduceProcessor> reducers; public FloatReducerImpl(int totalClients){ reducers = new HashMap<String, FloatReducerImpl.ReduceProcessor>(6); reducers.put("sum", new SumReducer(totalClients)); reducers.put("avg", new AvgReducer(totalClients)); reducers.put("max", new MaxReducer(totalClients)); reducers.put("min", new MinReducer(totalClients)); } public float[] reduce(String name, int clientId, float[] values){ ReduceProcessor processor = reducers.get(name); if (processor == null) return null; try { processor.prepare(clientId, values); processor.waitForOther(); processor.doReduce(); return processor.getResult(); } catch (Exception e) { e.printStackTrace(); return null; } } public String getName() { return "floatReducer"; } public static class SumReducer extends ReduceProcessor{ public SumReducer(int total) { super(total); } @Override public void processValues() { for(int l = 0 ; l < toReduce[0].values.length ; l++){ //each input float accValue = toReduce[0].values[l]; for(int i = 1 ; i < toReduce.length; i++){ //merge from other accValue += toReduce[i].values[l]; } accValues.add(accValue); } } } public static class AvgReducer extends ReduceProcessor{ public AvgReducer(int total) { super(total); } @Override public void processValues() { for(int l = 0 ; l < toReduce[0].values.length ; l++){ //each input float accValue = toReduce[0].values[l]; for(int i = 1 ; i < toReduce.length; i++){ //merge from other accValue += toReduce[i].values[l]; } accValues.add( accValue / toReduce.length); } } } public static class MinReducer extends ReduceProcessor{ public MinReducer(int total) { super(total); } @Override public void processValues() { for(int l = 0 ; l < toReduce[0].values.length ; l++){ //each input float accValue = toReduce[0].values[l]; for(int i = 1 ; i < toReduce.length; i++){ //merge from other accValue = Math.min(accValue, toReduce[i].values[l]); } accValues.add(accValue); } } } public static class MaxReducer extends ReduceProcessor{ public MaxReducer(int total) { super(total); } @Override public void processValues() { for(int l = 0 ; l < toReduce[0].values.length ; l++){ //each input float accValue = toReduce[0].values[l]; for(int i = 1 ; i < toReduce.length; i++){ //merge from other accValue = Math.max(accValue, toReduce[i].values[l]); } accValues.add(accValue); } } } }