package org.radargun.service; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Map; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.radargun.traits.MapReducer; import org.radargun.utils.KeyValueProperty; import org.radargun.utils.Utils; /** * @author Matej Cimbora */ public class SparkMapReduce implements MapReducer, Serializable { private SparkDriverService sparkDriverService; public SparkMapReduce(SparkDriverService sparkDriverService) { this.sparkDriverService = sparkDriverService; } @Override public Builder builder() { return new Builder(this); } @Override public boolean supportsCombiner() { return false; } @Override public boolean supportsTimeout() { return false; } public static class Builder<KOut, VOut, R> implements MapReducer.Builder<KOut, VOut, R> { private SparkMapReduce sparkMapReduce; private Object source; private Object mapper; private Object reducer; public Builder(SparkMapReduce sparkMapReduce) { this.sparkMapReduce = sparkMapReduce; } @Override public MapReducer.Builder timeout(long timeout) { throw new UnsupportedOperationException("Timeout not supported"); } @Override public MapReducer.Builder source(String source) { try { this.source = Utils.instantiate(source); } catch (Exception e) { throw new IllegalArgumentException("Could not instantiate RDD source class: " + source, e); } Utils.invokeMethodWithProperties(this.source, sparkMapReduce.sparkDriverService.mapReduceSourceProperties); return this; } @Override public MapReducer.Task build() { final SparkDriverService sparkDriverService = sparkMapReduce.sparkDriverService; if (mapper instanceof SparkMapper) { return new MapReduceTask((SparkMapper) mapper, (SparkReducer) reducer, (SparkJavaRDDSource) source, sparkDriverService.sparkContext); } else if (mapper instanceof SparkPairMapper) { return new MapToPairReduceByKeyTask((SparkPairMapper) mapper, (SparkReducer) reducer, (SparkJavaRDDSource) source, sparkDriverService.sparkContext); } else { throw new IllegalStateException("Invalid Mapper implementation " + mapper + " has been provided. " + "Expecting one of (" + SparkMapper.class + ", " + SparkPairMapper.class + ")"); } } @Override public MapReducer.Builder collator(String collatorFqn, Collection<KeyValueProperty> collatorParameters) { throw new UnsupportedOperationException("Collator not supported"); } @Override public MapReducer.Builder combiner(String combinerFqn, Collection<KeyValueProperty> combinerParameters) { throw new UnsupportedOperationException("Combiner not supported"); } @Override public MapReducer.Builder reducer(String reducerFqn, Collection<KeyValueProperty> reducerParameters) { try { reducer = Utils.instantiate(reducerFqn); Utils.invokeMethodWithProperties(reducer, reducerParameters); } catch (Exception e) { throw new IllegalArgumentException("Could not instantiate Reducer class: " + reducerFqn, e); } return this; } @Override public MapReducer.Builder mapper(String mapperFqn, Collection<KeyValueProperty> mapperParameters) { try { mapper = Utils.instantiate(mapperFqn); Utils.invokeMethodWithProperties(mapper, mapperParameters); } catch (Exception e) { throw new IllegalArgumentException("Could not instantiate Mapper class: " + mapperFqn, e); } return this; } } public abstract static class AbstractTask implements MapReducer.Task, Serializable { protected SparkReducer reducer; protected SparkJavaRDDSource source; protected JavaSparkContext sparkContext; public AbstractTask(SparkReducer reducer, SparkJavaRDDSource source, JavaSparkContext javaSparkContext) { this.reducer = reducer; this.source = source; this.sparkContext = javaSparkContext; source.setSparkContext(sparkContext); // Run dummy task to make sure jars are added to workers before performance test starts sparkContext.parallelize(new ArrayList<>(0)).count(); } } public static class MapReduceTask extends AbstractTask { private SparkMapper mapper; public MapReduceTask(SparkMapper mapper, SparkReducer reducer, SparkJavaRDDSource source, JavaSparkContext javaSparkContext) { super(reducer, source, javaSparkContext); this.mapper = mapper; this.reducer = reducer; } @Override public Map execute() { Object resultObject = source.getSource().map(mapper.getMapFunction()).reduce(reducer.getReduceFunction()); Map resultMap = new HashMap(1); resultMap.put("result_key", resultObject); return resultMap; } @Override public Object executeWithCollator() { throw new UnsupportedOperationException("Collator not supported"); } } public static class MapToPairReduceByKeyTask extends AbstractTask { private SparkPairMapper mapper; public MapToPairReduceByKeyTask(SparkPairMapper mapper, SparkReducer reducer, SparkJavaRDDSource source, JavaSparkContext javaSparkContext) { super(reducer, source, javaSparkContext); this.mapper = mapper; } @Override public Map execute() { JavaPairRDD javaPairRDD = source.getSource().mapToPair(mapper.getMapFunction()).reduceByKey(reducer.getReduceFunction()); return javaPairRDD.collectAsMap(); } @Override public Object executeWithCollator() { throw new UnsupportedOperationException("Collator not supported"); } } /** * Provides JavaRDD from various sources (e.g. file, ISPN cluster, parallelized collection) */ public interface SparkJavaRDDSource<T> extends Serializable { JavaRDD<T> getSource(); /** * Set spark context to obtain JavaRDD with */ void setSparkContext(JavaSparkContext context); } /** * Basic mapper implementation allowing to invoke JavaRDD.map() function */ public interface SparkMapper<T, R> extends Serializable { Function<T, R> getMapFunction(); } /** * Mapper implementation allowing to invoke JavaRDD.mapToPair() function */ public interface SparkPairMapper<T, K, V> extends Serializable { PairFunction<T, K, V> getMapFunction(); } /* Reducer implementation */ public interface SparkReducer<T1, T2, R> extends Serializable { Function2<T1, T2, R> getReduceFunction(); } }