/*
 * Copyright © 2016 Cask Data, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package co.cask.cdap.spark.app;

import co.cask.cdap.api.TxRunnable;
import co.cask.cdap.api.common.Bytes;
import co.cask.cdap.api.data.DatasetContext;
import co.cask.cdap.api.dataset.table.Get;
import co.cask.cdap.api.dataset.table.Increment;
import co.cask.cdap.api.dataset.table.Put;
import co.cask.cdap.api.dataset.table.Table;
import co.cask.cdap.api.spark.AbstractSpark;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import co.cask.cdap.api.spark.JavaSparkMain;
import co.cask.cdap.api.spark.SparkClientContext;
import com.google.common.base.Preconditions;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;

/**
 * A Spark program that reads string values from the "keys" dataset, computes the length of each
 * string, saves the per-key lengths to the "count" dataset, and records the total number of
 * entries in the "totals" table.
 */
public class CharCountProgram extends AbstractSpark implements JavaSparkMain {

  @Override
  protected void configure() {
    setMainClass(CharCountProgram.class);
  }

  @Override
  public void beforeSubmit(SparkClientContext context) throws Exception {
    // Set a custom compression codec; run() verifies that this setting takes effect.
    context.setSparkConf(
      new SparkConf().set("spark.io.compression.codec", "org.apache.spark.io.LZFCompressionCodec"));

    // Verify that the "totals" dataset can be read and written from the client context
    // before the Spark program is submitted.
    Table totals = context.getDataset("totals");
    totals.get(new Get("total").add("total")).getLong("total");
    totals.put(new Put("total").add("total", 0L));
  }

  @Override
  public void run(final JavaSparkExecutionContext sec) throws Exception {
    JavaSparkContext sc = new JavaSparkContext();

    // Verify the codec is being set
    Preconditions.checkArgument(
      "org.apache.spark.io.LZFCompressionCodec".equals(sc.getConf().get("spark.io.compression.codec")));

    // read the dataset
    JavaPairRDD<byte[], String> inputData = sec.fromDataset("keys");

    // create a new RDD with the same key but with a new value which is the length of the string
    final JavaPairRDD<byte[], byte[]> stringLengths = inputData.mapToPair(
      new PairFunction<Tuple2<byte[], String>, byte[], byte[]>() {
        @Override
        public Tuple2<byte[], byte[]> call(Tuple2<byte[], String> stringTuple2) throws Exception {
          return new Tuple2<>(stringTuple2._1(), Bytes.toBytes(stringTuple2._2().length()));
        }
      });

    // write a total count to a table (that emits a metric we can validate in the test case)
    sec.execute(new TxRunnable() {
      @Override
      public void run(DatasetContext context) throws Exception {
        long count = stringLengths.count();
        Table totals = context.getDataset("totals");
        totals.increment(new Increment("total").add("total", count));

        // write the character count to dataset
        sec.saveAsDataset(stringLengths, "count");
      }
    });
  }
}