/*
 * Copyright 2014 Cloudera, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.kitesdk.data.spark;

import java.util.HashMap;
import java.util.Map;
import org.apache.avro.generic.GenericData.Record;
import org.apache.hadoop.mapreduce.Job;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.kitesdk.data.Dataset;
import org.kitesdk.data.DatasetDescriptor;
import org.kitesdk.data.DatasetReader;
import org.kitesdk.data.DatasetWriter;
import org.kitesdk.data.Format;
import org.kitesdk.data.mapreduce.DatasetKeyInputFormat;
import org.kitesdk.data.mapreduce.DatasetKeyOutputFormat;
import org.kitesdk.data.mapreduce.FileSystemTestBase;
import org.kitesdk.data.mapreduce.TestMapReduce;
import scala.Tuple2;

@RunWith(Parameterized.class)
public class TestSpark extends FileSystemTestBase {

  public TestSpark(Format format) {
    super(format);
  }

  // These are static inner classes because Format does not implement Serializable

  /** Maps an Avro record to a (word, 1) pair. */
  public static class ToJava
      implements PairFunction<Tuple2<Record, Void>, String, Integer> {
    @Override
    public Tuple2<String, Integer> call(Tuple2<Record, Void> t) throws Exception {
      return new Tuple2<String, Integer>(t._1().get("text").toString(), 1);
    }
  }

  /** Adds two partial counts. */
  public static class Sum implements Function2<Integer, Integer, Integer> {
    @Override
    public Integer call(Integer t1, Integer t2) throws Exception {
      return t1 + t2;
    }
  }

  /** Converts a (word, count) pair back into an Avro stats record. */
  public static class ToAvro
      implements PairFunction<Tuple2<String, Integer>, Record, Void> {
    @Override
    public Tuple2<Record, Void> call(Tuple2<String, Integer> t) throws Exception {
      Record record = new Record(TestMapReduce.STATS_SCHEMA);
      record.put("name", t._1());
      record.put("count", t._2());
      return new Tuple2<Record, Void>(record, null);
    }
  }

  @Test
  @SuppressWarnings("deprecation")
  public void testSparkJob() throws Exception {
    // Create the input dataset and populate it with a few string records
    Dataset<Record> inputDataset = repo.create("ns", "in",
        new DatasetDescriptor.Builder()
            .property("kite.allow.csv", "true")
            .schema(TestMapReduce.STRING_SCHEMA)
            .format(format)
            .build(), Record.class);
    DatasetWriter<Record> writer = inputDataset.newWriter();
    writer.write(newStringRecord("apple"));
    writer.write(newStringRecord("banana"));
    writer.write(newStringRecord("banana"));
    writer.write(newStringRecord("carrot"));
    writer.write(newStringRecord("apple"));
    writer.write(newStringRecord("apple"));
    writer.close();

    // Create the output dataset that will receive the word counts
    Dataset<Record> outputDataset = repo.create("ns", "out",
        new DatasetDescriptor.Builder()
            .property("kite.allow.csv", "true")
            .schema(TestMapReduce.STATS_SCHEMA)
            .format(format)
            .build(), Record.class);

    // Wire the datasets into the MapReduce input/output formats that Spark uses
    Job job = Job.getInstance();
    DatasetKeyInputFormat.configure(job).readFrom(inputDataset);
    DatasetKeyOutputFormat.configure(job).writeTo(outputDataset);

    // Run a word count: read records, map to (word, 1), sum by key, write back
    @SuppressWarnings("unchecked")
    JavaPairRDD<Record, Void> inputData = SparkTestHelper.getSparkContext()
        .newAPIHadoopRDD(job.getConfiguration(), DatasetKeyInputFormat.class,
            Record.class, Void.class);
    JavaPairRDD<String, Integer> mappedData = inputData.mapToPair(new ToJava());
    JavaPairRDD<String, Integer> sums = mappedData.reduceByKey(new Sum());
    JavaPairRDD<Record, Void> outputData = sums.mapToPair(new ToAvro());
    outputData.saveAsNewAPIHadoopDataset(job.getConfiguration());

    // Read the results back from the output dataset and verify the counts
    DatasetReader<Record> reader = outputDataset.newReader();
    Map<String, Integer> counts = new HashMap<String, Integer>();
    for (Record record : reader) {
      counts.put(record.get("name").toString(), (Integer) record.get("count"));
    }
    reader.close();

    Assert.assertEquals(3, counts.get("apple").intValue());
    Assert.assertEquals(2, counts.get("banana").intValue());
    Assert.assertEquals(1, counts.get("carrot").intValue());
  }
}