/* * Copyright © 2016 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package org.example.plugin; import co.cask.cdap.api.data.format.StructuredRecord; import co.cask.cdap.api.data.schema.Schema; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFlatMapFunction; import scala.Tuple2; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.regex.Pattern; /** * Common class for word count spark plugins. */ public class WordCount implements Serializable { private static final Pattern WHITESPACE = Pattern.compile("\\s"); private final String field; public WordCount(String field) { this.field = field; } public void validateSchema(Schema inputSchema) { // a null input schema means its unknown until runtime, or its not constant if (inputSchema != null) { // if the input schema is constant and known at configure time, check that the input field exists and is a string. Schema.Field inputField = inputSchema.getField(field); if (inputField == null) { throw new IllegalArgumentException( String.format("Field '%s' does not exist in input schema %s.", field, inputSchema)); } Schema fieldSchema = inputField.getSchema(); Schema.Type fieldType = fieldSchema.isNullable() ? fieldSchema.getNonNullable().getType() : fieldSchema.getType(); if (fieldType != Schema.Type.STRING) { throw new IllegalArgumentException( String.format("Field '%s' is of illegal type %s. Must be of type %s.", field, fieldType, Schema.Type.STRING)); } } } public JavaPairRDD<String, Long> countWords(JavaRDD<StructuredRecord> input) { return input.flatMap(new SplitFunction(field)) .groupBy(new Identity<String>()) .flatMapToPair(new CountFunction()); } private static class Identity<T> implements Function<T, T> { @Override public T call(T t) throws Exception { return t; } } private static class SplitFunction implements FlatMapFunction<StructuredRecord, String> { private final String field; public SplitFunction(String field) { this.field = field; } @Override public Iterable<String> call(StructuredRecord record) throws Exception { String val = record.get(field); List<String> words = new ArrayList<>(); if (val != null) { Collections.addAll(words, WHITESPACE.split(val)); } return words; } } private static class CountFunction implements PairFlatMapFunction<Tuple2<String, Iterable<String>>, String, Long> { @Override public Iterable<Tuple2<String, Long>> call(Tuple2<String, Iterable<String>> tuples) throws Exception { String word = tuples._1(); Long count = 0L; for (String s : tuples._2()) { count++; } List<Tuple2<String, Long>> output = new ArrayList<>(); output.add(new Tuple2<>(word, count)); return output; } } }