package org.apache.pig.backend.hadoop.executionengine.spark.converter;

import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POGlobalRearrange;
import org.apache.pig.backend.hadoop.executionengine.spark.ScalaUtil;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.spark.HashPartitioner;
import org.apache.spark.rdd.CoGroupedRDD;
import org.apache.spark.rdd.RDD;

import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.Lists;

/**
 * Converts a {@link POGlobalRearrange} physical operator into a Spark RDD
 * transformation.
 *
 * <p>With a single predecessor this is a plain GROUP: tuples are grouped by
 * their key field. With multiple predecessors this is a COGROUP: each
 * predecessor's tuples (shaped {@code (index, key, value)} — see
 * PigGenericMapReduce) are mapped to {@code (key, value)} pairs and co-grouped
 * with a {@link CoGroupedRDD}, then re-shaped into Pig's expected
 * {@code (key, bag-of-(index, key, value))} form.
 */
@SuppressWarnings({ "serial" })
public class GlobalRearrangeConverter implements
        POConverter<Tuple, Tuple, POGlobalRearrange> {
    private static final Log LOG = LogFactory
            .getLog(GlobalRearrangeConverter.class);

    private static final TupleFactory tf = TupleFactory.getInstance();

    // GROUP functions
    private static final ToKeyValueFunction TO_KEY_VALUE_FUNCTION = new ToKeyValueFunction();
    private static final GetKeyFunction GET_KEY_FUNCTION = new GetKeyFunction();
    // COGROUP functions
    private static final GroupTupleFunction GROUP_TUPLE_FUNCTION = new GroupTupleFunction();
    private static final ToGroupKeyValueFunction TO_GROUP_KEY_VALUE_FUNCTION = new ToGroupKeyValueFunction();

    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Override
    public RDD<Tuple> convert(List<RDD<Tuple>> predecessors,
            POGlobalRearrange physicalOperator) throws IOException {
        SparkUtil.assertPredecessorSizeGreaterThan(predecessors,
                physicalOperator, 0);
        int parallelism = SparkUtil.getParallelism(predecessors,
                physicalOperator);
        if (LOG.isDebugEnabled()) {
            // Log at debug level to match the guard; was LOG.info, which made
            // the isDebugEnabled() check pointless.
            LOG.debug("Parallelism for Spark groupBy: " + parallelism);
        }
        if (predecessors.size() == 1) {
            // GROUP: single input, group tuples by key.
            return predecessors.get(0)
                    // group by key
                    .groupBy(GET_KEY_FUNCTION, parallelism,
                            ScalaUtil.getClassTag(Object.class))
                    // convert result to a tuple (key, { values })
                    .map(GROUP_TUPLE_FUNCTION,
                            ScalaUtil.getClassTag(Tuple.class));
        } else {
            // COGROUP: each predecessor produces (index, key, value) tuples.
            ClassTag<Tuple2<Object, Tuple>> tuple2ClassTag = ScalaUtil
                    .<Object, Tuple> getTuple2ClassTag();
            List<RDD<Tuple2<Object, Tuple>>> rddPairs = Lists.newArrayList();
            for (RDD<Tuple> rdd : predecessors) {
                // Re-shape each input to (key, value) pairs for co-grouping.
                RDD<Tuple2<Object, Tuple>> rddPair = rdd.map(
                        TO_KEY_VALUE_FUNCTION, tuple2ClassTag);
                rddPairs.add(rddPair);
            }
            CoGroupedRDD<Object> coGroupedRDD = new CoGroupedRDD<Object>(
                    (Seq) JavaConversions.asScalaBuffer(rddPairs),
                    new HashPartitioner(parallelism));
            // CoGroupedRDD's element type is erased; cast via Object to the
            // (key, Seq<Seq<Tuple>>) shape it actually produces.
            RDD<Tuple2<Object, Seq<Seq<Tuple>>>> rdd =
                    (RDD<Tuple2<Object, Seq<Seq<Tuple>>>>) (Object) coGroupedRDD;
            return rdd.map(TO_GROUP_KEY_VALUE_FUNCTION,
                    ScalaUtil.getClassTag(Tuple.class));
        }
    }

    /**
     * Extracts the grouping key (field 1) from an (index, key, value) tuple.
     * See PigGenericMapReduce for the tuple layout.
     */
    private static class GetKeyFunction extends AbstractFunction1<Tuple, Object>
            implements Serializable {
        @Override
        public Object apply(Tuple t) {
            try {
                // Guard debug logging: this runs per tuple, and the string
                // concatenation would otherwise execute even when disabled.
                if (LOG.isDebugEnabled()) {
                    LOG.debug("GetKeyFunction in " + t);
                }
                // see PigGenericMapReduce for the key position
                Object key = t.get(1);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("GetKeyFunction out " + key);
                }
                return key;
            } catch (ExecException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Converts a grouped (key, Seq of tuples) pair into Pig's
     * (key, iterator-of-values) output tuple for the GROUP case.
     */
    private static class GroupTupleFunction extends
            AbstractFunction1<Tuple2<Object, Seq<Tuple>>, Tuple>
            implements Serializable {
        @Override
        public Tuple apply(Tuple2<Object, Seq<Tuple>> v1) {
            try {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("GroupTupleFunction in " + v1);
                }
                Tuple tuple = tf.newTuple(2);
                tuple.set(0, v1._1()); // the grouping key
                // the Seq<Tuple> aka bag of values, exposed as an iterator
                tuple.set(1, JavaConversions.asJavaCollection(v1._2())
                        .iterator());
                if (LOG.isDebugEnabled()) {
                    LOG.debug("GroupTupleFunction out " + tuple);
                }
                return tuple;
            } catch (ExecException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Maps an (index, key, value) tuple to a (key, value) pair so the inputs
     * can be fed to {@link CoGroupedRDD} for the COGROUP case.
     */
    private static class ToKeyValueFunction extends
            AbstractFunction1<Tuple, Tuple2<Object, Tuple>>
            implements Serializable {
        @Override
        public Tuple2<Object, Tuple> apply(Tuple t) {
            try {
                // input shape: (index, key, value)
                if (LOG.isDebugEnabled()) {
                    LOG.debug("ToKeyValueFunction in " + t);
                }
                Object key = t.get(1);
                Tuple value = (Tuple) t.get(2);
                // output shape: (key, value)
                Tuple2<Object, Tuple> out = new Tuple2<Object, Tuple>(key,
                        value);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("ToKeyValueFunction out " + out);
                }
                return out;
            } catch (ExecException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Converts a co-grouped (key, Seq of per-input bags) pair into Pig's
     * expected (key, iterator) tuple, where the iterator lazily yields
     * (input-index, key, value) tuples across all input bags in order.
     */
    private static class ToGroupKeyValueFunction extends
            AbstractFunction1<Tuple2<Object, Seq<Seq<Tuple>>>, Tuple>
            implements Serializable {
        @Override
        public Tuple apply(Tuple2<Object, Seq<Seq<Tuple>>> input) {
            try {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("ToGroupKeyValueFunction2 in " + input);
                }
                final Object key = input._1();
                Seq<Seq<Tuple>> bags = input._2();
                Iterable<Seq<Tuple>> bagsList = JavaConversions
                        .asJavaIterable(bags);
                int i = 0;
                List<Iterator<Tuple>> tupleIterators = Lists.newArrayList();
                for (Seq<Tuple> bag : bagsList) {
                    Iterator<Tuple> iterator = JavaConversions
                            .asJavaCollection(bag).iterator();
                    final int index = i;
                    // Lazily re-attach (index, key) to each value so downstream
                    // operators can tell which co-grouped input it came from.
                    tupleIterators.add(new IteratorTransform<Tuple, Tuple>(
                            iterator) {
                        @Override
                        protected Tuple transform(Tuple next) {
                            try {
                                Tuple tuple = tf.newTuple(3);
                                tuple.set(0, index);
                                tuple.set(1, key);
                                tuple.set(2, next);
                                return tuple;
                            } catch (ExecException e) {
                                throw new RuntimeException(e);
                            }
                        }
                    });
                    ++i;
                }
                Tuple out = tf.newTuple(2);
                out.set(0, key);
                out.set(1, new IteratorUnion<Tuple>(tupleIterators.iterator()));
                if (LOG.isDebugEnabled()) {
                    LOG.debug("ToGroupKeyValueFunction2 out " + out);
                }
                return out;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Chains a sequence of iterators into one flat iterator, advancing to the
     * next underlying iterator as each is exhausted.
     *
     * <p>Not thread-safe; {@link #next()} must only be called after
     * {@link #hasNext()} has returned {@code true}.
     */
    private static class IteratorUnion<T> implements Iterator<T> {
        private final Iterator<Iterator<T>> iterators;
        // Iterator currently being drained; null until first hasNext().
        private Iterator<T> current;

        public IteratorUnion(Iterator<Iterator<T>> iterators) {
            super();
            this.iterators = iterators;
        }

        @Override
        public boolean hasNext() {
            if (current != null && current.hasNext()) {
                return true;
            } else if (iterators.hasNext()) {
                // Advance and recurse: the next iterator may itself be empty.
                current = iterators.next();
                return hasNext();
            } else {
                return false;
            }
        }

        @Override
        public T next() {
            return current.next();
        }

        @Override
        public void remove() {
            current.remove();
        }
    }
}