// ============================================================================
//
// Copyright (C) 2006-2016 Talend Inc. - www.talend.com
//
// This source code is available under agreement available at
// %InstallDIR%\features\org.talend.rcp.branding.%PRODUCTNAME%\%PRODUCTNAME%license.txt
//
// You should have received a copy of the agreement
// along with this program; if not, write to Talend SA
// 9 rue Pages 92150 Suresnes, France
//
// ============================================================================
package org.talend.dataquality.sampling.parallel;
import static org.junit.Assert.assertTrue;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.talend.dataquality.duplicating.AllDataqualitySamplingTests;
/**
 * Tests for {@code SparkSamplingUtil}: sampling a JavaRDD and a DataFrame with the fixed
 * random seed from {@link AllDataqualitySamplingTests} must yield IDs from the expected
 * deterministic sample.
 */
public class SparkSamplingUtilTest implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Number of rows requested from the sampler. */
    private static final int SAMPLE_SIZE = 10;

    /** Number of rows in the generated input data set. */
    private static final int ORIGINAL_COUNT = 100;

    /** Input rows, rebuilt before each test by {@link #init()}. */
    private TestRowStruct[] testers;

    /** IDs expected in the sample for the fixed random seed. */
    private static final Integer[] EXPECTED_SAMPLES_LIST = { 31, 79, 93, 32, 45, 90, 15, 59, 91, 89 };

    private static JavaSparkContext sc;

    @BeforeClass
    public static void beforeClass() {
        sc = new JavaSparkContext(new SparkConf().setAppName("Simple Application").setMaster("local[1]"));
    }

    @AfterClass
    public static void afterClass() {
        // Release the local Spark context once all tests in this class have run;
        // without this the context (threads, UI port, temp dirs) leaks across test classes.
        if (sc != null) {
            sc.stop();
            sc = null;
        }
    }

    @Before
    public void init() {
        // Build ORIGINAL_COUNT rows with ids 1..ORIGINAL_COUNT and matching "cityN" names.
        testers = new TestRowStruct[ORIGINAL_COUNT];
        for (int j = 0; j < ORIGINAL_COUNT; j++) {
            TestRowStruct struct = new TestRowStruct();
            struct.setId(j + 1);
            struct.setCity("city" + (j + 1));
            testers[j] = struct;
        }
    }

    @Test
    public void testSamplePairList() {
        JavaRDD<TestRowStruct> rdd = sc.parallelize(Arrays.asList(testers));
        SparkSamplingUtil<TestRowStruct> samplingUtil = new SparkSamplingUtil<>(AllDataqualitySamplingTests.RANDOM_SEED);
        List<ImmutablePair<Double, TestRowStruct>> topPairs = samplingUtil.getSamplePairList(rdd, SAMPLE_SIZE);
        List<Integer> expected = Arrays.asList(EXPECTED_SAMPLES_LIST);
        for (ImmutablePair<Double, TestRowStruct> pair : topPairs) {
            Integer id = pair.getRight().getId();
            assertTrue("The ID " + id + " is expected", expected.contains(id));
        }
    }

    @Test
    public void getSamplePairListForDataFrame() {
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD<TestRowStruct> rdd = sc.parallelize(Arrays.asList(testers));
        DataFrame df = sqlContext.createDataFrame(rdd, TestRowStruct.class);
        SparkSamplingUtil<TestRowStruct> samplingUtil = new SparkSamplingUtil<>(AllDataqualitySamplingTests.RANDOM_SEED);
        // NOTE(review): column index 1 is assumed to be "id" (bean columns are ordered
        // alphabetically: city, id) — confirm if the bean gains more properties.
        List<ImmutablePair<Double, Row>> sampleList = samplingUtil.getSamplePairList(df, 5);
        List<Integer> expected = Arrays.asList(EXPECTED_SAMPLES_LIST);
        for (ImmutablePair<Double, Row> pair : sampleList) {
            int id = pair.getRight().getInt(1);
            assertTrue("The ID " + id + " is expected", expected.contains(id));
        }
    }

    @Test
    public void testGetSampleList() {
        JavaRDD<TestRowStruct> rdd = sc.parallelize(Arrays.asList(testers));
        SparkSamplingUtil<TestRowStruct> samplingUtil = new SparkSamplingUtil<>(AllDataqualitySamplingTests.RANDOM_SEED);
        List<TestRowStruct> sampleList = samplingUtil.getSampleList(rdd, SAMPLE_SIZE);
        List<Integer> expected = Arrays.asList(EXPECTED_SAMPLES_LIST);
        for (TestRowStruct row : sampleList) {
            assertTrue("The ID " + row.getId() + " is expected", expected.contains(row.getId()));
        }
    }

    /**
     * Simple JavaBean used as test data. Declared {@code static} so that Spark can
     * serialize instances without dragging the enclosing test class along.
     */
    public static class TestRowStruct implements Serializable {

        private static final long serialVersionUID = 1L;

        private Integer id;

        private String city;

        public Integer getId() {
            return this.id;
        }

        public void setId(Integer id) {
            this.id = id;
        }

        public String getCity() {
            return this.city;
        }

        public void setCity(String city) {
            this.city = city;
        }

        @Override
        public String toString() {
            return id + " -> " + city; //$NON-NLS-1$
        }
    }
}