// ============================================================================
//
// Copyright (C) 2006-2016 Talend Inc. - www.talend.com
//
// This source code is available under agreement available at
// %InstallDIR%\features\org.talend.rcp.branding.%PRODUCTNAME%\%PRODUCTNAME%license.txt
//
// You should have received a copy of the agreement
// along with this program; if not, write to Talend SA
// 9 rue Pages 92150 Suresnes, France
//
// ============================================================================
package org.talend.dataquality.sampling.parallel;
import java.io.Serializable;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
/**
* Sampling API for Spark components.
*/
public class SparkSamplingUtil<T> implements Serializable {
private Long seed = null;
public SparkSamplingUtil() {
this(null);
}
/**
* constructor with random seed as parameter
*
* @param seed
*/
public SparkSamplingUtil(Long seed) {
this.seed = seed;
}
/**
* do sampling on RDD
*
* @param rdd
* @param nbSamples
* @return list of sample pairs, with generated score as left value and original data as right value.
*/
public List<ImmutablePair<Double, T>> getSamplePairList(JavaRDD<T> rdd, int nbSamples) {
JavaRDD<ImmutablePair<Double, T>> mappedRdd = rdd.mapPartitions(new SamplingMapFunction(nbSamples));
List<ImmutablePair<Double, T>> topPairs = mappedRdd.top(nbSamples, new PairComparator());
return topPairs;
}
/**
* do sampling on DF
*
* @param df
* @param nbSamples
* @return list of sample pairs, with generated score as left value and original data as right value.
*/
public List<ImmutablePair<Double, Row>> getSamplePairList(DataFrame df, int nbSamples) {
JavaRDD<ImmutablePair<Double, Row>> mappedRdd = df.javaRDD().mapPartitions(new SamplingMapFunction<Row>(nbSamples));
List<ImmutablePair<Double, Row>> topPairs = mappedRdd.top(nbSamples, new PairComparator());
return topPairs;
}
/**
* do sampling on DateFrame
*
* @param rdd
* @param nbSamples
* @return list of sample values
*/
public List<T> getSampleList(JavaRDD<T> rdd, int nbSamples) {
List<ImmutablePair<Double, T>> topPairs = getSamplePairList(rdd, nbSamples);
List<T> result = new ArrayList<T>();
for (ImmutablePair<Double, T> pair : topPairs) {
result.add(pair.getRight());
}
return result;
}
private class SamplingMapFunction<T> implements FlatMapFunction<Iterator<T>, ImmutablePair<Double, T>> {
private final int nbSamples;
public SamplingMapFunction(int nbSamples) {
this.nbSamples = nbSamples;
}
@Override
public Iterable<ImmutablePair<Double, T>> call(Iterator<T> tIterator) throws Exception {
if (seed == null) {
seed = new Random().nextLong();
}
ReservoirSamplerWithBinaryHeap<T> sampler = new ReservoirSamplerWithBinaryHeap<T>(nbSamples, seed);
sampler.clear();
while (tIterator.hasNext()) {
sampler.onNext(tIterator.next());
}
sampler.onCompleted(true);
Iterable<ImmutablePair<Double, T>> samplePairs = sampler.samplePairs();
return samplePairs;
}
}
private class PairComparator<T> implements Serializable, Comparator<ImmutablePair<Double, T>> {
@Override
public int compare(ImmutablePair<Double, T> o1, ImmutablePair<Double, T> o2) {
if (o1.left > o2.left) {
return 1;
} else if (o1.left < o2.left) {
return -1;
}
return 0;
}
}
}