package org.apache.pig.backend.hadoop.executionengine.spark.converter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POGlobalRearrange;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.rdd.CoGroupedRDD;
import org.apache.spark.rdd.RDD;
import scala.Product2;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Seq;
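/**
 * Converts the predecessor RDD(s) of a {@link POGlobalRearrange} operator
 * into an RDD of (key, bag) tuples. A single predecessor (GROUP) is handled
 * with a Spark groupBy; multiple predecessors (COGROUP) are combined with a
 * {@link CoGroupedRDD}. Input tuples are expected in the (index, key, value)
 * shape produced by the local rearrange.
 */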
@SuppressWarnings({ "serial" })
public class GlobalRearrangeConverter implements
POConverter<Tuple, Tuple, POGlobalRearrange> {
private static final Log LOG = LogFactory
.getLog(GlobalRearrangeConverter.class);
private static final TupleFactory tf = TupleFactory.getInstance();
// GROUP FUNCTIONS (single predecessor)
private static final GetKeyFunction GET_KEY_FUNCTION = new GetKeyFunction();
private static final GroupTupleFunction GROUP_TUPLE_FUNCTION = new GroupTupleFunction();
// COGROUP FUNCTIONS (multiple predecessors)
private static final ToKeyValueFunction TO_KEY_VALUE_FUNCTION = new ToKeyValueFunction();
private static final ToGroupKeyValueFunction TO_GROUP_KEY_VALUE_FUNCTION = new ToGroupKeyValueFunction();
@Override
public RDD<Tuple> convert(List<RDD<Tuple>> predecessors,
POGlobalRearrange physicalOperator) throws IOException {
SparkUtil.assertPredecessorSizeGreaterThan(predecessors,
physicalOperator, 0);
int parallelism = SparkUtil.getParallelism(predecessors,
physicalOperator);
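// Allow the SPARK_REDUCERS environment variable to override the
// parallelism derived from the predecessors.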
String reducers = System.getenv("SPARK_REDUCERS");
if (reducers != null) {
parallelism = Integer.parseInt(reducers);
}
LOG.info("Parallelism for Spark groupBy: " + parallelism);
if (predecessors.size() == 1) {
// GROUP: a single predecessor is grouped by key
JavaRDD<Tuple> jrdd = predecessors.get(0).toJavaRDD();
JavaPairRDD<Object, Iterable<Tuple>> prdd = jrdd.groupBy(GET_KEY_FUNCTION, parallelism);
JavaRDD<Tuple> jrdd2 = prdd.map(GROUP_TUPLE_FUNCTION);
return jrdd2.rdd();
} else {
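// COGROUP: convert each predecessor to (key, value) pairs, then
// cogroup all inputs on the key.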
List<RDD<Tuple2<Object, Tuple>>> rddPairs = new ArrayList<RDD<Tuple2<Object, Tuple>>>();
for (RDD<Tuple> rdd : predecessors) {
JavaRDD<Tuple> jrdd = JavaRDD.fromRDD(rdd, SparkUtil.getManifest(Tuple.class));
JavaRDD<Tuple2<Object, Tuple>> rddPair = jrdd.map(TO_KEY_VALUE_FUNCTION);
rddPairs.add(rddPair.rdd());
}
// CoGroupedRDD is parameterized only by the key type, so the value
// type of the pair RDDs cannot be expressed in its signature; cast
// through Object to satisfy the compiler.
CoGroupedRDD<Object> coGroupedRDD = new CoGroupedRDD<Object>(
(Seq<RDD<? extends Product2<Object, ?>>>) (Object) (JavaConversions
.asScalaBuffer(rddPairs).toSeq()),
new HashPartitioner(parallelism));
RDD<Tuple2<Object, Seq<Seq<Tuple>>>> rdd =
(RDD<Tuple2<Object, Seq<Seq<Tuple>>>>) (Object) coGroupedRDD;
return rdd.toJavaRDD().map(TO_GROUP_KEY_VALUE_FUNCTION).rdd();
}
}
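/**
 * Extracts the key (field 1) from the (index, key, value) tuple emitted by
 * the local rearrange.
 */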
private static class GetKeyFunction implements Function<Tuple, Object>, Serializable {
public Object call(Tuple t) {
try {
LOG.debug("GetKeyFunction in " + t);
// see PigGenericMapReduce for how the key is produced
Object key = t.get(1);
LOG.debug("GetKeyFunction out " + key);
return key;
} catch (ExecException e) {
throw new RuntimeException(e);
}
}
}
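/**
 * Turns the (key, values) pair produced by groupBy into a two-field Pig
 * tuple: the key and an iterator over the bag of values.
 */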
private static class GroupTupleFunction implements Function<Tuple2<Object, Iterable<Tuple>>, Tuple>,
Serializable {
public Tuple call(Tuple2<Object, Iterable<Tuple>> v1) {
try {
LOG.debug("GroupTupleFunction in " + v1);
Tuple tuple = tf.newTuple(2);
tuple.set(0, v1._1()); // the group key
tuple.set(1, v1._2().iterator()); // an iterator over the bag of values
LOG.debug("GroupTupleFunction out " + tuple);
return tuple;
} catch (ExecException e) {
throw new RuntimeException(e);
}
}
}
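/**
 * Drops the index from an (index, key, value) tuple, yielding the
 * (key, value) pair that {@link CoGroupedRDD} expects.
 */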
private static class ToKeyValueFunction implements
Function<Tuple, Tuple2<Object, Tuple>>, Serializable {
@Override
public Tuple2<Object, Tuple> call(Tuple t) {
try {
// (index, key, value)
LOG.debug("ToKeyValueFunction in " + t);
Object key = t.get(1);
Tuple value = (Tuple) t.get(2); // value
// (key, value)
Tuple2<Object, Tuple> out = new Tuple2<Object, Tuple>(key,
value);
LOG.debug("ToKeyValueFunction out " + out);
return out;
} catch (ExecException e) {
throw new RuntimeException(e);
}
}
}
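/**
 * Converts a cogrouped (key, bags) pair into a two-field Pig tuple: the key
 * and a single iterator over all bags, re-attaching the input index and the
 * key to each value as it is consumed.
 */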
private static class ToGroupKeyValueFunction implements
Function<Tuple2<Object, Seq<Seq<Tuple>>>, Tuple>, Serializable {
@Override
public Tuple call(Tuple2<Object, Seq<Seq<Tuple>>> input) {
try {
LOG.debug("ToGroupKeyValueFunction in " + input);
final Object key = input._1();
Object obj = input._2();
// XXX hack for Spark 1.1.0: at runtime the value is an Array of bags,
// not the Seq<Seq<Tuple>> the signature declares, hence the raw cast.
Seq<Tuple>[] bags = (Seq<Tuple>[]) obj;
List<Iterator<Tuple>> tupleIterators = new ArrayList<Iterator<Tuple>>();
for (int i = 0; i < bags.length; i++) {
Seq<Tuple> bag = bags[i];
Iterator<Tuple> iterator = JavaConversions
.asJavaCollection(bag).iterator();
final int index = i;
tupleIterators.add(new IteratorTransform<Tuple, Tuple>(
iterator) {
@Override
protected Tuple transform(Tuple next) {
try {
Tuple tuple = tf.newTuple(3);
tuple.set(0, index);
tuple.set(1, key);
tuple.set(2, next);
return tuple;
} catch (ExecException e) {
throw new RuntimeException(e);
}
}
});
}
Tuple out = tf.newTuple(2);
out.set(0, key);
out.set(1, new IteratorUnion<Tuple>(tupleIterators.iterator()));
LOG.debug("ToGroupKeyValueFunction out " + out);
return out;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
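/**
 * Lazily chains a collection of iterators into one, moving to the next
 * iterator only once the current one is exhausted.
 */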
private static class IteratorUnion<T> implements Iterator<T> {
private final Iterator<Iterator<T>> iterators;
private Iterator<T> current;
public IteratorUnion(Iterator<Iterator<T>> iterators) {
this.iterators = iterators;
}
@Override
public boolean hasNext() {
if (current != null && current.hasNext()) {
return true;
} else if (iterators.hasNext()) {
current = iterators.next();
return hasNext();
} else {
return false;
}
}
@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
return current.next();
}
@Override
public void remove() {
current.remove();
}
}
}