package org.apache.pig.backend.hadoop.executionengine.spark.converter; import java.io.IOException; import java.util.Iterator; import java.util.List; import org.apache.pig.backend.executionengine.ExecException; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POCollectedGroup; import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil; import org.apache.pig.data.Tuple; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.rdd.RDD; @SuppressWarnings({ "serial"}) public class CollectedGroupConverter implements POConverter<Tuple, Tuple, POCollectedGroup> { @Override public RDD<Tuple> convert(List<RDD<Tuple>> predecessors, POCollectedGroup physicalOperator) throws IOException { SparkUtil.assertPredecessorSize(predecessors, physicalOperator, 1); RDD<Tuple> rdd = predecessors.get(0); // return predecessors.get(0); RDD<Tuple> rdd2 = rdd.coalesce(1, false, null); long count = 0; try { count = rdd2.count(); } catch (Exception e) { } CollectedGroupFunction collectedGroupFunction = new CollectedGroupFunction(physicalOperator, count); return rdd.toJavaRDD().mapPartitions(collectedGroupFunction, true).rdd(); } private static class CollectedGroupFunction implements FlatMapFunction<Iterator<Tuple>, Tuple> { /** * */ private POCollectedGroup poCollectedGroup; public long total_limit; public long current_val; public boolean proceed; private CollectedGroupFunction(POCollectedGroup poCollectedGroup, long count) { this.poCollectedGroup = poCollectedGroup; this.total_limit = count; this.current_val = 0; } public Iterable<Tuple> call(final Iterator<Tuple> input) { return new Iterable<Tuple>() { @Override public Iterator<Tuple> iterator() { return new POOutputConsumerIterator(input) { protected void attach(Tuple tuple) { poCollectedGroup.setInputs(null); poCollectedGroup.attachInput(tuple); poCollectedGroup.setParentPlan(poCollectedGroup.getPlans().get(0)); try{ current_val = current_val + 1; //System.out.println("Row: =>" + current_val); if (current_val == total_limit) { proceed = true; } else { proceed = false; } } catch(Exception e){ System.out.println("Crashhh in CollectedGroupConverter :" + e); e.printStackTrace(); } } protected Result getNextResult() throws ExecException { return poCollectedGroup.getNextTuple(proceed); } }; } }; } } }